import time
from typing import Optional

import numpy as np
import torch
from numpy import concatenate  
from torch import from_numpy
from torch.optim.adam import Adam

from Common import Config
from Common.models.model_general import PolicyModel, PolicyModelRNN
from Common.utils import RunningMeanStd, clip_grad_norm_, explained_variance, RunningMeanStdTorch

# Module-level side effect: enables cuDNN benchmark mode as soon as this module is
# imported, presumably so cuDNN can auto-tune convolution algorithms for the fixed
# input shapes used in training.
# NOTE(review): benchmark mode may select different algorithms between runs; confirm
# this is acceptable wherever bitwise reproducibility matters.
torch.backends.cudnn.benchmark = True


class BaseWorkflowController:
    """Fixed-size rollout buffer for PPO-style training with intrinsic and extrinsic streams.

    Every entry of ``attr_dict`` (name -> (per-step shape, dtype)) is materialised as a
    zero-filled numpy attribute of shape ``(n_workers, rollout_length) + per_step_shape``.
    Subclasses can register extra buffers through ``get_add_attr_dict`` and extra
    training fields through ``get_spec_training_data_list``.
    """

    states: np.ndarray
    next_states: np.ndarray
    obs_states: np.ndarray
    obs_next_states: np.ndarray
    cell_states: np.ndarray
    cell_next_states: np.ndarray
    actions: np.ndarray
    log_probs: np.ndarray
    action_probs: np.ndarray
    int_rewards: np.ndarray
    ext_rewards: np.ndarray
    ext_rewards_gt: np.ndarray
    erir_mask: np.ndarray
    episode_steps: np.ndarray
    dones: np.ndarray
    int_values: np.ndarray
    ext_values: np.ndarray
    int_rets: np.ndarray
    ext_rets: np.ndarray
    int_advs: np.ndarray
    ext_advs: np.ndarray
    advs: np.ndarray
    rollout_last_states: np.ndarray
    rollout_last_ext_values: np.ndarray
    rollout_last_int_values: np.ndarray
    rollout_last_hidden_states: np.ndarray  # only created when config.policy_use_rnn
    training_data_attr_list: list
    hidden_states: Optional[np.ndarray]

    def __init__(self, config: Config):
        self.config = config
        # Buffer name -> (per-step shape, dtype). Each entry becomes a zero-filled
        # attribute of shape (n_workers, rollout_length) + per-step shape.
        self.attr_dict = {
            "states": (config.state_shape, np.uint8),
            "next_states": (config.state_shape, np.uint8),
            "obs_states": (config.obs_shape, np.uint8),
            "obs_next_states": (config.obs_shape, np.uint8),
            "cell_states": (config.cell_shape, np.uint8),
            "cell_next_states": (config.cell_shape, np.uint8),
            "actions": ((), np.uint8),
            "log_probs": ((), np.float32),
            "action_probs": ((config.n_actions,), np.float32),
            "int_rewards": ((), np.float32),
            "ext_rewards": ((), np.float32),
            "ext_rewards_gt": ((), np.float32),
            "erir_mask": ((), np.uint8),
            "episode_steps": ((), np.float32),
            "dones": ((), bool),
            "int_values": ((), np.float32),
            "ext_values": ((), np.float32),
            "int_rets": ((), np.float32),
            "ext_rets": ((), np.float32),
            "int_advs": ((), np.float32),
            "ext_advs": ((), np.float32),
            "advs": ((), np.float32),
        }
        if config.policy_use_rnn:
            self.attr_dict["hidden_states"] = ((config.extra_hidden_size,), np.float32)
            # NOTE: unlike the per-step buffers, reset() does not re-zero this array.
            self.rollout_last_hidden_states = np.zeros(
                (config.n_workers, config.extra_hidden_size), dtype=np.float32
            )
        self.rollout_base_shape = (config.n_workers, config.rollout_length)
        self.attr_dict.update(self.get_add_attr_dict())
        # A private helper is used rather than reset(), so subclass overrides of
        # reset() are not invoked before the subclass has finished initialising.
        self._allocate_buffers()
        self.set_training_data_list()
        self.set_adv_norm()

    def _allocate_buffers(self):
        """(Re)create every ``attr_dict`` buffer and the ``rollout_last_*`` arrays as zeros."""
        for name, spec in self.attr_dict.items():
            setattr(self, name, np.zeros(self.rollout_base_shape + spec[0], dtype=spec[1]))
        n_workers = self.config.n_workers
        self.rollout_last_states = np.zeros((n_workers,) + self.config.state_shape, dtype=np.uint8)
        self.rollout_last_ext_values = np.zeros((n_workers,), dtype=np.float32)
        self.rollout_last_int_values = np.zeros((n_workers,), dtype=np.float32)

    def set_adv_norm(self):
        """Create the advantage running statistics when advantage normalisation is enabled."""
        if self.config.adv_norm_type > 0:
            self.adv_rms = RunningMeanStd(shape=(1,), momentum=0.9)

    def get_add_attr_dict(self):
        """Return extra ``name -> (per-step shape, dtype)`` buffers; none by default."""
        return {}

    def set_training_data_list(self):
        """Populate ``training_data_attr_list`` with the buffer names passed to training."""
        self.training_data_attr_list = ["states", "actions", "int_rets", "ext_rets", "advs", "log_probs"]
        if self.config.policy_use_rnn:
            self.training_data_attr_list.append("hidden_states")
        self.training_data_attr_list.extend(self.get_spec_training_data_list())

    def get_spec_training_data_list(self):
        """Return extra buffer names to include in the training data; none by default."""
        return []

    def reset(self):
        """Zero every rollout buffer and the ``rollout_last_*`` arrays by reallocating them."""
        self._allocate_buffers()

    def get_gae(self, rewards, values, next_values, dones, gamma, lam=None):
        """Compute GAE(lambda) returns (advantage + value) for every step.

        Args:
            rewards, values, dones: arrays of shape (n_workers, T). 1-D inputs of shape
                (T,) are treated as a single worker.
            next_values: bootstrap value after the last step, shape (n_workers,), or a
                scalar for 1-D input.
            gamma: discount factor.
            lam: GAE lambda; defaults to ``self.config.lam``.

        Returns:
            Array with the same dtype as ``values`` and shape (n_workers, T). For 1-D
            input, the added leading axis is kept, so the shape is (1, T).
        """
        if lam is None:
            lam = self.config.lam

        if len(rewards.shape) == 1:
            rewards = np.expand_dims(rewards, 0)
            values = np.expand_dims(values, 0)
            next_values = np.expand_dims(next_values, 0)
            dones = np.expand_dims(dones, 0)
        # The 2-tuple unpack also rejects inputs with more than two dimensions.
        _, length = rewards.shape

        returns = np.zeros_like(values)
        # Append the bootstrap value so that step t can always read V(t + 1).
        extended_values = np.concatenate([values, np.expand_dims(next_values, 1)], axis=1)
        gae = np.zeros_like(next_values)
        for step in reversed(range(length)):
            # A done at `step` cuts both the bootstrap and the carried advantage.
            not_done = 1 - dones[:, step]
            delta = rewards[:, step] + gamma * extended_values[:, step + 1] * not_done \
                - extended_values[:, step]
            gae = delta + gamma * lam * not_done * gae
            returns[:, step] = gae + extended_values[:, step]

        return returns

    def calc_gae(self):
        """Compute returns and advantages for both reward streams and the combined advantage.

        Sets ``int_rets``, ``ext_rets``, ``int_advs``, ``ext_advs`` and ``advs``. The
        combined advantage is ``ext_adv_coeff * ext_advs + int_adv_coeff * int_advs``,
        post-processed according to ``config.adv_norm_type``:
        0 means no normalisation, 1 means standardise with running mean/std, and
        2 means divide by running std only.

        Raises:
            ValueError: if ``config.adv_norm_type`` is not 0, 1 or 2. Previously, such
                values left ``self.advs`` stale (or crashed on a missing ``adv_rms``).
        """
        norm_type = self.config.adv_norm_type
        if norm_type not in (0, 1, 2):
            raise ValueError(f"Unsupported adv_norm_type {norm_type!r}; expected 0, 1 or 2")

        # The intrinsic stream may ignore episode ends (see process_done).
        self.int_rets = self.get_gae(self.int_rewards, self.int_values, self.rollout_last_int_values,
                                     self.process_done(), self.config.int_gamma)
        self.ext_rets = self.get_gae(self.ext_rewards, self.ext_values, self.rollout_last_ext_values,
                                     self.dones, self.config.ext_gamma)
        self.int_advs = self.int_rets - self.int_values
        self.ext_advs = self.ext_rets - self.ext_values
        advs = self.ext_advs * self.config.ext_adv_coeff + self.int_advs * self.config.int_adv_coeff
        if norm_type == 0:
            self.advs = advs
            return

        self.adv_rms.update(np.concatenate(advs))
        std = self.adv_rms.var ** 0.5 + 1e-5
        if norm_type == 1:
            self.advs = (advs - self.adv_rms.mean) / std
        else:  # norm_type == 2
            self.advs = advs / std

    def process_done(self):
        """Return the done mask used for the intrinsic-value GAE.

        When the config defines ``rnd_int_value_non_episodic``, dones are scaled by
        ``1 - rnd_int_value_non_episodic``. Presumably this is a 0/1 flag, where 1
        disables episode boundaries for the intrinsic stream (TODO confirm).
        """
        if hasattr(self.config, "rnd_int_value_non_episodic"):
            return self.dones * (1 - self.config.rnd_int_value_non_episodic)
        return self.dones

    def get_training_data(self):
        """Flatten each training buffer from (n_workers, T, ...) to (n_workers * T, ...)."""
        return {i: concatenate(getattr(self, i)) for i in self.training_data_attr_list}

    def get_ev_info_dict(self):
        """Explained variance of the value predictions against the computed returns."""
        return {
            "Int": explained_variance(concatenate(self.int_values), concatenate(self.int_rets)),
            "Ext": explained_variance(concatenate(self.ext_values), concatenate(self.ext_rets)),
        }

class Brain:
    """PPO learner that owns the policy network, its optimizer and running statistics.

    Methods such as ``get_policy_cls``, ``set_model_and_optimizer``, ``set_rms``,
    ``preprocess_wf_training_data``, ``calc_spec_loss`` and ``calc_rollout_int_reward``
    have trivial defaults here and appear intended as override points for subclasses.
    """

    total_trainable_params: list
    optimizer: torch.optim.Optimizer
    model_ckpt_list: list
    rms_ckpt_list: list
    policy: PolicyModel

    def __init__(self, config: Config):
        self.config = config
        self.mini_batch_size = self.config.batch_size
        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        self.obs_shape = self.config.obs_shape
        self.mse_loss = torch.nn.MSELoss()

        self.set_model_and_optimizer()
        self.set_rms()
        # Kept as an attribute for callers that use `self.reshp(x)`.
        self.reshp = self._reshape_rollout

    def _reshape_rollout(self, x):
        """Reshape a flat per-sample array to (n_workers, rollout_length[, -1])."""
        base = (self.config.n_workers, self.config.rollout_length)
        return x.reshape(base) if len(x.shape) == 1 else x.reshape(base + (-1,))

    def get_policy_cls(self):
        """Return the policy class: the RNN variant when ``config.policy_use_rnn`` is set."""
        if self.config.policy_use_rnn:
            return PolicyModelRNN
        return PolicyModel

    def set_model_and_optimizer(self):
        """Build the policy on ``self.device`` and an Adam optimizer over its parameters."""
        self.policy = self.get_policy_cls()(self.config).to(self.device)
        self.total_trainable_params = list(self.policy.parameters())
        self.optimizer = Adam(self.total_trainable_params, lr=self.config.lr)
        self.model_ckpt_list = ["policy", "optimizer"]

    def print_brain(self):
        """Print every checkpointed model (the optimizer is skipped)."""
        for i in self.model_ckpt_list:
            if i != "optimizer":
                print(getattr(self, i))

    def set_rms(self):
        """Create the observation and intrinsic-reward running statistics."""
        self.state_rms_torch = RunningMeanStdTorch(self.device, shape=self.obs_shape)
        self.int_reward_rms = RunningMeanStd(shape=(1,))
        self.rms_ckpt_list = ["state_rms_torch", "int_reward_rms"]

    def get_actions_and_values(self, state, batch=False):
        """Sample actions and evaluate values without tracking gradients.

        Args:
            state: numpy array for a single state, or a batch of states when ``batch=True``.

        Returns:
            Numpy arrays ``(action, int_value, ext_value, log_prob, action_prob, logits)``.
            The value arrays are squeezed.
        """
        if not batch:
            state = np.expand_dims(state, 0)
        state = from_numpy(state).to(self.device)
        with torch.no_grad():
            dist, int_value, ext_value, action_prob = self.policy(state)
            action = dist.sample()
            log_prob = dist.log_prob(action)
            log_prob_all = dist.logits
        return action.cpu().numpy(), int_value.cpu().numpy().squeeze(), \
            ext_value.cpu().numpy().squeeze(), log_prob.cpu().numpy(), \
            action_prob.cpu().numpy(), log_prob_all.cpu().numpy()

    def get_actions_and_values_rnn(self, state, hidden_state, batch=False):
        """RNN variant of ``get_actions_and_values``; also returns the next hidden state.

        Returns:
            Numpy arrays ``(action, int_value, ext_value, log_prob, action_prob, logits, h)``.
        """
        if not batch:
            state = np.expand_dims(state, 0)
            hidden_state = np.expand_dims(hidden_state, 0)
        state = from_numpy(state).to(self.device)
        hidden_state = from_numpy(hidden_state).to(self.device)
        with torch.no_grad():
            dist, int_value, ext_value, action_prob, h = self.policy([state, hidden_state])
            action = dist.sample()
            log_prob = dist.log_prob(action)
            log_prob_all = dist.logits
        return action.cpu().numpy(), int_value.cpu().numpy().squeeze(), \
            ext_value.cpu().numpy().squeeze(), log_prob.cpu().numpy(), \
            action_prob.cpu().numpy(), log_prob_all.cpu().numpy(), h.cpu().numpy()

    def choose_mini_batch(self, training_data_dict):
        """Yield ``n_mini_batch`` shuffled mini-batches as dicts of tensors on ``self.device``.

        ``uint8`` arrays stay ``uint8`` tensors; every other array is copied as
        ``float32`` (the same dtypes the legacy ``torch.ByteTensor``/``torch.Tensor``
        constructors produced).

        Raises:
            ValueError: if ``training_data_dict`` is empty, or if its sample count is not
                ``n_mini_batch * batch_size``.
        """
        if not training_data_dict:
            raise ValueError("training_data_dict is empty; nothing to build mini-batches from")
        tensors = {
            k: torch.tensor(arg, dtype=torch.uint8 if arg.dtype == np.uint8 else torch.float32,
                            device=self.device)
            for k, arg in training_data_dict.items()
        }
        n_samples = len(next(iter(tensors.values())))
        expected = self.config.n_mini_batch * self.mini_batch_size
        if n_samples != expected:
            raise ValueError(
                f"Expected {expected} samples (n_mini_batch * batch_size), got {n_samples}"
            )
        rand_idx = np.arange(n_samples)
        np.random.shuffle(rand_idx)
        indices = rand_idx.reshape((self.config.n_mini_batch, self.mini_batch_size))

        for idx in indices:
            yield {k: tensor[idx] for k, tensor in tensors.items()}

    def preprocess_wf_training_data(self, wf_data_dict):
        """Hook applied to the flattened rollout data before training; identity by default."""
        return wf_data_dict

    def calc_rl_loss_fwd_policy(self, batch_dict):
        """Run the policy on a mini-batch, passing hidden states for the RNN policy."""
        if self.config.policy_use_rnn:
            return self.policy([batch_dict["states"], batch_dict["hidden_states"]])
        return self.policy(batch_dict["states"])

    def calc_rl_loss(self, batch_dict):
        """PPO loss: clipped policy loss + 0.5 * (int + ext value MSE) - ent_coeff * entropy.

        Returns:
            ``(loss_tensor, info)``, where ``info`` maps log names to Python floats.
        """
        dist, int_value, ext_value, *_ = self.calc_rl_loss_fwd_policy(batch_dict)
        entropy = dist.entropy().mean()
        new_log_prob = dist.log_prob(batch_dict["actions"])
        ratio = (new_log_prob - batch_dict["log_probs"]).exp()
        pg_loss = self.compute_pg_loss(ratio, batch_dict["advs"])
        int_value_loss = self.mse_loss(int_value.squeeze(-1), batch_dict["int_rets"])
        ext_value_loss = self.mse_loss(ext_value.squeeze(-1), batch_dict["ext_rets"])
        critic_loss = 0.5 * (int_value_loss + ext_value_loss)
        rl_loss = critic_loss + pg_loss - self.config.ent_coeff * entropy
        return rl_loss, {
            "PG Loss": pg_loss.item(),
            "Entropy": entropy.item(),
            "Int Value Loss": int_value_loss.item(),
            "Ext Value Loss": ext_value_loss.item(),
        }

    def calc_spec_loss(self, batch_dict):
        """Extra loss term for subclasses; zero with no log entries by default."""
        return 0.0, {}

    def train(self, wf_controller: BaseWorkflowController):
        """Run ``config.n_epochs`` passes of mini-batch updates over the controller's data.

        Returns:
            Flat logging dict with mean losses, wall-clock timings and explained variance.
            Timings use ``time.time()`` without device synchronisation.
        """
        bn_start_time = time.time()
        training_data_dict = self.preprocess_wf_training_data(wf_controller.get_training_data())
        bn_done_time = time.time()

        forward_times, backward_times = [], []
        training_log_dict = {}
        for _ in range(self.config.n_epochs):
            for batch_dict in self.choose_mini_batch(training_data_dict):
                forward_start_time = time.time()

                rl_loss, rl_loss_info = self.calc_rl_loss(batch_dict)
                spec_loss, spec_loss_info = self.calc_spec_loss(batch_dict)
                total_loss = rl_loss + spec_loss

                backward_start_time = time.time()
                self.optimize(total_loss)
                backward_done_time = time.time()

                for k, v in {**rl_loss_info, **spec_loss_info}.items():
                    training_log_dict.setdefault(k, []).append(v)

                forward_times.append(backward_start_time - forward_start_time)
                backward_times.append(backward_done_time - backward_start_time)

        return_dict = {f"RL Loss/{k}": np.mean(v) for k, v in training_log_dict.items()}
        return_dict.update({
            "Time Consumption/RND State BN Time": bn_done_time - bn_start_time,
            "Time Consumption/Train Forward Time": np.sum(forward_times),
            "Time Consumption/Train Backward Time": np.sum(backward_times),
        })
        return_dict.update({
            f"Explained Variance/{k}": v for k, v in wf_controller.get_ev_info_dict().items()
        })
        return return_dict

    def optimize(self, loss):
        """Take one optimizer step on ``loss`` with gradient-norm clipping."""
        self.optimizer.zero_grad()
        loss.backward()
        clip_grad_norm_(self.total_trainable_params)
        self.optimizer.step()

    def calc_rollout_int_reward(self, wf_controller: BaseWorkflowController, update_bn=True):
        """Intrinsic rewards for the rollout; all zeros (and no info) by default."""
        return np.zeros_like(wf_controller.int_rewards), {}

    def compute_pg_loss(self, ratio, adv):
        """PPO clipped surrogate loss: ``-mean(min(r * A, clip(r, 1 - eps, 1 + eps) * A))``."""
        new_r = ratio * adv
        clamped_r = torch.clamp(ratio, 1 - self.config.clip_range, 1 + self.config.clip_range) * adv
        return -torch.min(new_r, clamped_r).mean()

    def set_from_checkpoint(self, checkpoint):
        """Restore models/optimizer via ``load_state_dict`` and the RMS objects via ``set_from_checkpoint``."""
        for model_ckpt in self.model_ckpt_list:
            getattr(self, model_ckpt).load_state_dict(checkpoint[f"{model_ckpt}_state_dict"])

        for rms_ckpt in self.rms_ckpt_list:
            getattr(self, rms_ckpt).set_from_checkpoint(checkpoint[f"{rms_ckpt}_state_dict"])

    def save_model(self, file_name, info):
        """Save model/optimizer state dicts, RMS mean/var/count and ``info`` with ``torch.save``."""
        checkpoint = {
            f"{model_ckpt}_state_dict": getattr(self, model_ckpt).state_dict()
            for model_ckpt in self.model_ckpt_list
        }
        checkpoint.update({
            f"{k0}_state_dict": {k: getattr(getattr(self, k0), k) for k in "mean var count".split(" ")}
            for k0 in self.rms_ckpt_list
        })
        checkpoint.update(info)
        torch.save(checkpoint, file_name)

    def set_to_eval_mode(self):
        """Put the policy network into eval mode."""
        self.policy.eval()
