import os
import time
from collections import deque
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import torch as th
import torch.multiprocessing as mp

from common import DecentralizedRLAlgorithm, WandbLogger, min_proportion, max_min_proportion

from .eupg_agent import EUPG


class DecEUPG(DecentralizedRLAlgorithm):
    def __init__(
        self,
        n_agents,
        agent_rewards: List[list[int]],
        env,
        hidden_layers: List[int],
        scalarization: Callable,
        scalarization_weights: Union[List[float], th.Tensor, np.ndarray],
        device: str = "cpu",
        gamma: float = 0.9,
        lr: float = 1e-3,
        standardization: bool = True,
        use_baseline: bool = False,
        experience_sharing: bool = False,
        weight_sharing: bool = False,
        recurrent_policy: bool = False,
        reward_conditioned: bool = True,
    ) -> None:
        """Create the EUPG learner(s) and store the run configuration.

        With ``weight_sharing`` a single ``EUPG`` instance (``self.agent``) is
        built from agent 0's reward list and observation/action spaces, and
        one trajectory deque per agent is kept in ``self.trajectories_buffer``.
        Otherwise one ``EUPG`` per agent is stored in ``self.agents``. Only one
        of ``self.agent`` / ``self.agents`` exists after construction.

        Args:
            n_agents: Number of agents acting in ``env``.
            agent_rewards: For each agent, the list of reward indices it
                receives; its length is passed to ``EUPG`` as first argument.
            env: Environment exposing per-agent ``observation_space[i]`` and
                ``action_space[i]`` plus a joint ``reward_space``.
            hidden_layers: Hidden layer sizes forwarded to ``EUPG``.
            scalarization: Callable forwarded to ``EUPG`` and also used on the
                joint reward during evaluation.
            scalarization_weights: Weights forwarded together with
                ``scalarization``.
            device: Torch device string used for tensors created here.
            gamma: Stored as ``self.gamma``.
            lr: Stored as ``self.learning_rate``.
            standardization: Forwarded to ``EUPG``.
            use_baseline: Forwarded to ``EUPG``.
            experience_sharing: Stored as ``self.experience_sharing``.
            weight_sharing: Whether all agents share one ``EUPG`` instance.
            recurrent_policy: Forwarded to ``EUPG``.
            reward_conditioned: Forwarded to ``EUPG``.
        """
        self.n_agents = n_agents

        def build_learner(idx: int) -> EUPG:
            # Both the shared and the per-agent setup use the same constructor
            # call; only the agent index selecting rewards/spaces differs.
            return EUPG(
                len(agent_rewards[idx]),
                env,
                env.observation_space[idx].shape[0],
                hidden_layers,
                env.action_space[idx].n,
                scalarization,
                scalarization_weights,
                device,
                standardization=standardization,
                use_baseline=use_baseline,
                recurrent_policy=recurrent_policy,
                reward_conditioned=reward_conditioned,
            )

        if not weight_sharing:
            self.agents: List[EUPG] = [
                build_learner(idx) for idx in range(self.n_agents)
            ]
        else:
            self.agent: EUPG = build_learner(0)
            self.trajectories_buffer = [deque([]) for _ in range(self.n_agents)]

        # Environment and bookkeeping.
        self.env = env
        self.reward_dim = env.reward_space.shape[0]
        self.agents_rewards = agent_rewards
        self.curr_step = 0
        self.device = device

        # Hyper-parameters and mode flags.
        # NOTE(review): gamma and lr are only stored here; they are not passed
        # to the EUPG constructor above -- confirm EUPG's defaults are intended.
        self.gamma = gamma
        self.learning_rate = lr
        self.weight_sharing = weight_sharing
        self.experience_sharing = experience_sharing
        self.scalarization = scalarization
        self.scalarization_weights = scalarization_weights

        # Best evaluation score seen so far and the matching policy weights.
        self.best_score = None
        self.best_policy = None

    def act(self, states, cumm_rewards, prev_joint_action, hiddens, dones):
        joint_action = []
        log_probs = []
        new_hiddens = []
        if self.weight_sharing:
            for i in range(self.n_agents):
                if(not dones[i]):
                    action, (log_prob,entropy), hidden = self.agent.act(
                        states[i], cumm_rewards[i, prev_joint_action[i]], hiddens[i]
                    )
                else:
                    action = None
                    log_prob = None
                    hidden = None
                joint_action.append(action)
                log_probs.append(log_prob)
                new_hiddens.append(hidden)
            
                    
        else:
            for i, agent in enumerate(self.agents):
                # Forward pass through policy network to get action probabilities
                if(not dones[i]):
                    action, (log_prob,entropy), hidden = agent.act(
                        states[i], cumm_rewards[i], prev_joint_action[i], hiddens[i]
                    )
                else:
                    action = None
                    log_prob = None
                    hidden = None
                joint_action.append(action)
                log_probs.append(log_prob)
                new_hiddens.append(hidden)
        
        return joint_action, log_probs, new_hiddens

    def train(
        self,
        timesteps: int,
        eval_env: object,
        eval_freq: int = 100,
        n_evals: int = 32,
        log: bool = False,
        logger: Optional[WandbLogger] = None,
        other_metrics: Optional[Dict[str, Callable]] = {},
        save_results: bool = False,
        results_dir: str = None,
    ):
        def log_agents_performance():
            (
                returns,
                vec_returns,
                other_metrics_values,
                agents_bags,
                episode_lengths,
            ) = self.eval(eval_env, n_evals, other_metrics)
            returns, vec_returns, agents_bags, episode_lengths = map(
                th.stack, [returns, vec_returns, agents_bags, episode_lengths]
            )
            avg_scalarized_return = th.mean(returns)
            avg_episode_length = th.mean(episode_lengths)
            if log:
                if logger is None:
                    raise AttributeError(
                        "Logging enabled but no logger assigned"
                    )
                to_log = {
                    "timestep": self.curr_step,
                    "mean reward": avg_scalarized_return,
                    "mean episode_length": avg_episode_length,
                }
                for i in other_metrics_values:
                    to_log.update({i : other_metrics_values[i].mean()})
                
                for i in range(len(vec_returns)):
                    to_log.update(
                        {
                            f"exp{i}/objective_{j}": vec_returns[i, j]
                            for j in range(self.reward_dim)
                        }
                    )
                for i in range(len(agents_bags)):
                    for k in range(self.n_agents):
                        to_log.update(
                            {
                                    f"exp{i}/agents{k}/objective_{j}": agents_bags[
                                    i, k, ind
                                ]
                                for ind,j in enumerate(self.agents_rewards[k])
                            }
                        )
                logger.log(to_log)
            else:
                print(returns)
                print(vec_returns)
                print(th.median(vec_returns, axis=1).values)
                print(avg_scalarized_return)
                print(avg_episode_length)
                print(other_metrics_values)

            if (
                self.best_score is None
                or self.best_score < avg_scalarized_return
            ):
                self.best_score = avg_scalarized_return
                if self.weight_sharing:
                    self.best_policy = deepcopy(
                        self.agent.policy_net.state_dict()
                    )

                else:
                    self.best_policy = [
                        deepcopy(self.agents[i].policy_net.state_dict())
                        for i in range(self.n_agents)
                    ]
            if save_results:
                if not os.path.exists(f"{results_dir}/{self.curr_step}"):
                    os.makedirs(f"{results_dir}/{self.curr_step}")

                pd.DataFrame(episode_lengths.numpy()).to_csv(
                    f"{results_dir}/{self.curr_step}/episode_lengths.csv", index=False
                )
              
                pd.DataFrame(other_metrics_values["curr_proportion"].numpy()).to_csv(
                    f"{results_dir}/{self.curr_step}/curr_proportion.csv", index=False
                )
                pd.DataFrame(other_metrics_values["best_proportion"].numpy()).to_csv(
                    f"{results_dir}/{self.curr_step}/best_proportion.csv", index=False
                )
                pd.DataFrame(returns.numpy()).to_csv(
                    f"{results_dir}/{self.curr_step}/returns.csv", index=False
                )
                pd.DataFrame(vec_returns.numpy()).to_csv(
                    f"{results_dir}/{self.curr_step}/vec_returns.csv", index=False
                )
                for i in range(self.n_agents):
                    curr_agent_bag = agents_bags[:, i, :]
                    pd.DataFrame(curr_agent_bag.numpy()).to_csv(
                        f"{results_dir}/{self.curr_step}/agent_{i}_bag.csv", index=False
                    )

        while self.curr_step < timesteps:
            states, info = self.env.reset()

            states = [
                th.tensor(
                    states[i], dtype=th.float32, device=self.device
                ).reshape(1, -1)
                for i in range(self.n_agents)
            ]
            if self.weight_sharing:
                cumm_rewards = [
                    th.zeros(
                        (1, self.agent.reward_dim),
                        dtype=th.float32,
                        device=self.device,
                    )
                    for _ in range(self.n_agents)
                ]

            else:
                cumm_rewards = [
                    th.zeros(
                        (1, a.reward_dim),
                        dtype=th.float32,
                        device=self.device,
                    )
                    for a in self.agents
                ]

            dones = [False] * self.n_agents
            hiddens = [
                self.agents[i].policy_net.init_hidden().to(self.device)
                for i in range(self.n_agents)
            ]
            prev_joint_action = [None for _ in range(self.n_agents)]
            while not all(dones):
                if eval_env and self.curr_step % eval_freq == 0:
                    log_agents_performance()
                    input()
                joint_action, log_probs, hiddens = self.act(
                    states, cumm_rewards, prev_joint_action, hiddens, dones
                )
                observations, rewards, terminated, truncated, _ = self.env.step(
                    joint_action
                )
                for i in range(self.n_agents):
                    if(not dones[i]):
                        r = th.tensor(
                            rewards[i], dtype=th.float32, device=self.device
                        )
                        cumm_rewards[i] += r
                        dones[i] = terminated[i] or truncated
                        prev_joint_action[i] = joint_action[i]
                        # Store experience in the buffer
                        agent_experience = (
                            states[i],
                            joint_action[i],
                            log_probs[i],
                            r,
                            cumm_rewards[i],
                        )
                        if self.weight_sharing:
                            self.trajectories_buffer[i].append(agent_experience)
                        else:
                            self.agents[i].on_policy_trajectory_buffer.append(
                                agent_experience
                            )

                # Move to the next state
                next_states = [
                    th.tensor(
                        observations[i], dtype=th.float32, device=self.device
                    ).reshape(1, -1)
                    for i in range(self.n_agents)
                ]
                states = next_states
                self.curr_step += 1

            # Perform policy update after each episode
            self.update()
        log_agents_performance()

    def eval(self, eval_env, n_evals, metrics: Dict[str, Callable]):

        def one_episode_eval():
            obs, info = eval_env.reset()
            total_demand = info["total_demand"]
            total_bag_size = sum(info["agents_bags_size"])
            dones = [False] * self.n_agents
            if self.weight_sharing:
                cumm_rewards = [
                    th.zeros(
                        (1, self.agent.reward_dim),
                        dtype=th.float32,
                        device=self.device,
                    )
                    for _ in range(self.n_agents)
                ]

            else:
                cumm_rewards = [
                    th.zeros(
                        (1, a.reward_dim),
                        dtype=th.float32,
                        device=self.device,
                    )
                    for a in self.agents
                ]
            states = [
                th.tensor(obs[i], dtype=th.float32, device=self.device).reshape(
                    1, -1
                )
                for i in range(self.n_agents)
            ]
            joint_reward = th.zeros((1, self.reward_dim))

            hiddens = [
                self.agents[i].policy_net.init_hidden().to(self.device)
                for i in range(self.n_agents)
            ]
            ep_length = 0
            prev_joint_action = [None for _ in range(self.n_agents)]
            while not all(dones):
                joint_action, _, hiddens = self.act(states, cumm_rewards, prev_joint_action, hiddens, dones)
                observations, rewards, terminated, truncated, info = (
                    eval_env.step(joint_action)
                )
                joint_reward += info["joint_reward"]
                # reward = th.tensor(reward, device=self.device)
                for i in range(self.n_agents):
                    cumm_rewards[i] += th.tensor(rewards[i], device=self.device)
                    prev_joint_action[i] = joint_action[i]
                    dones[i] = terminated[i] or truncated

                states = [
                    th.tensor(
                        observations[i], dtype=th.float32, device=self.device
                    ).reshape(1, -1)
                    for i in range(self.n_agents)
                ]
                ep_length += 1

            vec_reward = joint_reward.squeeze()
            scalarazied_reward = self.scalarization(
                joint_reward, self.scalarization_weights
            )
            agent_bags = th.tensor(
                [(eval_env.agents[i].bag) for i in range(self.n_agents)]
            )
            # print(scalarazied_reward, vec_reward, agent_bags, ep_length)
            return (
                scalarazied_reward,
                vec_reward,
                agent_bags,
                th.tensor([ep_length], dtype=th.float32),
                total_demand,
                total_bag_size
            )

        metrics_values = {}
        scalarazied_rewards, vec_rewards = [], []
        agents_bags = []
        episode_lengths = []
        demands = []
        for _ in range(n_evals):
            scalarazied_reward, vec_reward, agent_bags, ep_length, total_demand, total_bag_size = (
                one_episode_eval()
            )
            episode_lengths.append(ep_length)
            vec_rewards.append(vec_reward)
            scalarazied_rewards.append(scalarazied_reward)
            agents_bags.append(agent_bags)
            demands.append(th.from_numpy(total_demand))
        for m in metrics.keys():
            metrics_values[m] = metrics[m](th.stack(vec_rewards))
        curr_proportions = min_proportion(th.stack(vec_rewards), th.stack(demands))
        best_proportions = max_min_proportion(th.stack(demands), total_bag_size)
        metrics_values["curr_proportion"] = curr_proportions
        metrics_values["best_proportion"] = best_proportions 
        metrics_values["proportion"] = (curr_proportions / best_proportions) 
        return (
            scalarazied_rewards,
            vec_rewards,
            metrics_values,
            agents_bags,
            episode_lengths,
        )

    def update(self):
        if self.weight_sharing:
            self.agent.update(self.trajectories_buffer)
        else:
            for i in range(self.n_agents):
                self.agents[i].update()

    def save_best_policy(self, path) -> None:
        if self.weight_sharing:
            th.save(self.best_policy, path)
        else:
            if not os.path.exists(path):
                os.makedirs(path)

            for i in range(self.n_agents):
                th.save(self.best_policy, f"{path}/{i}.pt")

    def play_best_policy(self, env):
        for i in range(self.n_agents):
            self.agents[i].policy_net.load_state_dict(self.best_policy[i])
        obs, _ = env.reset()
        done = False
        state = [
            th.tensor(obs[i], dtype=th.float32, device=self.device).reshape(
                1, -1
            )
            for i in range(self.n_agents)
        ]
        cumm_rewards = [th.zeros(1, a.reward_dim) for a in self.agents]
        hiddens = [
            self.agents[i].policy_net.init_hidden()
            for i in range(self.n_agents)
        ]
        prev_joint_action = [None for _ in range(self.n_agents)]
        while not done:
            joint_action, _, hiddens = self.act(state, cumm_rewards,prev_joint_action, hiddens)
            obs, rewards, terminated, truncated, _ = env.step(joint_action)
            for i in range(self.n_agents):
                cumm_rewards[i] += rewards[i]
                prev_joint_action[i] = joint_action[i] 
            env.render()
            done = terminated or truncated
            state = [
                th.tensor(obs[i], dtype=th.float32, device=self.device).reshape(
                    1, -1
                )
                for i in range(self.n_agents)
            ]
            time.sleep(0.05)
