import numpy as np
import torch
import torch.nn.functional as F
from numpy import concatenate  
from torch import from_numpy
from torch.optim.adam import Adam

from Algorithms.Common.brain_common import BaseWorkflowController, Brain
from Algorithms.Slav3.config_slav3 import Slav3Config
from Algorithms.Slav3.model_sla import SlaValueModel, SlaValueRNDModel
from Common.models.model_general import RNDModel
from Common.utils import RunningMeanStd, explained_variance, RunningMeanStdTorch

# Module-level side effect: importing this file switches cuDNN benchmark mode on
# for the whole process. Per the PyTorch documentation, this lets cuDNN time
# several convolution algorithms and keep the fastest one.
torch.backends.cudnn.benchmark = True

class Slav3WorkflowController(BaseWorkflowController):
    """Workflow controller that declares the extra SLA rollout buffers.

    The additional attributes are listed in :meth:`get_add_attr_dict`. How they
    are allocated and filled is decided by ``BaseWorkflowController``, which is
    not part of this file.
    """

    sla_is_starts: np.ndarray
    sla_values: np.ndarray
    sla_next_values: np.ndarray
    sla_next_rets: np.ndarray
    sla_int_rewards: np.ndarray
    config: Slav3Config

    def get_sla_value_state(self, is_next_state: bool):
        """Return the attribute that serves as input state of the SLA value model.

        The attribute name is obtained from
        ``config.get_sla_state_str(is_next_state=...)`` and looked up on ``self``.

        Raises:
            AttributeError: if the controller has no attribute with that name.
        """
        attr_name = self.config.get_sla_state_str(is_next_state=is_next_state)
        return getattr(self, attr_name)

    def get_add_attr_dict(self):
        """Return the SLA-specific attributes as ``name -> (shape, dtype)``.

        NOTE(review): the tuple is presumably (per-step shape, dtype) and is
        consumed by the base class when allocating buffers -- confirm there.
        """
        return {
            "sla_is_starts": ((), bool),
            "sla_values": ((), np.float32),
            "sla_next_values": ((), np.float32),
            "sla_next_rets": ((), np.float32),
            "sla_int_rewards": ((), np.float32),
        }

    def get_spec_training_data_list(self):
        """Return the de-duplicated list of data names used by this algorithm.

        Always contains ``"sla_next_rets"`` and the next-state name reported by
        the config; ``"obs_next_states"`` is added when
        ``config.sla_rnd_inte_type > 0``.
        """
        names = ["sla_next_rets", self.config.get_sla_state_str(is_next_state=True)]
        if self.config.sla_rnd_inte_type > 0:
            names.append("obs_next_states")
        # The config-provided name may coincide with "obs_next_states", hence the
        # de-duplication. dict.fromkeys keeps first-seen order, so the result is
        # the same on every run (list(set(...)) depends on string hash seeding).
        return list(dict.fromkeys(names))

    def calc_sla_gae(self):
        """Set ``sla_next_rets`` to ``config.sla_reward * self.episode_steps``.

        Despite the method name, no GAE recursion is performed here.
        """
        self.sla_next_rets = self.config.sla_reward * self.episode_steps

    def get_ev_info_dict(self):
        """Return the explained variance of the Int, Ext and Sla value estimates.

        Each pair of buffers is flattened with ``numpy.concatenate`` before being
        passed to ``explained_variance(values, returns)``.
        """
        return {
            "Int": explained_variance(concatenate(self.int_values), concatenate(self.int_rets)),
            "Ext": explained_variance(concatenate(self.ext_values), concatenate(self.ext_rets)),
            "Sla": explained_variance(concatenate(self.sla_next_values), concatenate(self.sla_next_rets)),
        }

class BrainSlav3(Brain):
    """Brain that trains a policy together with an SLA value model.

    When ``config.sla_rnd_inte_type > 0`` an RND predictor/target pair and the
    matching running statistics are created as well.
    """

    config: Slav3Config
    sla_value_model: SlaValueModel  # sla heuristic model

    def set_rms(self):
        """Create the running-statistics trackers and the normalisation callables.

        ``state_rms_torch`` and ``int_reward_rms`` only exist when
        ``config.sla_rnd_inte_type > 0``. The callables look them up lazily, so
        with normalisation enabled they must only be called in that mode.
        """
        self.rms_ckpt_list = []
        if self.config.sla_rnd_inte_type > 0:
            self.state_rms_torch = RunningMeanStdTorch(self.device, shape=self.obs_shape)
            self.int_reward_rms = RunningMeanStd(shape=(1,))
            self.rms_ckpt_list.extend(["state_rms_torch", "int_reward_rms"])

        # Standardise with the running mean/variance and clip to [-6, 6];
        # identity when state normalisation is disabled.
        self.statenorm_torch = (
            lambda x: torch.clamp((x - self.state_rms_torch.mean) / (self.state_rms_torch.var ** 0.5), -6, 6)) \
            if self.config.rnd_state_norm else (lambda x: x)

        # The conditional lives inside the lambda body and its else-branch is
        # the value itself. The previous form `... else lambda x: x` made the
        # disabled path return a function object instead of `x`, because a
        # lambda body extends over the whole conditional expression.
        self.rnd_reward_norm = lambda x: (
            x / (self.int_reward_rms.var ** 0.5) if self.config.rnd_use_reward_norm else x)

    def get_sla_model_cls(self):
        """Return the SLA value model class chosen by ``config.sla_value_model_type``.

        Returns:
            ``SlaValueModel`` for type 0, ``SlaValueRNDModel`` for type 1.

        Raises:
            NotImplementedError: for any other value.
        """
        if self.config.sla_value_model_type == 0:
            return SlaValueModel
        elif self.config.sla_value_model_type == 1:
            return SlaValueRNDModel
        else:
            raise NotImplementedError(
                f"Unsupported sla_value_model_type: {self.config.sla_value_model_type!r}")

    def set_model_and_optimizer(self):
        """Build the networks, the checkpoint name list and a single Adam optimizer.

        The optimizer covers the policy, the SLA value model and, when
        ``config.sla_rnd_inte_type > 0``, the RND predictor. The RND target is
        frozen (``requires_grad = False``) and never handed to the optimizer.
        """
        self.policy = self.get_policy_cls()(self.config).to(self.device)
        self.sla_value_model = self.get_sla_model_cls()(self.config).to(self.device)
        self.total_trainable_params = list(self.policy.parameters()) + list(
            self.sla_value_model.parameters())
        # Attribute names that set_from_checkpoint restores via load_state_dict.
        self.model_ckpt_list = ["policy", "sla_value_model", "optimizer"]
        if self.config.sla_rnd_inte_type > 0:
            self.rnd_predictor = RNDModel(self.config).to(self.device)
            self.rnd_target = RNDModel(self.config).to(self.device)
            for param in self.rnd_target.parameters():
                param.requires_grad = False
            self.total_trainable_params += list(self.rnd_predictor.parameters())
            self.model_ckpt_list.extend(["rnd_predictor", "rnd_target"])
        self.optimizer = Adam(self.total_trainable_params, lr=self.config.lr)

    def preprocess_wf_training_data(self, wf_data_dict):
        """Add ``"obs_next_states_normed"`` to the training data when RND is active.

        With ``config.rnd_state_norm`` the next states are normalised on
        ``self.device`` and converted back to a NumPy array; otherwise the new
        key refers to the very same object as ``"obs_next_states"``.

        Returns:
            The same ``wf_data_dict`` instance (modified in place).
        """
        if self.config.sla_rnd_inte_type > 0:
            obs_next = wf_data_dict["obs_next_states"]
            if self.config.rnd_state_norm:
                obs_next = self.statenorm_torch(from_numpy(obs_next).to(self.device)).cpu().numpy()
            wf_data_dict["obs_next_states_normed"] = obs_next
        return wf_data_dict

    def calculate_rnd_loss(self, next_state):
        """Return the mean squared error between RND predictor and (detached) target.

        The squared error is averaged over the last dimension first and then
        over all remaining elements.
        """
        loss = (self.rnd_predictor(next_state) - self.rnd_target(next_state).detach()).pow(2).mean(-1)
        return loss.mean()

    def calc_spec_loss(self, batch_dict):
        """Compute the SLA value loss and, when RND is active, the RND loss.

        The SLA loss is an asymmetrically weighted squared error. The weight is
        1 where the prediction exceeds ``sla_next_rets`` and
        ``1 / config.sla_shortest_coeff`` where it falls below.

        Returns:
            ``(loss, loss_info)`` where ``loss`` is the summed tensor and
            ``loss_info`` maps loss names to Python floats.
        """
        loss_info = {}
        state_key = self.config.get_sla_state_str(is_next_state=True)
        sla_next_value = self.sla_value_model(batch_dict[state_key]).squeeze(-1)
        sla_err = sla_next_value - batch_dict["sla_next_rets"]
        # sign() yields -1/0/+1; leaky_relu with a negative slope of -1/c maps
        # +1 -> 1, 0 -> 0 and -1 -> 1/c, i.e. a strictly non-negative weight.
        sla_adv_loss_weight = F.leaky_relu(torch.sign(sla_err),
                                           negative_slope=-1.0 / self.config.sla_shortest_coeff)
        sla_adv_loss = (sla_adv_loss_weight * sla_err ** 2).mean()
        loss = sla_adv_loss
        loss_info.update({"sla_adv_loss": sla_adv_loss.item()})

        if self.config.sla_rnd_inte_type > 0:
            rnd_loss = self.calculate_rnd_loss(batch_dict["obs_next_states_normed"])
            loss = loss + rnd_loss
            loss_info.update({"rnd_loss": rnd_loss.item()})

        return loss, loss_info

    def forward_novelty(self, s_4ne: torch.Tensor, update_bn=True):
        """Return the RND novelty of ``s_4ne`` as a NumPy array.

        Novelty is the root-mean-square (over dim 1) of the difference between
        predictor and target outputs on the normalised states.
        NOTE(review): assumes the model output is 2-D (batch, features) -- confirm.

        Args:
            s_4ne: states on ``self.device``.
            update_bn: when true, the running state statistics (if
                ``config.rnd_state_norm``) and the running reward statistics (if
                ``config.rnd_use_reward_norm``) are updated with this batch.
        """
        f2ne = lambda f: (f.pow(2).mean(1).pow(0.5).detach().cpu().numpy())
        with torch.no_grad():
            # Statistics are updated before normalising, so the current batch
            # already influences its own normalisation.
            if update_bn and self.config.rnd_state_norm:
                self.state_rms_torch.update(s_4ne)
            s_4ne = self.statenorm_torch(s_4ne)
            novelty = f2ne(self.rnd_predictor(s_4ne) - self.rnd_target(s_4ne))
        if self.config.rnd_use_reward_norm:
            if update_bn:
                self.int_reward_rms.update(novelty)
            novelty = novelty / (self.int_reward_rms.var ** 0.5)

        return novelty

    def forward_sla_value(self, state):
        """Run the SLA value model on ``state`` and return a detached NumPy array.

        The trailing dimension of the model output is squeezed away.
        """
        sla_value = self.sla_value_model(state).squeeze(-1).detach().cpu().numpy()
        return sla_value

    def calculate_sla_value(self, s4sla: np.ndarray, s_4sla: np.ndarray, batch=True) -> [np.ndarray, np.ndarray]:
        """Evaluate the SLA value model on two sets of states in one forward pass.

        Args:
            s4sla: first state(s).
            s_4sla: second state(s).
            batch: if false, both inputs are single states and get a leading
                batch axis added before evaluation.

        Returns:
            With ``batch=True`` the two-element list produced by ``np.split`` at
            ``s4sla.shape[0]``; with ``batch=False`` the pair
            ``(value_of_first, value_of_second)``.
        """
        if not batch:
            s4sla, s_4sla = np.expand_dims(s4sla, 0), np.expand_dims(s_4sla, 0)
        # Stack both inputs along axis 0 so the model is called only once.
        s_all = from_numpy(np.concatenate([s4sla, s_4sla], axis=0)).to(self.device)
        with torch.no_grad():
            sla_value_all = self.forward_sla_value(s_all)
        if not batch:
            return sla_value_all[0], sla_value_all[1]
        return np.split(sla_value_all, [s4sla.shape[0]], axis=0)

    def calc_rollout_int_reward(self, wf_controller: Slav3WorkflowController, update_bn=True):
        """Compute the intrinsic rewards of a rollout from the controller's buffers.

        Starts from ``wf_controller.sla_int_rewards``. It optionally adds
        ``sla_erir_coeff * sla_reward * erir_mask``. When RND is active it then
        scales the result by the novelty clipped to ``[config.sla_rnd_lb, 1000]``.

        Returns:
            ``(int_rewards, {})`` -- the info dict is always empty here.
        """
        int_rewards = wf_controller.sla_int_rewards

        if self.config.sla_use_erir_mask:
            # NOTE(review): in-place add. If sla_int_rewards is an ndarray (as
            # annotated) this also modifies the controller's buffer, so calling
            # the method twice on one rollout adds the bonus twice.
            int_rewards += self.config.sla_erir_coeff * self.config.sla_reward * wf_controller.erir_mask

        if self.config.sla_rnd_inte_type > 0:
            s_4ne = from_numpy(concatenate(wf_controller.obs_next_states)).to(self.device)
            # self.reshp comes from the base class; presumably it restores the
            # rollout layout of the flattened novelty -- confirm there.
            novelty_s_ = self.reshp(self.forward_novelty(s_4ne, update_bn))
            int_rewards = int_rewards * np.clip(novelty_s_, self.config.sla_rnd_lb, 1000)
        return int_rewards, {}

    def set_from_checkpoint(self, checkpoint):
        """Restore models, optimizer and running statistics from ``checkpoint``.

        Every name in ``model_ckpt_list`` / ``rms_ckpt_list`` is restored from the
        entry ``f"{name}_state_dict"``.

        Raises:
            KeyError: if an expected entry is missing from ``checkpoint``.
        """
        for model_ckpt in self.model_ckpt_list:
            getattr(self, model_ckpt).load_state_dict(checkpoint[f"{model_ckpt}_state_dict"])
        for rms_ckpt in self.rms_ckpt_list:
            getattr(self, rms_ckpt).set_from_checkpoint(checkpoint[f"{rms_ckpt}_state_dict"])
        if self.config.sla_rnd_inte_type > 0:
            # Re-assert that the RND target stays frozen after loading.
            for param in self.rnd_target.parameters():
                param.requires_grad = False
