import dataclasses
from typing import Any, Callable, Iterable, Optional

import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement


@dataclasses.dataclass
class CudaGraphBenchParams:
    """
    Settings for benchmarking through a CUDA graph instead of eager calls.
    """
    # How many invocations of the benchmarked function get recorded into a
    # single CUDA graph.
    num_ops_in_cuda_graph: int


@dataclasses.dataclass
class ArgPool:
    """
    Marks a benchmark argument as a pool of interchangeable values.

    When an argument passed to the benchmarking class (Bench) is an ArgPool,
    Bench collapses it to a single value picked from `values` at function
    invocation time. Successive invocations within a benchmarking run pick
    successive values from the pool.

    `values` is indexed with `[]`, so it has to be an indexable container
    (e.g. a list or tuple) rather than a one-shot iterable.
    """
    values: Iterable[Any]

    def __getitem__(self, index: int) -> Any:
        """Return the pool entry stored at position `index`."""
        pool = self.values
        return pool[index]


class Bench:
    """
    Times a callable with torch.utils.benchmark, either by calling it eagerly
    or by replaying a CUDA graph that captured a fixed number of calls.

    Any positional or keyword argument given as an ArgPool is expanded into
    one concrete (args, kwargs) pair per pool entry; successive invocations
    of the benchmarked function cycle through those pairs.

    The class can be used as a context manager; on exit it only prints any
    in-flight exception and does not suppress it.
    """

    class ArgsIterator:
        """
        Endless round-robin cursor over parallel lists of args and kwargs.

        NOTE: `__next__` is a generator function. Calling it does not return
        an item; it returns a generator, and `next()` on that generator
        yields (args, kwargs) pairs. The cursor position lives on this
        object (`self.idx`), so all generators obtained from one
        ArgsIterator share it. Usage throughout this class is:

            it = args_iterator.__next__()
            args, kwargs = next(it)
        """

        def __init__(self, args_list, kwargs_list):
            assert len(args_list) == len(kwargs_list)
            self.args_list = args_list
            self.kwargs_list = kwargs_list
            self.n = len(self.args_list)
            self.idx = 0

        def __next__(self):
            """Return a generator yielding (args, kwargs) pairs forever."""
            while True:
                yield (self.args_list[self.idx], self.kwargs_list[self.idx])
                # Advance only after the consumer resumes us, wrapping
                # around at the end of the list.
                self.idx = (self.idx + 1) % self.n

        def reset(self):
            """Move the shared cursor back to the first pair."""
            self.idx = 0

        @property
        def n_args(self):
            """Number of distinct (args, kwargs) pairs."""
            return self.n

    def __init__(self, cuda_graph_params: Optional[CudaGraphBenchParams],
                 label: str, sub_label: str, description: str, fn: Callable,
                 *args, **kwargs):
        """
        Args:
            cuda_graph_params: When not None, the benchmark replays a CUDA
                graph (captured here, at construction time) instead of
                calling `fn` eagerly.
            label, sub_label, description: Passed through to
                torch.utils.benchmark.Timer.
            fn: The callable to benchmark.
            *args, **kwargs: Arguments for `fn`. ArgPool values are
                expanded, see `collapse_argpool`.
        """
        self.cuda_graph_params = cuda_graph_params
        self.use_cuda_graph = self.cuda_graph_params is not None
        self.label = label
        self.sub_label = sub_label
        self.description = description
        self.fn = fn

        # Process args
        self._args = args
        self._kwargs = kwargs
        self.args_list, self.kwargs_list = self.collapse_argpool(
            *args, **kwargs)
        self.args_iterator = self.ArgsIterator(self.args_list,
                                               self.kwargs_list)

        # Cudagraph runner
        self.g = None
        if self.use_cuda_graph:
            self.g = self.get_cuda_graph_runner()

        # benchmark run params
        self.min_run_time = 1

    def collapse_argpool(self, *args, **kwargs):
        """
        Expand ArgPool arguments into concrete argument sets.

        Returns:
            (args_list, kwargs_list): two lists of equal length. Entry i
            holds the args tuple / kwargs dict in which every ArgPool has
            been replaced by its i-th value. Without any ArgPool the result
            is `[args], [kwargs]`.

        Raises:
            AssertionError: if an ArgPool is empty or the ArgPools differ
                in size.
        """
        pools = [arg for arg in args if isinstance(arg, ArgPool)]
        pools += [arg for arg in kwargs.values() if isinstance(arg, ArgPool)]
        if not pools:
            return [args], [kwargs]

        # Make sure all argpools are non-empty and of the same size. An
        # empty pool would otherwise only surface later as an IndexError
        # inside the args iterator.
        pool_size = len(pools[0].values)
        assert pool_size > 0, (
            "ArgPool arguments must contain at least one value")
        assert all(len(pool.values) == pool_size for pool in pools), (
            "all ArgPool arguments must have the same number of values")

        # For the i-th argument set, pick the i-th value of every pool and
        # pass every other argument through untouched.
        args_list = [
            tuple(arg[i] if isinstance(arg, ArgPool) else arg
                  for arg in args) for i in range(pool_size)
        ]
        kwargs_list = [{
            key: (val[i] if isinstance(val, ArgPool) else val)
            for key, val in kwargs.items()
        } for i in range(pool_size)]

        return args_list, kwargs_list

    def get_cuda_graph_runner(self):
        """
        Capture `num_ops_in_cuda_graph` calls of `fn` into a CUDA graph.

        `fn` is first called twice eagerly as warmup. The args iterator is
        then reset so that the capture starts from the first argument set.

        Returns:
            The captured torch.cuda.CUDAGraph.
        """
        assert self.use_cuda_graph
        assert self.args_iterator is not None

        num_graph_ops = self.cuda_graph_params.num_ops_in_cuda_graph

        # warmup
        args_it = self.args_iterator.__next__()
        for _ in range(2):
            args, kwargs = next(args_it)
            self.fn(*args, **kwargs)

        self.args_iterator.reset()
        args_it = self.args_iterator.__next__()
        stream = torch.cuda.Stream()
        with torch.cuda.stream(stream):
            g = torch.cuda.CUDAGraph()
            with torch.cuda.graph(g):
                for _ in range(num_graph_ops):
                    args, kwargs = next(args_it)
                    self.fn(*args, **kwargs)
        return g

    def run_cudagrah(self) -> TMeasurement:
        """
        Time one replay of the captured CUDA graph.

        The timed statement is a single `g.replay()`, i.e. it covers all
        `num_ops_in_cuda_graph` captured calls, not a single call of `fn`.
        """
        assert self.use_cuda_graph
        timer_globals = {'g': self.g}

        return TBenchmark.Timer(
            stmt="g.replay()",
            globals=timer_globals,
            label=(
                f"{self.label}"
                f" | cugraph {self.cuda_graph_params.num_ops_in_cuda_graph} ops"
            ),
            sub_label=self.sub_label,
            description=self.description,
        ).blocked_autorange(min_run_time=self.min_run_time)

    def run_eager(self) -> TMeasurement:
        """
        Time eager calls of `fn`.

        With more than one argument set, every timed call pulls the next
        (args, kwargs) pair from the args iterator, so fetching the pair is
        part of the timed statement. With a single argument set, the
        arguments are bound once and only the call itself is timed.
        """
        if self.args_iterator.n_args > 1:
            # Snippets are written without leading indentation so they are
            # valid Python as-is.
            setup = ("args_iterator.reset()\n"
                     "args_it = args_iterator.__next__()")
            stmt = ("args, kwargs = next(args_it)\n"
                    "fn(*args, **kwargs)")
            timer_globals = {
                'fn': self.fn,
                'args_iterator': self.args_iterator
            }
        else:
            # no arg pool. Just use the args and kwargs directly
            self.args_iterator.reset()
            args, kwargs = next(self.args_iterator.__next__())

            setup = ""
            stmt = "fn(*args, **kwargs)"
            timer_globals = {'fn': self.fn, 'args': args, 'kwargs': kwargs}

        return TBenchmark.Timer(
            stmt=stmt,
            setup=setup,
            globals=timer_globals,
            label=self.label,
            sub_label=self.sub_label,
            description=self.description,
        ).blocked_autorange(min_run_time=self.min_run_time)

    def run(self) -> TMeasurement:
        """
        Run the benchmark and return its measurement.

        The benchmark is repeated until the measurement meets confidence
        and carries no warnings. There is no limit on the number of
        repetitions; the retry is a loop rather than recursion so that a
        long streak of noisy measurements cannot end in a RecursionError.
        """
        while True:
            if self.use_cuda_graph:
                measurement = self.run_cudagrah()
            else:
                measurement = self.run_eager()

            if measurement.meets_confidence() and \
                    not measurement.has_warnings:
                return measurement
            print("Doesn't meet confidence - re-running bench ...")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Only report the exception. Returning None (falsy) lets it
        # propagate to the caller.
        if exc_type:
            print(f"exc type {exc_type}")
            print(f"exc value {exc_value}")
            print(f"exc traceback {traceback}")
