import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as distributions
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import copy
from torch.cuda.amp import autocast 

from sub_models.functions_losses import SymLogTwoHotLoss
from utils import EMAScalar


def percentile(x, percentage):
    """Return the element at the given fractional percentile of ``x``.

    ``x`` is flattened and the k-th smallest element is selected with
    ``k = int(percentage * numel)``. ``torch.kthvalue`` uses a 1-based ``k``,
    so ``k`` is clamped to ``[1, numel]``. Without the clamp, small inputs or
    small percentages (e.g. 10 elements at 0.05) would request ``k = 0`` and
    raise.

    Args:
        x: tensor of any shape.
        percentage: fraction, e.g. 0.05 for the 5th percentile.

    Returns:
        0-d tensor holding the selected element.
    """
    flat_x = torch.flatten(x)
    n = flat_x.numel()
    kth = min(max(int(percentage * n), 1), n)
    return torch.kthvalue(flat_x, kth).values


def calc_lambda_return(rewards, values, termination, gamma, lam, device, dtype=torch.float32):
    """Compute TD(lambda) returns, recursing backwards from a final bootstrap.

        G_T = values[:, -1]
        G_t = r_t + gamma * (1 - d_t) * ((1 - lam) * V_{t+1} + lam * G_{t+1})

    Args:
        rewards: tensor whose first two dims are (batch, T); ``rewards[:, t]``
            is r_t.
        values: value estimates with T + 1 entries along dim 1.
            ``values[:, t + 1]`` bootstraps step t, and ``values[:, -1]`` seeds
            the recursion.
        termination: same leading shape as ``rewards``. It is 1 where the
            episode ended at step t (cutting the bootstrap) and 0 otherwise.
        gamma: discount factor.
        lam: lambda mixing coefficient between one-step and longer returns.
        device: device on which the output buffer is allocated.
        dtype: dtype of the output buffer.

    Returns:
        Tensor of shape (batch, T) holding the lambda return for each step.
    """
    # Invert termination to have 0 if the episode ended and 1 otherwise.
    inv_termination = (termination * -1) + 1

    batch_size, batch_length = rewards.shape[:2]
    gamma_return = torch.zeros((batch_size, batch_length + 1), dtype=dtype, device=device)
    gamma_return[:, -1] = values[:, -1]
    for t in reversed(range(batch_length)):
        discount = gamma * inv_termination[:, t]
        # Bootstrap from the NEXT state's value V_{t+1}. Using V_t would make
        # the critic's target for state t depend on its own estimate for t.
        gamma_return[:, t] = (
            rewards[:, t]
            + discount * (1 - lam) * values[:, t + 1]
            + discount * lam * gamma_return[:, t + 1]
        )
    return gamma_return[:, :-1]


class ActorCriticAgent(nn.Module):
    """Shared actor with per-cluster critics (and EMA "slow" critics).

    One actor MLP is shared by every environment. Critics are kept per cluster.
    Each environment maps to a cluster by its position in ``env_names``, looked
    up in ``self.cluster_ids``. ``cluster_ids`` must be set with
    ``initial_cluster`` or ``initial_cluster_index_eval`` before any value
    method is called.

    Each critic head has 255 outputs, decoded to scalars with
    ``SymLogTwoHotLoss(255, -20, 20).decode`` (presumably two-hot bin logits).
    """

    def __init__(self, feat_dim, env_names, n_clusters, num_layers, hidden_dim, action_dim, gamma, lambd, entropy_coef) -> None:
        """Build the actor, critics, per-env return EMAs, optimizer and grad scaler.

        Args:
            feat_dim: input feature size for the actor and critics.
            env_names: environment names; their positions index ``cluster_ids``.
            n_clusters: number of (critic, slow critic) pairs.
            num_layers: number of Linear -> LayerNorm -> ReLU blocks per MLP
                (at least one block is always built).
            hidden_dim: hidden width of every block.
            action_dim: number of discrete actions (actor output size).
            gamma: discount factor for lambda returns.
            lambd: lambda for lambda returns.
            entropy_coef: weight of the entropy bonus in the loss.
        """
        super().__init__()
        self.gamma = gamma
        self.lambd = lambd
        self.entropy_coef = entropy_coef
        self.use_amp = True
        self.tensor_dtype = torch.bfloat16 if self.use_amp else torch.float32
        self.env_names = env_names
        self.n_clusters = n_clusters
        # Registered before actor/critics. Keep this order so parameter
        # ordering (and hence optimizer state layout) is unchanged.
        self.symlog_twohot_loss = SymLogTwoHotLoss(255, -20, 20)

        self.actor = nn.Sequential(
            *self._mlp_trunk(feat_dim, hidden_dim, num_layers),
            nn.Linear(hidden_dim, action_dim)
        )

        self.critic = nn.ModuleList()
        self.slow_critic = nn.ModuleList()
        for _ in range(self.n_clusters):
            critic_net = nn.Sequential(
                *self._mlp_trunk(feat_dim, hidden_dim, num_layers),
                nn.Linear(hidden_dim, 255)
            )
            self.critic.append(critic_net)
            # The slow critic starts as an exact copy and is then EMA-updated.
            self.slow_critic.append(copy.deepcopy(critic_net))

        # Per-environment EMAs of the 5th/95th return percentiles. These are
        # plain dicts, so their state is NOT included in state_dict().
        self.lowerbound_ema = {}
        self.upperbound_ema = {}
        for env in self.env_names:
            self.lowerbound_ema[env] = EMAScalar(decay=0.99)
            self.upperbound_ema[env] = EMAScalar(decay=0.99)

        # self.parameters() also includes the slow critics. Those are only used
        # under no_grad, so their .grad stays None and Adam skips them.
        self.optimizer = torch.optim.Adam(self.parameters(), lr=3e-5, eps=1e-5)
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_amp)

    @staticmethod
    def _mlp_trunk(feat_dim, hidden_dim, num_layers):
        """Return the hidden layers of an MLP: feat_dim -> hidden_dim, then hidden -> hidden."""
        layers = [
            nn.Linear(feat_dim, hidden_dim, bias=False),
            nn.LayerNorm(hidden_dim),
            nn.ReLU()
        ]
        for _ in range(num_layers - 1):
            layers.extend([
                nn.Linear(hidden_dim, hidden_dim, bias=False),
                nn.LayerNorm(hidden_dim),
                nn.ReLU()
            ])
        return layers

    @torch.no_grad()
    def update_slow_critic(self, decay=0.98):
        """EMA-update each slow critic: slow <- decay * slow + (1 - decay) * critic."""
        for critic_net, slow_net in zip(self.critic, self.slow_critic):
            for slow_param, param in zip(slow_net.parameters(), critic_net.parameters()):
                slow_param.data.copy_(slow_param.data * decay + param.data * (1 - decay))

    def policy(self, x):
        """Return the actor's action logits for features ``x``."""
        logits = self.actor(x)
        return logits

    def value(self, x, env):
        """Return decoded scalar values from the critic of ``env``'s cluster."""
        cluster_id = self.get_cluster_id(env)
        value = self.critic[cluster_id](x)
        value = self.symlog_twohot_loss.decode(value)
        return value

    @torch.no_grad()
    def slow_value(self, x, env):
        """Return decoded values from the slow critic of ``env``'s cluster (no gradients)."""
        cluster_id = self.get_cluster_id(env)
        value = self.slow_critic[cluster_id](x)
        value = self.symlog_twohot_loss.decode(value)
        return value

    def get_cluster_id(self, env):
        """Map an environment name to its cluster index via ``cluster_ids``.

        Raises ``ValueError`` if ``env`` is not in ``env_names``. Raises
        ``AttributeError`` if no cluster assignment has been set yet.
        """
        index = self.env_names.index(env)
        label = int(self.cluster_ids[index])
        return label

    def initial_cluster(self, cluster_labels):
        """Set the cluster assignment, aligned with ``env_names``."""
        self.cluster_ids = cluster_labels

    def initial_cluster_index_eval(self, cluster_indices):
        """Set the cluster assignment used at evaluation time (same storage as ``initial_cluster``)."""
        self.cluster_ids = cluster_indices

    def get_logits_raw_value(self, x, env):
        """Return (actor logits, undecoded critic outputs) for ``env``'s cluster."""
        cluster_id = self.get_cluster_id(env)
        logits = self.actor(x)
        raw_value = self.critic[cluster_id](x)
        return logits, raw_value

    @torch.no_grad()
    def sample(self, latent, greedy=False):
        """Pick actions from the policy: argmax if ``greedy``, else a categorical sample."""
        self.eval()
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=self.use_amp):
            logits = self.policy(latent)
            dist = distributions.Categorical(logits=logits)
            if greedy:
                action = dist.probs.argmax(dim=-1)
            else:
                action = dist.sample()
        return action

    def sample_as_env_action(self, latent, greedy=False):
        """Like ``sample`` but returns a CPU numpy array with the last dim squeezed."""
        action = self.sample(latent, greedy)
        return action.detach().cpu().squeeze(-1).numpy()

    def forward(self, rank_env, latent, action, old_logprob, old_value, reward, termination, logger=None):
        '''
        Compute actor-critic losses for every environment in ``rank_env``.

        This does not call backward() or step the optimizer; it only returns
        losses. ``latent``, ``action``, ``reward`` and ``termination`` are
        indexed by environment name. ``latent[env]`` presumably has T + 1 steps
        along dim 1 and the others T steps (logits[:, :-1] is matched with
        ``action``) — confirm against the caller. ``old_logprob`` and
        ``old_value`` are not used here.

        Returns:
            (sum of per-env losses, dict env_name -> per-env loss)
        '''
        self.train()
        model_device = next(self.parameters()).device
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=self.use_amp):
            loss = {}
            loss_item = torch.tensor(0.0, device=model_device)
            for env_name in rank_env:
                logits, raw_value = self.get_logits_raw_value(latent[env_name], env_name)
                # The final step only supplies the bootstrap value; no action is scored there.
                dist = distributions.Categorical(logits=logits[:, :-1])
                log_prob = dist.log_prob(action[env_name])
                entropy = dist.entropy()

                # Lambda returns from the slow (EMA) critic and from the live critic.
                slow_value = self.slow_value(latent[env_name], env_name)
                slow_lambda_return = calc_lambda_return(reward[env_name], slow_value, termination[env_name], self.gamma, self.lambd, model_device)
                value = self.symlog_twohot_loss.decode(raw_value)
                lambda_return = calc_lambda_return(reward[env_name], value, termination[env_name], self.gamma, self.lambd, model_device)

                value_loss = self.symlog_twohot_loss(raw_value[:, :-1], lambda_return.detach())
                slow_value_regularization_loss = self.symlog_twohot_loss(raw_value[:, :-1], slow_lambda_return.detach())

                # Return-scale normalisation. S is only used detached (and for
                # logging), so detaching the EMA input changes no values. The
                # EMA objects persist across calls; if they store the tensor
                # they receive, a graph-carrying input would keep every past
                # step's autograd graph alive.
                returns_for_scale = lambda_return.detach()
                lower_bound = self.lowerbound_ema[env_name](percentile(returns_for_scale, 0.05))
                upper_bound = self.upperbound_ema[env_name](percentile(returns_for_scale, 0.95))
                S = upper_bound - lower_bound

                norm_ratio = torch.max(torch.ones(1, device=model_device), S)  # max(1, S) in the paper
                norm_advantage = (lambda_return - value[:, :-1]) / norm_ratio
                policy_loss = -(log_prob * norm_advantage.detach()).mean()

                entropy_loss = entropy.mean()

                loss[env_name] = policy_loss + value_loss + slow_value_regularization_loss - self.entropy_coef * entropy_loss
                loss_item += loss[env_name]
                if logger is not None:
                    logger.log(f"{env_name}_ActorCritic/policy_loss", policy_loss.item())
                    logger.log(f"{env_name}_ActorCritic/value_loss", value_loss.item())
                    logger.log(f"{env_name}_ActorCritic/entropy_loss", entropy_loss.item())
                    logger.log(f"{env_name}_ActorCritic/S", S.item())
                    logger.log(f"{env_name}_ActorCritic/norm_ratio", norm_ratio.item())
                    logger.log(f"{env_name}_ActorCritic/total_loss", loss[env_name].item())
        return loss_item, loss
