import torch
from numpy import concatenate  
from torch import from_numpy
from torch.optim.adam import Adam

from Algorithms.Common.brain_common import BaseWorkflowController, Brain
from Algorithms.RIDE.config_ride import RIDEConfig
from Algorithms.RIDE.model_ride import RIDEModel

# Turn on the cuDNN autotuner globally. This runs at import time, so it affects
# every model in the process, not just RIDE.
# NOTE(review): per the PyTorch reproducibility notes, benchmark mode can make
# results nondeterministic across runs — confirm this is acceptable here.
torch.backends.cudnn.benchmark = True

class RIDEWorkflowController(BaseWorkflowController):
    """Workflow controller for RIDE.

    Declares the algorithm-specific training-data field(s) RIDE needs on top
    of whatever the base controller already tracks.
    """

    def get_spec_training_data_list(self):
        """Return the names of RIDE-specific training-data fields.

        A new list is built on every call, so callers may mutate it freely.
        NOTE(review): presumably consumed by BaseWorkflowController when it
        assembles training data — verify against the base class.
        """
        spec_keys = ["next_states"]
        return spec_keys

class BrainRIDE(Brain):
    """Brain for RIDE: a policy plus a RIDEModel sharing a single optimizer.

    This class provides three RIDE-specific pieces:
      * construction of the models and the optimizer,
      * the auxiliary RIDE training loss,
      * the intrinsic reward for a collected rollout.
    """

    config: RIDEConfig

    def set_model_and_optimizer(self):
        """Build the policy, the RIDE model and one Adam optimizer over both.

        Both modules are constructed from ``self.config`` and moved to
        ``self.device``. The policy class is obtained from
        ``self.get_policy_cls()``. The optimizer uses ``self.config.lr``.
        """
        self.policy = self.get_policy_cls()(self.config).to(self.device)
        self.ride_model = RIDEModel(self.config).to(self.device)
        # A single optimizer steps policy and RIDE parameters together.
        self.total_trainable_params = list(self.policy.parameters()) + list(self.ride_model.parameters())
        self.optimizer = Adam(self.total_trainable_params, lr=self.config.lr)
        # NOTE(review): attribute names presumably used by the base class for
        # checkpoint save/load — verify in Brain.
        self.model_ckpt_list = ["policy", "ride_model", "optimizer"]

    def calc_spec_loss(self, batch_dict):
        """Compute the RIDE-specific loss for one training batch.

        Args:
            batch_dict: mapping that must contain ``"states"``,
                ``"next_states"`` and ``"actions"``. These are forwarded
                unchanged to ``RIDEModel.loss_process``.

        Returns:
            ``(ride_loss, info)`` exactly as produced by ``loss_process``.
        """
        ride_loss, info = self.ride_model.loss_process(
            batch_dict["states"], batch_dict["next_states"], batch_dict["actions"]
        )
        return ride_loss, info

    def calc_rollout_int_reward(self, wf_controller: BaseWorkflowController, update_bn=True):
        """Compute the masked intrinsic reward for the rollout held by ``wf_controller``.

        The per-step ``states``, ``next_states`` and ``actions`` arrays of the
        controller are concatenated and moved to ``self.device``. Actions are
        cast to int64. The reward from ``RIDEModel.reward_process`` is brought
        back to NumPy, reshaped with ``self.reshp``, and multiplied by
        ``wf_controller.erir_mask``.

        Args:
            wf_controller: controller holding the collected rollout.
            update_bn: not used by this implementation. It is presumably kept
                to match the base-class signature.

        Returns:
            ``(int_reward, {})``: the masked NumPy reward and an empty info dict.
        """
        states = from_numpy(concatenate(wf_controller.states)).to(self.device)
        next_states = from_numpy(concatenate(wf_controller.next_states)).to(self.device)
        actions = from_numpy(concatenate(wf_controller.actions)).long().to(self.device)
        # The reward is only consumed as a NumPy array, so there is no need to
        # record an autograd graph over the entire rollout batch.
        with torch.no_grad():
            int_reward, _ = self.ride_model.reward_process(states, next_states, actions)
        int_reward = self.reshp(int_reward.detach().cpu().numpy())
        int_reward = int_reward * wf_controller.erir_mask
        return int_reward, {}





