#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from models.model import NetHackPolicyNet, NetHackValueNet
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn
from torch.distributions import Categorical

class PPO:
    """Proximal Policy Optimization trainer for a NetHack policy network.

    Holds the policy network, a single Adam optimizer over its parameters and
    the loss hyperparameters read from ``config``. ``update_policy`` performs
    clipped-surrogate PPO updates (plus an optional kickstarting term);
    ``update_network`` trains the same network to imitate a meta-controller's
    values and actions.
    """

    def __init__(self, config) -> None:
        """Build the networks/optimizer and read hyperparameters from *config*.

        Args:
            config: attribute-style config. Read here: ``lr``,
                ``max_grad_norm``, ``adam_eps``, ``clip_eps``,
                ``entropy_coef``, ``iter_with_ks``,
                ``kickstarting_coef_initial``, ``kickstarting_coef_minimum``,
                ``kickstarting_coef_decent``, ``value_loss_coef``,
                ``batch_size``, ``epoch``, ``num_worker``. Optional
                ``detect_anomaly`` (default False) enables autograd anomaly
                detection while ``update_policy`` runs.
        """
        self.config = config
        super().__init__()

        # model / optimizer
        self.lr                = config.lr
        self.grad_clip         = config.max_grad_norm
        self.eps               = config.adam_eps
        self.model = NetHackPolicyNet(config)
        # NOTE(review): value_model is constructed but neither optimized nor
        # used by this class; kept in case callers rely on the attribute.
        self.value_model = NetHackValueNet(config)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, eps=self.eps)

        # loss hyperparameters
        self.clip              = config.clip_eps
        self.entropy_coef      = config.entropy_coef
        self.iter_with_ks      = config.iter_with_ks
        self.ks_coef           = config.kickstarting_coef_initial
        self.ks_coef_minimum   = config.kickstarting_coef_minimum
        self.ks_coef_descent   = config.kickstarting_coef_decent
        self.value_loss_coef   = config.value_loss_coef
        self.meta_controller_value_loss_coef = 1
        self.meta_controller_pi_loss_coef = 1

        # other settings
        self.batch_size        = config.batch_size
        self.epochs            = config.epoch
        self.num_worker        = config.num_worker
        self.iter              = 0
        self.eplison = 1e-3  # not read anywhere in this class
        # Anomaly detection is a debugging aid that slows every backward pass,
        # so it is opt-in rather than always on.
        self.detect_anomaly = getattr(config, "detect_anomaly", False)

    def __call__(self, obs, core_state):
        """Forward *obs* and *core_state* through the policy network."""
        return self.model(obs, core_state)

    def update_kickstarting_coef(self):
        """Advance the kickstarting schedule by one iteration.

        Increments ``self.iter`` and decreases ``self.ks_coef`` by
        ``ks_coef_descent``, never letting it drop below ``ks_coef_minimum``.
        """
        self.iter += 1
        # Clamp in the same step: subtracting first and clamping on the next
        # call let the coefficient fall below the minimum for one iteration.
        self.ks_coef = max(self.ks_coef - self.ks_coef_descent, self.ks_coef_minimum)

    def _optimize(self, loss):
        """Backpropagate *loss*, clip the global grad norm, take one optimizer step."""
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
        self.optimizer.step()

    def _ppo_step(self, batch, llm4t):
        """Apply one clipped-PPO gradient step on *batch*; return the logged loss values."""
        (obs_batch, _, action_batch, return_batch, advantage_batch, values_batch,
         mask, log_prob_batch, _, _, meta_controller_prob_batch, _, _, _,
         core_state) = batch

        output, _ = self.model(obs_batch, core_state)
        # pdf exposes entropy()/log_prob()/logits, presumably a
        # torch.distributions.Categorical — confirm against the model.
        pdf = output["policy_logits"]
        # NOTE(review): [-1] presumably selects the last time step of the
        # baseline; its shape must broadcast with values_batch — confirm.
        value = output["baseline"][-1]

        entropy_loss = pdf.entropy().mean()
        if llm4t:
            # Cross-entropy against the teacher distribution when pdf.logits
            # are normalized log-probs (true for Categorical).
            kickstarting_loss = -(pdf.logits * meta_controller_prob_batch).sum(dim=-1).mean()
        else:
            kickstarting_loss = -pdf.logits.sum(dim=-1).mean()

        # Clipped surrogate objective. The mask is applied to BOTH surrogates:
        # masking only the clipped one let masked samples with a negative
        # advantage leak into torch.min and hence into the loss.
        ratio = torch.exp(pdf.log_prob(action_batch) - log_prob_batch)
        surr1 = ratio * advantage_batch * mask
        surr2 = torch.clamp(ratio, 1.0 - self.clip, 1.0 + self.clip) * advantage_batch * mask
        policy_loss = -torch.min(surr1, surr2).mean()

        # Clipped value loss, reusing the policy clip range as the value clip.
        value_clipped = values_batch + torch.clamp(value - values_batch, -self.clip, self.clip)
        v_unclipped = ((value - return_batch) * mask).pow(2)
        v_clipped = ((value_clipped - return_batch) * mask).pow(2)
        value_loss = torch.max(v_unclipped, v_clipped).mean()

        loss = policy_loss + self.value_loss_coef * value_loss - self.entropy_coef * entropy_loss
        if self.iter < self.iter_with_ks:
            loss = loss + self.ks_coef * kickstarting_loss

        self._optimize(loss)
        return [loss.item(), entropy_loss.item(), kickstarting_loss.item(),
                policy_loss.item(), value_loss.item()]

    def update_policy(self, buffer, llm4t = False):
        """Run ``self.epochs`` passes of clipped-PPO minibatch updates.

        Args:
            buffer: object whose ``sample(batch_size)`` yields 15-tuples
                (see ``_ppo_step`` for the fields consumed).
            llm4t: if True, the kickstarting term is
                ``-(logits * meta_controller_prob_batch).sum(-1).mean()``;
                otherwise ``-logits.sum(-1).mean()``. The term is only added
                while ``self.iter < self.iter_with_ks``.

        Returns:
            numpy array ``[loss, entropy, kickstarting_loss, policy_loss,
            value_loss]`` averaged over every minibatch of every epoch.
        """
        losses = []
        # Context-manager form: restores the previous global anomaly mode on exit.
        with torch.autograd.set_detect_anomaly(self.detect_anomaly):
            for _ in range(self.epochs):
                for batch in buffer.sample(self.batch_size):
                    losses.append(self._ppo_step(batch, llm4t))

        if self.iter < self.iter_with_ks:
            self.update_kickstarting_coef()

        return np.mean(losses, axis=0)

    def update_network(self, buffer, setting = -1):
        """Train the policy network to imitate the meta-controller.

        Args:
            buffer: object whose ``sample(batch_size)`` yields 15-tuples; uses
                obs (0), meta-controller value (11), meta-controller target
                (12) and core state (14).
            setting: ablation switch for this call only — 1 drops the
                value-imitation term, 2 drops the action-imitation term, any
                other value keeps both.

        Returns:
            numpy array ``[loss, value_dis_loss, average_kl_div]`` averaged
            over every minibatch of every epoch.
        """
        # Per-call coefficients: the ablation used to overwrite the instance
        # attributes, so one ablated call silently disabled the term forever.
        value_coef = 0 if setting == 1 else self.meta_controller_value_loss_coef
        pi_coef = 0 if setting == 2 else self.meta_controller_pi_loss_coef

        losses = []
        for _ in range(self.epochs):
            for batch in buffer.sample(self.batch_size):
                (obs_batch, _, _, _, _, _, _, _, _, _, _,
                 meta_controller_value_batch, meta_controller_tensor_batch, _,
                 core_state) = batch

                output, _ = self.model(obs_batch, core_state)
                pdf = output["policy_logits"]
                value = output["baseline"][-1]

                # Imitate the meta-controller's value with the MODEL's value.
                # (The old code compared two buffer tensors, so this term
                # carried no gradient and never trained anything.)
                value_dis_loss = F.mse_loss(value, meta_controller_value_batch)
                # NOTE(review): clamping the scalar MSE caps it at self.clip and
                # zeroes its gradient whenever MSE > self.clip; kept from the
                # original — confirm this is intended.
                value_dis_loss = torch.clamp(value_dis_loss, -self.clip, self.clip)

                # Imitate the meta-controller's actions: cross-entropy of the
                # policy logits against meta_controller_tensor_batch (class
                # indices or class probabilities, per F.cross_entropy).
                average_kl_div = F.cross_entropy(pdf.logits, meta_controller_tensor_batch, reduction='mean')

                # Overall scale factor of 5 is inherited from the original code.
                loss = (value_coef * value_dis_loss + pi_coef * average_kl_div) * 5

                self._optimize(loss)
                losses.append([loss.item(), value_dis_loss.item(), average_kl_div.item()])

        return np.mean(losses, axis=0)



# class Critic():
#     def __init__(self, obs_space, action_space):
#         self.obs_space = obs_space
#         self.action_space = action_space
#         self.critic_network = Critic_network(self.obs_space, self.action_space)
#         self.gamma = 0.9
#         self.batch_size = 128
#         self.learning_rate = 0.01
#         self.epochs = 3
#         self.optimizer = optim.Adam(self.critic_network.parameters(), self.learning_rate)
#         self.grad_clip = 0.5
#         self.max_value = 2
#         self.coef = 0.5

#     def __call__(self, state):
#         clipped_output = torch.clamp(self.critic_network(state), min=-5, max=5)
#         return clipped_output
    
#     def update_policy(self, buffer, upp):
#         for _ in range(self.epochs):
#             losses = []
#             for batch in buffer.sample(self.batch_size):
#                 obs_batch, next_obs_batch, _, return_batch, _, mask, _, _, controller_logits, _, _, meta_controller_value_batch, meta_controller_tensor_batch, done = batch

#                 kl_div_per_sample = F.cross_entropy(controller_logits, meta_controller_tensor_batch, reduction='mean')

#                 v_next = self.critic_network(next_obs_batch)
#                 v_next = torch.minimum(v_next, torch.tensor(self.max_value, device=v_next.device))
#                 return_batch = return_batch.unsqueeze(1)
#                 max_q_value = torch.max(v_next).item()
#                 max_q_values = torch.full_like(return_batch, max_q_value, requires_grad=True)
#                 done = done.squeeze()
#                 one_minus_done = 1 - done
#                 one_minus_done = one_minus_done.unsqueeze(1)
#                 predicted_values = torch.minimum(torch.maximum(return_batch + self.gamma * max_q_values * one_minus_done - self.coef *kl_div_per_sample, torch.tensor(0, device=v_next.device)), torch.tensor(self.max_value, device=v_next.device))
#                 meta_controller_value_batch =  torch.minimum(torch.maximum(meta_controller_value_batch, torch.tensor(-self.max_value, device=meta_controller_value_batch.device)), torch.tensor(self.max_value, device=meta_controller_value_batch.device))
#                 loss = nn.functional.mse_loss(predicted_values, meta_controller_value_batch)
#                 loss.backward()
#                 torch.nn.utils.clip_grad_norm_(self.critic_network.parameters(), self.grad_clip)
#                 self.optimizer.step()
#                 losses.append(loss.item())
#         mean_losses = np.mean(losses, axis=0)
#         return mean_losses
    
