import numpy as np
import torch
from numpy import concatenate  
from torch import from_numpy

from Algorithms.Common.brain_common import BaseWorkflowController
from Algorithms.NovelD.config_noveld import NovelDConfig
from Algorithms.RND.brain_rnd import BrainRND

# Let cuDNN benchmark candidate convolution algorithms and cache the fastest
# one per input configuration; this pays off when input shapes stay fixed.
torch.backends.cudnn.benchmark = True

class NovelDWorkflowController(BaseWorkflowController):
    """Workflow controller variant used by the NovelD brain.

    Its only addition over the base controller is requesting the
    ``obs_next_states`` field, which ``BrainNovelD.calc_rollout_int_reward``
    reads from the controller.
    """

    config: NovelDConfig

    def get_spec_training_data_list(self):
        """Return the algorithm-specific training-data field names.

        Returns:
            A fresh single-element list containing ``"obs_next_states"``.
        """
        extra_fields = ["obs_next_states"]
        return extra_fields

class BrainNovelD(BrainRND):
    """NovelD intrinsic-reward brain layered on top of the RND brain.

    The novelty of a state is the root-mean-square error between the RND
    predictor and target outputs. The intrinsic reward for a transition is
    ``clip(N(s') - noveld_alpha * N(s), 0, 1000)``, optionally multiplied by a
    mask taken from the workflow controller.
    """

    config: NovelDConfig

    def forward_novelty(self, s_4ne: torch.Tensor, update_bn=True):
        """Compute the (optionally reward-normalised) RND novelty of states.

        Args:
            s_4ne: Batch of states already on the model device. Presumably
                shaped (batch, ...) with the RND networks returning
                (batch, features) -- TODO confirm.
            update_bn: When True, this batch first updates the running
                statistics that are about to be used:
                - the state statistics, if ``config.rnd_state_norm`` is set;
                - the novelty statistics, if ``config.rnd_use_reward_norm``
                  is set.

        Returns:
            NumPy array holding, for each row, the RMS over dim 1 of
            ``predictor - target``. When reward normalisation is enabled, the
            values are divided by ``sqrt(int_reward_rms.var)``.
        """
        with torch.no_grad():
            # Refresh the state statistics before normalising, so the
            # current batch is normalised with statistics that include it.
            if update_bn and self.config.rnd_state_norm:
                self.state_rms_torch.update(s_4ne)
            normed_states = self.statenorm_torch(s_4ne)
            residual = self.rnd_predictor(normed_states) - self.rnd_target(normed_states)
            rms_error = residual.pow(2).mean(1).pow(0.5)
            novelty = rms_error.detach().cpu().numpy()

        if not self.config.rnd_use_reward_norm:
            return novelty
        if update_bn:
            self.int_reward_rms.update(novelty)
        return novelty / (self.int_reward_rms.var ** 0.5)

    def calc_rollout_int_reward(self, wf_controller: NovelDWorkflowController, update_bn=True):
        """Compute NovelD intrinsic rewards for a collected rollout.

        Args:
            wf_controller: Controller exposing ``obs_states``,
                ``obs_next_states`` and, when ``config.use_erir`` is set,
                ``erir_mask``. The first two are sequences of NumPy arrays
                that are concatenated along axis 0.
            update_bn: Forwarded to the next-state novelty pass only; the
                current-state pass never updates running statistics.

        Returns:
            A tuple ``(int_reward, info)``:
            - ``int_reward`` is
              ``clip(N(s') - noveld_alpha * N(s), 0, 1000) * mask``, where
              the mask is ``wf_controller.erir_mask`` when ``config.use_erir``
              is set and ``1.0`` otherwise. The mask is presumably NovelD's
              episodic-restriction indicator -- confirm against the
              controller.
            - ``info`` maps ``"next_state_novelty"`` to ``N(s')``.
            Both novelty arrays pass through ``self.reshp``, presumably back
            to the rollout layout -- TODO confirm.
        """
        next_batch = from_numpy(concatenate(wf_controller.obs_next_states)).to(self.device)
        cur_batch = from_numpy(concatenate(wf_controller.obs_states)).to(self.device)

        # Next states go first: only they may update the running statistics,
        # which the current-state pass then reuses unchanged.
        next_novelty = self.reshp(self.forward_novelty(next_batch, update_bn))
        cur_novelty = self.reshp(self.forward_novelty(cur_batch, False))

        if self.config.use_erir:
            mask = wf_controller.erir_mask
        else:
            mask = 1.0

        novelty_gain = next_novelty - cur_novelty * self.config.noveld_alpha
        int_reward = np.clip(novelty_gain, 0, 1000) * mask
        return int_reward, {"next_state_novelty": next_novelty}
