import torch
import time
from abc import ABC
from easydict import EasyDict
import numpy as np
from src.chpo.Chpo_policy_utils import Adam
from collections import namedtuple
from typing import Optional, Tuple
from torch.distributions import Independent, Normal, Categorical

import os
import gym
import csv
import wandb
import gym_hybrid
from datetime import datetime
from src.chpo.Chpo_model import RunningMeanStd


def gae(value, next_value, reward, done, gamma: float = 0.99, lambda_: float = 0.95) -> torch.FloatTensor:
    """Compute Generalized Advantage Estimation along the leading (time) dimension.

    Args:
        value: V(s_t), indexed by time along dim 0.
        next_value: V(s_{t+1}), same shape as ``value``. It is NOT modified.
        reward: Per-step reward (or cost), indexed by time along dim 0. If it has one
            dimension fewer than ``value`` (e.g. value ``(T, B, A)``, reward ``(T, B)``),
            a trailing axis is added so it broadcasts against ``value``.
        done: Termination flags (bool or numeric). At a done step the target does not
            bootstrap from ``next_value`` and the trace from later steps is cut.
        gamma: Discount factor.
        lambda_: GAE lambda.

    Returns:
        Advantages, a tensor shaped like ``value``.
    """
    done = done.float()
    if len(value.shape) == len(reward.shape) + 1:  # for some marl case: value(T, B, A), reward(T, B)
        reward = reward.unsqueeze(-1)
        done = done.unsqueeze(-1)
    not_done = 1 - done

    # Out-of-place on purpose: an in-place ``*=`` would overwrite the caller's tensor.
    next_value = next_value * not_done
    delta = reward + gamma * next_value - value
    # The trajectory flag is the done flag: the GAE trace stops at episode boundaries.
    factor = gamma * lambda_ * not_done

    adv = torch.zeros_like(value)
    gae_item = torch.zeros_like(value[0])
    for t in reversed(range(reward.shape[0])):
        gae_item = delta[t] + factor[t] * gae_item
        adv[t] = gae_item
    return adv

class CHPOPolicy(ABC):
    """Cost-constrained policy for hybrid (discrete type + continuous parameter) actions.

    Entry points:

    * ``rollout`` - step ``env`` with the stochastic policy and store transitions in ``buf``.
    * ``update`` - PPO-style clipped updates. In the second half of training, minibatches
      whose estimated episode cost exceeds ``cost_limit + cost_distance`` may be trained
      on the negated cost advantage instead of the reward advantage.
    * ``evaluate`` - greedy episodes whose mean return/cost is appended to a CSV file.
    """

    # Entries sliced out of the full batch for every minibatch in ``update``.
    # ``epret_cost`` is not read by the losses but is still required in the batch.
    _MINIBATCH_KEYS = (
        'obs', 'next_obs', 'discrete_act', 'parameter_act', 'reward', 'ret', 'adv', 'value',
        'logp_discrete_act', 'logp_parameter_act', 'done', 'logit_action_type',
        'logit_action_argsmu', 'logit_action_argssigma', 'cost', 'cost_ret', 'adc',
        'cost_value', 'epret_cost',
    )

    def __init__(
            self,
            env_id,
            buf,
            model,
            device,
            chpo_param_init=True,
            learning_rate: float = 3e-4,
            grad_clip_type='clip_norm',
            grad_clip_value=0.5,
            gamma=0.99,
            gae_lambda=0.95,
            recompute_adv=True,
            value_weight=0.5,
            entropy_weight=0.5,
            clip_ratio=0.2,
            adv_norm=True,
            value_norm=True,
            wandb_flag=False,
            env=None,
            share_encoder=True,
            batch_size=320,
            cost_limit=1.0,
            cost_distance=0.0,
            seed=0,
            rc_ratio=2,
    ) -> None:
        """Build the policy, create (truncate) the evaluation CSV and the optimizer(s).

        Args:
            env_id: Root directory of the evaluation CSV path.
            buf: Rollout buffer; ``rollout`` calls its ``store`` and ``finish_path``.
            model: Network with ``actor``/``critic``/``cost_critic`` sub-modules and the
                ``compute_*`` methods used below. It is moved to ``device``.
            device: Device for the model and the observation tensors.
            chpo_param_init: If True, re-initialise every ``nn.Linear`` layer.
            learning_rate, grad_clip_type, grad_clip_value: Forwarded to ``Adam``.
            gamma, gae_lambda: Discount and lambda used when ``update`` recomputes GAE.
            recompute_adv: Stored; ``update`` currently always recomputes advantages.
            value_weight: Weight of the value and cost-value losses (shared encoder).
            entropy_weight: Weight of the entropy bonus.
            clip_ratio: PPO clip range, also used for value clipping.
            adv_norm: Normalise advantages per minibatch.
            value_norm: Scale critic outputs by the running std of returns.
            wandb_flag: Log metrics with wandb.
            env: Environment used by ``rollout`` and ``evaluate``.
            share_encoder: One optimizer over all parameters (True) or separate
                actor / critic / cost-critic optimizers (False).
            batch_size: Minibatch size in ``update``.
            cost_limit, cost_distance: Their sum is the cost threshold used in ``update``.
            seed: Only used in the CSV file name.
            rc_ratio: Cost-driven updates are taken while
                ``cost_update < _reward_cost_update / rc_ratio``.
        """
        self._model = model.to(device)
        self._env_id = env_id
        self._buf = buf
        self._device = device
        self._chpo_param_init = chpo_param_init
        self._learning_rate = learning_rate
        self._grad_clip_type = grad_clip_type
        self._grad_clip_value = grad_clip_value

        self._gamma = gamma
        self._gae_lambda = gae_lambda
        self._recompute_adv = recompute_adv
        self._value_weight = value_weight
        self._entropy_weight = entropy_weight
        self._clip_ratio = clip_ratio
        self._adv_norm = adv_norm
        self._value_norm = value_norm
        self._wandb_flag = wandb_flag
        self._env = env
        self._share_encoder = share_encoder
        self.batch_size = batch_size
        self._running_mean_std = RunningMeanStd(epsilon=1e-4, device=self._device)
        self.cost_running_mean_std = RunningMeanStd(epsilon=1e-4, device=self._device)
        self.cost_limit = cost_limit
        self.cost_distance = cost_distance
        # Counters driving the reward/cost advantage switch in ``_select_advantage``.
        self.cost_update = 0
        self.reward_update = 0
        self._reward_cost_update = 0
        self.rc_ratio = rc_ratio

        self._init_cost_update = True

        # NOTE: 'ration_' (sic) is kept so result directories keep their existing names.
        directory = f'{env_id}/ration_{self.rc_ratio}_limit_{self.cost_limit}_dis_{self.cost_distance}'
        eval_csv_name = f'seed_{seed}.csv'
        os.makedirs(directory, exist_ok=True)
        self.filepath = os.path.join(directory, eval_csv_name)

        # Opening with 'w' discards any previous results for this seed.
        with open(self.filepath, 'w', newline='') as file:
            csv.writer(file).writerow(['step', "eval_ret", "eval_cost"])

        if self._chpo_param_init:
            self._init_parameters()

        if self._share_encoder:
            self._optimizer = self._make_optimizer(self._model.parameters())
        else:
            self._actor_optimizer = self._make_optimizer(self._model.actor.parameters())
            self._critic_optimizer = self._make_optimizer(self._model.critic.parameters())
            self._cost_critic_optimizer = self._make_optimizer(self._model.cost_critic.parameters())

    def _make_optimizer(self, params):
        """Create the project ``Adam`` optimizer with this policy's lr and clipping."""
        return Adam(
            params,
            lr=self._learning_rate,
            grad_clip_type=self._grad_clip_type,
            clip_value=self._grad_clip_value
        )

    @staticmethod
    def _linear_layers(modules):
        """Return the ``nn.Linear`` instances among ``modules``, in order."""
        return [m for m in modules if isinstance(m, torch.nn.Linear)]

    def _init_parameters(self) -> None:
        """Orthogonally initialise Linear layers; actor layers end up scaled by 0.01."""
        # 1) Every Linear layer in the model: orthogonal weights (gain 1), zero bias.
        for m in self._linear_layers(self._model.modules()):
            torch.nn.init.orthogonal_(m.weight)
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)

        # 2) Critic and actor layers are re-drawn with gain sqrt(2).
        for m in self._linear_layers(list(self._model.critic.modules()) + list(self._model.actor.modules())):
            torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)

        # 3) Actor weights are scaled down so the initial policy outputs are small.
        for m in self._linear_layers(self._model.actor.modules()):
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)
            m.weight.data.copy_(0.01 * m.weight.data)

    @staticmethod
    def _minibatch_slices(num_samples, batch_size, max_iters):
        """Return up to ``max_iters`` consecutive, full-size minibatch slices.

        Raises:
            ValueError: If ``num_samples`` cannot fill even one minibatch.
        """
        num_iters = min(max_iters, num_samples // batch_size)
        if num_iters <= 0:
            raise ValueError(
                f"update() needs at least batch_size={batch_size} samples, got {num_samples}"
            )
        return [slice(i * batch_size, (i + 1) * batch_size) for i in range(num_iters)]

    def _compute_targets(self, critic_fn, rms, batch, signal):
        """Return ``(advantage, value, return)`` for one critic over the whole batch.

        ``critic_fn`` is evaluated on ``obs``/``next_obs``. With value normalisation,
        its outputs are scaled by ``rms.std`` before GAE. The returned value/return are
        divided by that same std, and ``rms`` is then updated with the unnormalised returns.
        """
        value = critic_fn(batch['obs'])['value']
        next_value = critic_fn(batch['next_obs'])['value']
        if self._value_norm:
            value = value * rms.std
            next_value = next_value * rms.std

        adv = gae(value, next_value, signal, batch['done'], gamma=self._gamma, lambda_=self._gae_lambda)
        unnormalized_returns = value + adv

        if not self._value_norm:
            return adv, value, unnormalized_returns
        norm_value = value / rms.std
        norm_returns = unnormalized_returns / rms.std
        rms.update(unnormalized_returns.cpu().numpy())
        return adv, norm_value, norm_returns

    def _recompute_targets(self, batch) -> None:
        """Overwrite the reward and cost advantages/targets in ``batch`` using the current critics."""
        with torch.no_grad():
            batch['adv'], batch['value'], batch['ret'] = self._compute_targets(
                self._model.compute_critic, self._running_mean_std, batch, batch['reward'])
            batch['adc'], batch['cost_value'], batch['cost_ret'] = self._compute_targets(
                self._model.compute_cost_critic, self.cost_running_mean_std, batch, batch['cost'])

    def _select_advantage(self, adv, adc, cost_violated):
        """Choose the policy advantage for one minibatch and update the update counters.

        If ``cost_violated`` is true, ``-adc`` is returned (counting a cost update) on
        the first such call, and later whenever ``cost_update < _reward_cost_update / rc_ratio``.
        In every other case ``adv`` is returned and a reward update is counted.
        Once any cost update has happened, ``_reward_cost_update`` is also incremented.
        """
        if cost_violated:
            if self._init_cost_update:
                self.cost_update += 1
                self._init_cost_update = False
                return -adc
            if self.cost_update < (self._reward_cost_update / self.rc_ratio):
                self.cost_update += 1
                return -adc
        self.reward_update += 1
        if self.cost_update > 0:
            self._reward_cost_update += 1
        return adv

    @staticmethod
    def _clipped_surrogate(logp_new, logp_old, adv, clip_ratio):
        """Return ``(loss, approx_kl, clipfrac)`` for the PPO clipped objective.

        If the ratio's shape differs from ``adv``'s, the ratio is averaged over dim 1 first.
        """
        ratio = torch.exp(logp_new - logp_old)
        if ratio.shape != adv.shape:
            ratio = ratio.mean(dim=1)
        surr1 = ratio * adv
        surr2 = ratio.clamp(1 - clip_ratio, 1 + clip_ratio) * adv
        loss = (-torch.min(surr1, surr2)).mean()
        with torch.no_grad():
            approx_kl = (logp_old - logp_new).mean().item()
            clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio)
            clipfrac = clipped.float().mean().item()
        return loss, approx_kl, clipfrac

    @staticmethod
    def _clipped_value_loss(new_value, old_value, target, clip_ratio):
        """Return the PPO clipped value loss ``0.5 * mean(max(unclipped, clipped))``."""
        value_clip = old_value + (new_value - old_value).clamp(-clip_ratio, clip_ratio)
        v1 = (target - new_value).pow(2)
        v2 = (target - value_clip).pow(2)
        return 0.5 * torch.max(v1, v2).mean()

    def update(self, data, train_iters, train_epoch, train_all_epoch, ep_len) -> None:
        """Run ``train_iters`` optimisation passes over the batch returned by ``data.get()``.

        Each pass recomputes advantages and targets with the current critics. It then
        performs up to 10 minibatch updates of ``self.batch_size`` samples each.

        Args:
            data: Object whose ``get()`` returns a dict of batched tensors. Presumably
                the rollout buffer; it must provide the keys read in this method.
            train_iters: Number of passes over the batch.
            train_epoch, train_all_epoch: Cost-driven updates are only considered
                when ``train_epoch > 0.5 * train_all_epoch``.
            ep_len: Scales the minibatch mean step cost into an episode-cost estimate.

        Raises:
            ValueError: If the batch holds fewer than ``self.batch_size`` samples.
        """
        value_weight = self._value_weight
        entropy_weight = self._entropy_weight
        clip_ratio = self._clip_ratio
        max_batch_iters = 10

        batch = data.get()

        for _ in range(train_iters):
            self._recompute_targets(batch)

            for sl in self._minibatch_slices(len(batch['obs']), self.batch_size, max_batch_iters):
                batch_train = {key: batch[key][sl] for key in self._MINIBATCH_KEYS}
                output = self._model.compute_actor_critic(batch_train['obs'])

                # Mean per-step cost of this minibatch, scaled to an episode of ep_len steps.
                batch_train_epcost_judge = (batch_train['cost'].sum() / self.batch_size) * ep_len

                adv = batch_train['adv']
                adc = batch_train['adc']
                if self._adv_norm:
                    adv = (adv - adv.mean()) / (adv.std() + 1e-8)
                    adc = (adc - adc.mean()) / (adc.std() + 1e-8)

                cost_violated = bool(
                    batch_train_epcost_judge > (self.cost_limit + self.cost_distance)
                    and train_epoch > 0.5 * train_all_epoch
                )
                adv_final = self._select_advantage(adv, adc, cost_violated)

                # Discrete (action type) head.
                dist_discrete_new = Categorical(logits=output['logit']['action_type'])
                dist_discrete_old = Categorical(logits=batch_train['logit_action_type'])
                logp_discrete_new = dist_discrete_new.log_prob(batch_train['discrete_act'])
                logp_discrete_old = dist_discrete_old.log_prob(batch_train['discrete_act'])
                discrete_entropy = dist_discrete_new.entropy()
                if discrete_entropy.shape != adv_final.shape:
                    discrete_entropy = discrete_entropy.mean(dim=1)
                discrete_entropy_loss = discrete_entropy.mean()
                discrete_policy_loss, dis_approx_kl, dis_clipfrac = self._clipped_surrogate(
                    logp_discrete_new, logp_discrete_old, adv_final, clip_ratio)

                # Continuous (action parameter) head: diagonal Gaussian.
                dist_args_new = Independent(
                    Normal(output['logit']['action_args']['mu'], output['logit']['action_args']['sigma']), 1)
                dist_args_old = Independent(
                    Normal(batch_train['logit_action_argsmu'], batch_train['logit_action_argssigma']), 1)
                logp_args_new = dist_args_new.log_prob(batch_train['parameter_act'])
                logp_args_old = dist_args_old.log_prob(batch_train['parameter_act'])
                args_entropy_loss = dist_args_new.entropy().mean()
                args_policy_loss, args_approx_kl, args_clipfrac = self._clipped_surrogate(
                    logp_args_new, logp_args_old, adv_final, clip_ratio)

                # Reward and cost critics.
                value_loss = self._clipped_value_loss(
                    output['value'], batch_train['value'], batch_train['ret'], clip_ratio)
                cost_value_loss = self._clipped_value_loss(
                    output['cost_value'], batch_train['cost_value'], batch_train['cost_ret'], clip_ratio)

                entropy_bonus = entropy_weight * (discrete_entropy_loss + args_entropy_loss)
                if self._share_encoder:
                    total_loss = (discrete_policy_loss + args_policy_loss + value_weight * value_loss
                                  + value_weight * cost_value_loss - entropy_bonus)
                    self._optimizer.zero_grad()
                    total_loss.backward()
                    self._optimizer.step()
                else:
                    actor_loss = discrete_policy_loss + args_policy_loss - entropy_bonus
                    optimizers = (self._actor_optimizer, self._critic_optimizer, self._cost_critic_optimizer)
                    for optimizer in optimizers:
                        optimizer.zero_grad()
                    # A single backward accumulates the same gradients as three separate
                    # calls, and does not fail if the three losses share graph nodes.
                    (actor_loss + value_loss + cost_value_loss).backward()
                    for optimizer in optimizers:
                        optimizer.step()

                if self._wandb_flag:
                    wandb.log({'record/discrete_policy_loss': discrete_policy_loss.item(),
                               'record/args_policy_loss': args_policy_loss.item(),
                               'record/value_loss': value_loss.item(),
                               'record/discrete_entropy_loss': discrete_entropy_loss.item(),
                               'record/args_entropy_loss': args_entropy_loss.item(),
                               'record/dis_approx_kl:': dis_approx_kl,
                               'record/dis_clipfrac:': dis_clipfrac,
                               'record/args_approx_kl:': args_approx_kl,
                               'record/args_clipfrac:': args_clipfrac,
                               'record/cost_value_loss': cost_value_loss.item(),
                               'record/reward_update': self.reward_update,
                               'record/cost_update': self.cost_update,
                               'record/reward_cost_update': self._reward_cost_update,
                               'record/cost_ret_max': batch_train['cost_ret'].max().item(),
                               'record/cost_ret_mean': batch_train['cost_ret'].mean().item(),
                               })

    def _bootstrap_values(self, obs, reward, cost):
        """Return the ``(value, cost_value)`` pair passed to ``buf.finish_path``.

        A non-zero final reward (resp. cost) is used as-is. Otherwise the critic
        (resp. cost critic) output on ``obs`` is used, converted to a NumPy array.
        The model is not evaluated when both reward and cost are non-zero.
        """
        if reward != 0 and cost != 0:
            return reward, cost
        with torch.no_grad():
            state = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0).to(self._device)
            action_value = self._model.compute_actor_critic(state)
        value = reward if reward != 0 else action_value['value'].cpu().float().numpy()
        cost_value = cost if cost != 0 else action_value['cost_value'].cpu().float().numpy()
        return value, cost_value

    def rollout(self, steps_per_epoch) -> float:
        """Take ``steps_per_epoch`` environment steps with the stochastic policy.

        Every transition is stored in ``buf``. A path is closed with ``buf.finish_path``
        whenever an episode ends or the step budget runs out.

        Args:
            steps_per_epoch: Number of environment steps to take.

        Returns:
            Mean length of the paths finished during this call, including the final,
            possibly truncated, one.
        """
        env = self._env
        obs, ep_ret, ep_len, ep_cost = env.reset(), 0, 0, 0
        ep_num = 0
        ep_len_total = 0

        for t in range(steps_per_epoch):
            with torch.no_grad():
                state = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0).to(self._device)
                action_value = self._model.compute_actor_critic(state)
                logit = action_value['logit']
                value = action_value['value']
                cost_value = action_value['cost_value']

                # Discrete action type: sample from the softmax over the action-type logits.
                pi_action = Categorical(torch.softmax(logit['action_type'], dim=-1))
                action_type = pi_action.sample()
                log_prob_action_type = pi_action.log_prob(action_type)

                # Continuous parameters: sample from a diagonal Gaussian.
                dist = Independent(Normal(logit['action_args']['mu'], logit['action_args']['sigma']), 1)
                action_args = dist.sample()
                log_prob_action_args = dist.log_prob(action_args)

                action = (int(action_type.item()), action_args.cpu().float().numpy().flatten())

            next_obs, reward, cost, done, _ = env.step(action)
            ep_ret += reward
            ep_cost += cost
            ep_len += 1

            # NOTE(review): this NumPy conversion only runs when ``self._device`` compares
            # equal to the string 'cuda:0'. Confirm what ``buf.store`` accepts before widening it.
            if self._device == 'cuda:0':
                action_args = action_args.cpu().float().numpy().flatten()
                log_prob_action_type = log_prob_action_type.cpu().float().numpy().flatten()
                log_prob_action_args = log_prob_action_args.cpu().float().numpy().flatten()
                logit['action_type'] = logit['action_type'].cpu().float().numpy().flatten()
                logit['action_args']['mu'] = logit['action_args']['mu'].cpu().float().numpy().flatten()
                logit['action_args']['sigma'] = logit['action_args']['sigma'].cpu().float().numpy().flatten()

            self._buf.store(
                obs=obs,
                next_obs=next_obs,
                discrete_act=action_type,
                parameter_act=action_args,
                rew=reward,
                val=value,
                logp_discrete_act=log_prob_action_type,
                logp_parameter_act=log_prob_action_args,
                done=done,
                logit_action_type=logit['action_type'],
                logit_action_argsmu=logit['action_args']['mu'],
                logit_action_argssigma=logit['action_args']['sigma'],
                cost=cost,
                cost_value=cost_value
            )

            if self._wandb_flag:
                wandb.log({'reward/rew': reward,
                           'reward/value': value,
                           })

            obs = next_obs

            epoch_ended = t == steps_per_epoch - 1
            if done or epoch_ended:
                self._buf.finish_path(*self._bootstrap_values(obs, reward, cost))
                ep_len_total += ep_len
                ep_num += 1

                if self._wandb_flag:
                    wandb.log({'ep_ret': ep_ret,
                               'ep_len': ep_len,
                               'ep_cost': ep_cost})

                obs, ep_ret, ep_len, ep_cost = env.reset(), 0, 0, 0

        return ep_len_total / ep_num

    def evaluate(self, eval_epoch=5, step=0):
        """Run ``eval_epoch`` greedy episodes and append the mean return/cost to the CSV.

        The discrete action is the arg-max of the action-type softmax. The continuous
        parameters are the Gaussian means; no sampling is done.

        Args:
            eval_epoch: Number of complete episodes to run.
            step: Value written in the ``step`` column of the CSV row.
        """
        print('=====evaluate======')
        env = self._env
        success_ratio = 0  # episodes with return above 1.5
        mean_ep_ret, mean_ep_len, mean_ep_cost = 0, 0, 0
        obs, ep_ret, ep_len, ep_cost = env.reset(), 0, 0, 0

        episodes_left = eval_epoch
        while episodes_left > 0:
            with torch.no_grad():
                state = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0).to(self._device)
                logit = self._model.compute_actor(state)['logit']
                prob = torch.softmax(logit['action_type'], dim=-1)
                action_type = torch.argmax(prob, dim=1, keepdim=True)
                action_args = logit['action_args']['mu']
                action = (int(action_type.item()), action_args.cpu().float().numpy().flatten())

            obs, reward, cost, done, _ = env.step(action)
            ep_ret += reward
            ep_cost += cost
            ep_len += 1

            if done:
                mean_ep_ret += ep_ret
                mean_ep_cost += ep_cost
                mean_ep_len += ep_len
                if ep_ret > 1.5:
                    success_ratio += 1

                if self._wandb_flag:
                    wandb.log({'eval_ep_ret': ep_ret, 'eval_ep_len': ep_len, 'eval_ep_cost': ep_cost})

                obs, ep_ret, ep_len, ep_cost = env.reset(), 0, 0, 0
                episodes_left -= 1

        if self._wandb_flag:
            wandb.log({'eval_mean_ep_ret': mean_ep_ret/eval_epoch, 'eval_mean_ep_len': mean_ep_len/eval_epoch, 'eval_mean_ep_cost': mean_ep_cost/eval_epoch})
            wandb.log({'eval_success_ratio': success_ratio/eval_epoch})

        with open(self.filepath, 'a', newline='') as file:
            csv.writer(file).writerow([step, mean_ep_ret / eval_epoch, mean_ep_cost / eval_epoch])
        
        
            
            

    


        
        




    




                
                


























