try:
    from benchmark import (
        benchmark_forward,
        benchmark_combined,
        benchmark_all,
        benchmark_memory_forward,
        benchmark_memory_combined,
    )
except ImportError:
    from benchmark import (
        benchmark_forward,
        benchmark_combined,
        benchmark_all,
        benchmark_memory_forward,
        benchmark_memory_combined,
    )

from typing import Callable
import os
import matplotlib.pyplot as plt


class BenchmarkFigureHelper(object):
    """Benchmark several implementations and plot the results as a figure.

    For every data point passed to ``__call__`` the helper builds inputs with
    ``make_input_fn``, runs each requested benchmark item against every
    function in ``name_2_fn``, and finally draws one subplot per benchmark
    item (one line per function) into a PNG file under ``save_dir``.
    """

    # Benchmark items this helper knows how to run.
    _SUPPORTED_ITEMS = ("fwd_tflops", "all_tflops", "memory_fwd", "memory_all")

    # Y-axis label used for each benchmark item.
    _Y_LABELS = {
        "fwd_tflops": "Forward Pass TFLOPS",
        "all_tflops": "Forward + Backward Pass TFLOPS",
        "memory_fwd": "Forward Pass Memory (GB)",
        "memory_all": "Forward + Backward Pass Memory (GB)",
    }

    def __init__(
        self,
        benchmark_items: list,
        name_2_fn: dict,
        make_input_fn: Callable,
        save_dir: str,
        repeats: int = 10,
        amp: bool = True,
    ):
        """Store the benchmark configuration and create the output directory.

        Args:
            benchmark_items: Items to measure; each must be one of
                "fwd_tflops", "all_tflops", "memory_fwd", "memory_all".
                One subplot is drawn per item, in this order.
            name_2_fn: Mapping from a display name to the callable to
                benchmark.
            make_input_fn: Called as ``make_input_fn(*shape_params)``; its
                result is unpacked as positional arguments for the
                benchmarked callables.
            save_dir: Directory the figures are written to (created if
                missing).
            repeats: Forwarded as ``repeats`` to the timing benchmarks.
            amp: Forwarded as ``amp`` to every benchmark function.

        Raises:
            AssertionError: If ``benchmark_items`` contains an unknown item.
        """
        self.name_2_fn = name_2_fn
        self.benchmark_items = benchmark_items
        self.make_input_fn = make_input_fn

        unknown_items = [
            item for item in benchmark_items if item not in self._SUPPORTED_ITEMS
        ]
        assert not unknown_items, (
            f"Unsupported benchmark items {unknown_items}; "
            f"expected a subset of {list(self._SUPPORTED_ITEMS)}"
        )

        self.save_dir = save_dir
        os.makedirs(self.save_dir, exist_ok=True)
        print(f"Saving figures to {self.save_dir}")

        self.repeats = repeats
        self.amp = amp

    def _run_item(self, item, fn_name, fn, inputs, flops):
        """Run one benchmark item for one callable, print and return its metric."""
        if item == "fwd_tflops":
            m = benchmark_forward(
                fn,
                *inputs,
                repeats=self.repeats,
                amp=self.amp,
            )
            # NOTE(review): m[1].median is used as a time in seconds —
            # confirm against the benchmark helper's return value.
            t_s = m[1].median
            tflops = (flops * 1e-12) / t_s
            print(f"{fn_name} - Forward: {t_s * 1e3:.2f} ms, {tflops:.1f} TFLOPS")
            return tflops
        if item == "all_tflops":
            m = benchmark_combined(
                fn,
                *inputs,
                repeats=self.repeats,
                amp=self.amp,
            )
            t_s = m[1].median
            # The combined pass is counted as 3x the forward FLOPS
            # (presumably forward + a backward pass costing ~2x forward).
            tflops = (flops * 3 * 1e-12) / t_s
            print(f"{fn_name} - Combined: {t_s * 1e3:.2f} ms, {tflops:.1f} TFLOPS")
            return tflops
        if item == "memory_fwd":
            m = benchmark_memory_forward(
                fn,
                *inputs,
                amp=self.amp,
            )
            print(f"{fn_name} - Forward memory: {m:.2f} GB")
            return m
        if item == "memory_all":
            m = benchmark_memory_combined(
                fn,
                *inputs,
                amp=self.amp,
            )
            print(f"{fn_name} - Combined memory: {m:.2f} GB")
            return m
        # Only reachable if benchmark_items was modified after __init__.
        raise ValueError(f"Unsupported benchmark item: {item!r}")

    def __call__(self, data_list: list, save_name: str, x_label: str):
        """Benchmark every data point and save the resulting figure.

        Args:
            data_list: One dict per data point with the keys
                "shape_params" (unpacked into ``make_input_fn``),
                "name" (printed as a progress message),
                "FLOPS" (used to convert timings into TFLOPS) and
                "x" (the x-axis value of this data point).
            save_name: File name (without extension) of the PNG written
                into ``save_dir``.
            x_label: Label applied to the x-axis of every subplot.
        """
        x_list = []
        result_dict = {fn_name: [] for fn_name in self.name_2_fn}

        for data in data_list:
            shape_params = data["shape_params"]
            case_name = data["name"]
            flops = data["FLOPS"]
            x_list.append(data["x"])
            print("Benchmarking", case_name)

            inputs = self.make_input_fn(*shape_params)

            for fn_name, fn in self.name_2_fn.items():
                # One metric per benchmark item, in benchmark_items order.
                metrics = [
                    self._run_item(item, fn_name, fn, inputs, flops)
                    for item in self.benchmark_items
                ]
                print("-" * 20)
                result_dict[fn_name].append(metrics)

        # begin to make figure
        img_path = os.path.join(self.save_dir, f"{save_name}.png")
        num_items = len(self.benchmark_items)

        # squeeze=False keeps the result a 2-D array even for a single
        # subplot, so indexing works for any number of benchmark items.
        fig, axes_grid = plt.subplots(num_items, figsize=(20, 10), squeeze=False)
        axes = axes_grid[:, 0]

        for fn_name, results in result_dict.items():
            for item_i in range(num_items):
                axes[item_i].plot(
                    x_list, [result[item_i] for result in results], label=fn_name
                )

        for item_i, item in enumerate(self.benchmark_items):
            axis = axes[item_i]
            axis.set_xlabel(x_label)
            axis.set_ylabel(self._Y_LABELS.get(item, item))
            # Give every subplot its own legend rather than only the last one.
            axis.legend()

        fig.savefig(img_path)
        plt.close(fig)
        print(f"Saved figure to {img_path}")
