from ot_jax.data.datasets.DOTmark import DOTmarkClass, DOTmarkResolution
from ot_jax.benchmark.dot_base import benchmark_data, get_output_file, process_tasks

from ot_jax.optimal_transport.jax_wasserstein import (bilevel_upper_bound, weighted_cost_upper_bound, bilevel_lower_bound, min_cost_lower_bound)
from ot_jax.optimal_transport.jax_wasserstein import (entropy_upper_bound, entropy_lower_bound)
from ot_jax.optimal_transport.jax_wasserstein import exact_wasserstein

from pathlib import Path
import jax
from typing import Optional, Callable
import pandas as pd
import time
from itertools import product

from logging import getLogger, basicConfig, INFO
basicConfig(level=INFO)
logger = getLogger(__name__)

class BenchmarkTest:
    """One benchmark family for a single DOTmark class and resolution.

    A family is a set of OT functions, optionally swept over one named
    parameter (e.g. ``kappa`` or ``epsilon_factor``). ``process`` delegates the
    actual work to ``process_tasks`` with :attr:`path` as the output file, and
    ``load`` reads that headerless CSV back with named columns.
    """

    # Output-file prefix: this script's file name without its final ".py" suffix.
    # (Path.stem removes only the suffix; str.replace would strip every ".py".)
    __root = Path(__file__).stem

    # CSV columns written before / after the optional parameter column. The row
    # layout must match the row function given to ``process``
    # (``compute_benchmark`` in ``main``).
    _LEADING_COLUMNS = ["class", "resolution", "ot_fun", "p", "i", "j"]
    _TRAILING_COLUMNS = ["ot", "time"]

    def __init__(self, dot_class: DOTmarkClass, resolution: DOTmarkResolution, p, ot_fun, **param):
        """Store the benchmark configuration.

        Args:
            dot_class: DOTmark class being benchmarked.
            resolution: DOTmark resolution being benchmarked.
            p: forwarded to ``process_tasks`` / ``get_output_file`` (``main``
                passes a list such as ``[1, 2]``).
            ot_fun: OT function(s) forwarded to ``process_tasks``.
            **param: at most one keyword naming the swept parameter, e.g.
                ``kappa=[2, 4]``. Its name becomes a CSV column in ``load``.

        Raises:
            TypeError: if more than one parameter keyword is given. Only one
                parameter column is supported, so extras would otherwise be
                silently ignored.
        """
        if len(param) > 1:
            raise TypeError(
                f"BenchmarkTest accepts at most one parameter keyword, got {sorted(param)}"
            )
        self.dot_class = dot_class
        self.resolution = resolution
        self.p = p
        self.ot_fun = ot_fun

        # Always defined, so unparameterized benchmarks expose param_name=None
        # instead of lacking the attribute.
        self.param_name = None
        self.param = None
        if param:
            self.param_name, self.param = next(iter(param.items()))

    @property
    def path(self):
        """Output CSV path for this configuration, as computed by ``get_output_file``."""
        return get_output_file(self.__root, self.dot_class, self.resolution, self.p, self.ot_fun, self.param)

    def process(self, data: dict, func: Callable):
        """Run the benchmark via ``process_tasks``, with ``func`` as the row producer and :attr:`path` as output."""
        process_tasks(data, self.p, self.ot_fun, self.param, self.path, func)

    def load(self):
        """Read this benchmark's headerless CSV output into a DataFrame.

        The parameter column (named after the parameter keyword) is only
        included when a non-empty parameter value was given.
        """
        middle = [self.param_name] if self.param else []
        names = self._LEADING_COLUMNS + middle + self._TRAILING_COLUMNS
        return pd.read_csv(self.path, header=None, names=names)
        

        

# Prefix for this script's output files: the file name without its final ".py"
# suffix. Path.stem removes only the suffix, whereas str.replace(".py", "")
# would also strip any ".py" occurring elsewhere in the name.
FILE = Path(__file__).stem

def compute_benchmark(dot_class: DOTmarkClass, res: DOTmarkResolution, i: int, j: int, x: jax.Array, y: jax.Array, p: int, ot_fun: Callable, param: Optional[int | float]):
    """Time a single optimal-transport evaluation and format it as a CSV row.

    ``ot_fun`` is called as ``ot_fun(x, y, p)``, or as ``ot_fun(x, y, p, param)``
    when ``param`` is not ``None``. The result is passed through
    ``jax.block_until_ready`` so the measured time covers the full computation
    rather than just its asynchronous dispatch.

    NOTE(review): if ``ot_fun`` is jitted, the first call for a given input
    shape also includes compilation time in ``duration``. Confirm whether the
    caller warms up.

    Args:
        dot_class, res: DOTmark class and resolution, written to the row.
        i, j: indices of the input pair, written to the row.
        x, y: the two inputs passed to ``ot_fun``.
        p: passed to ``ot_fun`` and written to the row.
        ot_fun: OT routine. Its ``__name__`` is written to the row.
        param: optional extra argument for ``ot_fun`` (e.g. a ``kappa`` or
            ``epsilon_factor`` value). ``None`` means "no extra argument"; a
            falsy value such as ``0`` is still forwarded and recorded.

    Returns:
        A newline-terminated row ``class,resolution,ot_fun,p,i,j[,param],ot,time``
        with ``time`` in seconds.
    """
    has_param = param is not None
    args = (x, y, p, param) if has_param else (x, y, p)

    # perf_counter is monotonic and high-resolution; time.time() follows the
    # wall clock and can jump if the system time is adjusted mid-run.
    start = time.perf_counter()
    ot = jax.block_until_ready(ot_fun(*args))
    duration = time.perf_counter() - start

    fields = [dot_class, res, ot_fun.__name__, p, i, j]
    if has_param:
        fields.append(param)
    fields += [ot, duration]
    # f-string formatting (not str()) keeps the exact textual form the row has
    # always used, e.g. for enum members with a mixed-in value type.
    return ",".join(f"{value}" for value in fields) + "\n"


def _save_csv(df: pd.DataFrame, kind: str, name: str) -> None:
    """Write ``df`` to ``<FILE>_<kind>_<name>.csv`` and then to ``<FILE>_<kind>_LATEST.csv``.

    Both are bare file names resolved against the current working directory.
    """
    for tag in (name, "LATEST"):
        df.to_csv(Path(f"{FILE}_{kind}_{tag}.csv").resolve())


def _summarize(df: pd.DataFrame, value: str, column: str) -> pd.DataFrame:
    """Return mean/std of ``df[value]`` per benchmark configuration, pivoted on ``column``.

    Statistics are computed per (class, resolution, ot_fun_full, p, bound)
    group, then laid out with rows (p, bound, ot_fun_full) and columns
    (``column`` value, statistic).

    NOTE(review): ``pivot_table`` uses its default ``aggfunc="mean"``, so the
    grouping key that is neither an index nor ``column`` (class or resolution)
    is averaged out when it has several values. The resulting "std" is then a
    mean of per-group stds, not a pooled standard deviation.
    """
    return (
        df
        .groupby(["class", "resolution", "ot_fun_full", "p", "bound"])[value]
        .agg(["mean", "std"])
        .reset_index()
        .pivot_table(index=["p", "bound", "ot_fun_full"], columns=[column], values=["mean", "std"])
        .swaplevel(axis="columns")
        .sort_index(axis=1, level=[0, 1])
    )


def main(data: dict, dot_class_list: list[DOTmarkClass], resolution_list: list[DOTmarkResolution]):
    """Run the multi-scale, entropic and exact benchmarks and summarize them.

    For every (DOTmark class, resolution) pair, three ``BenchmarkTest`` families
    are processed with ``compute_benchmark`` and their CSV output is loaded:

    * multi-scale bounds swept over ``kappa`` (``[2, 4]``, or ``[4, 8]`` for
      resolutions above ``DOTmarkResolution.MEDIUM``);
    * entropic bounds swept over ``epsilon_factor`` (``[1e-3, 4e-3]``);
    * the exact Wasserstein distance, used as the reference.

    Each approximate result is joined with the exact one on
    (class, resolution, p, i, j) to get relative run time and relative error.
    The joined table and two summaries are written to CSV via ``_save_csv``.

    Args:
        data: benchmark inputs (as built by ``benchmark_data``), forwarded to
            ``BenchmarkTest.process``.
        dot_class_list: DOTmark classes to benchmark.
        resolution_list: DOTmark resolutions to benchmark.

    Returns:
        ``(df_time, df_acc)``: mean/std of approximate-over-exact run time
        pivoted by resolution, and mean/std of relative error pivoted by class.

    Raises:
        ValueError: if either list is empty, since nothing would be benchmarked.
    """
    if not dot_class_list or not resolution_list:
        raise ValueError("dot_class_list and resolution_list must both be non-empty")

    # Parameters shared by every benchmark family.
    p = [1, 2]
    # Approximate (multi-scale) bounds, swept over kappa.
    scale_funs = [bilevel_upper_bound, weighted_cost_upper_bound, bilevel_lower_bound, min_cost_lower_bound]
    # Entropic bounds, swept over epsilon_factor.
    entropy_funs = [entropy_upper_bound, entropy_lower_bound]
    epsilon_factor = [1e-3, 4e-3]
    # Reference solution that every approximation is compared against.
    exact_funs = [exact_wasserstein]

    df_scale, df_entropy, df_exact = [], [], []
    for dot_class, resolution in product(dot_class_list, resolution_list):
        # Larger kappa values for resolutions above MEDIUM.
        scale_factor = [4, 8] if resolution > DOTmarkResolution.MEDIUM else [2, 4]

        benchmark_scale = BenchmarkTest(dot_class, resolution, p, scale_funs, kappa=scale_factor)
        benchmark_scale.process(data, compute_benchmark)

        benchmark_entropy = BenchmarkTest(dot_class, resolution, p, entropy_funs, epsilon_factor=epsilon_factor)
        benchmark_entropy.process(data, compute_benchmark)

        benchmark_exact = BenchmarkTest(dot_class, resolution, p, exact_funs)
        benchmark_exact.process(data, compute_benchmark)

        df_scale.append(benchmark_scale.load())
        df_entropy.append(benchmark_entropy.load())
        df_exact.append(benchmark_exact.load())

    name = "_".join(r.name for r in resolution_list) + "_" + "_".join(c.name for c in dot_class_list)

    df_scale = pd.concat(df_scale)
    df_scale["ot_fun_full"] = df_scale["ot_fun"] + r"(\kappa=" + df_scale["kappa"].astype(str) + ")"

    df_entropy = pd.concat(df_entropy)
    # Absolute epsilon = epsilon_factor * N^p. This assumes the "resolution"
    # column holds N as a number — TODO confirm the CSV format.
    df_entropy["epsilon"] = df_entropy["epsilon_factor"] * df_entropy["resolution"] ** df_entropy["p"]
    df_entropy["ot_fun_full"] = df_entropy["ot_fun"] + r"(\varepsilon=" + df_entropy["epsilon_factor"].astype(str) + "N^p)"

    df_exact = pd.concat(df_exact)

    # Pair every approximate result with the exact value for the same problem;
    # overlapping columns (ot_fun, ot, time) get _approx / _exact suffixes.
    df = pd.concat([df_scale, df_entropy]).merge(
        df_exact, on=["class", "resolution", "p", "i", "j"], suffixes=("_approx", "_exact")
    )
    df["time_rel"] = df["time_approx"] / df["time_exact"]
    df["error"] = (df["ot_approx"] - df["ot_exact"]).abs()
    # A zero exact cost makes the relative error undefined. Dividing by NaN
    # there (instead of 0) avoids inf values, which the mean/std aggregation
    # below would otherwise propagate; NaN is skipped by it.
    df["error_rel"] = df["error"] / df["ot_exact"].where(df["ot_exact"] != 0)
    df["bound"] = df["ot_fun_approx"].apply(lambda fun: "upper" if "upper" in fun else "lower")

    _save_csv(df, "full", name)

    df_time = _summarize(df, "time_rel", "resolution")
    _save_csv(df_time, "time", name)

    df_acc = _summarize(df, "error_rel", "class")
    _save_csv(df_acc, "accuracy", name)

    return df_time, df_acc
             
if __name__ == "__main__":
    # Benchmark selection: which DOTmark classes and resolutions to run.
    # Other setups that have been used: resolutions MEDIUM and LOW together,
    # or only the Microscopy_Images class.
    selected_classes = [
        DOTmarkClass.Microscopy_Images,
        DOTmarkClass.Shapes,
        DOTmarkClass.Classic_Images,
    ]
    selected_resolutions = [DOTmarkResolution.MEDIUM]

    # Load the input pairs once, then run and print the (time, accuracy) summaries.
    dataset = benchmark_data(dot_class=selected_classes, resolution=selected_resolutions)
    summaries = main(dataset, selected_classes, selected_resolutions)
    print(summaries)