import numpy as np
import torch
from torch import from_numpy
from torch.optim.adam import Adam

from Algorithms.Common.brain_common import BaseWorkflowController, Brain
from Algorithms.ETD.config_etd import ETDConfig
from Algorithms.ETD.model_etd import ETDModel
from Common.utils import RunningMeanStd

# Module-level side effect: turn on cuDNN benchmark mode for the whole process
# as soon as this module is imported. Per the PyTorch docs this lets cuDNN
# try several convolution algorithms and keep the fastest one.
# NOTE(review): presumably input shapes are fixed during training, which is
# when this setting pays off -- confirm.
torch.backends.cudnn.benchmark = True
class ETDWorkflowController(BaseWorkflowController):
    """Workflow controller carrying the extra rollout attributes used by ETD.

    On top of whatever the base controller provides, it declares the
    ``future_states``, ``etd_emb_states``, ``etd_emb_next_states`` and
    ``etd_dists`` attributes and knows how to fill ``future_states`` from the
    collected ``next_states`` and ``dones``.
    """

    config: ETDConfig
    future_states: np.ndarray
    etd_emb_states: np.ndarray
    etd_emb_next_states: np.ndarray
    etd_dists: np.ndarray

    def get_add_attr_dict(self):
        """Describe the additional attributes as ``name -> (shape, dtype)``.

        The embedding width is the last entry of
        ``config.etd_model_kwargs["etd_encoder_arch"]`` when that key is
        present, otherwise ``config.extra_hidden_size``.
        """
        if "etd_encoder_arch" in self.config.etd_model_kwargs:
            emb_width = self.config.etd_model_kwargs["etd_encoder_arch"][-1]
        else:
            emb_width = self.config.extra_hidden_size

        attr_specs = {}
        attr_specs["future_states"] = (self.config.state_shape, np.uint8)
        attr_specs["etd_emb_states"] = ((emb_width,), np.float32)
        attr_specs["etd_emb_next_states"] = ((emb_width,), np.float32)
        attr_specs["etd_dists"] = ((), np.float32)
        return attr_specs

    def get_spec_training_data_list(self):
        """Return the names of the ETD-specific entries listed as training data."""
        # NOTE(review): how the base class consumes this list is not visible
        # here -- presumably these attributes are added to the training batch.
        return ["future_states"]

    def get_future_states(self):
        """Sample one future state per (worker, step) and store the result.

        For every step ``t`` the number of steps left until the first index
        ``>= t`` where ``dones`` is truthy (or until the last rollout index
        when there is none) bounds the sampled offset, so the index never
        leaves the rollout window nor crosses such a ``dones`` flag.

        The offset distribution depends on ``config.future_sample_gamma``:

        * ``0``  -- always one step ahead (clipped to the steps left);
        * ``1``  -- ``ceil(u * steps_left)`` with ``u`` uniform in ``[0, 1)``;
        * else   -- inverse-CDF sampling, which for ``0 < gamma < 1`` yields
          offsets ``k`` in ``1..steps_left`` with probability proportional to
          ``gamma ** (k - 1)``.

        The result, ``next_states[n, t + offset]``, is written to
        ``self.future_states``.
        """
        n_envs = self.config.n_workers
        horizon = self.config.rollout_length

        # steps_left[n, t]: distance from t to the closest boundary at or
        # after t. Filled by scanning the rollout backwards.
        steps_left = np.full((n_envs, horizon), horizon, dtype=np.int32)
        boundary = np.full((n_envs,), horizon - 1, dtype=np.int32)
        for step in range(horizon - 1, -1, -1):
            boundary = np.where(self.dones[:, step], step, boundary)
            steps_left[:, step] = boundary - step

        # Draw the uniform noise before branching so that every configuration
        # consumes the global NumPy RNG in the same way.
        uniform = np.random.rand(*steps_left.shape)
        gamma = self.config.future_sample_gamma
        if gamma == 0:
            offsets = np.ones_like(steps_left)
        elif gamma == 1:
            offsets = np.ceil(uniform * steps_left).astype(int)
        else:
            offsets = np.log(1 - (1 - gamma ** steps_left) * uniform) / np.log(gamma)

        # Round up to whole steps and never step past the boundary.
        offsets = np.minimum(np.ceil(offsets).astype(int), steps_left)

        worker_idx = np.arange(n_envs)[:, None]
        time_idx = offsets + np.arange(horizon)[None, :]
        self.future_states = self.next_states[worker_idx, time_idx]
        assert self.future_states.shape == self.next_states.shape




class BrainETD(Brain):
    """Brain variant that trains an ``ETDModel`` next to the policy.

    The ETD model shares one Adam optimizer with the policy, contributes an
    extra loss term, and the distances stored on the workflow controller are
    turned into intrinsic rewards (optionally normalized by running
    statistics).
    """

    config: ETDConfig

    def set_model_and_optimizer(self):
        """Build the policy, the ETD model and a single Adam optimizer for both."""
        policy_cls = self.get_policy_cls()
        self.policy = policy_cls(self.config).to(self.device)
        self.etd_model = ETDModel(self.config).to(self.device)
        self.total_trainable_params = [
            *self.policy.parameters(),
            *self.etd_model.parameters(),
        ]
        self.optimizer = Adam(self.total_trainable_params, lr=self.config.lr)
        # Attribute names registered for checkpointing.
        self.model_ckpt_list = ["policy", "etd_model", "optimizer"]

    def set_rms(self):
        """Create the running statistics used to normalize intrinsic rewards."""
        self.int_reward_rms = RunningMeanStd(
            shape=(1,),
            momentum=self.config.etd_int_reward_momentum,
        )
        self.rms_ckpt_list = ["int_reward_rms"]

    def calc_spec_loss(self, batch_dict):
        """Compute the ETD loss from ``states`` and ``future_states`` of the batch.

        Returns:
            A ``(loss, info)`` tuple as produced by ``ETDModel.loss_process``.
        """
        anchor_states = batch_dict["states"]
        sampled_future_states = batch_dict["future_states"]
        loss, loss_info = self.etd_model.loss_process(anchor_states, sampled_future_states)
        return loss, loss_info

    def calc_rollout_int_reward(self, wf_controller: ETDWorkflowController, update_bn=True):
        """Turn the controller's ``etd_dists`` into intrinsic rewards.

        Args:
            wf_controller: Controller holding the ``etd_dists`` of the rollout.
            update_bn: When true (and normalization is enabled), the running
                statistics are updated with the distances before they are
                used for normalization.

        Returns:
            ``(int_rewards, {})``. With ``config.etd_use_int_reward_rms`` the
            rewards are ``(dists - mean) / (sqrt(var) + 1e-4)``; otherwise
            they are the distances multiplied by ``1.0``.
        """
        if not self.config.etd_use_int_reward_rms:
            # Multiplying by 1.0 yields a new array instead of the buffer itself.
            return wf_controller.etd_dists * 1.0, {}

        if update_bn:
            # np.concatenate joins the sub-arrays along the first axis.
            # NOTE(review): presumably etd_dists is (workers, steps), making
            # this a flatten -- confirm.
            self.int_reward_rms.update(np.concatenate(wf_controller.etd_dists))

        centered = wf_controller.etd_dists - self.int_reward_rms.mean
        scale = self.int_reward_rms.var ** 0.5 + 1e-4
        return centered / scale, {}

    def get_etd_embedding(self, states, next_states, batch=True):
        """Embed ``states`` and ``next_states`` with one forward pass.

        Both inputs are concatenated along axis 0, run through
        ``etd_model.etd_embedding_process`` without gradient tracking and
        converted back to NumPy.

        Args:
            states: NumPy array of states (a single state when ``batch`` is
                false).
            next_states: NumPy array of next states, same convention.
            batch: Whether the inputs already carry a leading batch axis.

        Returns:
            With ``batch`` true, the two-element list from ``np.split`` at
            index ``states.shape[0]`` (state embeddings, next-state
            embeddings). Otherwise the tuple ``(emb[0], emb[1]``) of the two
            single embeddings.
        """
        if not batch:
            # Add the missing batch axis so both modes share one code path.
            states = np.expand_dims(states, 0)
            next_states = np.expand_dims(next_states, 0)

        stacked = np.concatenate([states, next_states], axis=0)
        stacked_tensor = from_numpy(stacked).to(self.device)
        with torch.no_grad():
            embeddings = self.etd_model.etd_embedding_process(stacked_tensor).cpu().numpy()

        if batch:
            return np.split(embeddings, [states.shape[0]], axis=0)
        return embeddings[0], embeddings[1]









