import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as distributions
import einops
import copy
import numpy as np
from torch.cuda.amp import autocast

from .modules.functions_losses import SymLogTwoHotLoss
from .modules.attention_blocks import PositionwiseFeedForward, AttentionBlockKVCache

class EMAScalar():
    """Exponential moving average of a stream of values.

    The running value starts at ``0.0`` and no bias correction is applied, so
    early readings are pulled towards zero. ``value`` may be anything that
    supports ``*`` and ``+`` with Python floats (for example a float or a tensor).
    """

    def __init__(self, decay) -> None:
        # Weight kept from the previous average on every update.
        self.decay = decay
        # Current running average.
        self.scalar = 0.0

    def get(self):
        """Return the current running average."""
        return self.scalar

    def update(self, value):
        """Blend ``value`` into the running average."""
        carried_over = self.scalar * self.decay
        contribution = value * (1 - self.decay)
        self.scalar = carried_over + contribution

    def __call__(self, value):
        """Blend ``value`` into the running average and return the new average."""
        self.update(value)
        return self.get()


def percentile(x, percentage):
    """Return the ``percentage`` quantile of all elements of ``x``.

    The tensor is flattened and the ``int(percentage * numel)``-th smallest
    element is selected (no interpolation between neighbours).

    Args:
        x: non-empty tensor of any shape.
        percentage: fraction of the data, e.g. ``0.05`` for the 5th percentile.

    Returns:
        A 0-dim tensor holding the selected element.

    Raises:
        ValueError: if ``x`` has no elements.
    """
    flat_x = torch.flatten(x)
    num_values = flat_x.numel()
    if num_values == 0:
        raise ValueError("percentile() requires a tensor with at least one element")
    # torch.kthvalue takes a 1-based rank. Truncation gives 0 whenever
    # percentage * numel < 1 (e.g. 5% of fewer than 20 values), so clamp the
    # rank into the valid range [1, numel] instead of letting kthvalue fail.
    kth = min(max(int(percentage * num_values), 1), num_values)
    return torch.kthvalue(flat_x, kth).values


def calc_lambda_return(rewards, values, termination, gamma, lam, dtype=torch.float32):
    """Compute bootstrapped lambda-returns with a backward recursion over time.

    With ``L = rewards.shape[1]`` and ``c_t = 1 - termination[:, t]`` the
    recursion implemented here is::

        G[:, L] = values[:, -1]
        G[:, t] = rewards[:, t] + gamma * c_t * ((1 - lam) * values[:, t] + lam * G[:, t + 1])

    Args:
        rewards: tensor whose first two dims are (batch, length).
        values: value estimates; indexed at ``[:, t]`` for ``t < length`` and at
            ``[:, -1]`` for the final bootstrap.
        termination: 1 where the episode ended at that step, 0 otherwise.
        gamma: discount factor.
        lam: mixing weight between the value estimate and the recursive return.
        dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape (batch, length) and dtype ``dtype`` on ``rewards.device``.
    """
    # Invert termination to have 0 if the episode ended and 1 otherwise
    inv_termination = (termination * -1) + 1

    batch_size, batch_length = rewards.shape[:2]
    # Allocate on the inputs' device instead of a hard-coded "cuda" so the
    # function also works on CPU and on a non-default GPU.
    gamma_return = torch.zeros((batch_size, batch_length+1), dtype=dtype, device=rewards.device)
    gamma_return[:, -1] = values[:, -1]
    for t in reversed(range(batch_length)):  # with last bootstrap
        # Discount for this step; zero where the episode terminated.
        discount = gamma * inv_termination[:, t]
        # NOTE(review): the one-step term bootstraps from values[:, t], not
        # values[:, t+1] -- confirm this matches how the caller aligns rewards
        # and values in time.
        gamma_return[:, t] = rewards[:, t] + \
                            discount * (1-lam) * values[:, t] + \
                            discount * lam * gamma_return[:, t+1]
    return gamma_return[:, :-1]


class Network(nn.Module): # TODO: have similar structure to reward/termination decoder
    """Shared trunk producing policy logits and value logits from object features.

    A learned policy token and a learned value token are prepended to the
    per-object features of every (batch, time) position, the token set is run
    through a stack of ``AttentionBlockKVCache`` blocks, and the two tokens are
    read out by separate linear heads.
    """
    # NOTE: This at least works to some extent
    def __init__(self, latent_dim, feat_dim, num_layers, num_heads, action_dims) -> None:
        super().__init__()
        self.latent_dim = latent_dim
        self.feat_dim = feat_dim

        # "Add & norm" mixing: the latent slice is projected to feat_dim, added
        # to the hidden slice and layer-normalised. (An alternative considered
        # earlier was concatenating both slices and mixing them with an MLP stem.)
        self.latent_proj = nn.Linear(latent_dim, feat_dim)
        self.norm = nn.LayerNorm(feat_dim)

        # Learned read-out tokens, one per head.
        self.policy_token = nn.Parameter(torch.randn(feat_dim))
        self.value_token = nn.Parameter(torch.randn(feat_dim))

        blocks = []
        for _ in range(num_layers):
            blocks.append(
                AttentionBlockKVCache(feat_dim=feat_dim, hidden_dim=feat_dim*2, num_heads=num_heads, dropout=0.1)
            )
        self.attention_blocks = nn.ModuleList(blocks)

        # One logit per choice of every action dimension, flattened together.
        self.policy_head = nn.Linear(feat_dim, sum(action_dims))
        # 255 value logits; must match the bin count of the loss that decodes them.
        self.value_head = nn.Linear(feat_dim, 255)

    def _forward(self, feat):
        """Return the attended (policy_token, value_token) features, each B L D.

        ``feat`` is expected as B L Obj D, where the first ``latent_dim``
        channels of the last axis are the latent part and the remaining
        channels are the hidden part that the projected latent is added to.
        """
        batch_size, batch_length = feat.shape[:2] # B L Obj D

        split_at = self.latent_dim
        projected_latent = self.latent_proj(feat[:, :, :, :split_at])
        mixed = self.norm(feat[:, :, :, split_at:] + projected_latent)

        # Broadcast each learned token to every (batch, time) position.
        policy_prefix = einops.repeat(self.policy_token, "D -> B L 1 D", B=batch_size, L=batch_length)
        value_prefix = einops.repeat(self.value_token, "D -> B L 1 D", B=batch_size, L=batch_length)
        tokens = torch.cat([policy_prefix, value_prefix, mixed], dim=-2) # -> B L Obj+2 D

        # Fold time into the batch axis so every time step is processed as its
        # own token set by the attention blocks.
        tokens = einops.rearrange(tokens, "B L P_V_Obj D -> (B L) P_V_Obj D")
        for block in self.attention_blocks:
            # Same tensor passed for all three inputs; the second return value is unused here.
            tokens, _ = block(tokens, tokens, tokens)
        tokens = einops.rearrange(tokens, "(B L) P_V_Obj D -> B L P_V_Obj D", B=batch_size)

        # Index 0 is the policy token, index 1 the value token (see the cat above).
        return tokens[:, :, 0], tokens[:, :, 1]

    def forward(self, feat):
        """Return ``(action_logits, value_logits)`` for ``feat``."""
        policy_feat, value_feat = self._forward(feat)
        action_logits = self.policy_head(policy_feat)
        value_logits = self.value_head(value_feat) # TODO: this name should be value logits?
        return action_logits, value_logits

    def policy(self, feat):
        """Return only the action logits for ``feat``."""
        policy_feat, _ = self._forward(feat)
        return self.policy_head(policy_feat)

    def value(self, feat):
        """Return only the value logits for ``feat``."""
        _, value_feat = self._forward(feat)
        return self.value_head(value_feat)

class ActorCritic(nn.Module):
    """Actor-critic agent built on ``Network``, with its own optimizer and loss.

    Holds the trainable ``network``, a slowly-updated copy ``slow_network`` used
    as a regularisation target for the critic, EMA trackers of the return range
    used to normalise advantages, and the AdamW optimizer / grad scaler that
    ``update`` drives.
    """

    def __init__(self, latent_dim, feat_dim, num_layers, num_heads, action_dims, gamma, lambd, entropy_coef) -> None:
        super().__init__()
        assert self.check_action_dims(action_dims), "Currently only support Atari like action space = [A_dim] or Hollow Knight like action space = [2]*A_dim"
        self.action_dims = action_dims
        # All action dimensions share the same number of choices (checked above).
        self.action_choices = action_dims[0]

        self.gamma = gamma
        self.lambd = lambd
        self.entropy_coef = entropy_coef
        self.clip_coef = 0.2  # only referenced by the commented-out PPO variant in update()
        self.use_amp = True
        self.tensor_dtype = torch.bfloat16 if self.use_amp else torch.float32

        # 255 matches the output width of Network.value_head.
        # Presumably 255 bins over [-20, 20] in symlog space -- confirm in SymLogTwoHotLoss.
        self.symlog_twohot_loss = SymLogTwoHotLoss(255, -20, 20)

        self.network = Network(latent_dim=latent_dim,
                               feat_dim=feat_dim,
                               num_layers=num_layers,
                               num_heads=num_heads,
                               action_dims=action_dims)
        # Starts as an exact copy; afterwards only moved by update_slow_critic().
        self.slow_network = copy.deepcopy(self.network)

        # Running estimates of the 5th / 95th percentile of the lambda-returns.
        self.lowerbound_ema = EMAScalar(decay=0.99)
        self.upperbound_ema = EMAScalar(decay=0.99)

        # NOTE: self.parameters() also contains slow_network's parameters. They
        # never receive gradients from update(), because slow_value() runs under
        # no_grad.
        # self.optimizer = torch.optim.Adam(self.parameters(), lr=3e-5, eps=1e-5)
        self.optimizer = torch.optim.AdamW(self.parameters(), lr=3e-5, weight_decay=0.1)
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_amp)

    def check_action_dims(self, action_dims):
        """Return True if ``action_dims`` is non-empty and all entries are equal.

        Supported layouts are an Atari-like space ``[A_dim]`` and a
        Hollow-Knight-like space ``[2] * A_dim``. An empty sequence is rejected
        here so that ``__init__`` fails with its assertion message rather than
        with an IndexError on ``action_dims[0]``.
        """
        if len(action_dims) == 0:
            return False
        return all(value == action_dims[0] for value in action_dims)

    @torch.no_grad()
    def update_slow_critic(self, decay=0.98):
        """Move every slow-network parameter towards the live network (EMA with ``decay``)."""
        for slow_param, param in zip(self.slow_network.parameters(), self.network.parameters()):
            slow_param.data.copy_(slow_param.data * decay + param.data * (1 - decay))

    def policy(self, x):
        """Return action logits shaped B L A_dim A_choices."""
        logits = self.network.policy(x)
        logits = einops.rearrange(logits, "B L (A_dim A_choices) -> B L A_dim A_choices", A_choices=self.action_choices)
        return logits

    def value(self, x):
        """Return decoded value estimates from the live network."""
        value = self.network.value(x)
        value = self.symlog_twohot_loss.decode(value)
        return value

    @torch.no_grad()
    def slow_value(self, x):
        """Return decoded value estimates from the slow network (no gradients).

        Side effect: puts ``slow_network`` into eval mode.
        """
        self.slow_network.eval()
        value = self.slow_network.value(x)
        value = self.symlog_twohot_loss.decode(value)
        return value

    def get_logits_raw_value(self, x):
        """Return ``(action logits as B L A_dim A_choices, undecoded value logits)``."""
        logits, raw_value = self.network(x)
        logits = einops.rearrange(logits, "B L (A_dim A_choices) -> B L A_dim A_choices", A_choices=self.action_choices)
        return logits, raw_value

    @torch.no_grad()
    def sample(self, latent, greedy=False):
        """Pick one choice per action dimension; returns indices shaped B L A_dim.

        With ``greedy=True`` the most probable choice is taken, otherwise the
        action is sampled from the categorical policy.
        Side effect: switches the whole module to eval mode.
        """
        self.eval()

        # latent = einops.rearrange(latent, "B L Obj D -> B L (Obj D)")
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=self.use_amp):
            logits = self.policy(latent)
            dist = distributions.Categorical(logits=logits)
            if greedy:
                action = dist.probs.argmax(dim=-1)
            else:
                action = dist.sample()
        return action

    @torch.no_grad()
    def sample_with_log_prob(self, latent, temperature=1.0):
        """Sample an action and return ``(action, log_prob)``.

        Logits are divided by ``temperature`` before sampling. ``log_prob`` is
        the log-probability under that tempered distribution, summed over the
        action dimensions (shape B L).
        Side effect: switches the whole module to eval mode.
        """
        self.eval()

        # latent = einops.rearrange(latent, "B L Obj D -> B L (Obj D)")
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=self.use_amp):
            logits = self.policy(latent)
            dist = distributions.Categorical(logits=logits/temperature)
            action = dist.sample()
            log_prob = dist.log_prob(action)
            log_prob = einops.reduce(log_prob, "B L A_dim -> B L", "sum")
        return action, log_prob

    def sample_as_env_action(self, latent, greedy=False) -> np.ndarray:
        """Like ``sample`` but returns the action as a NumPy array on the host."""
        action = self.sample(latent, greedy)
        return action.detach().cpu().numpy()

    def update(self, latent, action, old_log_prob, reward, termination, logger=None):
        '''
        Update policy and value model with one optimizer step.

        Args:
            latent: network input with one more time step than ``action`` and
                ``reward`` (the last step is only used for bootstrapping);
                presumably B L+1 Obj D -- confirm against the caller.
            action: taken actions, B L A_dim.
            old_log_prob: log-probability of ``action`` under the behaviour
                policy, B L.
            reward: rewards, B L.
            termination: 1 where the episode ended, 0 otherwise, B L.
            logger: optional object with ``log(name, value)``; if given, loss
                statistics are reported to it.

        Side effects: switches the module to train mode, steps the optimizer and
        grad scaler, updates the return-range EMAs and the slow network.
        '''
        self.train()

        # latent = einops.rearrange(latent, "B L Obj D -> B L (Obj D)")
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=self.use_amp):
            logits, raw_value = self.get_logits_raw_value(latent)
            # The last time step has no action; it only provides the bootstrap value.
            dist = distributions.Categorical(logits=logits[:, :-1])
            log_prob = dist.log_prob(action)
            log_prob = einops.reduce(log_prob, "B L A_dim -> B L", "sum")
            entropy = dist.entropy()
            entropy = einops.reduce(entropy, "B L A_dim -> B L", "sum")

            # decode value, calc lambda return
            slow_value = self.slow_value(latent)
            slow_lambda_return = calc_lambda_return(reward, slow_value, termination, self.gamma, self.lambd)
            value = self.symlog_twohot_loss.decode(raw_value)
            lambda_return = calc_lambda_return(reward, value, termination, self.gamma, self.lambd)

            # update value function with slow critic regularization
            value_loss = self.symlog_twohot_loss(raw_value[:, :-1], lambda_return.detach())
            slow_value_regularization_loss = self.symlog_twohot_loss(raw_value[:, :-1], slow_lambda_return.detach())

            # Track the return range on detached values: the EMA state lives
            # across update() calls and must never hold on to an autograd graph
            # (the advantage is detached below anyway, so gradients are unchanged).
            detached_return = lambda_return.detach()
            lower_bound = self.lowerbound_ema(percentile(detached_return, 0.05))
            upper_bound = self.upperbound_ema(percentile(detached_return, 0.95))
            S = upper_bound-lower_bound
            # max(1, S) in the paper; clamp keeps S's device instead of forcing CUDA.
            norm_ratio = torch.clamp(S, min=1.0)
            norm_advantage = (lambda_return-value[:, :-1]) / norm_ratio

            # original on-policy loss
            # policy_loss = -(log_prob * norm_advantage.detach()).mean()

            # off-policy loss: importance-weight the advantage by pi_new / pi_old
            log_ratio = log_prob - old_log_prob
            ratio = torch.exp(log_ratio)
            policy_loss = -ratio * norm_advantage.detach()
            policy_loss = policy_loss.mean()

            # PPO
            # log_ratio = log_prob - old_log_prob
            # ratio = torch.exp(log_ratio)
            # policy_loss1 = -ratio * norm_advantage.detach()
            # policy_loss2 = -torch.clamp(ratio, 1-self.clip_coef, 1+self.clip_coef) * norm_advantage.detach()
            # policy_loss = torch.max(policy_loss1, policy_loss2).mean()

            # Mean policy entropy; subtracted from the loss, i.e. maximised.
            entropy_loss = entropy.mean()

            loss = policy_loss + value_loss + slow_value_regularization_loss - self.entropy_coef * entropy_loss

        # gradient descent
        self.scaler.scale(loss).backward()
        self.scaler.unscale_(self.optimizer)  # for clip grad
        torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=100.0)
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad(set_to_none=True)

        self.update_slow_critic()

        if logger is not None:
            logger.log('ActorCritic/policy_loss', policy_loss.item())
            logger.log('ActorCritic/value_loss', value_loss.item())
            logger.log('ActorCritic/entropy_loss', entropy_loss.item())
            logger.log('ActorCritic/S', S.item())
            logger.log('ActorCritic/norm_ratio', norm_ratio.item())
            logger.log('ActorCritic/total_loss', loss.item())