from __future__ import annotations

import os

import sys
import time
import random
from collections import deque
import numpy as np
import torch
import torch.nn as nn
import torch.optim
from torch.distributions import Normal
from torch.nn.utils.clip_grad import clip_grad_norm_
from torch.optim.lr_scheduler import LinearLR
from torch.utils.data import DataLoader, TensorDataset


from spinup.utils.logx import EpochLogger
import spinup.algos.vanilla.cup.core as core
from spinup.utils.mpi_pytorch import setup_pytorch_for_mpi
from spinup.utils.mpi_tools import num_procs

class VectorizedOnPolicyBuffer:
    """
    Fixed-size storage for on-policy rollouts collected from one or more environments.

    One independent set of tensors is kept per environment. ``store`` writes one
    time step for every environment at once, ``finish_path`` computes reward/cost
    advantages and value targets for the current trajectory segment of a single
    environment, and ``get`` concatenates all environments' data and resets the
    write pointers.

    Args:
        obs_space: Observation space; only its ``shape`` attribute is used.
        act_space: Action space; only its ``shape`` attribute is used.
        size (int): Number of time steps stored per environment.
        gamma (float, optional): Discount factor. Defaults to 0.99.
        lam (float, optional): GAE lambda for reward advantages. Defaults to 0.95.
        lam_c (float, optional): GAE lambda for cost advantages. Defaults to 0.95.
        standardized_adv_r (bool, optional): If True, ``get`` rescales reward
            advantages to zero mean and unit standard deviation. Defaults to True.
        standardized_adv_c (bool, optional): If True, ``get`` subtracts the mean
            from cost advantages. They are centered but NOT rescaled.
            Defaults to True.
        device (torch.device | str, optional): Device the tensors are allocated
            on. Defaults to "cpu".
        num_envs (int, optional): Number of parallel environments. Defaults to 1.
    """

    def __init__(
        self,
        obs_space,
        act_space,
        size: int,
        gamma: float = 0.99,
        lam: float = 0.95,
        lam_c: float = 0.95,
        standardized_adv_r: bool = True,
        standardized_adv_c: bool = True,
        device: torch.device | str = "cpu",
        num_envs: int = 1,
    ) -> None:
        # Per-step scalar fields. "obs" and "act" are allocated separately because
        # their trailing dimensions come from the spaces' shapes.
        scalar_keys = (
            "reward",
            "cost",
            "done",
            "value_r",
            "value_c",
            "adv_r",
            "adv_c",
            "target_value_r",
            "target_value_c",
            "log_prob",
        )

        def _new_env_buffer() -> dict[str, torch.Tensor]:
            """Allocate zero-filled storage for a single environment."""
            storage = {
                "obs": torch.zeros(
                    (size, *obs_space.shape), dtype=torch.float32, device=device
                ),
                "act": torch.zeros(
                    (size, *act_space.shape), dtype=torch.float32, device=device
                ),
            }
            for key in scalar_keys:
                storage[key] = torch.zeros(size, dtype=torch.float32, device=device)
            return storage

        self.buffers: list[dict[str, torch.Tensor]] = [
            _new_env_buffer() for _ in range(num_envs)
        ]
        self._gamma = gamma
        self._lam = lam
        self._lam_c = lam_c
        self._standardized_adv_r = standardized_adv_r
        self._standardized_adv_c = standardized_adv_c
        # Next write index per environment.
        self.ptr_list = [0] * num_envs
        # Start index of the unfinished trajectory per environment.
        self.path_start_idx_list = [0] * num_envs
        self._device = device
        self.num_envs = num_envs

    def store(self, **data: torch.Tensor) -> None:
        """
        Append one time step for every environment.

        Args:
            **data: Mapping from buffer field name to a tensor whose leading
                dimension indexes the environments. ``value[i]`` is written to
                environment ``i``.

        Raises:
            AssertionError: If an environment's buffer is already full.
        """
        for i, buffer in enumerate(self.buffers):
            ptr = self.ptr_list[i]
            assert ptr < buffer["obs"].shape[0], "Buffer overflow"
            for key, value in data.items():
                buffer[key][ptr] = value[i]
            self.ptr_list[i] = ptr + 1

    def finish_path(
        self,
        last_value_r: torch.Tensor | None = None,
        last_value_c: torch.Tensor | None = None,
        idx: int = 0,
    ) -> None:
        """
        Close the current trajectory of one environment.

        The bootstrap value is appended to both the reward (or cost) sequence and
        the value sequence of the segment. Both are then passed to
        ``core.calculate_adv_and_value_targets``, and the resulting advantages and
        value targets are written back into the segment.

        Args:
            last_value_r (torch.Tensor, optional): Bootstrap reward value of the
                state after the segment. Pass zero/None for a true terminal.
                Any single-element tensor (0-d, (1,), (1, 1), ...) is accepted.
            last_value_c (torch.Tensor, optional): Same as ``last_value_r``, for
                costs.
            idx (int, optional): Index of the environment. Defaults to 0.
        """
        if last_value_r is None:
            last_value_r = torch.zeros(1, device=self._device)
        if last_value_c is None:
            last_value_c = torch.zeros(1, device=self._device)
        # Flatten to 1-D so the bootstrap value can be concatenated with the 1-D
        # per-step slices regardless of how the caller shaped it.
        last_value_r = last_value_r.to(self._device).reshape(-1)
        last_value_c = last_value_c.to(self._device).reshape(-1)

        buffer = self.buffers[idx]
        path_slice = slice(self.path_start_idx_list[idx], self.ptr_list[idx])
        rewards = torch.cat([buffer["reward"][path_slice], last_value_r])
        costs = torch.cat([buffer["cost"][path_slice], last_value_c])
        values_r = torch.cat([buffer["value_r"][path_slice], last_value_r])
        values_c = torch.cat([buffer["value_c"][path_slice], last_value_c])

        adv_r, target_value_r = core.calculate_adv_and_value_targets(
            values_r,
            rewards,
            lam=self._lam,
            gamma=self._gamma,
        )
        adv_c, target_value_c = core.calculate_adv_and_value_targets(
            values_c,
            costs,
            lam=self._lam_c,
            gamma=self._gamma,
        )
        buffer["adv_r"][path_slice] = adv_r
        buffer["adv_c"][path_slice] = adv_c
        buffer["target_value_r"][path_slice] = target_value_r
        buffer["target_value_c"][path_slice] = target_value_c

        self.path_start_idx_list[idx] = self.ptr_list[idx]

    def get(self) -> dict[str, torch.Tensor]:
        """
        Return all stored data and reset the write pointers.

        Each field is the concatenation (along dim 0) of every environment's
        full buffer, in environment order. Reward advantages are standardized and
        cost advantages mean-centered, according to the constructor flags.

        Returns:
            dict[str, torch.Tensor]: Field name to concatenated tensor.
        """
        data = {
            key: torch.cat([buffer[key] for buffer in self.buffers], dim=0)
            for key in self.buffers[0]
        }
        if self._standardized_adv_r:
            adv_r = data["adv_r"]
            data["adv_r"] = (adv_r - adv_r.mean()) / (adv_r.std() + 1e-8)
        if self._standardized_adv_c:
            # Cost advantages are only centered (not divided by their std).
            data["adv_c"] = data["adv_c"] - data["adv_c"].mean()
        self.ptr_list = [0] * self.num_envs
        self.path_start_idx_list = [0] * self.num_envs

        return data


def cup(env_fn, actor_critic=core.ActorVCritic, ac_kwargs=None, seed=0, device="cuda:0",
        steps_per_epoch=4000, epochs=50, gamma=0.99, max_grad_norm=40, cup_lambda=0.95, cup_nu=0.2,
        train_pi_iters=80, train_v_iters=80, max_ep_len=1000, num_envs=1,
        target_kl=0.01, logger_kwargs=None, save_freq=100):
    """
    Train a constrained policy with a two-stage CUP-style update.

    Each epoch does the following:

    - Collect ``steps_per_epoch / num_procs()`` environment steps.
    - Update the Lagrange multiplier from the logged episode costs ("EpRisk").
    - Update the policy in two stages:

      1. PPO-clip on the reward advantages, trained jointly with MSE regression
         of the reward and cost critics onto their value targets.
      2. A cost step that minimizes
         ``lagrangian_multiplier * (1 - gamma * cup_lambda) / (1 - gamma)
         * ratio * adv_c + KL(pi_new || pi_old)``, where ``pi_old`` is the policy
         reached at the end of stage 1.

    Each stage stops early once the mean KL to its reference policy exceeds
    ``1.5 * target_kl``.

    Args:
        env_fn: Zero-argument callable returning the environment.
            ``reset()`` must return ``(obs, info)``.
            ``step(action)`` must return ``(next_obs, reward, terminated, info)``,
            with the per-step cost in ``info['constraint_violation']``.
        actor_critic: Class constructed as
            ``actor_critic(obs_dim=..., act_dim=..., hidden_sizes=...)``.
            The instance must expose ``actor``, ``reward_critic``, ``cost_critic``
            and ``step``.
        ac_kwargs (dict, optional): Must contain ``"hidden_sizes"``.
        seed (int): Seed for ``random``, ``numpy`` and ``torch``.
        device (str): Torch device for the policy and the rollout buffer.
        steps_per_epoch (int): Environment steps per epoch, summed over MPI
            processes.
        epochs (int): Number of epochs. This is also the length of the linear
            decay of the actor learning rate to 0.
        gamma (float): Discount factor.
        max_grad_norm (float): Currently unused; gradient clipping is disabled.
        cup_lambda (float): Lambda in the stage-2 coefficient
            ``(1 - gamma * cup_lambda) / (1 - gamma)``. The buffer's GAE lambdas
            keep their own defaults.
        cup_nu (float): Upper bound for the Lagrange multiplier.
        train_pi_iters (int): Maximum passes over the batch in stage 1.
        train_v_iters (int): Maximum passes over the batch in stage 2. Stage 2
            only updates the actor; the critics are trained in stage 1.
        max_ep_len (int): Episodes are cut (and bootstrapped) after this many
            steps.
        num_envs (int): Length of the per-env episode statistic arrays. The
            rollout code drives a single environment (the buffer is built with
            ``num_envs=1``), so only 1 is supported.
        target_kl (float): KL threshold used for early stopping.
        logger_kwargs (dict, optional): Keyword arguments for ``EpochLogger``.
        save_freq (int): In logged epochs, the environment state is saved every
            ``save_freq`` epochs and at the final epoch.
    """
    # Avoid shared mutable defaults; rebinding keeps the names that
    # ``save_config(locals())`` records unchanged.
    ac_kwargs = {} if ac_kwargs is None else ac_kwargs
    logger_kwargs = {} if logger_kwargs is None else logger_kwargs

    # Special function to avoid certain slowdowns from PyTorch + MPI combo.
    setup_pytorch_for_mpi()

    # Set up logger and save configuration
    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    # set the random seed, device and number of threads
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    device = torch.device(device)
    env = env_fn()

    # set training steps
    local_steps_per_epoch = int(steps_per_epoch / num_procs())
    # create the actor-critic module
    policy = actor_critic(
        obs_dim=env.observation_space.shape[0],
        act_dim=env.action_space.shape[0],
        hidden_sizes=ac_kwargs["hidden_sizes"],
    ).to(device)
    actor_optimizer = torch.optim.Adam(policy.actor.parameters(), lr=3e-4)
    actor_scheduler = LinearLR(
        actor_optimizer,
        start_factor=1.0,
        end_factor=0.0,
        total_iters=epochs,
        verbose=False,
    )
    reward_critic_optimizer = torch.optim.Adam(
        policy.reward_critic.parameters(), lr=1e-3
    )
    cost_critic_optimizer = torch.optim.Adam(
        policy.cost_critic.parameters(), lr=1e-3
    )
    logger.setup_pytorch_saver(policy)
    # create the on-policy buffer (a single environment is driven below)
    buffer = VectorizedOnPolicyBuffer(
        obs_space=env.observation_space,
        act_space=env.action_space,
        size=local_steps_per_epoch,
        device=device,
        num_envs=1,
        gamma=gamma,
    )
    # setup lagrangian multiplier
    lagrange = core.Lagrange(
        cost_limit=10.0,
        lagrangian_multiplier_init=0.001,
        lagrangian_multiplier_lr=0.035,
        lagrangian_upper_bound=cup_nu,
    )

    # maxlen=1: the logged episode statistics are those of the most recently
    # finished episode.
    rew_deque = deque(maxlen=1)
    cost_deque = deque(maxlen=1)
    len_deque = deque(maxlen=1)

    logger.log("Start with training.")
    obs, _ = env.reset()
    # Add a leading environment dimension of size 1.
    obs = torch.as_tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
    ep_ret, ep_cost, ep_len = (
        np.zeros(num_envs),
        np.zeros(num_envs),
        np.zeros(num_envs),
    )
    # training loop
    for epoch in range(epochs):
        rollout_start_time = time.time()
        # collect samples until we have enough to update
        for steps in range(local_steps_per_epoch):
            with torch.no_grad():
                act, log_prob, value_r, value_c = policy.step(obs, deterministic=False)
            action = act.detach().squeeze().cpu().numpy()
            next_obs, reward, terminated, infos = env.step(action)
            cost = infos['constraint_violation']
            ep_ret += reward
            ep_cost += cost
            ep_len += 1
            next_obs, reward, cost, terminated = (
                torch.as_tensor(x, dtype=torch.float32, device=device).unsqueeze(0)
                for x in (next_obs, reward, cost, terminated)
            )

            buffer.store(
                obs=obs,
                act=act,
                reward=reward,
                cost=cost,
                value_r=value_r,
                value_c=value_c,
                log_prob=log_prob,
            )

            obs = next_obs
            timeout = ep_len == max_ep_len  # boolean array, one entry per env
            epoch_end = steps >= local_steps_per_epoch - 1

            for idx, done in enumerate(terminated):
                time_out = bool(timeout[idx])
                if not (epoch_end or done or time_out):
                    continue
                last_value_r = torch.zeros(1, device=device)
                last_value_c = torch.zeros(1, device=device)
                if not done:
                    # The trajectory was cut (timeout or end of epoch): bootstrap
                    # from the critics' estimates for the last observation.
                    with torch.no_grad():
                        _, _, last_value_r, last_value_c = policy.step(
                            obs[idx], deterministic=False
                        )
                    last_value_r = last_value_r.unsqueeze(0)
                    last_value_c = last_value_c.unsqueeze(0)
                if done or time_out:
                    # A finished episode, by termination or by reaching
                    # max_ep_len, is recorded. Previously timed-out episodes
                    # were dropped, so they never reached EpRisk or the
                    # Lagrange update.
                    rew_deque.append(ep_ret[idx])
                    cost_deque.append(ep_cost[idx])
                    len_deque.append(ep_len[idx])
                    logger.store(
                        **{
                            "EpRet": np.mean(rew_deque),
                            "EpRisk": np.mean(cost_deque),
                            "EpLen": np.mean(len_deque),
                        }
                    )
                    ep_ret[idx] = 0.0
                    ep_cost[idx] = 0.0
                    ep_len[idx] = 0.0
                    logger.logged = False

                buffer.finish_path(last_value_r=last_value_r, last_value_c=last_value_c, idx=idx)
                obs = env.reset()[0]
                obs = torch.as_tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
                ep_ret, ep_cost, ep_len = (
                    np.zeros(num_envs),
                    np.zeros(num_envs),
                    np.zeros(num_envs),
                )
        rollout_end_time = time.time()

        # update lagrange multiplier
        ep_costs = logger.get_cost_stats("EpRisk")
        lagrange.update_lagrange_multiplier(ep_costs)

        # update policy
        data = buffer.get()
        # Reference policy for the stage-1 KL check. It is only used for
        # monitoring, so no autograd graph is built.
        with torch.no_grad():
            old_distribution = policy.actor(data["obs"])

        # compute advantage
        advantage = data["adv_r"]

        dataloader = DataLoader(
            dataset=TensorDataset(
                data["obs"],
                data["act"],
                data["log_prob"],
                data["target_value_r"],
                data["target_value_c"],
                advantage,
            ),
            batch_size=64,
            shuffle=True,
        )
        update_counts = 0
        final_kl = 0.0

        # the first stage update is the same as the original PPO
        for _ in range(train_pi_iters):
            for (
                obs_b,
                act_b,
                log_prob_b,
                target_value_r_b,
                target_value_c_b,
                adv_b,
            ) in dataloader:
                reward_critic_optimizer.zero_grad()
                loss_r = nn.functional.mse_loss(policy.reward_critic(obs_b), target_value_r_b)
                cost_critic_optimizer.zero_grad()
                loss_c = nn.functional.mse_loss(policy.cost_critic(obs_b), target_value_c_b)

                distribution = policy.actor(obs_b)
                log_prob = distribution.log_prob(act_b).sum(dim=-1)
                ratio = torch.exp(log_prob - log_prob_b)
                ratio_cliped = torch.clamp(ratio, 0.8, 1.2)
                loss_pi = -torch.min(ratio * adv_b, ratio_cliped * adv_b).mean()
                actor_optimizer.zero_grad()
                total_loss = loss_pi + loss_r + loss_c
                total_loss.backward()
                # Gradient clipping is disabled (max_grad_norm is unused):
                # clip_grad_norm_(policy.parameters(), max_grad_norm)
                reward_critic_optimizer.step()
                cost_critic_optimizer.step()
                actor_optimizer.step()

                logger.store(
                    **{
                        "Loss_reward_critic": loss_r.mean().item(),
                        "Loss_cost_critic": loss_c.mean().item(),
                        "Loss_actor": loss_pi.mean().item(),
                    }
                )

            with torch.no_grad():
                new_distribution = policy.actor(data["obs"])
                kl = (
                    torch.distributions.kl.kl_divergence(old_distribution, new_distribution)
                    .sum(-1, keepdim=True)
                    .mean()
                    .item()
                )
            final_kl = kl
            update_counts += 1
            if kl > 1.5 * target_kl:
                break

        # Stage 2 reference policy: the actor as it stands after stage 1.
        with torch.no_grad():
            old_distribution = policy.actor(data["obs"])
            old_mean = old_distribution.mean
            old_std = old_distribution.stddev

        advantage = data["adv_c"]

        dataloader = DataLoader(
            dataset=TensorDataset(
                data["obs"], data["act"], data["log_prob"], advantage, old_mean, old_std
            ),
            batch_size=64,
            shuffle=True,
        )

        # Loop-invariant weight of the cost surrogate.
        coef = (1 - gamma * cup_lambda) / (1 - gamma)
        update_counts_2 = 0
        for i in range(train_v_iters):
            for obs_b, act_b, log_prob_b, adv_b, old_mean_b, old_std_b in dataloader:
                old_distribution_b = Normal(old_mean_b, old_std_b)
                distribution = policy.actor(obs_b)
                log_prob = distribution.log_prob(act_b).sum(dim=-1)
                ratio = torch.exp(log_prob - log_prob_b)
                # Per-sample KL summed over action dims. Without keepdim it has
                # the same shape as ``ratio * adv_b`` and adds elementwise. With
                # keepdim the sum broadcast to a (batch, batch) tensor that had
                # the same mean (assumes the actor's distribution has batch
                # shape (batch, act_dim), like old_distribution_b).
                temp_kl = torch.distributions.kl_divergence(
                    distribution, old_distribution_b
                ).sum(-1)
                loss_pi_cost = (
                    lagrange.lagrangian_multiplier * coef * ratio * adv_b + temp_kl
                ).mean()
                actor_optimizer.zero_grad()
                loss_pi_cost.backward()
                # Gradient clipping is disabled (max_grad_norm is unused):
                # clip_grad_norm_(policy.actor.parameters(), max_grad_norm)
                actor_optimizer.step()

            with torch.no_grad():
                new_distribution = policy.actor(data["obs"])
                kl = (
                    torch.distributions.kl.kl_divergence(old_distribution, new_distribution)
                    .sum(-1, keepdim=True)
                    .mean()
                    .item()
                )
            final_kl = kl
            update_counts_2 += 1
            if kl > 1.5 * target_kl:
                logger.log(
                    f"Early stopping at iter {i + 1} due to reaching max kl at second stage"
                )
                break

        update_end_time = time.time()
        actor_scheduler.step()
        if not logger.logged:
            # log data
            logger.log_tabular("EpRet", average_only=True)
            logger.log_tabular("EpRisk", average_only=True)
            logger.log_tabular("EpLen", average_only=True)
            logger.log_tabular("Epoch", epoch + 1)
            logger.log_tabular("TotalSteps", (epoch + 1) * steps_per_epoch)
            logger.log_tabular("StopIter", update_counts)
            logger.log_tabular("SeconStageStopIter", update_counts_2)
            logger.log_tabular("KL", final_kl)
            logger.log_tabular("LagragianMultiplier", lagrange.lagrangian_multiplier)
            logger.log_tabular("LR", actor_scheduler.get_last_lr()[0])
            logger.log_tabular("Loss_reward_critic", average_only=True)
            logger.log_tabular("Loss_cost_critic", average_only=True)
            logger.log_tabular("Loss_actor", average_only=True)
            logger.log_tabular("Rollout", rollout_end_time - rollout_start_time)
            logger.log_tabular("Total", update_end_time - rollout_start_time)
            logger.log_tabular("RewardAdv", data["adv_r"].mean().item())
            logger.log_tabular("CostAdv", data["adv_c"].mean().item())

            logger.dump_tabular()
            if (epoch % save_freq == 0) or (epoch == epochs-1):
                logger.save_state({'env': env}, epoch)

