import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.categorical import Categorical

from env_list import env_action_dict

class LossWithIntermediateLosses:
    """Bundle a summed loss with a scalar snapshot of each named component.

    Attributes:
        loss_total: Sum of every value passed to the constructor (the plain
            integer ``0`` when no component is given).
        intermediate_losses: Mapping from component name to the result of
            calling ``.item()`` on that component.
    """

    def __init__(self, **kwargs):
        """Store the total and the per-component scalars.

        Args:
            **kwargs: Named loss terms. Each value must support ``+`` (for the
                total) and expose ``.item()`` (for the scalar snapshot).
        """
        named_terms = list(kwargs.items())
        self.loss_total = sum(term for _, term in named_terms)
        self.intermediate_losses = dict(
            (name, term.item()) for name, term in named_terms
        )

    def __truediv__(self, value):
        """Divide the total and every component by ``value`` IN PLACE.

        The object itself is mutated and returned; no copy is made, so
        ``loss / n`` and ``loss`` refer to the same (already scaled) object.
        """
        # Compute all scaled scalars first, then write them back into the
        # existing dict so its identity is preserved.
        rescaled = {
            name: scalar / value
            for name, scalar in self.intermediate_losses.items()
        }
        self.intermediate_losses.update(rescaled)
        self.loss_total = self.loss_total / value
        return self

class PPO(nn.Module):
    """Actor-critic networks plus the clipped PPO loss, GAE and a rollout helper.

    The actor and the critic are two independent MLPs (three ``Linear`` + ``ELU``
    hidden layers followed by a linear head) that both consume the same context
    vectors. The actor emits one logit per action slot; the critic emits one
    value per input.
    """

    def __init__(self, action_size, ih, hidden_size=512):
        """Build the networks and set the PPO hyper-parameters.

        Args:
            action_size: Number of logits produced by the actor head.
            ih: Number of environment steps collected by ``batch_rollout``.
            hidden_size: Width of the context vectors and of every hidden
                layer. Defaults to 512, the value that used to be hard-coded.
        """
        super().__init__()
        self.action_size = action_size
        # The critic is created before the actor so that parameter
        # registration order (and therefore seeded initialisation and
        # state_dict key order) is the same as it has always been.
        self.critic = self._build_mlp(hidden_size, 1)
        self.actor = self._build_mlp(hidden_size, action_size)

        self.ih = ih

        self.gamma = 0.99          # discount factor used in compute_advantage
        self.lambda_ = 0.95        # GAE smoothing factor used in compute_advantage
        self.clip_coef = 0.2       # half-width of the probability-ratio clip
        self.entropy_coef = 0.005  # weight of the entropy term in forward
        self.vf_coef = 0.5         # weight of the value loss in forward
        self.clip_vf = 0.1         # max change of the value prediction vs old_vals
        # Not read anywhere in this class; presumably consumed by the training
        # loop together with the approx_kl returned by forward -- TODO confirm.
        self.target_kl = 0.01

    @staticmethod
    def _build_mlp(width, out_features):
        """Return the shared MLP layout: 3 x (Linear(width, width), ELU) + head."""
        return nn.Sequential(
            nn.Linear(width, width), nn.ELU(),
            nn.Linear(width, width), nn.ELU(),
            nn.Linear(width, width), nn.ELU(),
            nn.Linear(width, out_features),
        )

    def get_output(self, cvs, env_names):
        """Run actor and critic and mask the actions an environment lacks.

        Args:
            cvs: Context vectors; the first dimension indexes environments and
                the last dimension is the feature dimension.
            env_names: Sequence of keys into ``env_action_dict``;
                ``env_names[i]`` describes row ``i`` of ``cvs``.

        Returns:
            Tuple ``(als, vals)``: masked action logits and critic values.
        """
        als = self.actor(cvs)
        for i in range(len(als)):
            n_valid = len(env_action_dict[env_names[i]])
            # Action slots beyond this environment's action count get a very
            # negative logit so they are effectively never sampled. The
            # ellipsis pins the slice to the LAST (action) axis, so the mask
            # stays correct if cvs carries extra dimensions between the batch
            # and feature axes; for 2-D input this equals ``als[i, n_valid:]``.
            als[i, ..., n_valid:] = -1e9
        vals = self.critic(cvs)
        return (als, vals)

    @torch.no_grad()
    def compute_advantage(self, rewards, values, dones):
        """Compute Generalised Advantage Estimation over a batch of rollouts.

        Args:
            rewards: Tensor of shape ``(B, T)``.
            values: Value predictions; expected to be ``(B, T + 1)`` so that
                ``values[:, :-1]`` and ``values[:, 1:]`` line up with rewards.
            dones: Tensor of shape ``(B, T)``; a non-zero entry at step ``t``
                stops bootstrapping from step ``t + 1``.

        Returns:
            Tensor of shape ``(B, T)`` with dtype ``values.dtype``.
        """
        rewards = rewards.float()
        not_done = 1.0 - dones.float()
        v_t, v_tp1 = values[:, :-1], values[:, 1:]
        # One-step TD errors; the bootstrap term vanishes where done is set.
        deltas = rewards + self.gamma * not_done * v_tp1 - v_t

        B, T = rewards.shape
        adv = torch.zeros_like(rewards, dtype=values.dtype)
        gae = torch.zeros(B, dtype=values.dtype, device=values.device)
        # Backward recursion: A_t = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}.
        for t in range(T - 1, -1, -1):
            gae = deltas[:, t] + self.gamma * self.lambda_ * not_done[:, t] * gae
            adv[:, t] = gae
        return adv

    def forward(self, cvs, actions, old_log_probs, lambda_return, advantage, old_vals, env_names):
        """Compute the PPO loss terms for one batch.

        Args:
            cvs: Context vectors fed to ``get_output``.
            actions: Actions whose log-probability is evaluated.
            old_log_probs: Log-probabilities of ``actions`` under the policy
                that collected the data.
            lambda_return: Regression targets for the critic.
            advantage: Advantage estimates weighting the policy ratio.
            old_vals: Critic predictions at collection time, used for value
                clipping.
            env_names: Environment names forwarded to ``get_output``.

        Returns:
            Tuple ``(losses, approx_kl)`` where ``losses`` is a
            ``LossWithIntermediateLosses`` holding ``loss_actions``,
            ``loss_values`` and ``loss_entropy``, and ``approx_kl`` is the mean
            of ``old_log_probs - log_probs`` computed without gradient.
        """
        als, vals = self.get_output(cvs, env_names)
        dist = Categorical(logits=als)
        log_probs = dist.log_prob(actions)

        # Clipped surrogate objective: take the pessimistic of the two terms.
        ratio = torch.exp(log_probs - old_log_probs)
        surr1 = ratio * advantage
        surr2 = torch.clamp(ratio, 1.0 - self.clip_coef, 1.0 + self.clip_coef) * advantage
        loss_actions = -torch.min(surr1, surr2).mean()
        # Negative sign: minimising the loss maximises entropy.
        loss_entropy = - self.entropy_coef * dist.entropy().mean()

        # Clipped value loss: the prediction may move at most clip_vf away
        # from old_vals; the larger of the two squared errors is penalised.
        vals = vals.squeeze(-1)
        vals_clipped = old_vals + (vals - old_vals).clamp(-self.clip_vf, self.clip_vf)

        v_loss1 = (vals - lambda_return)**2
        v_loss2 = (vals_clipped - lambda_return)**2
        loss_values = self.vf_coef * torch.max(v_loss1, v_loss2).mean()

        with torch.no_grad():
            approx_kl = (old_log_probs - log_probs).mean()

        return LossWithIntermediateLosses(loss_actions=loss_actions, loss_values=loss_values, loss_entropy=loss_entropy), approx_kl

    @torch.no_grad()
    def batch_rollout(self, envs, env_names):
        """Collect ``self.ih`` steps from ``envs`` with the current policy.

        Args:
            envs: Batched environment object. This method calls
                ``envs.reset(device, env_names)``, ``envs.get_context(option=2)``
                and ``envs.step(actions)``, the last returning
                ``(rewards, dones)``.
            env_names: Environment names passed to ``envs.reset`` and
                ``get_output``.

        Returns:
            Tuple ``(context_vectors, actions, action_logits, values, rewards,
            dones)``. Context vectors, actions and logits are stacked along
            dim 1 (one entry per step). ``values`` is concatenated along dim 1
            and holds ``ih + 1`` entries because the value of the final context
            is appended for bootstrapping. ``rewards`` and ``dones`` are
            reshaped to columns and concatenated along dim 1.
        """
        actions = []
        action_logits = []
        values = []
        rewards = []
        dones = []
        context_vectors = []

        # Reset the environments on the device the model parameters live on.
        envs.reset(self.critic[0].weight.device, env_names)
        cv = envs.get_context(option=2)

        for _ in range(self.ih):
            context_vectors.append(cv)
            als, vals = self.get_output(cv, env_names)
            acts = Categorical(logits=als).sample()
            actions.append(acts)
            action_logits.append(als)
            values.append(vals)
            rs, ds = envs.step(acts)
            cv = envs.get_context(option=2)
            rewards.append(rs.reshape(-1, 1))
            dones.append(ds.reshape(-1, 1))
        # Value of the context reached after the last step (bootstrap value).
        _, vals = self.get_output(cv, env_names)
        values.append(vals)

        context_vectors = torch.stack(context_vectors, dim=1)
        actions = torch.stack(actions, dim=1)
        action_logits = torch.stack(action_logits, dim=1)
        values = torch.cat(values, dim=1)
        rewards = torch.cat(rewards, dim=1)
        dones = torch.cat(dones, dim=1)

        return (context_vectors, actions, action_logits, values, rewards, dones)