import json
import os
import shutil
import typing
from multiprocessing import Pool
import numpy as np
from tqdm import tqdm

from core.algorithm import ConstrainedMDP
from simulations.simulation_environment import (
    DataEnhancedEnvironment,
)
from simulations.utils import (
    plot_cumulative,
    plot_lagrangian,
    plot_no_constraints,
    plot_reward,
)

if typing.TYPE_CHECKING:
    from typing import Any, List, Tuple, Dict

class SimulationConstrainedMDP(ConstrainedMDP):
    """Class to run the PD-DP algorithm in a simulated environment.

    Reads a JSON simulation config, builds a ``DataEnhancedEnvironment`` from
    it, runs ``n_iter`` independently seeded iterations (in parallel via
    ``run_parallel``) and aggregates, saves and plots the regret, reward,
    constraint-violation and Lagrangian curves.
    """

    def __init__(
        self,
        save=True,
        sim_path="../config/sim_config.json",
        data=True,
    ):
        """
        Args:
            save: when True, (re)create the output folder, copy the simulation
                and environment config files into it, and save results there.
            sim_path: path of the JSON simulation config read by ``_read_config``.
            data: in this class only selects the output-folder prefix
                ("data_" when True, "sim_" otherwise).
        """
        super().__init__()

        # Values filled from the config file by _read_config
        self.confidence_delta: float = None
        self.n_iter: int = None
        self.temporal_horizon: int = None
        self.constraints_difficulty: float = None
        self.adversarial: bool = None
        self._read_config(sim_path)
        self.save = save
        self.data = data
        self.sim_path = sim_path

        self.environment = DataEnhancedEnvironment(
            time_horizon=self.temporal_horizon,
            constraints_difficulty=self.constraints_difficulty,
            n_shifts=self.n_shifts,
            path=self.env_path,
        )

        # Output folder name encodes the main experiment parameters
        run_kind = "data_" if self.data else "sim_"
        prefix = (self.prefix + "_") if self.prefix != "" else ""
        self.save_name = (
            f"output/{run_kind}{prefix}"
            f"{self.environment.n_actions}actions"
            f"{self.temporal_horizon}t_"
            f"{self.n_iter}iter_"
            f"{self.environment.n_constraints}constraints_"
            f"{self.constraints_difficulty}difficulty_"
            f"{self.n_batch}nbatch_"
            f"{self.mean_update}mean_update_"
            f"{self.n_shifts}nshifts/"
        )
        print(self.save_name)

        # Pre-allocated result arrays (not filled by the methods of this class;
        # kept because external code may rely on these attributes).
        self.y_vals_regret = np.empty((self.n_iter, self.temporal_horizon), dtype=float)
        self.y_vals_constraints = np.empty(
            (self.n_iter, self.temporal_horizon), dtype=float
        )
        self.y_vals_constraints_variation = np.empty(
            (self.n_iter, self.temporal_horizon), dtype=float
        )
        self.y_vals_lagrangian = np.empty(
            (self.n_iter, self.temporal_horizon, self.environment.n_constraints),
            dtype=float,
        )
        self.y_vals_rewards = np.empty(
            (self.n_iter, self.temporal_horizon), dtype=float
        )

        # Per-iteration structures, allocated in run_single_iteration
        self.occ_measure_struct = None
        self.lagrangian_struct = None
        self.optimal_q_struct = None
        self.policy_struct = None
        self.loss_struct = None

        if save:
            # Start from a clean folder; makedirs also creates a missing
            # "output/" parent instead of raising FileNotFoundError.
            if os.path.exists(self.save_name):
                shutil.rmtree(self.save_name)
            os.makedirs(self.save_name)
            print("created folder " + self.save_name)

            # Keep a copy of both config files next to the results
            shutil.copy(sim_path, self.save_name)
            shutil.copy(self.env_path, self.save_name)

    def _read_config(self, path):
        """Load the simulation parameters from the JSON file at ``path``.

        Raises KeyError if any of the expected keys is missing.
        """
        with open(path, "r", encoding="utf-8") as f:
            variables = json.load(f)

        self.confidence_delta = variables["confidence_delta"]
        self.n_iter = variables["n_iter"]
        self.temporal_horizon = variables["temporal_horizon"]
        self.adversarial = variables["adversarial"]
        self.constraints_difficulty = variables["constraints_difficulty"]
        self.n_batch = variables["n_batch"]
        self.n_shifts = variables["n_shifts"]
        self.mean_update = variables["mean_update"]
        self.env_path = variables["env_path"]
        self.prefix = variables["prefix"]

    def cumulative_regret(self, adversarial=False):
        """
        Calculates the cumulative regret for the algorithm.

        ``y[t]`` is the running sum over rounds ``s <= t`` of
        ``<optimal_q_s, reward_mean_s>`` minus the running sum of
        ``<occ_measure_struct[s], rewards_list[s]>``.

        Args:
            adversarial: currently unused; kept for interface compatibility.

        Returns:
            1-D float array of length ``temporal_horizon``.
        """
        print("Calculating cumulative regret")
        # Row-wise dot product between played occupancy measures and rewards
        played = np.einsum(
            "ij, ij->i",
            self.occ_measure_struct,
            self.environment.rewards_list,
        )
        # One cumulative sum instead of re-summing the prefix every round
        played_cumulative = np.cumsum(played)

        y = np.empty(self.temporal_horizon, dtype=float)
        opt = 0
        for t in range(self.temporal_horizon):
            opt += np.dot(
                self.environment.get_optimal_q_pairs(t),
                self.environment.get_reward_mean(t),
            )
            y[t] = opt - played_cumulative[t]
        return y

    def cumulative_mean_reward(self):
        """
        Calculates the cumulative mean reward for the algorithm.

        ``y[t]`` is the mean over rounds ``s <= t`` of
        ``<policy_struct[s], rewards_list[s]>``.

        NOTE(review): unlike ``cumulative_regret`` this weights rewards by the
        policy rather than by the occupancy measure — confirm this is intended.

        Returns:
            1-D float array of length ``temporal_horizon``.
        """
        print("Calculating cumulative reward")
        played = np.einsum(
            "ij, ij->i",
            self.policy_struct,
            self.environment.rewards_list,
        )
        horizon = self.temporal_horizon
        # Running mean: prefix sums divided by the number of rounds seen
        return np.cumsum(played[:horizon]) / np.arange(1, horizon + 1)

    def cumulative_constraints_violation(self):
        """
        Calculates the cumulative constraints violation for the algorithm.

        Per round, computes ``constraints_list[t].T @ occ_measure_struct[t]``
        (one value per constraint) and accumulates it over time. The curve of
        the constraint with the largest final cumulative value is returned.

        Returns:
            1-D float array of length ``temporal_horizon`` (empty when the
            horizon is 0).
        """
        print("Calculating cumulative constraints violation")
        horizon = self.temporal_horizon
        violations = np.empty(
            (horizon, self.environment.n_constraints), dtype=float
        )
        for t in range(horizon):
            violations[t] = self.environment.constraints_list[t].T.dot(
                self.occ_measure_struct[t]
            )

        # Guard: violations[-1] below would raise IndexError on an empty horizon
        if horizon == 0:
            return np.empty(0, dtype=float)

        violations = np.cumsum(violations, axis=0)
        worst_constraint = np.argmax(violations[-1])
        return violations[:, worst_constraint].copy()

    def _from_policy_tf_to_om(self, policy):
        """
        Given a policy, returns the corresponding occupancy measure (in pairs
        format) induced by the true transition function.

        Triples are processed in ``environment.sas_ind`` order. A triple
        ``(x, a, x')`` gets ``policy[(x, a)] * P(x, a, x')``; when ``x`` is not
        the initial state (``layers[0][0]``) this is further multiplied by the
        mass that already reached ``x`` from previously processed triples.
        The triple masses are then summed per ``(x, a)`` pair and returned as a
        list in ``environment.sa_ind`` order (0 for pairs with no triple).

        Running accumulators replace the original rescans of every stored
        triple, turning the O(n^2) per-call cost into O(n) with identical
        summation order.
        """
        env = self.environment
        initial_state = env.layers[0][0]
        transition = env.transition_function

        # inflow[x]: sum of q(x*, a*, x) over triples processed so far
        inflow = {}
        # pair_mass[(x, a)]: sum of q(x, a, x') over x'
        # (assumes each triple appears once in sas_ind — TODO confirm)
        pair_mass = {}

        for triple in env.sas_ind:
            pair = triple[0:2]
            mass = policy[pair] * transition[triple]
            if triple[0] != initial_state:
                mass = mass * inflow.get(triple[0], 0)

            inflow[triple[2]] = inflow.get(triple[2], 0) + mass
            pair_mass[pair] = pair_mass.get(pair, 0) + mass

        # Dict comprehension keeps the original de-duplication of sa_ind keys
        return list({key: pair_mass.get(key, 0) for key in env.sa_ind}.values())

    def run_single_iteration(self, seed):
        """Run the algorithm for ``temporal_horizon`` rounds with one seed.

        Resets the environment, re-instantiates the algorithm, plays every
        round, applies updates every ``n_batch`` rounds (and at the last
        round), then computes the cumulative curves. When ``save`` is True the
        algorithm's final empirical transition function is written to
        ``final_transition.json``.

        Returns:
            dict with keys "y_regret", "y_reward", "y_constraints" (None when
            there are no constraints), "y_lagrangian", "optimal_q_struct",
            "policy_struct", "occ_measure_struct", "loss_struct".
        """
        self.environment.reset(time_horizon=self.temporal_horizon, seed=seed)

        # Clean algorithm
        self.instantiate_algorithms(self.environment, self.sim_path)
        # NOTE(review): this re-runs ConstrainedMDP.__init__ after
        # instantiate_algorithms; confirm it does not discard state set there.
        super().__init__()

        n_pairs = self.environment.sa_count[-1]
        self.occ_measure_struct = np.empty(
            (self.temporal_horizon, n_pairs), dtype=float
        )
        self.lagrangian_struct = np.empty(
            (self.temporal_horizon, self.environment.n_constraints), dtype=float
        )
        # 3 loss components: overall loss, reward loss, constraints loss
        self.loss_struct = np.empty(
            (self.temporal_horizon, 3, n_pairs), dtype=float
        )
        self.optimal_q_struct = np.empty(
            (self.temporal_horizon, n_pairs), dtype=float
        )
        self.policy_struct = np.empty(
            (self.temporal_horizon, n_pairs), dtype=float
        )

        np.random.seed(seed)

        for t in tqdm(range(0, self.temporal_horizon)):
            policy, lagrangian = self.round_play(t)

            self.policy_struct[t] = list(policy.values())
            self.occ_measure_struct[t] = self._from_policy_tf_to_om(policy)
            self.lagrangian_struct[t] = lagrangian
            self.optimal_q_struct[t] = self.environment.get_optimal_q_pairs(t)

            # Update at the end of each batch and at the last round
            if ((t + 1) % self.n_batch == 0) or (t == self.temporal_horizon - 1):
                # Rounds in the batch being closed: a full batch, or the
                # shorter remainder at the end of the horizon
                batch_len = (t + 1) % self.n_batch or self.n_batch
                batch_start = t + 1 - batch_len
                # round_update yields one tuple per round of the batch; only
                # the loss (third element) is stored here
                for index, (_, _, round_loss, _) in enumerate(
                    self.round_update(self.mean_update)
                ):
                    self.loss_struct[batch_start + index] = round_loss

        print("Calculating regret and constraints violation...")
        y_regret = self.cumulative_regret()
        y_reward = self.cumulative_mean_reward()
        if self.environment.n_constraints > 0:
            y_constraints = self.cumulative_constraints_violation()
        else:
            y_constraints = None

        print("Saving results...")
        if self.save:
            # json needs string keys, so tuple keys are stringified
            empiric = self.algorithm_mdp.transition_function.empiric_transition_function
            with open(
                self.save_name + "final_transition.json", "w", encoding="utf-8"
            ) as f:
                json.dump(
                    {str(key): value for key, value in empiric.items()},
                    f,
                    indent=4,
                )

        return {
            "y_regret": y_regret,
            "y_reward": y_reward,
            "y_constraints": y_constraints,
            "y_lagrangian": self.lagrangian_struct,
            "optimal_q_struct": self.optimal_q_struct,
            "policy_struct": self.policy_struct,
            "occ_measure_struct": self.occ_measure_struct,
            "loss_struct": self.loss_struct,
        }

    def aggregate_results(self, results):
        """Helper function to aggregate results from parallel runs.

        Collects each key of the per-iteration dicts (as returned by
        ``run_single_iteration``) into a list, skipping ``None``
        "y_constraints" entries. When ``save`` is True each list is written
        with ``np.save`` to ``<save_name><file>.npy``; the constraints file is
        only written when the environment has constraints.

        Files are opened in "ab+" mode, so calling this more than once on the
        same instance appends further arrays to existing files.
        """
        result_keys = (
            "y_regret",
            "y_reward",
            "y_constraints",
            "y_lagrangian",
            "optimal_q_struct",
            "policy_struct",
            "occ_measure_struct",
            "loss_struct",
        )
        aggregated = {key: [] for key in result_keys}

        for result in results:
            for key in result_keys:
                value = result[key]
                # Runs without constraints report None for y_constraints
                if key == "y_constraints" and value is None:
                    continue
                aggregated[key].append(value)

        if self.save:
            # result key -> .npy file stem
            file_stems = {
                "y_regret": "y_vals_regret",
                "y_reward": "y_vals_rewards",
                "y_constraints": "y_vals_constraints",
                "occ_measure_struct": "occ_measure_struct",
                "policy_struct": "policy_struct",
                "y_lagrangian": "lagrangian_struct",
                "loss_struct": "loss_struct",
                "optimal_q_struct": "optimal_q_struct",
            }
            for key, stem in file_stems.items():
                if key == "y_constraints" and not self.environment.n_constraints > 0:
                    continue
                with open(self.save_name + stem + ".npy", "ab+") as f:
                    np.save(f, aggregated[key])

        return aggregated

    def run_parallel(self, base_seed=27091999):
        """Run ``n_iter`` iterations in a process pool, then aggregate and plot.

        Args:
            base_seed: iteration ``i`` uses seed ``base_seed + i``. The default
                reproduces the historical seeds.

        Returns:
            The aggregated results dict from ``aggregate_results``.
        """
        # One distinct, reproducible seed per iteration
        seeds = [base_seed + i for i in range(self.n_iter)]

        # Each worker process runs on its own (pickled) copy of this object
        with Pool() as pool:
            results = pool.map(self.run_single_iteration, seeds)

        aggregated_results = self.aggregate_results(results)

        print("Plotting...")
        plot_cumulative(
            aggregated_results["y_regret"],
            aggregated_results["y_constraints"],
            self.save,
            self.save_name,
            self.environment.n_constraints,
        )

        plot_reward(aggregated_results["y_reward"], self.save, self.save_name)

        if self.environment.n_constraints > 0:
            plot_lagrangian(
                aggregated_results["y_lagrangian"],
                self.environment.n_constraints,
                self.save,
                self.save_name,
            )
        else:
            plot_no_constraints(self.save, self.save_name)

        return aggregated_results
