import numpy as np
import torch
from numpy import concatenate  
from torch import from_numpy
from torch.optim.adam import Adam

from Algorithms.Common.brain_common import BaseWorkflowController, Brain
from Algorithms.RND.config_rnd import RNDConfig
from Common.models.model_general import RNDModel
from Common.utils import RunningMeanStd, RunningMeanStdTorch

# Enable cuDNN autotuning (benchmark convolution algorithms and pick the fastest).
# NOTE: this changes a process-wide torch setting as a side effect of importing this module.
setattr(torch.backends.cudnn, "benchmark", True)



class RNDWorkflowController(BaseWorkflowController):
    """Workflow controller paired with ``BrainRND``.

    Declares ``obs_next_states`` as the RND-specific training-data field.
    """

    config: RNDConfig

    # Extra training-data field name(s) requested by this controller; how the
    # base controller consumes this list is not visible here.
    _RND_SPEC_KEYS = ("obs_next_states",)

    def get_spec_training_data_list(self):
        """Return a new list with the RND-specific training-data field names."""
        return list(self._RND_SPEC_KEYS)

class BrainRND(Brain):
    """Brain with a Random Network Distillation (RND) intrinsic-reward head.

    A trainable ``rnd_predictor`` is optimised (jointly with the policy) to
    match a frozen ``rnd_target``.  The squared prediction error on next
    observations is used as the intrinsic reward.  Observations can optionally
    be whitened with running statistics (``config.rnd_state_norm``), and
    intrinsic rewards can optionally be scaled by a running std of their
    discounted returns (``config.rnd_use_reward_norm``).
    """

    config: RNDConfig

    def set_rms(self):
        """Create the running-statistics trackers and normalisation callables.

        Sets ``state_rms_torch`` (observation stats on ``self.device``),
        ``int_reward_rms`` (scalar stats), ``rms_ckpt_list`` (the names of those
        attributes), ``statenorm_torch`` (whiten and clamp to [-6, 6], or
        identity), and ``rnd_reward_norm`` (divide by the reward std, or
        identity).
        """
        self.state_rms_torch = RunningMeanStdTorch(self.device, shape=self.obs_shape)
        self.int_reward_rms = RunningMeanStd(shape=(1,))
        self.rms_ckpt_list = ["state_rms_torch", "int_reward_rms"]

        # The lambdas read self.state_rms_torch / self.int_reward_rms at call
        # time, so they always see whatever stats object is currently assigned.
        if self.config.rnd_state_norm:
            self.statenorm_torch = lambda x: torch.clamp(
                (x - self.state_rms_torch.mean) / (self.state_rms_torch.var ** 0.5), -6, 6)
        else:
            self.statenorm_torch = lambda x: x

        # Choose the callable here, once.  Writing
        # `lambda x: a if cond else lambda x: x` parses as
        # `lambda x: (a if cond else (lambda x: x))`, which re-tests the flag on
        # every call and returns a function instead of a value when it is off.
        if self.config.rnd_use_reward_norm:
            self.rnd_reward_norm = lambda x: x / (self.int_reward_rms.var ** 0.5)
        else:
            self.rnd_reward_norm = lambda x: x

    def set_model_and_optimizer(self):
        """Build the policy, the RND predictor/target pair, and a shared Adam optimiser.

        The target network is frozen (``requires_grad = False``).  Only the
        policy and predictor parameters are passed to the optimiser.
        """
        self.policy = self.get_policy_cls()(self.config).to(self.device)
        self.rnd_predictor = RNDModel(self.config).to(self.device)
        self.rnd_target = RNDModel(self.config).to(self.device)
        for param in self.rnd_target.parameters():
            param.requires_grad = False
        self.total_trainable_params = list(self.policy.parameters()) + list(self.rnd_predictor.parameters())
        self.optimizer = Adam(self.total_trainable_params, lr=self.config.lr)
        self.model_ckpt_list = ["policy", "rnd_predictor", "rnd_target", "optimizer"]

    def preprocess_wf_training_data(self, wf_data_dict):
        """Add ``obs_next_states_normed`` to ``wf_data_dict`` and drop the raw entry.

        With ``config.rnd_state_norm`` the raw next observations (a NumPy array)
        are whitened on ``self.device`` with the current running statistics
        and returned to the dict as a NumPy array.  Otherwise the raw array is
        reused unchanged.  In both cases ``obs_next_states`` is removed, so the
        output schema does not depend on the flag.

        Returns:
            The same (mutated) ``wf_data_dict``.
        """
        raw_next_states = wf_data_dict["obs_next_states"]
        if self.config.rnd_state_norm:
            # Pure preprocessing: no graph is needed for the numpy round-trip.
            with torch.no_grad():
                normed = self.statenorm_torch(from_numpy(raw_next_states).to(self.device))
            wf_data_dict["obs_next_states_normed"] = normed.cpu().numpy()
        else:
            wf_data_dict["obs_next_states_normed"] = raw_next_states
        # Previously the normalised branch returned early and skipped this,
        # leaving the raw observations in the dict only in that case.
        return self.preprocess_delete(wf_data_dict)

    def preprocess_delete(self, wf_data_dict):
        """Remove the raw ``obs_next_states`` entry from ``wf_data_dict`` and return it."""
        del wf_data_dict["obs_next_states"]
        return wf_data_dict

    def calculate_rnd_loss(self, next_state):
        """Mean squared error between predictor and (detached) target outputs.

        The squared error is averaged over the last dimension, then over the
        remaining elements, giving a scalar tensor.
        """
        loss = (self.rnd_predictor(next_state) - self.rnd_target(next_state).detach()).pow(2).mean(-1)
        return loss.mean()

    def forward_rnd_reward(self, s_4ne: torch.Tensor, update_bn=True):
        """Compute raw (unscaled) RND rewards for a batch of next observations.

        Args:
            s_4ne: next observations already on ``self.device``.
            update_bn: if True and ``config.rnd_state_norm`` is set, update the
                observation running statistics with this (unnormalised) batch
                before normalising it.

        Returns:
            NumPy array holding the squared predictor/target difference,
            averaged over dim 1 of the network outputs.
        """
        f2ne = lambda f: (f.pow(2).mean(1).detach().cpu().numpy())
        with torch.no_grad():
            if update_bn and self.config.rnd_state_norm:
                self.state_rms_torch.update(s_4ne)
            s_4ne = self.statenorm_torch(s_4ne)
            rnd_reward = f2ne(self.rnd_predictor(s_4ne) - self.rnd_target(s_4ne))
        return rnd_reward

    def calc_spec_loss(self, batch_dict):
        """Return the RND loss tensor and a log dict ``{"rnd_loss": float}``.

        Reads ``batch_dict["obs_next_states_normed"]``.  Presumably the base
        class has already converted it to a tensor on ``self.device``; confirm
        with the caller.
        """
        rnd_loss = self.calculate_rnd_loss(batch_dict["obs_next_states_normed"])
        return rnd_loss, {"rnd_loss": rnd_loss.item()}

    def normalize_int_rewards(self, intrinsic_rewards, update_bn=True):
        """Scale intrinsic rewards by the running std of their discounted returns.

        If ``config.rnd_use_reward_norm`` is off, the input is returned
        unchanged.  When ``update_bn`` is true, discounted returns (discount
        ``config.int_gamma``) are accumulated backwards over steps
        ``0 .. config.rollout_length - 1`` (indexed as ``[:, step]``) and fed
        to ``int_reward_rms``.  Returns ``intrinsic_rewards / sqrt(var)``.
        """
        if not self.config.rnd_use_reward_norm:
            return intrinsic_rewards

        if update_bn:
            gamma = self.config.int_gamma  # hoisted out of the loop
            intrinsic_returns = np.zeros_like(intrinsic_rewards)
            running_return = 0
            for step in reversed(range(self.config.rollout_length)):
                running_return = running_return * gamma + intrinsic_rewards[:, step]
                intrinsic_returns[:, step] = running_return
            self.int_reward_rms.update(np.ravel(intrinsic_returns).reshape(-1, 1))

        return intrinsic_rewards / (self.int_reward_rms.var ** 0.5)

    def calc_rollout_int_reward(self, wf_controller: RNDWorkflowController, update_bn=True):
        """Compute (optionally normalised) intrinsic rewards for a finished rollout.

        Concatenates ``wf_controller.obs_next_states`` and moves it to
        ``self.device``.  It then computes raw RND rewards and reshapes them
        with ``self.reshp`` (defined outside this class; presumably to an
        env-by-step layout, confirm).  Finally the rewards are normalised.

        Returns:
            ``(int_reward, {})``, where the empty dict holds no extra logs.
        """
        s_4ne = from_numpy(concatenate(wf_controller.obs_next_states)).to(self.device)
        rnd_reward = self.reshp(self.forward_rnd_reward(s_4ne, update_bn))
        int_reward = self.normalize_int_rewards(rnd_reward, update_bn)
        return int_reward, {}

    def set_from_checkpoint(self, checkpoint):
        """Restore from ``checkpoint`` via the base class, then re-freeze the RND target."""
        super().set_from_checkpoint(checkpoint)
        for param in self.rnd_target.parameters():
            param.requires_grad = False




