import json
import os
from collections import defaultdict
from itertools import product
from pathlib import Path
from typing import Literal

import numpy as np
import pandas as pd

from datasets.datasets import Dataset
from models.mlp_pytorch import PyTorchMLP
from models.gradientboosting_sklearn import GradientBoostingSKLearn
from models.randomforest_sklearn import RandomForestSKLearn

from counterfactual_explanations.milp_based.cf_conformal import ConformalCF
from counterfactual_explanations.milp_based.cf_mindist import MinDistanceCF

from conformal.localised_conformal_baselcp import *
from conformal.localised_conformal_slcp import *
from conformal.split_conformal import *
from conformal.localised_conformal_others import *

from counterfactual_explanations.benchmarker.factories import *
from counterfactual_explanations.benchmarker.metrics import *

import matplotlib.pyplot as plt

from counterfactual_explanations.dim_reduction import DimensionalityReduction
from tqdm import tqdm

class CFBenchmarker:
    def __init__(self, dataset, n_factuals_main, n_repeats, metrics, model_factories, generator_factories, save_dir="experiments", use_pretrained=True, id=0):
        self.dataset = dataset
        self.n_factuals_main = n_factuals_main
        self.n_repeats = n_repeats
        self.use_pretrained = use_pretrained
        self.metrics = metrics
        self.model_factories = model_factories
        self.generator_factories = generator_factories

        self.X_train, self.y_train, self.X_calib, self.y_calib, self.X_test, self.y_test = self.dataset.get_X_y_split()
        self.save_dir = save_dir / str(id) / self.dataset.get_name()

        self.models_evaluation_path = self.save_dir / "model_evaluation.json"
        self.factuals_path = self.save_dir / "factuals.json"
        self.counterfactuals_path = self.save_dir / "counterfactuals.json"
        self.eval_path_raw = self.save_dir / "evaluation_raw.json"
        self.eval_path_table = self.save_dir / "evaluation_table.txt"
        self.eval_path_table_2 = self.save_dir / "evaluation_table_2.txt"
        self.figs_save_dir = self.save_dir / "figures"
        self.generators_dir = self.save_dir / "generators"
        self.models_dir = self.save_dir / "models"
        self.conformal_eval_raw_path = self.save_dir / "conformal_eval.json"
        self.conformal_eval_text_path = self.save_dir / "conformal_eval.txt"
        self.additional_conformal = {}

        os.makedirs(self.save_dir, exist_ok=True)
        os.makedirs(self.figs_save_dir, exist_ok=True)
        os.makedirs(self.generators_dir, exist_ok=True)
        os.makedirs(self.models_dir, exist_ok=True)

    def get_dataset(self):
        """Return the dataset object this benchmarker was constructed with."""
        dataset = self.dataset
        return dataset

    def setup_models(self):
        all_models = {}

        for factory in self.model_factories:
            models = factory.train_models(self.X_train, self.y_train, self.n_repeats, self.models_dir, self.use_pretrained)
            
            for model in models:
                all_models[str(model.save_dir)] = model

        self.all_models = all_models
 
    def evaluate_models(self):
        all_evaluations = {}

        for pathname, model in self.all_models.items():
            evaluation = model.load_or_save_evaluation(self.X_test, self.y_test, use_pretrained=self.use_pretrained)
            all_evaluations[pathname] = evaluation
        
        with open(self.models_evaluation_path, 'a') as f:
            json.dump(all_evaluations, f, indent=4)


    def set_factuals(self):
        all_factuals = {}

        for pathname, model in self.all_models.items():
            factuals_path = Path(pathname) / "factuals.json"

            if factuals_path.is_file() and self.use_pretrained:
                with open(factuals_path, 'r') as f:
                    factuals = json.load(f)
                    all_factuals[pathname] = factuals

            else:
                seed = model.random_state
                factuals_bank = {}

                x_factuals, y_factuals = self.dataset.sample_dataset(self.n_factuals_main, seed=seed)
                y_target = define_counterfactual_targets(x_factuals, model, self.dataset.input_properties.n_targets)
                
                factuals_bank['main'] = (x_factuals.tolist(), y_target.tolist())

                for metric in self.metrics:
                    cf_bank = metric.get_factuals_bank(model, self.dataset.input_properties, self.dataset, factuals_bank, seed)
                    if cf_bank is not None:
                        key, X_factuals, y_targets = cf_bank
                        factuals_bank[key] = (X_factuals.tolist(), y_targets.tolist())

                all_factuals[pathname] = factuals_bank

                with open(factuals_path, 'w') as f:
                    json.dump(factuals_bank, f, indent=4)
        
        return all_factuals

    def initialise_generators(self):
        self.model_generators = defaultdict(list)

        for pathname, model in self.all_models.items():
            for generator_factory in self.generator_factories:
                generators = generator_factory.setup_generators(model, self.dataset.input_properties, self.X_train, self.y_train, self.X_calib, self.y_calib, self.generators_dir, self.use_pretrained)
                self.model_generators[pathname].extend(generators)

    def get_counterfactuals(self, reset=False):
        factuals = self.set_factuals()

        counterfactuals_output = {}

        for model_pathname, generators in self.model_generators.items():
            counterfactuals_path = Path(model_pathname) / "counterfactuals.json"
            factuals_bank = factuals[model_pathname]

            model_counterfactuals = {}
            if counterfactuals_path.is_file():
                with open(counterfactuals_path, 'r') as f:
                    model_counterfactuals = json.load(f)

            for generator in generators:
                print(generator.name())
                counterfactuals_bank = {}

                if model_counterfactuals.get(generator.name()) != None and not reset and self.use_pretrained:
                    print(f"Using saved for {model_pathname}-{generator.name()}")
                    counterfactuals_bank = model_counterfactuals.get(generator.name())
                else:
                    for bank_name, bank_value in factuals_bank.items():
                        bank_factuals, bank_targets = bank_value
                        
                        bank_factuals = np.array(bank_factuals)
                        bank_targets = np.array(bank_targets).astype(int)

                        #TODO add timing
                        counterfactuals = generator.generate_counterfactuals(bank_factuals, bank_targets)
                        counterfactuals_bank[bank_name] = counterfactuals.tolist()
                        model_counterfactuals[generator.name()] = counterfactuals_bank

                        with open(counterfactuals_path, 'w') as f:
                            json.dump(model_counterfactuals, f, indent=4)

                model_counterfactuals[generator.name()] = counterfactuals_bank

            counterfactuals_output[model_pathname] = model_counterfactuals

        return counterfactuals_output

    def evaluate_counterfactuals(self, aggregate_means=False):
        factuals_output = self.set_factuals()
        counterfactuals_output = self.get_counterfactuals()

        model_generator_metrics = {}

        for model_factory in self.model_factories:
            for model_name, model_set in model_factory.get_models_over_repeats().items():
                generator_metrics = defaultdict(list)
                
                for model in model_set:
                    counterfactuals = counterfactuals_output[str(model.save_dir)]
                    factuals_bank = factuals_output[str(model.save_dir)]
                    factuals_bank = {k: (np.array(v[0]), np.array(v[1])) for k, v in factuals_bank.items()}


                    for generator_name, counterfactuals_bank in counterfactuals.items():
                        if generator_metrics.get(generator_name) is None:
                            generator_metrics[generator_name] = defaultdict(list)

                        for metric in self.metrics:
                            metric_results = metric.compute_metric(model, self.dataset.input_properties, self.dataset, factuals_bank, counterfactuals_bank)
                            if isinstance(metric_results, np.ndarray) and not aggregate_means:
                                mean_result = np.nanmean(metric_results)    
                                generator_metrics[generator_name][metric.name()].append(mean_result)
                            else:
                                generator_metrics[generator_name][metric.name()].append(metric_results.tolist())

                combined_metrics = defaultdict(dict)
                for generator, generator_results in generator_metrics.items():
                    for metric, metric_results in generator_results.items():
                        combined_metrics[generator][metric] = {
                            "mean": np.mean(metric_results),
                            "sd": np.std(metric_results)
                        }

                model_generator_metrics[str(model_name)] = {"raw": generator_metrics, "aggregated": combined_metrics}


        with open(self.eval_path_raw, 'w') as f:
            json.dump(model_generator_metrics, f, indent=4)

        print("Writing results to files")

        dfs = []

        with open(self.eval_path_table, 'w') as f:
            f.write(f"Date: {pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
            for model_type, metrics in model_generator_metrics.items():
                f.write(model_type + "\n")
                df = pd.DataFrame(metrics['aggregated'])
                dfs.append(df)
                df_formatted = df.applymap(lambda x: f"{x['mean']:.3f} ± {x['sd']:.3f}" if isinstance(x, dict) else x)
                z = df_formatted.T
                f.write(z.to_latex())
                f.write('\n\n')
        
        self.model_generator_metrics = model_generator_metrics

        return dfs
    
    def generate_table(self, 
                       distance_metric=None, plausibility_metric=None, implausibility_metric=None, sensitivity_metric=None, stability_metric=None,
                       generators_predicate=lambda gen: True, validity_metric=None, failures_metric=None, report_failures=True, report_invalidity=True, dp2=False, scaling=[],
                       include_extra=[]):

        if not self.eval_path_raw.is_file():
            print("Run evaluate_counterfactuals first!")
            return
        
        with open(self.eval_path_raw) as f:
            model_generator_metrics = json.load(f)


        metrics_objs = {
            "distance_metric": distance_metric or next((m.name() for m in self.metrics if isinstance(m, DistanceMetric)), None),
            "plausibility_metric": plausibility_metric or next((m.name() for m in self.metrics if isinstance(m, LOFMetric)), None),
            "implausibility_metric": implausibility_metric or next((m.name() for m in self.metrics if isinstance(m, ImplausibilityMetric)), None),
            "sensitivity_metric": sensitivity_metric or next((m.name() for m in self.metrics if isinstance(m, SensitivityMetric)), None),
            "stability_metric": stability_metric or next((m.name() for m in self.metrics if isinstance(m, StabilityMetric)), None),
            "failures_metric": failures_metric or next((m.name() for m in self.metrics if isinstance(m, FailuresMetric)), None),
            "validity_metric": validity_metric or next((m.name() for m in self.metrics if isinstance(m, ValidityMetric)), None)
        }

        metrics_cols = [metrics_objs['distance_metric'], metrics_objs['plausibility_metric'], metrics_objs['implausibility_metric'], metrics_objs['sensitivity_metric'], metrics_objs['stability_metric']]
        metrics_cols = [col for col in metrics_cols if col is not None]
        metrics_cols += include_extra

        dfs = []

        with open(self.eval_path_table_2, 'w') as f:
            f.write(f"Date: {pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
            for model_type, metrics in model_generator_metrics.items():
                f.write(model_type + "\n")
                df = pd.DataFrame(metrics['aggregated']).T
                
                if metrics_objs['failures_metric'] and report_failures:
                    failures = df['Failures'].map(lambda x: x['mean'])
                    failures_str = json.dumps(dict(failures[failures >= 0.01]))
                    f.write('Failures\n')
                    f.write(failures_str)
                    f.write('\n')

                if metrics_objs['validity_metric'] and report_invalidity:
                    validity = df['Validity'].map(lambda x: x['mean'])
                    validity_str = json.dumps(dict(validity[validity <= 0.99]))
                    f.write('Validity\n')
                    f.write(validity_str)
                    f.write('\n')

                df = df.loc[df.index[df.index.map(generators_predicate)], metrics_cols]

                for col_num, scaling_factor in scaling:
                    col_name = df.columns[col_num]
                    df[col_name] = df[col_name].apply(
                        lambda x: {"mean": x["mean"] * scaling_factor, "sd": x["sd"] * scaling_factor} if isinstance(x, dict) else x
                    )

                if dp2:
                    df_formatted = df.applymap(lambda x: f"{x['mean']:.2f} ± {x['sd']:.2f}" if isinstance(x, dict) else x)
                else:
                    df_formatted = df.applymap(lambda x: f"{x['mean']:.3f} ± {x['sd']:.3f}" if isinstance(x, dict) else x)

                f.write(df_formatted.to_latex())
                f.write('\n\n')

                dfs.append(df_formatted)

        return dfs


    def generate_figure(self, model_key, metric, x_axis_config, include_sep=True, save=False, exclude_predicate=lambda gen_name: False):
        """Plot one metric against one generator config parameter.

        Generators whose config contains ``x_axis_config`` (either as a
        top-level key or as a two-level dotted path ``"outer.inner"``) are
        drawn as lines, grouped by ``generator.name(exclude=x_axis_config)``.
        Generators without that parameter are drawn in a second panel, one
        point per generator. Error bars show the standard deviation stored in
        the aggregated results.

        Args:
            model_key: Key into ``self.model_generator_metrics``.
            metric: Metric name (row of the aggregated table) to plot.
            x_axis_config: Config key, or dotted two-level path, for the x axis.
            include_sep: Whether to draw the panel for generators that do not
                have ``x_axis_config``; the panel is skipped when there are none.
            save: Save a PNG under ``self.figs_save_dir`` instead of showing it.
            exclude_predicate: Called with each generator name; generators for
                which it returns a truthy value are left out.
        """
        metrics_agg = pd.DataFrame(self.model_generator_metrics[model_key]['aggregated'])

        # DataFrame.applymap is deprecated since pandas 2.1 in favour of
        # DataFrame.map; use whichever this pandas version provides.
        elementwise = metrics_agg.map if hasattr(pd.DataFrame, "map") else metrics_agg.applymap
        means = elementwise(lambda x: x['mean'] if isinstance(x, dict) else x)
        sds = elementwise(lambda x: x['sd'] if isinstance(x, dict) else x)

        means, sds = means.loc[metric], sds.loc[metric]

        # line label -> list of [x value, y value] / [x value, y error]
        y_values = defaultdict(list)
        y_errors = defaultdict(list)

        # generator name -> y value / y error, for generators without the x-axis parameter
        sep_y_values = {}
        sep_y_errors = {}

        for generator_factory in self.generator_factories:
            for generator in generator_factory.get_generators():
                config = generator.config
                generator_name = generator.name()

                if exclude_predicate(generator_name):
                    continue

                metric_mean = means[generator_name]
                # Error bars come from the standard deviations, not the means.
                metric_std = sds[generator_name]

                x_axis_config_val = config.get(x_axis_config, None)

                if x_axis_config_val is None:
                    # Fall back to a nested lookup for "outer.inner" style keys.
                    try:
                        keys = x_axis_config.split(".")
                        x_axis_config_val = config[keys[0]][keys[1]]
                    except (KeyError, IndexError, TypeError, AttributeError):
                        x_axis_config_val = None

                if x_axis_config_val is None:
                    sep_y_values[generator_name] = metric_mean
                    sep_y_errors[generator_name] = metric_std
                else:
                    line_key = generator.name(exclude=x_axis_config)
                    y_values[line_key].append([x_axis_config_val, metric_mean])
                    y_errors[line_key].append([x_axis_config_val, metric_std])

        # Only reserve the second panel when there is something to draw in it.
        show_sep = include_sep and len(sep_y_values) > 0

        fig = plt.figure(figsize=(20, 6))

        # Together plot
        plt.subplot(1, 2 if show_sep else 1, 1)
        for line_key, points in y_values.items():
            points = np.array(points)
            x_vals, y_vals = points[:, 0], points[:, 1]

            y_errs = np.array(y_errors[line_key])[:, 1]

            label = f"{line_key}" if line_key is not None else "Default"
            plt.errorbar(x_vals, y_vals, yerr=y_errs, label=label, capsize=5, marker='o', linestyle='-')

        plt.xlabel(x_axis_config)
        plt.ylabel(metric)
        plt.title(f"{metric} vs {x_axis_config}")
        if y_values:
            # Skip the legend when no line was drawn, since there is nothing to label.
            plt.legend()
        plt.grid(True)

        if show_sep:
            # Sep plot
            plt.subplot(1, 2, 2)
            x_vals = list(sep_y_values.keys())
            y_vals = [sep_y_values[key] for key in x_vals]
            y_errs = [sep_y_errors[key] for key in x_vals]

            plt.errorbar(x_vals, y_vals, yerr=y_errs, capsize=5, marker='o', linestyle='-', label="Separate Generators")
            plt.xticks(rotation=45)
            plt.xlabel("Generator")
            plt.title(f"{metric} for Separate Generators")
            plt.grid(True)

        plt.tight_layout()

        if save:
            plt.savefig(self.figs_save_dir / f"{model_key}_{metric}_vs_{x_axis_config}.png")
            # Close the figure so batch saving does not accumulate open figures.
            plt.close(fig)
        else:
            plt.show()

    def get_figure(self, model_num, metric, x_axis_config, include_sep=True, exclude_predicate=lambda gen_name: False):
        """Show (without saving) the figure for the ``model_num``-th model type.

        ``model_num`` indexes the keys of ``self.model_generator_metrics`` in
        insertion order; the remaining arguments are forwarded to
        ``generate_figure``.
        """
        model_keys = list(self.model_generator_metrics)
        self.generate_figure(
            model_keys[model_num],
            metric,
            x_axis_config,
            include_sep=include_sep,
            save=False,
            exclude_predicate=exclude_predicate,
        )

    def save_figures(self, include_sep=True, exclude_predicate=lambda gen_name: False):
        config_multi_keys = []
        for generator_factory in self.generator_factories:
            config_multi_keys.extend(generator_factory.config_multi.keys())

        all_metrics = pd.DataFrame(next(iter(self.model_generator_metrics.values()))['aggregated']).index

        for model_key in self.model_generator_metrics.keys():
            for config_key in config_multi_keys:
                for metric in all_metrics:
                    self.generate_figure(model_key, metric, config_key, include_sep=include_sep, save=True, exclude_predicate=exclude_predicate)


    def get_means_sds(self, df):
        """Split a table of ``{"mean", "sd"}`` cells into two transposed tables.

        Args:
            df: Anything ``pd.DataFrame`` accepts; cells that are dicts are
                expected to carry ``'mean'`` and ``'sd'`` keys, other cells are
                passed through unchanged.

        Returns:
            tuple ``(means, sds)`` of DataFrames, both transposed relative to
            the input.
        """
        metrics_agg = pd.DataFrame(df)
        # DataFrame.applymap is deprecated since pandas 2.1 in favour of
        # DataFrame.map; use whichever this pandas version provides.
        elementwise = metrics_agg.map if hasattr(pd.DataFrame, "map") else metrics_agg.applymap
        means = elementwise(lambda x: x['mean'] if isinstance(x, dict) else x)
        sds = elementwise(lambda x: x['sd'] if isinstance(x, dict) else x)
        return means.T, sds.T
    
    def set_additional_conformal(self, conformal_classes, conformal_config, conformal_config_multi):
        for model_path, model in self.all_models.items():
            model_conformals = []

            for vals in product(*conformal_config_multi.values()):
                config = dict(zip(conformal_config_multi.keys(), vals))
                config = conformal_config | config

                for conformal_cls in conformal_classes:
                    conformal = conformal_cls(model, self.dataset.input_properties, config=config, save_path=self.generators_dir, use_pretrained=self.use_pretrained)

                    if conformal.dim_reduction:
                        conformal.dim_reduction.setup(model, self.dataset.input_properties, self.dataset.X_train, self.dataset.y_train, self.generators_dir, self.use_pretrained)

                    conformal.calibrate(self.dataset.X_calib, self.dataset.y_calib)
                    model_conformals.append(conformal)

            self.additional_conformal[model_path] = model_conformals
        

    def test_conformal(self, write_to_file=True):
        """Evaluate the conformal predictors on the test split.

        For every model the conformal predictors of its ``ConformalCF``
        generators, plus any registered in ``self.additional_conformal``, are
        evaluated with ``evaluate_conditional``. Set sizes and coverage gaps
        are then averaged (with standard deviation) over the models that share
        a conformal name, separately for each model type.

        Args:
            write_to_file: Append the formatted LaTeX tables for every model
                type to ``self.conformal_eval_text_path``.

        Returns:
            dict ``{model type: {"size": {"mean", "sd"}, "covgap": {"mean", "sd"}}}``
            whose leaves are DataFrames with conformal names as rows and the
            four evaluation settings as columns.
        """
        model_generators = self.model_generators

        metrics_key = ["Marginal", "Class Conditional", "Random Binning", "Counterfactual Sim"]

        model_results = {}

        for model_factory in self.model_factories:
            dfs = {}

            for model_desc, models in model_factory.get_models_over_repeats().items():
                # (conformal name, model path) -> one value per entry of metrics_key
                model_metrics_size = {}
                model_metrics_covgap = {}

                for model in models:
                    model_name = str(model.save_dir)
                    
                    generators = model_generators[model_name]
                    conformals = [g.conformal for g in generators if isinstance(g, ConformalCF)]
                    
                    if self.additional_conformal.get(model_name):
                        conformals += self.additional_conformal[model_name]

                    for conformal in tqdm(conformals, desc="Evaluating conformals", leave=False):
                        conformal_name = conformal.name()

                        set_size_m, coverage_gap_m, set_size_cc, coverage_gap_cc, set_size_rb, coverage_gap_rb, set_size_cf, coverage_gap_cf = conformal.evaluate_conditional(self.X_test, self.y_test, cov_gap=True)

                        model_metrics_size[(conformal_name, model_name)] = [set_size_m, set_size_cc, set_size_rb, set_size_cf]
                        model_metrics_covgap[(conformal_name, model_name)] = [coverage_gap_m, coverage_gap_cc, coverage_gap_rb, coverage_gap_cf]

                model_dfs = []

                for data_dict in (model_metrics_size, model_metrics_covgap):
                    df = pd.DataFrame(data_dict)
                    df = df.set_index(pd.MultiIndex.from_product([metrics_key], names=["metrics"]))   
                    # Level 0 of the transposed index is the conformal name, so
                    # this aggregates over the models of this type.
                    by_conformal = df.T.groupby(level=0)
                    model_dfs.append((by_conformal.mean(), by_conformal.std()))

                dfs[str(model_desc)] = {"size": {"mean": model_dfs[0][0], "sd": model_dfs[0][1]}, "covgap": {"mean": model_dfs[1][0], "sd": model_dfs[1][1]}}
            
            model_results |= dfs
        
        if write_to_file:
            def as_text(frame):
                # DataFrame.applymap is deprecated since pandas 2.1 in favour of
                # DataFrame.map; use whichever this pandas version provides.
                to_text = lambda x: f"{x:.3f}"
                if hasattr(pd.DataFrame, "map"):
                    return frame.map(to_text)
                return frame.applymap(to_text)

            with open(self.conformal_eval_text_path, 'a') as f:
                # Iterate the accumulated results so every factory is reported,
                # not just the last one processed above.
                for model_desc, results in model_results.items():
                    formatted_df = as_text(results["size"]["mean"]) + " ± " + as_text(results["size"]["sd"])
                    formatted_df_cg = as_text(results["covgap"]["mean"]) + " ± " + as_text(results["covgap"]["sd"])

                    f.write(f"{model_desc}\n")
                    f.write("Average set size\n")
                    f.write(formatted_df.to_latex())
                    f.write("Coverage gap\n")
                    f.write(formatted_df_cg.to_latex())
                    f.write("\n\n")

        return model_results










