import copy
import itertools
from typing import Optional, Callable

from ai_control_games.trusted_monitor.tm_blue_team_policy_subtypes import VariableAuditTMBlueTeamPolicy
from ai_control_games.trusted_monitor.tm_game import TMModellingAssumptions, TMGameRules
from ai_control_games.distributions import DistOver0to1, NORMALISED_NORMAL_APPROX_OF_RW, UNIFORM_DIST
from ai_control_games.trusted_monitor.trusted_monitor_calculations import *
from ai_control_games.other_utils import *

from tqdm import tqdm

import pandas as pd


# TODO currently, all betas are found before a random subset is taken, which is slow - change this!
class TMBatchGameManager(ABC):
    """
    Base class for evaluating a grid of blue/red team strategies on one Trusted Monitoring game.

    Subclasses must set ``self.indep_var_dict`` (column name -> list of values) before calling
    ``super().__init__`` and must implement ``get_blue_team_policy``. ``build_dataframe`` turns
    the cartesian product of those values into the rows of ``self.df`` and attaches a blue team
    and a red team strategy to every row; the ``get_*`` methods then add score columns.

    ``get_red_team_policy`` reads ``row.backdoor_prob``, so the default dataframe construction
    expects a ``"backdoor_prob"`` independent variable.
    """
    # Maps each independent-variable column name to the list of values it takes.
    indep_var_dict: dict
    # One row per combination of independent variables, plus strategy and score columns.
    df: pd.DataFrame
    _find_opt_beta_func_type = Callable[[TrustedMonitoringGame, TMBlueTeamPolicy], float]

    # Names of additional columns that build_dataframe creates, initialised to None.
    extra_columns = ()

    @abstractmethod
    def get_blue_team_policy(self, row) -> TMBlueTeamPolicy:
        """
        Build the blue team policy for one row of ``self.df``.
        :param row: a row of the dataframe (passed by ``DataFrame.apply(..., axis=1)``)
        :return: the blue team policy to store in the ``blue_team_strategy`` column
        """
        pass

    def __init__(
            self,
            bd_distribution: DistOver0to1 = NORMALISED_NORMAL_APPROX_OF_RW,
            sf_distribution: DistOver0to1 = UNIFORM_DIST,
            num_problems: int = 100,
            audit_prop: float = 0.02,
            human_FNR: float = 0,
            human_FPR: float = 0,
            should_verify: bool = False,
            _calculate_safety_manager: CalculateSafetyManager = None,
            _calculate_usefulness_manager: CalculateUsefulnessManager = None,
            should_calculate_immediately: bool = True,
            red_br_tol: float = 1e-5,
            red_br_grid_res: int = 32,
    ):
        """
        Build the game, the dataframe of strategies and (optionally) the exact scores.
        :param bd_distribution: passed to TMModellingAssumptions as ``bd_dist``
        :param sf_distribution: passed to TMModellingAssumptions as ``safe_dist``
        :param num_problems: passed to TMGameRules as ``num_problems``
        :param audit_prop: passed to TMGameRules as ``audit_proportion``
        :param human_FNR: passed to TMModellingAssumptions as ``human_FNR``
        :param human_FPR: passed to TMModellingAssumptions as ``human_FPR``
        :param should_verify: forwarded as ``should_verify`` to the safety/usefulness managers
        :param _calculate_safety_manager: object whose ``calculate_safety`` gives exact safety
        :param _calculate_usefulness_manager: object whose ``calculate_usefulness`` gives exact usefulness
        :param should_calculate_immediately: if True, compute exact safety and usefulness now
        :param red_br_tol: ``epsilon`` passed to the red team best-response search
        :param red_br_grid_res: ``grid_res`` passed to the red team best-response search
        :raises ValueError: if the subclass did not set ``indep_var_dict``, if the
            ``"backdoor_prob"`` values are invalid, or (when calculating immediately) if a
            calculation manager is missing
        """
        self.bd_distribution = bd_distribution
        self.sf_distribution = sf_distribution

        self.num_problems = num_problems
        self.audit_prop = audit_prop
        self.human_FNR = human_FNR
        self.human_FPR = human_FPR
        self.should_verify = should_verify
        self._calculate_safety_manager = _calculate_safety_manager
        self._calculate_usefulness_manager = _calculate_usefulness_manager

        self.red_br_tol = red_br_tol
        self.red_br_grid_res = red_br_grid_res

        # Set on first use by check_and_set_num_samples.
        self._num_samples = None

        self.modelling_assumptions = TMModellingAssumptions(
            bd_dist=self.bd_distribution,
            safe_dist=self.sf_distribution,
            human_FNR=self.human_FNR,
            human_FPR=self.human_FPR
        )

        self.game_rules = TMGameRules(
            audit_proportion=self.audit_prop,
            num_problems=self.num_problems
        )
        self.game = TrustedMonitoringGame(
            ma=self.modelling_assumptions,
            gr=self.game_rules
        )

        # The class-level annotation does not create the attribute, so a subclass that forgot
        # to set it must be detected with getattr rather than a plain attribute access.
        indep_var_dict = getattr(self, "indep_var_dict", None)
        if indep_var_dict is None:
            raise ValueError("Subclass of TMBatchGameManager must set indep_var_dict in init")

        if "backdoor_prob" in indep_var_dict:
            betas = indep_var_dict["backdoor_prob"]
            # Check before iterating: looping over None would raise an unhelpful TypeError.
            if betas is None:
                raise ValueError("backdoor_prob values must not be None")
            for beta in betas:
                if not isinstance(beta, float) and beta != "opt":
                    raise ValueError("Beta must be a float or 'opt'")

        self.build_dataframe()

        if should_calculate_immediately:
            self.get_exact_safety_calculation()
            self.get_exact_usefulness_calculation()

        self.df["Protocol Type"] = self.df["blue_team_strategy"].apply(lambda x: x.__class__.__name__)

        def get_bts_params(bts):
            # Dataclass policies expose their field values; anything else yields a single None.
            if hasattr(bts, "__dataclass_fields__"):
                for attr in bts.__dataclass_fields__.keys():
                    yield getattr(bts, attr)
            else:
                yield None

        self.df["Protocol Params"] = [
            list(get_bts_params(bts)) if bts is not None else None
            for bts in self.df["blue_team_strategy"]
        ]

    def create_union(self, other):
        """Return a UnionBatchGameManager combining this manager with ``other``."""
        return UnionBatchGameManager(self, other)

    def build_dataframe(self):
        """
        Create ``self.df`` from the cartesian product of ``indep_var_dict`` values, then add
        the ``blue_team_strategy`` and ``red_team_strategy`` columns.
        """
        columns = self.indep_var_columns

        combinations = [
            dict(zip(columns, vals))
            for vals in itertools.product(*[self.indep_var_dict[col] for col in columns])
        ]

        self.df = pd.DataFrame(columns=columns, data=combinations)
        for col in self.extra_columns:
            self.df[col] = None

        self.trim_redundant_rows()

        # Blue team strategies are built first: get_red_team_policy reads
        # row.blue_team_strategy when the backdoor probability is "opt".
        tqdm.pandas(desc="Creating Blue Team Strategies", delay=10)
        btss = self.df.progress_apply(self.get_blue_team_policy, axis=1)

        self.df["blue_team_strategy"] = btss

        tqdm.pandas(desc="Creating Red Team Strategies", delay=10)
        red_team_strats = self.df.progress_apply(self.get_red_team_policy, axis=1)

        self.df["red_team_strategy"] = red_team_strats

    @property
    def indep_var_columns(self):
        """Names of the independent-variable columns, in ``indep_var_dict`` order."""
        return list(self.indep_var_dict.keys())

    def trim_redundant_rows(self):
        """Hook for subclasses to drop rows before strategies are built; no-op by default."""
        pass

    def update_error_columns(self):
        """
        For each of Safety/Usefulness that has both an exact ("<name> Score") and a Monte Carlo
        ("MC <name>") column, (re)compute the signed, absolute and relative error columns.
        """
        nms = [
            x for x in ["Safety", "Usefulness"]
            if f"{x} Score" in self.df.columns
               and f"MC {x}" in self.df.columns
        ]
        for nm in nms:
            self.df[f"{nm} Error"] = self.df[f"{nm} Score"] - self.df[f"MC {nm}"]
            self.df[f"{nm} Abs Error"] = np.abs(self.df[f"{nm} Error"])
            self.df[f"{nm} Rel Error (err/MC)"] = self.df[f"{nm} Abs Error"] / self.df[f"MC {nm}"]
            self.df[f"{nm} Rel Error (err/Pred)"] = self.df[f"{nm} Abs Error"] / self.df[f"{nm} Score"]

    def check_and_set_num_samples(self, num_samples: int):
        """
        This is a slightly hacky way to set the number of samples for MC estimates for safety and usefulness.
        It should be replaced in future.
        :param num_samples: number of samples to record; must match any previously recorded value
        :return: None
        :raises ValueError: if a different number of samples was already recorded
        """
        if self._num_samples is not None:
            if self._num_samples != num_samples:
                raise ValueError("num_samples has already been set")
        else:
            self._num_samples = num_samples

    def get_MC_safety_estimates(self, num_samples: int = 100):
        """
        Add the "MC Safety" column (first element of ``game.estimate_MC_safety``) and
        "Estimated MC Safety Variance" (p * (1 - p) / num_samples), then refresh error columns.
        :param num_samples: number of Monte Carlo samples per row
        """
        df = self.df
        self.check_and_set_num_samples(num_samples)

        strategy_pairs = zip(df["blue_team_strategy"], df["red_team_strategy"])
        safety_estimates = []
        for bts, rts in tqdm(strategy_pairs, total=len(df), desc="MC Estimating Safety"):
            est, _ = self.game.estimate_MC_safety(
                num_samples,
                bts,
                rts,
            )
            safety_estimates.append(est)

        df["MC Safety"] = safety_estimates
        df["Estimated MC Safety Variance"] = df["MC Safety"] * (1 - df["MC Safety"]) / num_samples
        self.update_error_columns()

    def get_MC_usefulness_estimates(self, num_samples: int = 100):
        """
        Add the "MC Usefulness" column (first element of ``game.estimate_MC_usefulness``),
        then refresh error columns.
        :param num_samples: number of Monte Carlo samples per row
        """
        self.check_and_set_num_samples(num_samples)
        usefulness_estimates = []
        df = self.df
        for bts in tqdm(df["blue_team_strategy"], total=len(df), desc="MC Estimating Usefulness"):
            est, _ = self.game.estimate_MC_usefulness(
                num_samples,
                bts,
            )
            usefulness_estimates.append(est)
        self.df["MC Usefulness"] = usefulness_estimates
        self.update_error_columns()

    def get_MC_estimates(self, num_samples: int = 100):
        """Compute both Monte Carlo safety and usefulness estimates with ``num_samples`` samples."""
        self.check_and_set_num_samples(num_samples)
        self.get_MC_safety_estimates(num_samples)
        self.get_MC_usefulness_estimates(num_samples)
        self.update_error_columns()

    def get_exact_calculation(self, should_tqdm=True):
        """Compute both exact safety and exact usefulness columns."""
        self.get_exact_safety_calculation(should_tqdm)
        self.get_exact_usefulness_calculation(should_tqdm)
        self.update_error_columns()

    def get_exact_safety_calculation(
            self,
            should_tqdm=True
    ):
        """
        Add "Safety Score" (from the safety manager; -1.0 where the blue team strategy is None),
        "Predicted RTWR" (1 - safety) and its base-2 log, then refresh error columns.
        :param should_tqdm: whether to show a progress bar
        :raises ValueError: if no safety manager was supplied
        """
        if self._calculate_safety_manager is None:
            raise ValueError("Must set _calculate_safety_manager")
        calc_safety_manager = self._calculate_safety_manager

        strategy_pairs = zip(self.df["blue_team_strategy"], self.df["red_team_strategy"])
        self.df["Safety Score"] = [
            calc_safety_manager.calculate_safety(
                self.game,
                bts,
                rts,
                should_verify=self.should_verify
            ) if bts is not None else -1.0
            for bts, rts in tqdm(strategy_pairs, total=len(self.df), disable=not should_tqdm, delay=10)
        ]
        self.df["Predicted RTWR"] = 1 - self.df["Safety Score"]
        # A safety of exactly 1 gives log2(0) = -inf; silence the divide warning.
        with np.errstate(divide="ignore"):
            self.df["log Predicted RTWR"] = np.log2(self.df["Predicted RTWR"])
        self.update_error_columns()

    def get_exact_usefulness_calculation(
            self,
            should_tqdm=False
    ):
        """
        Add "Usefulness Score" (from the usefulness manager; -1.0 where the blue team strategy
        is None) and, if safety is already present, the "S+U" and "S-U" columns.
        :param should_tqdm: whether to show a progress bar
        :raises ValueError: if no usefulness manager was supplied
        """
        if self._calculate_usefulness_manager is None:
            raise ValueError("Must set _calculate_usefulness_manager")
        calc_usefulness_manager = self._calculate_usefulness_manager

        self.df["Usefulness Score"] = [
            calc_usefulness_manager.calculate_usefulness(
                self.game,
                bts,
                should_verify=self.should_verify
            ) if bts is not None else -1.0
            for bts in tqdm(self.df["blue_team_strategy"], total=len(self.df), disable=not should_tqdm)
        ]
        if "Safety Score" in self.df.columns:
            self.df["S+U"] = self.df["Safety Score"] + self.df["Usefulness Score"]
            self.df["S-U"] = self.df["Safety Score"] - self.df["Usefulness Score"]
        self.update_error_columns()

    def assert_MC_vs_exact(self):
        """
        Check that Monte Carlo estimates are within 0.1 of the exact scores for every row.

        On failure, prints the offending rows (with every independent variable that actually
        varies) and, for safety, how many estimated standard deviations apart the values are.
        :raises ValueError: if any row differs by 0.1 or more
        """
        for nm in ["Safety", "Usefulness"]:
            if f"{nm} Score" in self.df.columns and f"MC {nm}" in self.df.columns:
                mc_vals = self.df[f"MC {nm}"].values
                pred_vals = self.df[f"{nm} Score"].values
                diff = np.abs(mc_vals - pred_vals)
                # We have a high tolerance because MC estimates are noisy
                is_close = (diff < 0.1)

                if not is_close.all():
                    failed_rows = self.df[~is_close]

                    important_cols = [f"{nm} Score", f"MC {nm}"]
                    for col in self.indep_var_columns:
                        if len(self.df[col].unique()) > 1:
                            important_cols.append(col)
                    print(failed_rows[important_cols])

                    if nm == "Safety":
                        assert self._num_samples is not None
                        ps = failed_rows["Safety Score"]
                        # Standard deviation of a proportion estimated from _num_samples draws.
                        est_stds = np.sqrt(ps * (1 - ps) / self._num_samples)
                        print(f"Estimated Standard Deviation of Safety Estimate: {est_stds}")
                        abs_errors = np.abs(failed_rows["Safety Score"] - failed_rows["MC Safety"])
                        num_std_away = abs_errors / est_stds
                        print("Number of Standard Deviations Away:")
                        print(num_std_away)

                    raise ValueError(f"MC vs Predicted {nm} not close",
                                     diff[~is_close])

    def get_random_subset(self, num_rows: int):
        """
        Return a deep copy of this manager whose dataframe is a random sample of ``num_rows`` rows.
        :param num_rows: number of rows to sample (passed to ``DataFrame.sample``)
        :return: the copied manager
        """
        bgm_copy = copy.deepcopy(self)
        bgm_copy.df = self.df.sample(num_rows)
        return bgm_copy

    def find_opt_rts(
            self,
            bts: TMBlueTeamPolicy,
    ) -> TMRedTeamPolicy:
        """
        Find the backdoor probability that minimises safety against ``bts``.

        Safety is evaluated with ``CSM_FLEXI`` and minimised with ``grid_then_ternary_search``
        over the bounds (0, 1), using ``red_br_tol`` and ``red_br_grid_res``.
        :param bts: the blue team policy to best-respond to
        :return: a SimpleTMRedTeamPolicy using the minimising beta
        """

        def func_to_minimise(beta):
            rts = SimpleTMRedTeamPolicy(beta=beta)
            return CSM_FLEXI.calculate_safety(game=self.game, bts=bts, rts=rts)

        # Only the argmin is needed; the attained safety value is discarded.
        argmin_beta, _ = grid_then_ternary_search(
            func_to_minimise, 0, 1,
            epsilon=self.red_br_tol,
            grid_res=self.red_br_grid_res,
            f_options={}
        )
        return SimpleTMRedTeamPolicy(beta=argmin_beta)

    def get_red_team_policy(self, row):
        """
        Build the red team policy for one row.
        :param row: dataframe row with ``backdoor_prob`` and, for "opt", ``blue_team_strategy``
        :return: a SimpleTMRedTeamPolicy with the row's beta, or the best response when beta is "opt"
        """
        beta = row.backdoor_prob
        if beta == "opt":
            return self.find_opt_rts(
                bts=row.blue_team_strategy,
            )
        return SimpleTMRedTeamPolicy(
            beta=beta
        )


class UnionBatchGameManager(TMBatchGameManager):
    """
    Batch game manager whose dataframe is the concatenation of two existing managers' dataframes.

    Both managers must share the same distributions, number of problems, audit proportion,
    human error rates and ``should_verify`` flag, because a single game is rebuilt from them.
    Columns present in only one dataframe are filled with the string "NA" in the other.
    """

    def __init__(self, bgm1: TMBatchGameManager, bgm2: TMBatchGameManager):
        """
        :param bgm1: first manager (its rows come second in the combined dataframe)
        :param bgm2: second manager (its rows come first in the combined dataframe)
        :raises AssertionError: if the two managers disagree on a shared game setting
        """
        self._bgm1_get_bts = bgm1.get_blue_team_policy
        self._bgm2_get_bts = bgm2.get_blue_team_policy

        # List comprehensions (rather than set differences) keep the padded columns in a
        # deterministic order.
        cols_unique_to_bgm1 = [col for col in bgm1.df.columns if col not in bgm2.df.columns]
        cols_unique_to_bgm2 = [col for col in bgm2.df.columns if col not in bgm1.df.columns]

        bgm1_df_copy = bgm1.df.copy()
        bgm2_df_copy = bgm2.df.copy()

        for col in cols_unique_to_bgm1:
            bgm2_df_copy[col] = "NA"
        for col in cols_unique_to_bgm2:
            bgm1_df_copy[col] = "NA"

        self.df = pd.concat([bgm2_df_copy, bgm1_df_copy])

        assert bgm1.bd_distribution == bgm2.bd_distribution
        bd_distribution = bgm1.bd_distribution
        assert bgm1.sf_distribution == bgm2.sf_distribution
        sf_distribution = bgm1.sf_distribution

        assert bgm1.num_problems == bgm2.num_problems
        num_problems = bgm1.num_problems
        assert bgm1.audit_prop == bgm2.audit_prop
        audit_prop = bgm1.audit_prop
        # The human error rates feed the modelling assumptions of the rebuilt game, so they
        # must agree and be forwarded; otherwise scores are recomputed with both rates at 0.
        assert bgm1.human_FNR == bgm2.human_FNR
        human_FNR = bgm1.human_FNR
        assert bgm1.human_FPR == bgm2.human_FPR
        human_FPR = bgm1.human_FPR
        assert bgm1.should_verify == bgm2.should_verify
        should_verify = bgm1.should_verify

        # On a key clash the value from bgm2 wins, as with dict() over the concatenated items.
        self.indep_var_dict = {**bgm1.indep_var_dict, **bgm2.indep_var_dict}

        super().__init__(
            bd_distribution=bd_distribution,
            sf_distribution=sf_distribution,
            num_problems=num_problems,
            audit_prop=audit_prop,
            human_FNR=human_FNR,
            human_FPR=human_FPR,
            should_verify=should_verify,
            _calculate_safety_manager=CalcSafetyManagerFlexibleDefault(),
            _calculate_usefulness_manager=CalcUsefulnessManagerFlexibleDefault()
        )

    def get_blue_team_policy(self, row):
        """
        Return the blue team strategy already stored on ``row``.
        :raises ValueError: if the row has no strategy yet
        """
        bts = row["blue_team_strategy"]
        if bts is None:
            raise ValueError("BTS should already be initialised")
        return bts

    def build_dataframe(self):
        """
        Validate the pre-built dataframe instead of constructing a new one.
        :raises ValueError: if any row is missing its blue team strategy
        """
        assert self.df is not None, "UnionBatchGameManager should have df set"
        if any(bts is None for bts in self.df["blue_team_strategy"]):
            raise ValueError("BTS should already be initialised")


class TMBatchGameManager_SimpleBTPolicy(TMBatchGameManager):
    """
    Batch manager over SimpleTMBlueTeamPolicy instances, parameterised by an audit quantile
    ("Audit q"), a defer quantile ("Defer q") and a backdoor probability ("backdoor_prob").
    """
    _find_opt_beta_func_type = Callable[[TrustedMonitoringGame, SimpleTMBlueTeamPolicy], float]

    def __init__(
            self,
            bd_distribution: DistOver0to1 = NORMALISED_NORMAL_APPROX_OF_RW,
            sf_distribution: DistOver0to1 = UNIFORM_DIST,
            num_problems: int = 100,
            audit_prop: float = 0.02,
            human_FNR: float = 0,
            human_FPR: float = 0,
            should_verify: bool = False,
            audit_qs: Optional[list[float]] = None,
            defer_qs: Optional[list[float]] = None,
            betas: Optional[list[Optional[float]]] = None,
            _calculate_safety_manager=CalculateSafetyViaRecurrenceForSimpleBTS(),
            _calculate_usefulness_manager=CalcUsefulnessManagerFlexibleDefault(),
            should_calculate_immediately: bool = True,
            red_br_tol: float = 1e-5,
            red_br_grid_res: int = 32,
    ):
        """
        :param audit_qs: values for "Audit q"; defaults to the grid described below
        :param defer_qs: values for "Defer q"; defaults to the same grid
        :param betas: values for "backdoor_prob" (floats or the string "opt");
            defaults to [0, 1/num_problems, 2/num_problems, "opt"]
        All remaining parameters are forwarded unchanged to TMBatchGameManager.__init__.

        The default grid is the sorted, de-duplicated set of 0, 1/num_problems, audit_prop,
        2*audit_prop, 0.5 and one minus each of those small values, restricted to [0, 1].
        """
        epsilon = 1 / num_problems
        candidate_qs = (
            0.0, epsilon, audit_prop, audit_prop * 2,
            0.5,
            1.0 - epsilon, 1.0 - audit_prop, 1.0 - (audit_prop * 2), 1.0,
        )
        # Keep values inside [0, 1], drop duplicates, and order ascending.
        key_sigmoid = sorted({q for q in candidate_qs if 0 <= q <= 1})

        audit_qs = key_sigmoid if audit_qs is None else audit_qs
        defer_qs = key_sigmoid if defer_qs is None else defer_qs
        betas = [0.0, epsilon, 2 * epsilon, "opt"] if betas is None else betas

        self.indep_var_dict = {
            "Audit q": audit_qs,
            "Defer q": defer_qs,
            "backdoor_prob": betas
        }

        super().__init__(
            bd_distribution=bd_distribution,
            sf_distribution=sf_distribution,
            num_problems=num_problems,
            audit_prop=audit_prop,
            human_FNR=human_FNR,
            human_FPR=human_FPR,
            should_verify=should_verify,
            _calculate_safety_manager=_calculate_safety_manager,
            _calculate_usefulness_manager=_calculate_usefulness_manager,
            should_calculate_immediately=should_calculate_immediately,
            red_br_tol=red_br_tol,
            red_br_grid_res=red_br_grid_res
        )

    def get_blue_team_policy(self, row):
        """Build the SimpleTMBlueTeamPolicy described by the row's "Audit q" and "Defer q"."""
        q_audit, q_defer = row["Audit q"], row["Defer q"]
        return SimpleTMBlueTeamPolicy(
            q_audit=q_audit,
            q_defer=q_defer,
            safe_dist=self.sf_distribution
        )


class TMBatchGameManager_SimpleBTPolicy_Simple_or_Adaptive_Audit(TMBatchGameManager):
    """
    Batch manager comparing SimpleTMBlueTeamPolicy ("simple") with
    VariableAuditTMBlueTeamPolicy ("adaptive") over a grid of audit/defer quantiles and
    backdoor probabilities.
    """
    _find_opt_beta_func_type = Callable[[TrustedMonitoringGame, VariableAuditTMBlueTeamPolicy], float]

    def __init__(
            self,
            bd_distribution: DistOver0to1 = NORMALISED_NORMAL_APPROX_OF_RW,
            sf_distribution: DistOver0to1 = UNIFORM_DIST,
            num_problems: int = 100,
            audit_prop: float = 0.02,
            human_FNR: float = 0,
            human_FPR: float = 0,
            should_verify: bool = False,
            audit_qs: Optional[list[float]] = None,
            defer_qs: Optional[list[float]] = None,
            betas: Optional[list[Optional[float]]] = None,
            include_simple=True,
            red_br_tol: float = 1e-5,
            red_br_grid_res: int = 32,
    ):
        """
        :param audit_qs: values for "Audit q"; defaults to the grid described below
        :param defer_qs: values for "Defer q"; defaults to the same grid
        :param betas: values for "backdoor_prob" (floats or the string "opt");
            defaults to [0, 1/num_problems, 2/num_problems, "opt"]
        :param include_simple: if True, evaluate both "simple" and "adaptive" policies;
            otherwise only "adaptive"
        All remaining parameters are forwarded unchanged to TMBatchGameManager.__init__.

        The default grid is 0, 1/num_problems, audit_prop, 0.5, 1 - 1/num_problems,
        1 - audit_prop and 1, restricted to [0, 1] with duplicates removed.
        """
        epsilon = 1 / num_problems
        candidate_qs = [0.0, epsilon, audit_prop, 0.5, 1.0 - epsilon, 1.0 - audit_prop, 1.0]
        # Coinciding grid points (e.g. audit_prop == epsilon or audit_prop == 0.5) would
        # otherwise create duplicate rows. dict.fromkeys de-duplicates while keeping the
        # original order, so the default grid is unchanged.
        key_sigmoid = list(dict.fromkeys(q for q in candidate_qs if 0 <= q <= 1))

        if audit_qs is None:
            audit_qs = key_sigmoid
        if defer_qs is None:
            defer_qs = key_sigmoid
        if betas is None:
            betas = [0.0, epsilon, 2 * epsilon, "opt"]
        self.indep_var_dict = {
            "Audit q": audit_qs,
            "Defer q": defer_qs,
            "backdoor_prob": betas,
            "simple_or_adaptive": ["simple", "adaptive"] if include_simple else ["adaptive"]
        }

        super().__init__(
            bd_distribution=bd_distribution,
            sf_distribution=sf_distribution,
            num_problems=num_problems,
            audit_prop=audit_prop,
            human_FNR=human_FNR,
            human_FPR=human_FPR,
            should_verify=should_verify,
            _calculate_safety_manager=CalculateSafetyViaRecurrenceForDeterministicThresholdBTS(),
            _calculate_usefulness_manager=CalculateUsefulnessViaRecurrenceForDeterministicThresholdBTS(),
            red_br_tol=red_br_tol,
            red_br_grid_res=red_br_grid_res
        )

    def get_blue_team_policy(self, row):
        """
        Build the policy for one row: a SimpleTMBlueTeamPolicy when "simple_or_adaptive" is
        "simple", otherwise a VariableAuditTMBlueTeamPolicy initialised with ``audit_prop``.
        """
        if row["simple_or_adaptive"] == "simple":
            return SimpleTMBlueTeamPolicy(
                q_audit=row["Audit q"],
                q_defer=row["Defer q"],
                safe_dist=self.sf_distribution
            )
        return VariableAuditTMBlueTeamPolicy(
            q_audit=row["Audit q"],
            q_defer=row["Defer q"],
            init_audit_proportion=self.audit_prop,
            safe_dist=self.sf_distribution
        )


# TODO separate management of games from management of policies

class TMBatchGameManager_MixedBTPolicy(TMBatchGameManager):
    """
    Batch manager over MixedBlueTeamPolicy instances: each row mixes two
    SimpleTMBlueTeamPolicy instances, (qa1, qd1) and (qa2, qd2), with weights (1 - p, p).
    """
    _find_opt_beta_func_type = Callable[[TrustedMonitoringGame, MixedBlueTeamPolicy], float]

    def __init__(
            self,
            bd_distribution: DistOver0to1 = NORMALISED_NORMAL_APPROX_OF_RW,
            sf_distribution: DistOver0to1 = UNIFORM_DIST,
            num_problems: int = 100,
            audit_prop: float = 0.02,
            human_FNR: float = 0,
            human_FPR: float = 0,
            should_verify: bool = False,
            qa1s: Optional[list[float]] = None,
            qa2s: Optional[list[float]] = None,
            qd1s: Optional[list[float]] = None,
            qd2s: Optional[list[float]] = None,
            ps: Optional[list[float]] = None,
            betas: Optional[list[Optional[float]]] = None,
            red_br_tol: float = 1e-5,
            red_br_grid_res: int = 32,
    ):
        """
        :param qa1s: audit quantiles for the first policy; defaults to the grid described below
        :param qa2s: audit quantiles for the second policy; defaults to the same grid
        :param qd1s: defer quantiles for the first policy; defaults to the same grid
        :param qd2s: defer quantiles for the second policy; defaults to the same grid
        :param ps: mixing weights given to the second policy; defaults to [0.0, 0.25, 0.5]
        :param betas: values for "backdoor_prob" (floats or the string "opt");
            defaults to [0, 1/num_problems, 2/num_problems, "opt"]
        All remaining parameters are forwarded unchanged to TMBatchGameManager.__init__.

        The default grid is 0, audit_prop, 0.5, 1 - audit_prop and 1, restricted to [0, 1]
        with duplicates removed.
        """
        epsilon = 1 / num_problems
        candidate_qs = [0.0, audit_prop, 0.5, 1.0 - audit_prop, 1.0]
        # Coinciding grid points (e.g. audit_prop == 0.5) would otherwise multiply into
        # duplicate rows across the four quantile axes. dict.fromkeys de-duplicates while
        # keeping the original order, so the default grid is unchanged.
        key_sigmoid = list(dict.fromkeys(q for q in candidate_qs if 0 <= q <= 1))

        if qa1s is None:
            qa1s = key_sigmoid
        if qa2s is None:
            qa2s = key_sigmoid
        if qd1s is None:
            qd1s = key_sigmoid
        if qd2s is None:
            qd2s = key_sigmoid
        if ps is None:
            ps = [0.0, 0.25, 0.5]
        if betas is None:
            betas = [0.0, epsilon, 2 * epsilon, "opt"]
        self.indep_var_dict = {
            "qa1": qa1s,
            "qa2": qa2s,
            "qd1": qd1s,
            "qd2": qd2s,
            "p": ps,
            "backdoor_prob": betas
        }

        super().__init__(
            bd_distribution=bd_distribution,
            sf_distribution=sf_distribution,
            num_problems=num_problems,
            audit_prop=audit_prop,
            human_FNR=human_FNR,
            human_FPR=human_FPR,
            should_verify=should_verify,
            _calculate_safety_manager=CalculateSafetyViaRecurrenceForMixedBTS(),
            _calculate_usefulness_manager=CalculateUsefulnessViaRecurrenceForMixedBTS(),
            red_br_tol=red_br_tol,
            red_br_grid_res=red_br_grid_res
        )

    # TODO Maybe create a better automated way to deal with parameters

    def trim_redundant_rows(self):
        """Currently a no-op; the disabled pruning logic is kept below for reference."""
        # # Remove rows where the same policy is used for both blue teams
        # self.df = self.df[
        #     (self.df["qa1"] != self.df["qa2"]) | (self.df["qd1"] != self.df["qd2"])
        #     ]
        # # IF  row["p"] == 0.0, then set qa2 = qd2 = 0.0
        # self.df.loc[self.df["p"] == 0.0, ["qa2", "qd2"]] = 0.0
        #
        # # Now remove duplicates
        # self.df = self.df.drop_duplicates()
        pass

    def get_blue_team_policy(self, row):
        """
        Build the MixedBlueTeamPolicy for one row: the (qa1, qd1) policy gets weight
        1 - p and the (qa2, qd2) policy gets weight p.
        """
        pols = (
            SimpleTMBlueTeamPolicy(
                q_audit=row["qa1"],
                q_defer=row["qd1"],
                safe_dist=self.sf_distribution
            ),
            SimpleTMBlueTeamPolicy(
                q_audit=row["qa2"],
                q_defer=row["qd2"],
                safe_dist=self.sf_distribution
            )
        )

        return MixedBlueTeamPolicy(
            policies=pols,
            ps=(1 - row["p"], row["p"])
        )
