import pickle

import tqdm

from experiments.config import ExperimentConfig
from valuation_func.sampler import SystematicSampler
from valuation_func.semivalues import (
    BanzhafValuationFunc,
    BetaShapleyValuationFunc,
    ShapleyValuationFunc,
)


class Experiment:
    """Run repeated data-valuation experiments over a set of configurations.

    For each of ``n_runs`` runs, a ``SystematicSampler`` seeded with the run
    index computes marginal contributions for the configurations produced by
    ``ExperimentConfig``. Every requested valuation function is then applied
    to those contributions, and its ``compute_data_values()`` result is
    appended to ``self.values`` under that function's key.
    """

    # Valuation function names understood by ``_make_evaluator``.
    _SUPPORTED_FUNC_NAMES = ("shapley", "beta_shapley", "banzhaf")

    def __init__(
        self,
        fixed_params=None,
        changing_params=None,
        n_runs=1,
        valuation_task_name=None,
        valuation_func=None,
        save_path=None,
        gr_threshold=1.05,
        max_mc_epochs=100,
    ):
        """Store the experiment settings; no computation happens here.

        Args:
            fixed_params: Passed as the first argument to ``ExperimentConfig``;
                ``None`` (or any falsy value) becomes ``{}``.
            changing_params: Passed as the second argument to
                ``ExperimentConfig``; ``None`` (or any falsy value) becomes ``{}``.
            n_runs: Number of independent sampling runs; run ``i`` uses
                ``random_state=i``.
            valuation_task_name: Stored as-is; not used by this class.
            valuation_func: List of dicts, each with a ``"func_name"`` of
                ``"shapley"``, ``"beta_shapley"`` (which also needs ``"alpha"``
                and ``"beta"``) or ``"banzhaf"``. Defaults to Shapley,
                (4,1)-beta-Shapley and Banzhaf.
            save_path: If not ``None``, ``run_experiment`` pickles the values to
                ``f"{save_path}_all_values.pkl"``.
            gr_threshold: Forwarded to ``SystematicSampler`` (presumably a
                Gelman-Rubin convergence threshold — confirm in the sampler).
            max_mc_epochs: Forwarded to ``SystematicSampler``.
        """
        self.fixed_params = fixed_params or {}
        self.changing_params = changing_params or {}
        self.n_runs = n_runs
        self.valuation_task_name = valuation_task_name
        if valuation_func is None:
            # Built fresh on every call. A list literal used as the parameter
            # default would be one shared, mutable object across all instances.
            valuation_func = [
                {"func_name": "shapley"},
                {"func_name": "beta_shapley", "alpha": 4, "beta": 1},
                {"func_name": "banzhaf"},
            ]
        self.valuation_func = valuation_func
        self.save_path = save_path
        self.gr_threshold = gr_threshold
        self.max_mc_epochs = max_mc_epochs

    @staticmethod
    def _value_key(func):
        """Return the ``self.values`` key for one valuation-function spec."""
        func_name = func["func_name"]
        if func_name == "beta_shapley":
            return f"({func['alpha']},{func['beta']})-{func_name}"
        return func_name

    def _validate_valuation_func(self):
        """Raise ``ValueError`` for unknown names or incomplete beta-Shapley specs."""
        for func in self.valuation_func:
            func_name = func["func_name"]
            if func_name not in self._SUPPORTED_FUNC_NAMES:
                raise ValueError(
                    f"Unknown valuation function {func_name!r}; "
                    f"expected one of {self._SUPPORTED_FUNC_NAMES}"
                )
            if func_name == "beta_shapley" and not ("alpha" in func and "beta" in func):
                raise ValueError(
                    "'beta_shapley' requires both 'alpha' and 'beta' parameters"
                )

    @staticmethod
    def _make_evaluator(func, marg_contrib_dict):
        """Build the valuation-function object described by ``func``."""
        func_name = func["func_name"]
        if func_name == "shapley":
            return ShapleyValuationFunc(marg_contrib_dict=marg_contrib_dict)
        if func_name == "beta_shapley":
            return BetaShapleyValuationFunc(
                marg_contrib_dict=marg_contrib_dict,
                alpha=func["alpha"],
                beta=func["beta"],
            )
        if func_name == "banzhaf":
            return BanzhafValuationFunc(marg_contrib_dict=marg_contrib_dict)
        raise ValueError(f"Unknown valuation function {func_name!r}")

    def run_experiment(self):
        """Sample marginal contributions and compute data values for every run.

        Side effects: sets ``self.values``, ``self.configurations`` and
        ``self.marg_contrib_dict`` (a list with one sampler result per run),
        and writes a pickle file when ``save_path`` is set.

        Returns:
            dict: ``self.values``, mapping each valuation-function key
            (``"shapley"``, ``"banzhaf"`` or ``"(alpha,beta)-beta_shapley"``)
            to a list holding one ``compute_data_values()`` result per run.

        Raises:
            ValueError: If a valuation-function spec is unknown or a
                ``beta_shapley`` spec lacks ``alpha``/``beta``. The check runs
                before any sampling is done.
        """
        # Fail fast: a typo or a missing parameter should not surface only
        # after the (potentially expensive) sampling has run.
        self._validate_valuation_func()
        keys = [self._value_key(func) for func in self.valuation_func]
        self.values = {key: [] for key in keys}

        exp_config = ExperimentConfig(self.fixed_params, self.changing_params)
        self.configurations = exp_config.get_configurations()
        self.marg_contrib_dict = []
        for run in tqdm.tqdm(range(self.n_runs)):
            sampler = SystematicSampler(
                configurations=self.configurations,
                gr_threshold=self.gr_threshold,
                max_mc_epochs=self.max_mc_epochs,
                random_state=run,  # seeds differ per run but are reproducible
            )
            marg_contrib_dict = sampler.compute_marginal_contributions_for_all(
                verbose=False
            )
            self.marg_contrib_dict.append(marg_contrib_dict)
            for func, key in zip(self.valuation_func, keys):
                evaluator = self._make_evaluator(func, marg_contrib_dict)
                self.values[key].append(evaluator.compute_data_values())

        if self.save_path is not None:
            with open(f"{self.save_path}_all_values.pkl", "wb") as file:
                pickle.dump(self.values, file)
        return self.values