import os
from collections import deque
from copy import deepcopy
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union

import numpy as np  # type:ignore
import pandas as pd
import torch as th  # type:ignore
from torch.distributions import Categorical  # type:ignore

from common import (
    GRUPolicyNetwork,
    PolicyNetwork,
    RewardConditionedGRUPolicyNetwork,
    RewardConditionedPolicyNetwork,
    RLAlgorithm,
    StateValueNetwork,
    WandbLogger,
    min_proportion, 
    max_min_proportion
)


class EUPG(RLAlgorithm):
    index = 0

    def __init__(
        self,
        reward_dim: int,
        env: Any,
        obs_shape: np.ndarray,
        hidden_layers: List[int],
        n_actions: int,
        scalarization: Callable[[th.Tensor, th.Tensor], th.Tensor],
        scalarization_weights: Union[List[float], th.Tensor, np.ndarray],
        device: str = "cpu",
        gamma: float = 0.999,
        lr: float = 2e-4,
        standardization: bool = True,
        use_baseline: bool = False,
        recurrent_policy: bool = False,
        reward_conditioned: bool = True,
    ) -> None:
        """Create the policy network, its optimizer and the optional baseline.

        Args:
            reward_dim: Number of reward components (objectives).
            env: Training environment; ``train`` calls ``reset``/``step`` on it.
            obs_shape: Observation shape forwarded to the network constructors.
            hidden_layers: Layer sizes forwarded to the feed-forward policy
                variants and to the baseline network.
            n_actions: Number of discrete actions.
            scalarization: Called as
                ``scalarization(vector_returns, scalarization_weights)``.
            scalarization_weights: Second argument given to ``scalarization``.
            device: Torch device the networks and tensors are placed on.
            gamma: Discount factor.
            lr: Learning rate shared by the policy and baseline Adam
                optimizers.
            standardization: Stored on the instance as ``self.standardization``.
            use_baseline: When true, a ``StateValueNetwork`` and its optimizer
                are created (``update`` trains it).
            recurrent_policy: Select the GRU policy variants.
            reward_conditioned: Select the policy variants that also receive
                the accrued reward as input.
        """
        super().__init__()
        self.env = env
        self.reward_dim = reward_dim
        self.scalarization = scalarization
        self.scalarization_weights = scalarization_weights
        self.obs_shape = obs_shape
        self.n_action = n_actions
        self.device = device
        self.curr_step = 0
        self.gamma = gamma
        self.lr = lr
        self.on_policy_trajectory_buffer: Iterable[
            Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]
        ] = deque([])
        # To use when adding experience sharing
        self.off_policy_trajectories_buffer = deque([])
        self.best_policy: Any = None
        # np.infty was removed in NumPy 2.0; np.inf is the supported spelling.
        self.best_score: float = -np.inf
        self.reward_conditioned = reward_conditioned
        self.recurrent_policy = recurrent_policy

        # Pick one of the four policy architectures. The GRU variants receive
        # the literal 32 where the feed-forward ones receive hidden_layers
        # (presumably the GRU hidden size -- confirm against common.py).
        if reward_conditioned and recurrent_policy:
            policy_net = RewardConditionedGRUPolicyNetwork(
                obs_shape, self.reward_dim, 32, n_actions
            )
        elif reward_conditioned:
            policy_net = RewardConditionedPolicyNetwork(
                obs_shape, self.reward_dim, hidden_layers, n_actions
            )
        elif recurrent_policy:
            policy_net = GRUPolicyNetwork(obs_shape, 32, n_actions)
        else:
            policy_net = PolicyNetwork(obs_shape, hidden_layers, n_actions)
        self.policy_net: th.nn.Module = policy_net.to(self.device)

        self.policy_optimizer: Any = th.optim.Adam(
            self.policy_net.parameters(), lr=self.lr
        )
        self.use_baseline: bool = use_baseline
        self.standardization: bool = standardization

        if self.use_baseline:
            self.baseline_net = StateValueNetwork(
                obs_shape, self.reward_dim, hidden_layers
            ).to(self.device)
            self.baseline_optimizer = th.optim.Adam(
                self.baseline_net.parameters(), lr=self.lr
            )

    def act(
        self,
        state: th.Tensor,
        accrued_reward: th.Tensor,
        prev_action: th.Tensor = None, 
        hidden: th.Tensor = None,
    ) -> Tuple[th.Tensor, th.Tensor]:
        # Forward pass through policy network to get action probabilities
        new_hidden = None

        if self.reward_conditioned:
            if self.recurrent_policy:
                probs, new_hidden = self.policy_net(
                    state, accrued_reward, prev_action, hidden
                )
            else:
                probs = self.policy_net(state, accrued_reward)
        else:
            if self.recurrent_policy:
                probs, new_hidden = self.policy_net(state, prev_action, hidden)
            else:
                probs = self.policy_net(state)
        m = Categorical(logits=probs)
        action = m.sample()
        log_prob = m.log_prob(action)
        entropy = m.entropy()
        return action, (log_prob,entropy), new_hidden

    def train(
        self,
        timesteps: int,
        eval_env: Any,
        eval_freq: int = 100,
        n_evals: int = 32,
        log: bool = False,
        logger: WandbLogger = None,
        other_metrics: Union[Dict[str, Callable], None] = None,
        save_results: bool = False,
        results_dir: str = None,
    ) -> None:
        """Run on-policy training for at least ``timesteps`` environment steps.

        One policy update is performed at the end of every episode. Whenever
        ``self.curr_step`` is a multiple of ``eval_freq`` (and ``eval_env`` is
        truthy) the agent is evaluated, and once more after training ends.

        Args:
            timesteps: Training stops after the episode during which
                ``self.curr_step`` reaches this value.
            eval_env: Environment handed to ``self.eval``; a falsy value
                disables evaluation.
            eval_freq: Evaluate every ``eval_freq`` environment steps.
            n_evals: Number of evaluation episodes per evaluation.
            log: Send evaluation statistics to ``logger`` instead of printing.
            logger: Object with a ``log(dict)`` method. Per-episode training
                statistics are sent to it whenever it is truthy.
            other_metrics: Forwarded to ``self.eval``; ``None`` means ``{}``.
            save_results: Write evaluation results as CSV files under
                ``{results_dir}/{curr_step}/``.
            results_dir: Root directory for the CSV files.

        Raises:
            ValueError: ``save_results`` is set but ``results_dir`` is None.
            AttributeError: ``log`` is set but ``logger`` is None (raised at
                the first evaluation).
        """
        if other_metrics is None:
            other_metrics = {}
        if save_results and results_dir is None:
            raise ValueError("save_results=True requires results_dir")

        def log_agent_performances():
            """Evaluate, report, track the best policy and optionally save."""
            returns, vec_returns, other_metrics_values, episode_lengths = (
                self.eval(eval_env, n_evals, other_metrics)
            )
            returns, vec_returns, episode_lengths = map(
                th.stack, [returns, vec_returns, episode_lengths]
            )

            avg_scalarized_return = th.mean(returns)
            avg_episode_length = th.mean(episode_lengths)
            if log:
                if logger is None:
                    raise AttributeError(
                        "Logging enabled but no logger assigned"
                    )
                to_log = {
                    "timestep": self.curr_step,
                    "mean reward": avg_scalarized_return,
                    "mean episode_length": avg_episode_length,
                }
                for metric_name, values in other_metrics_values.items():
                    to_log[metric_name] = values.mean()
                for i, vec_return in enumerate(vec_returns):
                    for j in range(self.reward_dim):
                        to_log[f"exp{i}/objective_{j}"] = vec_return[j]
                logger.log(to_log)
            else:
                print(returns)
                print(vec_returns)
                print(th.median(vec_returns, axis=1))
                print(avg_episode_length)
                print(avg_scalarized_return)
                print(other_metrics_values)

            if (
                self.best_score is None
                or self.best_score < avg_scalarized_return
            ):
                self.best_score = avg_scalarized_return
                self.best_policy = deepcopy(self.policy_net.state_dict())

            if save_results:
                step_dir = f"{results_dir}/{self.curr_step}"
                # exist_ok avoids the check-then-create race.
                os.makedirs(step_dir, exist_ok=True)
                tables = {
                    "episode_lengths": episode_lengths,
                    "curr_proportion": other_metrics_values["curr_proportion"],
                    "best_proportion": other_metrics_values["best_proportion"],
                    "returns": returns,
                    "vec_returns": vec_returns,
                }
                for table_name, values in tables.items():
                    # detach/cpu so the export also works for tensors that
                    # live on an accelerator or carry gradient history.
                    pd.DataFrame(values.detach().cpu().numpy()).to_csv(
                        f"{step_dir}/{table_name}.csv", index=False
                    )

        while self.curr_step < timesteps:
            state, _ = self.env.reset()

            state = th.tensor(
                state, dtype=th.float32, device=self.device
            ).reshape(1, -1)
            accrued_reward = th.zeros(
                (1, self.reward_dim), dtype=th.float32, device=self.device
            )
            done = False
            hidden = self.policy_net.init_hidden().to(self.device)
            prev_action = None
            entropies = []
            replay_inputs = []
            old_log_probs = []
            ep_length = 0
            while not done:
                ep_length += 1
                if eval_env and self.curr_step % eval_freq == 0:
                    log_agent_performances()
                replay_inputs.append(
                    (state, accrued_reward, prev_action, hidden)
                )
                action, (log_prob, ent), hidden = self.act(
                    state, accrued_reward, prev_action, hidden
                )
                old_log_probs.append(log_prob)
                entropies.append(ent)
                observation, reward, terminated, truncated, _ = self.env.step(
                    action.item()
                )
                # Only the first entry of the returned reward is used.
                reward = th.tensor(
                    reward[0],
                    dtype=th.float32,
                    device=self.device,
                )
                # Out-of-place add: an in-place `+=` would mutate the tensor
                # already stored in replay_inputs and in the trajectory
                # buffer, making every stored step alias the final value.
                accrued_reward = accrued_reward + reward
                done = terminated or truncated
                prev_action = action
                next_state = th.tensor(
                    observation, dtype=th.float32, device=self.device
                ).reshape(1, -1)

                # Store experience in the buffer
                self.on_policy_trajectory_buffer.append(
                    (state, action, log_prob, reward, accrued_reward)
                )

                # Move to the next state
                state = next_state
                self.curr_step += 1

            # Perform policy update after each episode
            self.update()

            # Replay the episode inputs through the updated policy to gauge
            # how far the policy moved. No gradients are needed for this.
            # NOTE(review): self.act re-samples an action on replay, so this
            # compares log-probs of (possibly) different actions; it is a
            # rough drift indicator rather than a true KL divergence.
            with th.no_grad():
                new_log_probs = [
                    self.act(*inputs)[1][0] for inputs in replay_inputs
                ]

            old_lp = th.stack(old_log_probs).detach()
            new_lp = th.stack(new_log_probs)
            to_log = {
                "policy_entropy": th.stack(entropies).detach().mean(),
                "one_step_kl_divergence": (old_lp - new_lp).abs().mean(),
                "train_episode_length": ep_length,
                "train episode scalarised reward": self.scalarization(
                    accrued_reward, self.scalarization_weights
                ),
            }

            if logger:
                logger.log(to_log)
            else:
                print(to_log)

        # Same guard as inside the loop: no evaluation without an eval_env.
        if eval_env:
            log_agent_performances()

    def eval(self, eval_env, n_evals, metrics: Dict[str, Callable]):
        """Roll out ``n_evals`` episodes on ``eval_env`` with the current policy.

        Actions are sampled through ``self.act`` (evaluation is stochastic).
        The policy is conditioned on the sum of ``reward[0]`` values, while
        the reported returns accumulate ``info["joint_reward"]``.

        Args:
            eval_env: Environment whose ``reset`` info provides
                ``"total_demand"`` and ``"agents_bags_size"`` and whose
                ``step`` info provides ``"joint_reward"``.
            n_evals: Number of evaluation episodes (must be positive).
            metrics: Accepted for interface compatibility; not read here.

        Returns:
            ``(scalarized_returns, vector_returns, metrics_values,
            episode_lengths)``: three per-episode lists of tensors and a dict
            with the keys ``"curr_proportion"``, ``"best_proportion"`` and
            ``"proportion"``.

        Raises:
            ValueError: ``n_evals`` is not positive.
        """
        if n_evals <= 0:
            raise ValueError("n_evals must be a positive integer")

        scalarized_rewards, vec_rewards = [], []
        episode_lengths = []
        demands, total_bag_size = [], 0

        # Evaluation never backpropagates; disabling autograd keeps the
        # rollouts (in particular the recurrent hidden state) from
        # accumulating a computation graph.
        with th.no_grad():
            for _ in range(n_evals):
                obs, info = eval_env.reset()
                demands.append(th.from_numpy(info["total_demand"]))
                # Overwritten each episode: the value of the last reset wins.
                total_bag_size = sum(info["agents_bags_size"])
                done = False
                unnormalized_acc_reward = th.zeros(
                    (1, self.reward_dim), dtype=th.float32, device="cpu"
                )
                acc_reward = th.zeros(
                    (1, self.reward_dim), dtype=th.float32, device=self.device
                )
                state = th.tensor(
                    obs, dtype=th.float32, device=self.device
                ).reshape(1, -1)
                hidden = self.policy_net.init_hidden().to(self.device)
                ep_length = 0
                prev_action = None
                while not done:
                    action, _, hidden = self.act(
                        state, acc_reward, prev_action, hidden
                    )
                    observation, reward, terminated, truncated, info = (
                        eval_env.step(action.item())
                    )
                    # Same dtype as in train() so the conditioning input
                    # matches what the policy saw during training.
                    reward = th.tensor(
                        reward[0], dtype=th.float32, device=self.device
                    )
                    unnormalized_acc_reward += info["joint_reward"]
                    acc_reward += reward
                    done = terminated or truncated
                    prev_action = action
                    state = th.tensor(
                        observation, dtype=th.float32, device=self.device
                    ).reshape(1, -1)
                    ep_length += 1
                episode_lengths.append(
                    th.tensor([ep_length], dtype=th.float32)
                )
                vec_rewards.append(unnormalized_acc_reward.squeeze())
                scalarized_rewards.append(
                    self.scalarization(
                        unnormalized_acc_reward, self.scalarization_weights
                    )
                )

        stacked_demands = th.stack(demands)
        curr_proportions = min_proportion(
            th.stack(vec_rewards), stacked_demands
        )
        best_proportions = max_min_proportion(stacked_demands, total_bag_size)
        metrics_values = {
            "curr_proportion": curr_proportions,
            "best_proportion": best_proportions,
            "proportion": (curr_proportions / best_proportions).mean(),
        }
        return scalarized_rewards, vec_rewards, metrics_values, episode_lengths

    def __compute_returns(self, trajectory):
        # Compute returns
        previous_returns = deque([])
        incoming_returns = deque([])
        G = th.zeros(self.reward_dim, device=self.device)
        t = 0
        previous_returns.append(G)
        for _, _, _, reward, _ in list(trajectory)[: len(trajectory) - 1]:
            G = G + reward * self.gamma**t
            previous_returns.append(G)
            t += 1

        # compute t+
        G = th.zeros(self.reward_dim, device=self.device)

        for _, _, _, reward, _ in reversed(trajectory):
            G = reward + self.gamma * G
            incoming_returns.appendleft(G)

        # Convert returns to tensors

        previous_returns = th.stack(list(previous_returns))
        incoming_returns = th.stack(list(incoming_returns))

        # compute total returns
        returns = (
            previous_returns + incoming_returns
            if self.reward_conditioned
            else incoming_returns
        )
        returns = self.scalarization(returns, self.scalarization_weights)
        return returns

    def __compute_actor_loss_from_trajectory(self, trajectory):
        returns = self.__compute_returns(trajectory)
        # Compute loss and perform gradient descent
        policy_loss = 0
        for (state, _, log_prob, _, acc_reward), G in zip(trajectory, returns):
            # if self.use_baseline:
            #     policy_loss += -log_prob * (
            #         G - self.baseline_net(state, acc_reward)
            #     )
            # elif self.standardization and std_G != 0:
            #     policy_loss += -log_prob * (G - mean_G) / std_G
            # else:
            policy_loss += -log_prob * G

        return policy_loss.mean()

    def __compute_baseline_loss_from_trajectory(self, trajectory):
        returns = self.__compute_returns(trajectory)
        states = [s.squeeze() for s, _, _, _, _ in trajectory]
        acc_rewards = [
            acc_reward.squeeze() for _, _, _, _, acc_reward in trajectory
        ]
        acc_rewards = th.stack(acc_rewards)
        states = th.stack(states)
        state_values = self.baseline_net(states, acc_rewards).to(self.device)
        criterion = th.nn.MSELoss()
        return criterion(state_values, returns.unsqueeze(1))

    def update(self, trajectories=None):
        if trajectories is None:
            policy_loss = self.__compute_actor_loss_from_trajectory(
                self.on_policy_trajectory_buffer
            )
        else:
            policy_loss = sum(self.__compute_actor_loss_from_trajectory(
                trajectories[i]
            ) for i in range(len(trajectories)))
        # Clear gradients
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()


        if self.use_baseline:
            self.baseline_optimizer.zero_grad()
            if trajectories is None:
                baseline_loss = self.__compute_baseline_loss_from_trajectory(
                    self.on_policy_trajectory_buffer
                )
            else:
                policy_loss = th.mean(
                    [
                        self.__compute_baseline_loss_from_trajectory(
                            trajectories[i]
                        )
                        for i in range(len(trajectories))
                    ]
                )

            baseline_loss.backward()
            self.baseline_optimizer.step()
        
        # Clear buffer after update
        self.on_policy_trajectory_buffer.clear()
        if trajectories is not None:
            for trajectory in trajectories:
                trajectory.clear()

    def play_best_policy(self, env):
        self.policy_net.load_state_dict(self.best_policy)
        obs, _ = env.reset()
        done = False
        state = th.tensor(obs, dtype=th.float32, device=self.device).unsqueeze(
            0
        )
        hidden = self.policy_net.init_hidden()
        while not done:
            action, _, hidden = self.act(state, hidden)
            obs, _, terminated, truncated, _ = env.step(action.item())
            env.render()
            done = terminated or truncated
            state = th.tensor(
                obs, dtype=th.float32, device=self.device
            ).unsqueeze(0)

    def save_best_policy(self, path) -> None:
        th.save(self.best_policy, path)

    def play_saved_policy(
        self, env, path, save_video=False, save_dir=None
    ) -> None:
        """Load a policy checkpoint and roll out one episode, saving frames.

        After every step the environment is rendered as an RGB array and
        written as ``{i}.png`` (``i`` = step index) into the output directory.

        Args:
            env: Environment providing ``reset``, ``step`` and
                ``render(mode="rgb_array")``.
            path: Checkpoint file holding a policy state dict.
            save_video: Accepted for interface compatibility; frames are
                written regardless of its value.
                NOTE(review): decide whether this flag should gate the
                frame export.
            save_dir: Directory for the frames; defaults to ``"videos"``.
                It is created if missing.
        """
        # Imported once, locally: matplotlib is only needed by this method.
        import matplotlib.pyplot as plt

        frames_dir = "videos" if save_dir is None else save_dir
        os.makedirs(frames_dir, exist_ok=True)

        # map_location lets a checkpoint saved on another device be loaded.
        checkpoint = th.load(
            path, weights_only=True, map_location=self.device
        )
        self.policy_net.load_state_dict(checkpoint)

        obs, _ = env.reset()
        state = th.tensor(
            obs, dtype=th.float32, device=self.device
        ).reshape(1, -1)
        hidden = self.policy_net.init_hidden().to(self.device)
        acc_reward = th.zeros(
            (1, self.reward_dim), dtype=th.float32, device=self.device
        )
        prev_action = None
        done = False
        i = 0

        # Pure inference: no gradients needed.
        with th.no_grad():
            while not done:
                action, _, hidden = self.act(
                    state, acc_reward, prev_action, hidden
                )
                obs, r, terminated, truncated, _ = env.step(action.item())
                acc_reward = acc_reward + th.tensor(
                    r[0], dtype=th.float32, device=self.device
                )

                arr = env.render(mode="rgb_array")
                # Start from an empty figure so images from earlier frames
                # do not pile up in the axes.
                plt.clf()
                plt.imshow(arr, origin="upper", extent=[0, 1, 0, 1])
                plt.axis("off")
                plt.savefig(f"{frames_dir}/{i}.png", bbox_inches="tight")

                done = terminated or truncated
                prev_action = action
                state = th.tensor(
                    obs, dtype=th.float32, device=self.device
                ).reshape(1, -1)
                i += 1

        # Release the figure used for the frame export.
        plt.close()
