import copy
import logging
import numpy as np
import ray
from ray import cloudpickle as pickle
from collections import deque

from agents.league.match_func import fsp, pfsp


@ray.remote
class Coordinator:
    def __init__(
        self,
        match_func,
        win_rate_threshold,
        iter_threshold,
        newest_prob,
        max_league_size,
        seed,
        debug=False,
    ):
        self.match_func = match_func
        self.win_rate_threshold = win_rate_threshold
        self.iter_threshold = iter_threshold
        self.newest_prob = newest_prob
        self.max_league_size = max_league_size
        self._seed = seed
        np.random.seed(self._seed)

        self._burn_in_policies = {"main_v0"}
        self.historical_policies = {"main_v0"}
        self.version = 0

        self._wins = {}
        self._draws = {}
        self._losses = {}
        self._games = {}
        self._decay = 0.99

        self._game_results = {}
        self._win_rate_history = {}
        self._burn_in_period = 20

        self.debug = debug
        if self.debug:
            logging.basicConfig(
                level=logging.INFO,
                format="%(asctime)s\t%(levelname)s %(filename)s:%(lineno)s -- %(message)s",
                handlers=[logging.FileHandler(f"ray_results/{match_func}_coordinator.log")],
            )

    def update(self, away, result):
        for stats in (self._wins, self._draws, self._losses, self._games):
            if away in stats:
                stats[away] *= self._decay
            else:
                stats[away] = 0
        self._games[away] += 1
        if result == "win":
            self._wins[away] += 1
        elif result == "draw":
            self._draws[away] += 1
        else:
            self._losses[away] += 1

        if away not in self._game_results:
            self._game_results[away] = deque(maxlen=self._burn_in_period)
        if result == "win":
            self._game_results[away].append(1)
        elif result == "draw":
            self._game_results[away].append(0.5)
        else:
            self._game_results[away].append(0)

        if away in self._burn_in_policies:
            if len(self._game_results[away]) >= self._burn_in_period:
                if self.debug:
                    logging.info(f"Remove {away} from burn-in period ({self._burn_in_policies})")
                self._burn_in_policies.remove(away)
        else:
            if away in self._win_rate_history:
                self._win_rate_history[away].append(self.get_win_rates()[away])
            else:
                self._win_rate_history[away] = deque([self.get_win_rates()[away]], maxlen=self._burn_in_period)

    def get_win_rates(self):
        return {
            k: (self._wins[k] + 0.5 * self._draws[k]) / self._games[k]
            if self._games[k] != 0
            else 0.5
            for k in self._wins.keys() if k not in self._burn_in_policies
        }
        # return {k: sum(v) / len(v) for k, v in self._game_results.items() if k not in self._burn_in_policies}

    def clear(self):
        self._wins.clear()
        self._draws.clear()
        self._losses.clear()
        self._games.clear()
        self._game_results.clear()
        self._win_rate_history.clear()

    def add_policy(self):
        self.version += 1
        new_pol_id = f"main_v{self.version}"
        self.historical_policies.add(new_pol_id)
        self._burn_in_policies.add(new_pol_id)
        if self.debug:
            logging.info(f"Snapshot and add new policy: {new_pol_id}")
        return new_pol_id

    def league_version(self):
        return self.version

    def get_new_policy_mapping_fn(self, policies=None):
        if self.debug:
            logging.info(f"Sampling new policy mapping function:")
        if policies is None:
            historical_team = sorted(
                self.historical_policies,
                key=lambda x: int(x.strip("main_v"))
            )
        else:
            historical_team = sorted(
                [p for p in policies if p.startswith("main_v")],
                key=lambda x: int(x.strip("main_v"))
            )

        if self.match_func == "FSP":
            probs = fsp(
                num_models=len(historical_team), newest_prob=self.newest_prob
            )
        elif self.match_func == "PFSP":
            win_rate_dict = self.get_win_rates()
            win_rates = []
            for p in historical_team:
                if p not in self._burn_in_policies:
                    win_rates.append(win_rate_dict[p])
                else:
                    win_rates.append(0.5)
            probs = pfsp(win_rates=win_rates, weighting="variance")
        elif self.match_func == "ALP":
            alp = []
            for p in historical_team:
                if p in self._burn_in_policies:
                    alp.append(1)
                elif p in self._win_rate_history:
                    alp.append(
                        abs(sum(list(self._win_rate_history[p])[:5]) -
                            sum(list(self._win_rate_history[p])[-5:])) / 5
                    )
                else:
                    alp.append(1)
            probs = pfsp(win_rates=alp, weighting="alp")
        opponent = np.random.choice(historical_team, p=probs)

        if np.random.uniform() < 0.5:
            left = "main"
            right = opponent
        else:
            left = opponent
            right = "main"

        if self.debug:
            logging.info(f"\tSampled new policy mapping: {left} vs {right}")

        def fn(agent_id, episode, worker, **kwargs):
            if agent_id == 0:
                return left
            elif agent_id == 1:
                return right
            else:
                raise ValueError(f"Wrong agent id: {agent_id}")

        return fn

    def save(self) -> bytes:
        """Serializes this Coordinator's current state and returns it.

        Returns:
            The current state of this coordinator as a serialized, pickled
            byte sequence.
        """
        return pickle.dumps(
            {
                "historical_policies": self.historical_policies,
                "burn_in_policies": self._burn_in_policies,
                "version": self.version,
                "game_results": self._game_results,
                "win_rate_history": self._win_rate_history,
                "wins": self._wins,
                "draws": self._draws,
                "losses": self._losses,
                "games": self._games,
            }
        )

    def restore(self, objs: bytes) -> None:
        """Restores this Coordinator's state from a sequence of bytes.

        Args:
            objs: The byte sequence to restore this coordinator's state from.
        """
        objs = pickle.loads(objs)
        self.historical_policies = objs["historical_policies"]
        self._burn_in_policies = objs["burn_in_policies"]
        self.version = objs["version"]
        self._game_results = objs["game_results"]
        self._win_rate_history = objs["win_rate_history"]
        self._wins = objs["wins"]
        self._draws = objs["draws"]
        self._losses = objs["losses"]
        self._games = objs["games"]


@ray.remote
class AsymmetricCoordinator:
    def __init__(
        self,
        match_func,
        win_rate_threshold,
        iter_threshold,
        newest_prob,
        max_league_size,
        seed,
        debug=False,
    ):
        self.match_func = match_func
        self.win_rate_threshold = win_rate_threshold
        self.iter_threshold = iter_threshold
        self.newest_prob = newest_prob
        self.max_league_size = max_league_size
        self._seed = seed
        np.random.seed(self._seed)

        self.left_team_policies = {"left_v0"}
        self.right_team_policies = {"right_v0"}
        # self._burn_in_policies = {"left_v0", "right_v0"}

        self.left_version = 0
        self.right_version = 0

        self._wins = {"main_left": {}, "main_right": {}}
        self._draws = {"main_left": {}, "main_right": {}}
        self._losses = {"main_left": {}, "main_right": {}}
        self._games = {"main_left": {}, "main_right": {}}
        self._decay = 0.99

        self._game_results = {"main_left": {}, "main_right": {}}
        self._win_rate_history = {"main_left": {}, "main_right": {}}
        self._burn_in_period = 20
        self._steps = 0

        self.training_ratio = [0.5, 0.5]
        self.policy_to_remove = []

        self.debug = debug
        if self.debug:
            logging.basicConfig(
                level=logging.INFO,
                format="%(asctime)s\t%(levelname)s %(filename)s:%(lineno)s -- %(message)s",
                handlers=[logging.FileHandler(f"ray_results/{match_func}_coordinator.log")],
            )

    def update(self, home, away, result):
        """Update game statistics. called on episode end."""
        # Decay.
        for stats in (self._wins, self._draws, self._losses, self._games):
            assert home in stats, home
            if away in stats[home]:
                stats[home][away] *= self._decay
            else:
                stats[home][away] = 0

        # Update.
        self._games[home][away] += 1
        if away not in self._game_results[home]:
            self._game_results[home][away] = deque(maxlen=self._burn_in_period)
        if result == "win":
            self._wins[home][away] += 1
            self._game_results[home][away].append(1)
        elif result == "draw":
            self._draws[home][away] += 1
            self._game_results[home][away].append(0.5)
        else:
            self._losses[home][away] += 1
            self._game_results[home][away].append(0)

        # Burn-in policies.
        # if away in self._burn_in_policies:
        #     if len(self._game_results[home][away]) >= self._burn_in_period:
        #         if self.debug:
        #             logging.info(f"Remove {away} from burn-in period ({self._burn_in_policies})")
        #         self._burn_in_policies.remove(away)
        # else:
        if away in self._win_rate_history[home]:
            self._win_rate_history[home][away].append(self.get_win_rates(home)[away])
        else:
            self._win_rate_history[home][away] = deque([self.get_win_rates(home)[away]], maxlen=self._burn_in_period)

    def add_policy(self, team_id):
        """Add a new policy if the league satisfies snapshot condition."""
        if team_id == "left":
            self.left_version += 1
            new_pol_id = f"left_v{self.left_version}"
            self.left_team_policies.add(new_pol_id)

            if self.left_version >= self.max_league_size:
                remove_pol_id = sorted(
                    self.left_team_policies,
                    key=lambda x: int(x.strip("left_v"))
                )[0]
                self.policy_to_remove.append(remove_pol_id)
                self.left_team_policies.remove(remove_pol_id)
                for stats in (
                    self._wins,
                    self._draws,
                    self._losses,
                    self._games,
                    self._game_results,
                    self._win_rate_history,
                ):
                    stats["main_right"].pop(remove_pol_id)

        elif team_id == "right":
            self.right_version += 1
            new_pol_id = f"right_v{self.right_version}"
            self.right_team_policies.add(new_pol_id)

            if self.right_version >= self.max_league_size:
                remove_pol_id = sorted(
                    self.right_team_policies,
                    key=lambda x: int(x.strip("right_v"))
                )[0]
                self.policy_to_remove.append(remove_pol_id)
                self.right_team_policies.remove(remove_pol_id)
                for stats in (
                    self._wins,
                    self._draws,
                    self._losses,
                    self._games,
                    self._game_results,
                    self._win_rate_history,
                ):
                    stats["main_left"].pop(remove_pol_id)

        else:
            raise ValueError(f"Wrong team id: {team_id}")
        # self._burn_in_policies.add(new_pol_id)
        if self.debug:
            logging.info(f"Snapshot main_{team_id} policy and add new policy: {new_pol_id}")
        return new_pol_id

    def get_policy_to_remove(self):
        if len(self.policy_to_remove) > 0:
            policy_to_remove = self.policy_to_remove.pop()
        else:
            policy_to_remove = None
        return policy_to_remove

    def get_new_policy_mapping_fn(self, policies=None, worker_index=0, env_index=0):
        """Update the mapping function.

        The "main" plays against previous policies uniformly.
        """
        if self.debug:
            logging.info(f"Sampling new policy mapping function:")

        if policies is None:
            left_team_policies = self.left_team_policies
            right_team_policies = self.right_team_policies
        else:
            left_team_policies = [p for p in policies if p.startswith("left") and p in self.left_team_policies]
            right_team_policies = [p for p in policies if p.startswith("right") and p in self.right_team_policies]

        historical_left_team = sorted(
            left_team_policies,
            key=lambda x: int(x.strip("left_v")))
        historical_right_team = sorted(
            right_team_policies,
            key=lambda x: int(x.strip("right_v"))
        )

        # print("=====historical_left_team:", historical_left_team, historical_right_team)

        home = np.random.choice(["main_left", "main_right"], p=self.training_ratio)
        if home == "main_left":
            away_candidates = historical_right_team
        else:
            away_candidates = historical_left_team

        # print("========away_candidates========:", away_candidates, self.policy_to_remove)

        if self.match_func == "FSP":
            probs = fsp(
                num_models=len(away_candidates), newest_prob=self.newest_prob
            )
            away = np.random.choice(away_candidates, p=probs)
            if self.debug:
                logging.info(f"\tHeuristic FSP: {home} vs {away_candidates}\n"
                             f"\tSampled with probability {probs}")

        elif self.match_func == "PFSP":
            win_rate_dict = self.get_win_rates(home)
            win_rates = []
            for p in away_candidates:
                if p in win_rate_dict:
                    win_rates.append(win_rate_dict[p])
                else:
                    win_rates.append(0.5)
            probs = pfsp(win_rates=win_rates, weighting="variance")
            away = np.random.choice(away_candidates, p=probs)

        elif self.match_func == "ALP":
            alp = []
            for p in away_candidates:
                if len(self._win_rate_history[home]) < self._burn_in_period:
                    alp.append(1)
                elif p in self._win_rate_history[home]:
                    alp.append(
                        abs(sum(list(self._win_rate_history[home][p])[:5]) -
                            sum(list(self._win_rate_history[home][p])[-5:])) / 5
                    )
                else:
                    alp.append(1)
            probs = pfsp(win_rates=alp, weighting="alp")
            away = np.random.choice(away_candidates, p=probs)
            if self.debug:
                logging.info(f"\tAbsolute Learning Progress PFSP: {home} vs {away_candidates}\n"
                             f"\tSampled with probability {probs}")

        if self.debug:
            logging.info(f"\tSampled new policy mapping: {home} vs {away}")

        def fn(agent_id, episode, worker, **kwargs):
            if agent_id.startswith("left"):
                return home if "left" in home else away
            elif agent_id.startswith("right"):
                return home if "right" in home else away
            else:
                raise ValueError(f"Wrong agent id: {agent_id}")

        return fn

    def update_training_ratio(self):
        # Unbalanced training proportion of each side.
        left_win_rates = self.get_win_rates("main_left")
        right_win_rates = self.get_win_rates("main_right")
        if len(left_win_rates) > 0 and len(right_win_rates) > 0:
            left_win_mean = float(np.mean(list(left_win_rates.values())))
            right_win_mean = float(np.mean(list(right_win_rates.values())))
            self.training_ratio = pfsp(win_rates=[left_win_mean, right_win_mean], weighting="linear_capped")

        return self.training_ratio

    def get_win_rates(self, home):
        return {
            k: (self._wins[home][k] + 0.5 * self._draws[home][k]) / self._games[home][k]
            if self._games[home][k] != 0
            else 0.5
            for k in self._wins[home].keys()
        }
        # return {k: sum(v) / len(v) for k, v in self._game_results[home].items() if k not in self._burn_in_policies}

    def clear(self):
        self._wins.clear()
        self._draws.clear()
        self._losses.clear()
        self._games.clear()
        self._game_results.clear()
        self._win_rate_history.clear()

    def league_version(self):
        return self.left_version, self.right_version

    def save(self) -> bytes:
        """Serializes this Coordinator's current state and returns it.

        Returns:
            The current state of this coordinator as a serialized, pickled
            byte sequence.
        """
        return pickle.dumps(
            {
                "left_team_policies": self.left_team_policies,
                "right_team_policies": self.right_team_policies,
                # "burn_in_policies": self._burn_in_policies,
                "left_version": self.left_version,
                "right_version": self.right_version,
                "game_results": self._game_results,
                "win_rate_history": self._win_rate_history,
                "training_ratio": self.training_ratio,
                "wins": self._wins,
                "draws": self._draws,
                "losses": self._losses,
                "games": self._games,
            }
        )

    def restore(self, objs: bytes) -> None:
        """Restores this Coordinator's state from a sequence of bytes.

        Args:
            objs: The byte sequence to restore this coordinator's state from.
        """
        objs = pickle.loads(objs)
        self.left_team_policies = objs["left_team_policies"]
        self.right_team_policies = objs["right_team_policies"]
        # self._burn_in_policies = objs["burn_in_policies"]
        self.left_version = objs["left_version"]
        self.right_version = objs["right_version"]
        self._game_results = objs["game_results"]
        self._win_rate_history = objs["win_rate_history"]
        self.training_ratio = objs["training_ratio"]
        self._wins = objs["wins"]
        self._draws = objs["draws"]
        self._losses = objs["losses"]
        self._games = objs["games"]


@ray.remote
class PopulationCoordinator:
    def __init__(
        self,
        match_func,
        population_size,
        seed,
        debug=False,
    ):
        self.match_func = match_func
        self.population_size = population_size
        self._seed = seed
        np.random.seed(self._seed)

        self.left_team_policies = {f"left_{i}" for i in range(self.population_size)}
        self.right_team_policies = {f"right_{i}" for i in range(self.population_size)}
        # self._burn_in_policies = {"left_v0", "right_v0"}

        self._wins = dict(**{f"left_{i}": {} for i in range(self.population_size)},
                          **{f"right_{i}": {} for i in range(self.population_size)})
        self._draws = copy.deepcopy(self._wins)
        self._losses = copy.deepcopy(self._wins)
        self._games = copy.deepcopy(self._wins)
        self._decay = 0.99

        self._game_results = copy.deepcopy(self._wins)
        self._win_rate_history = copy.deepcopy(self._wins)
        self._burn_in_period = 20

        self.training_ratio = [0.5, 0.5]
        self.left_training_ratio = None
        self.right_training_ratio = None

        self.debug = debug
        if self.debug:
            logging.basicConfig(
                level=logging.INFO,
                format="%(asctime)s\t%(levelname)s %(filename)s:%(lineno)s -- %(message)s",
                handlers=[logging.FileHandler(f"ray_results/{match_func}_coordinator.log")],
            )

    def update(self, home, away, result):
        """Update game statistics. called on episode end."""
        # Decay.
        for stats in (self._wins, self._draws, self._losses, self._games):
            assert home in stats, home
            if away in stats[home]:
                stats[home][away] *= self._decay
            else:
                stats[home][away] = 0

        # Update.
        self._games[home][away] += 1
        if away not in self._game_results[home]:
            self._game_results[home][away] = deque(maxlen=self._burn_in_period)
        if result == "win":
            self._wins[home][away] += 1
            self._game_results[home][away].append(1)
        elif result == "draw":
            self._draws[home][away] += 1
            self._game_results[home][away].append(0.5)
        else:
            self._losses[home][away] += 1
            self._game_results[home][away].append(0)

        # Burn-in policies.
        # if away in self._burn_in_policies:
        #     if len(self._game_results[home][away]) >= self._burn_in_period:
        #         if self.debug:
        #             logging.info(f"Remove {away} from burn-in period ({self._burn_in_policies})")
        #         self._burn_in_policies.remove(away)
        # else:
        if away in self._win_rate_history[home]:
            self._win_rate_history[home][away].append(self.get_win_rates(home)[away])
        else:
            self._win_rate_history[home][away] = deque([self.get_win_rates(home)[away]],
                                                       maxlen=self._burn_in_period)

    def get_policy_to_remove(self):
        if len(self.policy_to_remove) > 0:
            policy_to_remove = self.policy_to_remove.pop()
        else:
            policy_to_remove = None
        return policy_to_remove

    def get_new_policy_mapping_fn(self, policies=None):
        """Update the mapping function.

        The "main" plays against previous policies uniformly.
        """
        if self.debug:
            logging.info(f"Sampling new policy mapping function:")

        if policies is None:
            left_team_policies = self.left_team_policies
            right_team_policies = self.right_team_policies
        else:
            left_team_policies = [p for p in policies if p.startswith("left")]
            right_team_policies = [p for p in policies if p.startswith("right")]

        historical_left_team = sorted(
            left_team_policies,
            key=lambda x: int(x.strip("left_")))
        historical_right_team = sorted(
            right_team_policies,
            key=lambda x: int(x.strip("right_"))
        )

        home_team = np.random.choice(["left", "right"], p=self.training_ratio)
        if home_team == "left":
            home = np.random.choice([f"left_{i}" for i in range(self.population_size)], p=self.left_training_ratio)
            away_candidates = historical_right_team
        else:
            home = np.random.choice([f"right_{i}" for i in range(self.population_size)], p=self.right_training_ratio)
            away_candidates = historical_left_team

        away = np.random.choice(away_candidates)

        if self.debug:
            logging.info(f"\tSampled new policy mapping: {home} vs {away}")

        def fn(agent_id, episode, worker, **kwargs):
            if agent_id.startswith("left"):
                return home if "left" in home else away
            elif agent_id.startswith("right"):
                return home if "right" in home else away
            else:
                raise ValueError(f"Wrong agent id: {agent_id}")

        return fn

    def update_training_ratio(self):
        """Recompute the per-member training ratios of both populations.

        For every member, its mean win rate over all members of the
        opposing population is computed, and the resulting lists are passed
        to ``pfsp`` with "linear_capped" weighting. Pairings without
        recorded games count as 0.5, the same neutral value
        ``get_win_rates`` reports for a zero game count. This keeps both
        lists exactly ``population_size`` long, which
        ``get_new_policy_mapping_fn`` needs when it passes them as ``p=`` to
        ``np.random.choice``. It also avoids the ``KeyError`` previously
        raised when a member had not yet met every opponent.

        ``self.training_ratio`` itself is not modified here.

        Returns:
            Tuple ``(training_ratio, left_training_ratio,
            right_training_ratio)``.
        """
        left_ids = [f"left_{i}" for i in range(self.population_size)]
        right_ids = [f"right_{i}" for i in range(self.population_size)]

        def mean_win_rate(home, opponents):
            # Average over the full opposing population, not only the
            # opponents met so far.
            rates = self.get_win_rates(home)
            return float(np.mean([rates.get(opp, 0.5) for opp in opponents]))

        left_win_rates = [mean_win_rate(home, right_ids) for home in left_ids]
        right_win_rates = [mean_win_rate(home, left_ids) for home in right_ids]

        self.left_training_ratio = pfsp(win_rates=left_win_rates, weighting="linear_capped")
        self.right_training_ratio = pfsp(win_rates=right_win_rates, weighting="linear_capped")

        return self.training_ratio, self.left_training_ratio, self.right_training_ratio

    def get_win_rates(self, home):
        return {
            k: (self._wins[home][k] + 0.5 * self._draws[home][k]) / self._games[home][k]
            if self._games[home][k] != 0
            else 0.5
            for k in self._wins[home].keys()
        }
        # return {k: sum(v) / len(v) for k, v in self._game_results[home].items() if k not in self._burn_in_policies}

    def clear(self):
        self._wins.clear()
        self._draws.clear()
        self._losses.clear()
        self._games.clear()
        self._game_results.clear()
        self._win_rate_history.clear()

    def save(self) -> bytes:
        return pickle.dumps(
            {
                "left_team_policies": self.left_team_policies,
                "right_team_policies": self.right_team_policies,
                # "burn_in_policies": self._burn_in_policies,
                # "left_version": self.left_version,
                # "right_version": self.right_version,
                "game_results": self._game_results,
                "win_rate_history": self._win_rate_history,
                "training_ratio": self.training_ratio,
                "left_training_ratio": self.left_training_ratio,
                "right_training_ratio": self.right_training_ratio,
                "wins": self._wins,
                "draws": self._draws,
                "losses": self._losses,
                "games": self._games,
            }
        )

    def restore(self, objs: bytes) -> None:
        objs = pickle.loads(objs)
        self.left_team_policies = objs["left_team_policies"]
        self.right_team_policies = objs["right_team_policies"]
        # self._burn_in_policies = objs["burn_in_policies"]
        # self.left_version = objs["left_version"]
        # self.right_version = objs["right_version"]
        self._game_results = objs["game_results"]
        self._win_rate_history = objs["win_rate_history"]
        self.training_ratio = objs["training_ratio"]
        self.left_training_ratio = objs["left_training_ratio"]
        self.right_training_ratio = objs["right_training_ratio"]
        self._wins = objs["wins"]
        self._draws = objs["draws"]
        self._losses = objs["losses"]
        self._games = objs["games"]
