import itertools
import os
from abc import ABC, abstractmethod
from typing import NamedTuple, List, Tuple

import numpy as np
import pandas as pd
from tqdm import tqdm

from ai_control_games.generic_game import BlueTeamPolicy, RedTeamStrategy, Game, MixedBlueTeamPolicy
from ai_control_games.other_utils import get_leader_follower_eq_from_s_u_and_weight

# Module-wide latch: flips to True after the first warning has been printed.
HAS_WARNED_ALREADY_GENERIC_OPTIMISATION = False


def maybe_warn_generic_optimisation(message: str):
    """
    Print `message` the first time this function is called in the process; do nothing afterwards.

    The latch is shared by every caller, so only the very first warning (whatever its text)
    is ever shown.
    :param message: The warning text to print.
    """
    global HAS_WARNED_ALREADY_GENERIC_OPTIMISATION
    if HAS_WARNED_ALREADY_GENERIC_OPTIMISATION:
        return
    print(message)
    HAS_WARNED_ALREADY_GENERIC_OPTIMISATION = True


class SingleObjectiveOptimiserOutput(NamedTuple):
    """
    Result of solving the single-objective (weighted) problem for one weight.

    One of these is produced per frontier point by `get_single_obj_equilibrium`.
    """
    # Weight w placed on safety in the objective w * safety + (1 - w) * usefulness.
    weight: float
    # The blue team policy (possibly a mixture) found for this weight.
    blue_team_strategy: BlueTeamPolicy
    # The red team strategy paired with the blue policy in the result.
    red_team_strategy: RedTeamStrategy
    # Value of the weighted objective reported by the solver.
    single_objective_value: float
    # Safety of the blue policy as reported by the solver.
    safety: float
    # Usefulness of the blue policy.
    usefulness: float
    # Free-form solver diagnostics; may be empty.
    info: dict


class LoopInMaxMin(Exception):
    """Raised when an iterative max-min search produces a blue team policy it has already visited."""


class IterLimitExceeded(Exception):
    """Raised when an iterative equilibrium search runs past its iteration limit."""


class AllEdgesSplit(Exception):
    """Raised when no edge between neighbouring frontier points is left to split."""


class FrontierExplorer(ABC):
    """
    Abstract base class for tracing the safety/usefulness frontier of a game.

    Subclasses supply the game-specific primitives (safety, usefulness, best responses and a
    single-objective equilibrium solver). `explore_frontier` then repeatedly solves the
    single-objective problem `weight * safety + (1 - weight) * usefulness`, choosing each new
    weight so as to split the largest remaining gap between neighbouring frontier points.
    """
    # These tolerances are only declared here; this base class neither assigns nor reads them.
    # NOTE(review): blue_br_tol / red_br_tol are presumably best-response tolerances - confirm in subclasses.
    blue_br_tol: float
    red_br_tol: float
    # Stopping threshold on the gap between the upper and lower value bounds in the iterative solvers.
    equilibrium_tol: float

    @abstractmethod
    def get_safety(self, game: Game, blue_pol: BlueTeamPolicy, red_pol: RedTeamStrategy) -> float:
        """
        Return the safety of `blue_pol` when played against `red_pol` in `game`.
        """
        raise NotImplementedError

    @abstractmethod
    def get_usefulness(self, game: Game, blue_pol: BlueTeamPolicy) -> float:
        """
        Return the usefulness of `blue_pol` in `game` (it takes no red team strategy).
        """
        raise NotImplementedError

    @abstractmethod
    def get_initial_red_team_strategy(self, game: Game) -> RedTeamStrategy:
        """
        Return the red team strategy used to seed an iterative equilibrium search.
        """
        raise NotImplementedError

    @abstractmethod
    def get_initial_blue_team_policy(self, game: Game) -> BlueTeamPolicy:
        """
        Return the blue team policy used to seed an iterative equilibrium search.
        """
        raise NotImplementedError

    @abstractmethod
    def get_value_and_blue_best_response(self, game: Game, weight: float, red_pol: RedTeamStrategy) -> \
            tuple[float, BlueTeamPolicy]:
        """
        Given a red team strategy, find the best blue team strategy in terms of safety and usefulness.
        Return value = weight * safety + (1 - weight) * usefulness
        Return blue_pol = the best blue team strategy.
        :param game: The game to optimise over.
        :param weight: The weight placed on safety in the single objective.
        :param red_pol: The red team strategy.
        :return: (value, blue_pol) where blue_pol is the best blue team strategy.
        """
        raise NotImplementedError

    @abstractmethod
    def get_safety_and_red_best_response(self, game: Game, blue_pol: BlueTeamPolicy) -> Tuple[float, RedTeamStrategy]:
        """
        Find the red team's best response to `blue_pol`.
        :param game: The game to optimise over.
        :param blue_pol: The blue team policy.
        :return: (safety, red_pol) where red_pol is the red best response and safety is the
            safety of blue_pol against it.
        """
        raise NotImplementedError

    @abstractmethod
    def get_single_obj_equilibrium(self, game: Game, weight: float) -> SingleObjectiveOptimiserOutput:
        """
        Solve the single-objective problem weight * safety + (1 - weight) * usefulness for `game`.
        :return: The solver output for this weight.
        """
        raise NotImplementedError

    @staticmethod
    def find_weight_for_largest_gap(xys: List, edges_split_so_far: set) -> Tuple:
        """
        Pick the weight whose level sets are parallel to the longest not-yet-split frontier edge.

        Points are ordered by x and each consecutive pair forms a candidate edge. Pairs that share
        an x or a y coordinate, and pairs already in `edges_split_so_far`, are skipped with a
        (one-off) warning. The input list is not modified.
        :param xys: The (x, y) points found so far.
        :param edges_split_so_far: Edges ((x1, y1), (x2, y2)) that were already split.
        :return: (weight, (x1, y1), (x2, y2)) where weight * x + (1 - weight) * y takes the same
            value at both end points of the chosen edge.
        :raises ValueError: If two consecutive points are identical.
        :raises AllEdgesSplit: If no candidate edge remains.
        """
        # TODO: We could reduce redundancy by storing as a weighted graph but this code is not a bottleneck
        # Sort a copy: the caller's list must not be reordered as a side effect.
        sorted_xys = sorted(xys, key=lambda xy: xy[0])
        edges = []
        for (x1, y1), (x2, y2) in zip(sorted_xys, sorted_xys[1:]):
            if x1 == x2 and y1 == y2:
                raise ValueError("Repeated point found", sorted_xys)
            elif x1 == x2 or y1 == y2:
                maybe_warn_generic_optimisation(
                    f"Warning: weakly-dominated point found in frontier {(x1, y1), (x2, y2)}")
                continue
            elif ((x1, y1), (x2, y2)) in edges_split_so_far:
                maybe_warn_generic_optimisation(f"Warning: attempted edge split twice (skipping) {(x1, y1), (x2, y2)}")
            else:
                edges.append(((x1, y1), (x2, y2)))
        if len(edges) == 0:
            raise AllEdgesSplit("No edges found", sorted_xys)
        distances_squared = [(x2 - x1) ** 2 + (y2 - y1) ** 2
                             for (x1, y1), (x2, y2) in edges]
        assert all(d > 0 for d in distances_squared), distances_squared
        largest_gap_index = np.argmax(distances_squared)
        (x1, y1), (x2, y2) = edges[largest_gap_index]
        # Find w such that if f(x,y)=wx + (1-w)y, then f(x1,y1) = f(x2,y2)
        weight = (y1 - y2) / (y1 - y2 + x2 - x1)
        assert 0 < weight < 1, (weight, x1, y1, x2, y2)
        return weight, (x1, y1), (x2, y2)

    def explore_frontier(self, game: Game, num_points_along_frontier: int, name=None, verbose=False) -> List[SingleObjectiveOptimiserOutput]:
        """
        Build up a list of frontier points by solving the single-objective problem for a sequence of weights.

        The first point uses weight 0.0, the second weight 1.0, and each later weight is chosen by
        `find_weight_for_largest_gap`. Iterations whose solve raises `LoopInMaxMin` or
        `IterLimitExceeded`, or which reproduce an existing (safety, usefulness) pair, are skipped
        with a printed warning. A checkpoint CSV of the points so far is written under
        ./data/checkpoints (created if missing) at the end of every non-skipped iteration whose
        index is a multiple of 12.
        :param game: The game to explore.
        :param num_points_along_frontier: Number of iterations to run; an upper bound on the
            number of points returned.
        :param name: Optional identifier embedded in the checkpoint file names.
        :param verbose: If true, print each weight before solving for it.
        :return: The frontier points found, in the order they were found.
        :raises ValueError: If a new point is strictly dominated by a point already found.
        """
        frontier_points = []
        s_u_pairs = []
        edges_split_so_far = set()

        for idx in tqdm(range(num_points_along_frontier), delay=5, desc="Calculating Frontier Points"):
            # The first two points are the extremes: weight 0.0 (usefulness only), then 1.0 (safety only).
            # NOTE(review): the choice depends on len(frontier_points), so if an extreme-weight solve
            #   is skipped below, the next iteration tries the same weight again.
            if len(frontier_points) == 0:
                weight = 0.0
            elif len(frontier_points) == 1:
                weight = 1.0
            else:
                try:
                    weight, xy1, xy2 = FrontierExplorer.find_weight_for_largest_gap(s_u_pairs, edges_split_so_far)
                    # Record the edge before solving so it is never attempted twice, even if the solve fails.
                    edges_split_so_far.add((xy1, xy2))
                except AllEdgesSplit as e:
                    print("Warning: skipping all edges split, ending early", e)
                    break
            if verbose:
                print(weight)
            # TODO analyse this behaviour properly!
            try:
                out = self.get_single_obj_equilibrium(game, weight)
            except LoopInMaxMin as e:
                print("Warning: skipping loop in maxmin", e)
                continue
            except IterLimitExceeded as e:
                print("Warning: skipping iteration limit exceeded", e)
                continue

            if (out.safety, out.usefulness) in s_u_pairs:
                print("Warning: skipping repeated point found in frontier", out.safety, out.usefulness)
                continue
            elif any((out.safety < s and out.usefulness < u) for s, u in s_u_pairs):
                dominating_s_u = [(s, u) for s, u in s_u_pairs if out.safety < s and out.usefulness < u]
                raise ValueError(f"Found point dominated by {dominating_s_u} in frontier {out}")
            else:
                frontier_points.append(out)
                s_u_pairs.append((out.safety, out.usefulness))

            if idx % 12 == 0:
                identifier = "" if name is None else name
                checkpoint_path = f"./data/checkpoints/frontier_for_{identifier}_{idx+1}_points"
                # to_csv does not create missing directories; without this the very first
                # checkpoint (idx == 0) aborts the whole exploration on a fresh checkout.
                os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
                pd.DataFrame(frontier_points).to_csv(checkpoint_path, index=False)

        return frontier_points


class FrontierExplorerWithGridSearch(FrontierExplorer, ABC):
    """
    Frontier explorer that solves each single-objective problem by exhaustive search over a
    fixed list of candidate blue team policies.
    """

    def __init__(self, blue_team_policies: List[BlueTeamPolicy]):
        """
        :param blue_team_policies: The candidate blue team policies to search over.
        """
        self.blue_team_policies = blue_team_policies

    def get_initial_red_team_strategy(self, game: Game) -> RedTeamStrategy:
        raise NotImplementedError("Not needed for grid search")

    def get_initial_blue_team_policy(self, game: Game) -> BlueTeamPolicy:
        raise NotImplementedError("Not needed for grid search")

    def get_value_and_blue_best_response(self, game: Game, weight: float, red_pol: RedTeamStrategy):
        raise NotImplementedError("Not needed for grid search")

    def get_single_obj_equilibrium(self, game: Game, weight: float) -> SingleObjectiveOptimiserOutput:
        """
        Evaluate every candidate blue policy against red's best response and keep the one with the
        highest weight * safety + (1 - weight) * usefulness (the earliest candidate wins ties).
        :param game: The game to optimise over.
        :param weight: The weight on safety; must lie in [0, 1].
        :return: The output for the best candidate; safety and usefulness are re-evaluated for it
            via `get_safety` and `get_usefulness`.
        """
        assert 0.0 <= weight <= 1.0, weight
        assert len(self.blue_team_policies) > 0

        # Running record of the best candidate; starts as a sentinel that any finite value beats.
        incumbent = (float("-inf"), None)
        candidates = tqdm(self.blue_team_policies, delay=5, desc="Iterating over Blue policies")
        for candidate in candidates:
            cand_safety, cand_red_response = self.get_safety_and_red_best_response(game, candidate)
            cand_usefulness = self.get_usefulness(game, candidate)
            cand_value = weight * cand_safety + (1 - weight) * cand_usefulness
            if cand_value > incumbent[0]:
                incumbent = (cand_value, cand_safety, cand_usefulness, candidate, cand_red_response)

        # The stored safety/usefulness are discarded; both are recomputed for the winner below.
        top_value, _, _, top_policy, top_red_response = incumbent
        final_safety = self.get_safety(game, top_policy, top_red_response)
        final_usefulness = self.get_usefulness(game, top_policy)
        return SingleObjectiveOptimiserOutput(
            weight=weight,
            blue_team_strategy=top_policy,
            red_team_strategy=top_red_response,
            single_objective_value=top_value,
            safety=final_safety,
            usefulness=final_usefulness,
            info={},
        )


class FrontierExplorerWithIterativeDeterministicEquilibriumFinding(FrontierExplorer, ABC):
    """
    Frontier explorer that solves each single-objective problem by alternating deterministic
    best responses: blue responds to the latest red strategy, then red responds to that blue
    policy, until the optimistic and pessimistic values agree to within `equilibrium_tol`.
    """

    def get_initial_blue_team_policy(self, game: Game) -> BlueTeamPolicy:
        raise NotImplementedError("Not needed for deterministic equilibrium finding")

    def get_single_obj_equilibrium(self, game: Game, weight: float) -> SingleObjectiveOptimiserOutput:
        """
        Alternate blue and red best responses until the value gap is at most `equilibrium_tol`.
        :param game: The game to optimise over.
        :param weight: The weight placed on safety in the single objective.
        :return: The last blue policy and the red best response to it, with the pessimistic
            (red-best-response) safety and value.
        :raises IterLimitExceeded: If more than 100 iterations pass without convergence.
        :raises LoopInMaxMin: If a blue policy is produced a second time without convergence.
        """
        t = 0
        red_team_strategies = [self.get_initial_red_team_strategy(game)]
        blue_team_strategies = []
        delta = float("inf")

        while delta > self.equilibrium_tol:
            last_rts = red_team_strategies[-1]
            response = self.get_value_and_blue_best_response(game, weight, last_rts)
            if len(response) == 3:
                # Legacy shape (safety, usefulness, blue_pol), which this loop originally required.
                s_upper, u, next_bts = response
                v_upper = (weight * s_upper) + (1 - weight) * u
            else:
                # Shape declared by the abstract method: (value, blue_pol).
                v_upper, next_bts = response
                u = self.get_usefulness(game, next_bts)

            blue_team_strategies.append(next_bts)
            s_lower, next_rts = self.get_safety_and_red_best_response(game, next_bts)
            # Since usefulness does not depend on the red team strategy, we can reuse the value
            v_lower = (weight * s_lower) + (1 - weight) * u
            red_team_strategies.append(next_rts)
            delta = abs(v_upper - v_lower)

            t += 1
            # Only give up while still unconverged: a repeated policy (or hitting the limit) on the
            # very iteration that converges is a valid result and must be returned, not discarded.
            if delta > self.equilibrium_tol and (t > 100 or next_bts in blue_team_strategies[:-1]):
                err_str = f"δ={delta} > {self.equilibrium_tol}=τ\n" + \
                          "\n".join([f"{blue_team_strategies[i], red_team_strategies[i]}" for i in range(t)])
                if t > 100:
                    raise IterLimitExceeded(err_str)
                else:
                    raise LoopInMaxMin(err_str)

        return SingleObjectiveOptimiserOutput(
            weight=weight,
            blue_team_strategy=next_bts,
            red_team_strategy=next_rts,
            single_objective_value=v_lower,
            safety=s_lower,
            usefulness=u,
            info=dict()
        )


class FrontierExplorerWithIterativeMixedEquilibriumFinding(FrontierExplorer, ABC):
    """
    Frontier explorer that solves each single-objective problem over mixtures of deterministic
    blue policies, growing the pools of blue policies and red strategies by best responses.
    """

    def get_initial_blue_team_policy(self, game: Game) -> BlueTeamPolicy:
        raise NotImplementedError("Not needed for mixed equilibrium finding")

    def get_single_obj_equilibrium(self, game: Game, weight: float) -> SingleObjectiveOptimiserOutput:
        """
        Iterate until the optimistic and true values of the current blue mixture differ by at
        most `equilibrium_tol`.
        :param game: The game to optimise over.
        :param weight: The weight placed on safety in the single objective.
        :return: The final blue mixture, red's best response to it, and the resulting
            (pessimistic) value, safety and usefulness.
        """
        blue_pool = []
        red_pool = [self.get_initial_red_team_strategy(game)]

        # Caches: safety per (blue, red) pair and usefulness per blue policy.
        safety_table = {}
        usefulness_table = {}

        gap = float("inf")
        num_rounds = 0  # round counter (currently not used for anything else)

        def fill_missing_entries():
            # Simple rather than fast: scan every pair and compute whatever is not cached yet.
            for blue in blue_pool:
                for red in red_pool:
                    if (blue, red) not in safety_table:
                        safety_table[blue, red] = self.get_safety(game, blue, red)
            for blue in blue_pool:
                if blue not in usefulness_table:
                    usefulness_table[blue] = self.get_usefulness(game, blue)

        while gap > self.equilibrium_tol:
            # 1) Add blue's best response to the most recent red strategy to the blue pool.
            _, new_blue = self.get_value_and_blue_best_response(game, weight, red_pool[-1])
            blue_pool.append(new_blue)

            # 2) With blue restricted to mixtures over blue_pool and red restricted to red_pool,
            #    find the leader-follower equilibrium which is optimistic for blue.
            #    The caches must cover the new policy before they are handed over.
            fill_missing_entries()
            restricted_eq = get_leader_follower_eq_from_s_u_and_weight(
                blue_pool, red_pool, safety_table, usefulness_table, weight)
            (pols, probs), best_existing_br, value = restricted_eq
            mixed_blue = MixedBlueTeamPolicy(policies=pols, ps=probs)
            s_upper_for_mix = self.get_safety(game, mixed_blue, best_existing_br)
            u_for_mix = self.get_usefulness(game, mixed_blue)

            # 3) Holding the blue mixture fixed, let red best-respond across all strategies.
            true_safety_of_mix, red_br_to_mix = self.get_safety_and_red_best_response(game, mixed_blue)
            red_pool.append(red_br_to_mix)

            # Optimistic (restricted red) versus pessimistic (unrestricted red) value of the mixture.
            usefulness_term = (1 - weight) * u_for_mix
            v_upper = (weight * s_upper_for_mix) + usefulness_term
            v_lower = (weight * true_safety_of_mix) + usefulness_term
            assert v_upper + self.equilibrium_tol >= v_lower, (v_upper - v_lower)
            gap = v_upper - v_lower
            num_rounds += 1

        return SingleObjectiveOptimiserOutput(
            weight=weight,
            blue_team_strategy=mixed_blue,
            red_team_strategy=red_br_to_mix,
            single_objective_value=v_lower,
            safety=true_safety_of_mix,
            usefulness=u_for_mix,
            info={},
        )


class FrontierExplorerWithDoubleOracle(FrontierExplorer, ABC):
    """
    Frontier explorer that solves each single-objective problem with an adaptation of the
    Double Oracle algorithm: solve the game restricted to the discovered policies, then add
    each side's deterministic best response to the other side's mixture.
    """

    def get_single_obj_equilibrium(self, game: Game, weight: float, verbose=False) -> SingleObjectiveOptimiserOutput:
        """
        Run the Double Oracle loop until the upper and lower value bounds differ by at most
        `equilibrium_tol`.
        :param game: The game to optimise over.
        :param weight: The weight placed on safety in the single objective.
        :param verbose: If true, print the current gap and tolerance at the start of each round.
        :return: The final blue mixture with red's deterministic best response to it; `info`
            holds the (lower, upper) bound history and the tolerance used.
        """
        blue_pool = [self.get_initial_blue_team_policy(game)]
        red_pool = [self.get_initial_red_team_strategy(game)]

        # Caches: safety per (blue, red) pair and usefulness per blue policy.
        safety_table = {}
        usefulness_table = {}

        gap = float("inf")
        num_rounds = 0  # round counter (currently not used for anything else)
        bound_history = []

        def fill_missing_entries():
            for blue in blue_pool:
                for red in red_pool:
                    if (blue, red) not in safety_table:
                        safety_table[blue, red] = self.get_safety(game, blue, red)
            for blue in blue_pool:
                if blue not in usefulness_table:
                    usefulness_table[blue] = self.get_usefulness(game, blue)

        while gap > self.equilibrium_tol:
            if verbose:
                print(f"{gap:}>{self.equilibrium_tol}")

            # 1) Solve the subgame restricted to the discovered policies; the caches must be
            #    complete for the current pools before the solver is called.
            fill_missing_entries()
            blue_mix, red_mix = get_leader_follower_eq_from_s_u_and_weight(
                blue_pool, red_pool, safety_table, usefulness_table, weight, return_mixed=True)

            # 2) Blue's best deterministic response to the red mixture gives the upper bound.
            v_upper, pi_blue = self.get_value_and_blue_best_response(game, weight, red_mix)
            blue_pool.append(pi_blue)

            # 3) Red's best deterministic response to the blue mixture.
            blue_mix_safety, pi_red = self.get_safety_and_red_best_response(game, blue_mix)
            red_pool.append(pi_red)

            # 4) The blue mixture's value against that response gives the lower bound.
            round_usefulness = self.get_usefulness(game, blue_mix)
            v_lower = (weight * blue_mix_safety) + (1 - weight) * round_usefulness
            bound_history.append((v_lower, v_upper))
            gap = v_upper - v_lower
            num_rounds += 1

        final_usefulness = self.get_usefulness(game, blue_mix)
        diagnostics = {"bounds": bound_history, "equilibrium_tol": self.equilibrium_tol}
        return SingleObjectiveOptimiserOutput(
            weight=weight,
            blue_team_strategy=blue_mix,
            red_team_strategy=pi_red,
            single_objective_value=v_lower,
            safety=blue_mix_safety,
            usefulness=final_usefulness,
            info=diagnostics,
        )
