# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

import argparse
from datetime import datetime
from pathlib import Path
from typing import Optional

import pandas as pd

from nsys_recipe import log
from nsys_recipe.data_service import DataService
from nsys_recipe.lib import helpers, recipe, summary
from nsys_recipe.lib.args import ArgumentParser, Option
from nsys_recipe.log import logger


class CudaGpuMemSizeSum(recipe.Recipe):
    """Summarize the sizes of CUDA memcpy and memset operations.

    For each input report, the ``bytes`` column of the CUPTI memcpy and memset
    activity tables is summarized per operation name, and per operation name
    and device ID. The reducer combines the per-report results, assigns each
    report a rank, and writes parquet (and optionally CSV) outputs, a notebook,
    and an analysis file.
    """

    @staticmethod
    def _mapper_func(
        report_path: str, parsed_args: argparse.Namespace
    ) -> Optional[tuple[str, pd.DataFrame, pd.DataFrame]]:
        """Compute memory-operation size statistics for a single report.

        Memcpy rows are labeled with their copy-kind name looked up in
        ``ENUM_CUDA_MEMCPY_OPER``; memset rows are labeled ``"CUDA_MEMSET"``.

        Args:
            report_path: Path of the report to process.
            parsed_args: Parsed recipe arguments, forwarded to ``DataService``.

        Returns:
            ``(filename, stats_df, stats_by_device_df)`` where ``filename`` is
            the report's file stem, ``stats_df`` is indexed by ``"Name"`` and
            ``stats_by_device_df`` by ``("Name", "Device ID")``. Returns
            ``None`` when ``read_queued_tables`` returns ``None`` or when both
            activity tables are empty.
        """
        service = DataService(report_path, parsed_args)

        service.queue_table("StringIds")
        service.queue_table("ENUM_CUDA_MEMCPY_OPER")
        service.queue_table(
            "CUPTI_ACTIVITY_KIND_MEMCPY", ["bytes", "copyKind", "deviceId"]
        )
        service.queue_table("CUPTI_ACTIVITY_KIND_MEMSET", ["bytes", "deviceId"])

        df_dict = service.read_queued_tables()
        if df_dict is None:
            return None

        # Attach the copy-kind name to each memcpy row.
        # NOTE(review): a left merge leaves `name` as NaN for any copyKind that
        # has no match in the enum table; such rows are then dropped by the
        # groupby() calls below (dropna=True by default).
        memcpy_df = df_dict["CUPTI_ACTIVITY_KIND_MEMCPY"].merge(
            df_dict["ENUM_CUDA_MEMCPY_OPER"],
            left_on="copyKind",
            right_on="id",
            how="left",
        )
        memset_df = df_dict["CUPTI_ACTIVITY_KIND_MEMSET"].assign(name="CUDA_MEMSET")

        # Only concatenate non-empty frames: pd.concat raises on an empty list,
        # and pandas deprecates letting empty frames influence the result dtype.
        frames = [df for df in (memcpy_df, memset_df) if not df.empty]
        if not frames:
            logger.info(
                f"{report_path}: Report was successfully processed, but no data was found."
            )
            return None

        memory_df = pd.concat(frames)

        stats_df = summary.describe_column(memory_df.groupby("name")["bytes"])
        stats_df.index.name = "Name"

        stats_by_device_df = summary.describe_column(
            memory_df.groupby(["name", "deviceId"])["bytes"]
        )
        stats_by_device_df.index.names = ["Name", "Device ID"]

        filename = Path(report_path).stem
        return filename, stats_df, stats_by_device_df

    @log.time("Mapper")
    def mapper_func(
        self, context: recipe.Context
    ) -> list[Optional[tuple[str, pd.DataFrame, pd.DataFrame]]]:
        """Run ``_mapper_func`` on every input report and wait for the results.

        Args:
            context: Execution context used to map the work over the inputs.

        Returns:
            One ``_mapper_func`` result (possibly ``None``) per input report.
        """
        return context.wait(
            context.map(
                self._mapper_func,
                self._parsed_args.input,
                parsed_args=self._parsed_args,
            )
        )

    @log.time("Reducer")
    def reducer_func(
        self, mapper_res: list[Optional[tuple[str, pd.DataFrame, pd.DataFrame]]]
    ) -> None:
        """Combine per-report statistics and write the output files.

        Reports are sorted by file name; each report's rank is its position in
        that sorted order. Writes ``files``, ``rank_stats_by_device``,
        ``rank_stats`` and ``all_stats`` as parquet, plus CSV copies when the
        csv option is set.

        Args:
            mapper_res: Results returned by ``mapper_func``.

        Raises:
            ValueError: If no report produced any data. (NOTE(review):
                ``helpers.filter_none_or_empty`` may raise its own error first —
                not visible from here.)
        """
        # Sort by file name so ranks are assigned deterministically.
        filtered_res = sorted(
            helpers.filter_none_or_empty(mapper_res), key=lambda x: x[0]
        )
        if not filtered_res:
            # Without this guard the zip() unpacking below fails with an opaque
            # "not enough values to unpack" error.
            raise ValueError(
                "No CUDA memcpy or memset data was found in any input report."
            )
        filenames, stats_dfs, stats_by_device_dfs = zip(*filtered_res)

        files_df = pd.DataFrame({"File": filenames}).rename_axis("Rank")
        files_df.to_parquet(self.add_output_file("files.parquet"))

        rank_stats_by_device_df = pd.concat(
            [df.assign(Rank=rank) for rank, df in enumerate(stats_by_device_dfs)]
        )
        rank_stats_by_device_df.to_parquet(
            self.add_output_file("rank_stats_by_device.parquet")
        )

        rank_stats_df = pd.concat(
            [df.assign(Rank=rank) for rank, df in enumerate(stats_dfs)]
        )
        rank_stats_df.to_parquet(self.add_output_file("rank_stats.parquet"))

        all_stats_df = summary.aggregate_stats_df(rank_stats_df, index_col="Name")
        all_stats_df.to_parquet(self.add_output_file("all_stats.parquet"))

        if self._parsed_args.csv:
            files_df.to_csv(self.add_output_file("files.csv"))
            all_stats_df.to_csv(self.add_output_file("all_stats.csv"))
            rank_stats_df.to_csv(self.add_output_file("rank_stats.csv"))
            rank_stats_by_device_df.to_csv(
                self.add_output_file("rank_stats_by_device.csv")
            )

    def save_notebook(self) -> None:
        """Create ``stats.ipynb`` and bundle the ``nsys_display.py`` helper."""
        self.create_notebook("stats.ipynb")
        self.add_notebook_helper_file("nsys_display.py")

    def save_analysis_file(self) -> None:
        """Record the end time and output file list, then write the analysis file."""
        self._analysis_dict.update(
            {
                "EndTime": str(datetime.now()),
                "Outputs": self._output_files,
            }
        )
        self.create_analysis_file()

    def run(self, context: recipe.Context) -> None:
        """Run the recipe: map over reports, reduce, then save notebook and analysis."""
        super().run(context)

        mapper_res = self.mapper_func(context)
        self.reducer_func(mapper_res)

        self.save_notebook()
        self.save_analysis_file()

    @classmethod
    def get_argument_parser(cls) -> ArgumentParser:
        """Extend the base parser with a required input option and a CSV option."""
        parser = super().get_argument_parser()

        parser.add_recipe_argument(Option.INPUT, required=True)
        parser.add_recipe_argument(Option.CSV)

        return parser
