import torch
import typing

import numpy as np
import torch.nn as nn

from typing import List
import time
import torch.distributed as dist

from core.utils import TimeRecorder

class NetworkOutput(typing.NamedTuple):
    """Output format of the model, as a 5-field named tuple.

    Depending on the producer and on train/eval mode, each field may hold a
    numpy array, a torch tensor, a plain Python list or ``None``. The fields
    are therefore annotated as ``typing.Any`` rather than the scalar/list types
    they were previously (incorrectly) declared as.
    """
    value: typing.Any
    value_prefix: typing.Any
    policy_logits: typing.Any
    hidden_state: typing.Any
    # pair of reward (value prefix) LSTM hidden tensors/arrays, or None
    reward_hidden: typing.Any


def concat_output_value(output_lst):
    """Concatenate the ``value`` field of every model output along axis 0.

    Parameters
    ----------
    output_lst: list
        Model outputs exposing a numpy-compatible ``value`` attribute.

    Returns
    -------
    numpy.ndarray
        The values of all outputs stacked end-to-end along the first axis.
    """
    return np.concatenate([out.value for out in output_lst])


def concat_output(output_lst, state_gpu=False, reward_hidden_gpu=False):
    """Merge a list of batched model outputs into single batched results.

    Parameters
    ----------
    output_lst: list
        Model outputs. ``value``, ``value_prefix`` and ``policy_logits`` are
        joined with ``np.concatenate`` along axis 0.
    state_gpu: bool
        If True, hidden states are joined with ``torch.cat``; otherwise with
        ``np.concatenate``.
    reward_hidden_gpu: bool
        Same switch for the reward hidden pair. Each element of the pair has
        axis 0 squeezed (presumably a size-1 LSTM layer axis; TODO confirm),
        is concatenated, and then gets a new leading axis of size 1 again.

    Returns
    -------
    tuple
        ``(values, value_prefixes, policy_logits, hidden_states, (c, h))``.
    """
    values = [out.value for out in output_lst]
    prefixes = [out.value_prefix for out in output_lst]
    logits = [out.policy_logits for out in output_lst]
    states = [out.hidden_state for out in output_lst]
    cells = [out.reward_hidden[0].squeeze(0) for out in output_lst]
    hiddens = [out.reward_hidden[1].squeeze(0) for out in output_lst]

    merged_values = np.concatenate(values)
    merged_prefixes = np.concatenate(prefixes)
    merged_logits = np.concatenate(logits)

    merged_states = torch.cat(states) if state_gpu else np.concatenate(states)

    if not reward_hidden_gpu:
        merged_c = np.expand_dims(np.concatenate(cells), axis=0)
        merged_h = np.expand_dims(np.concatenate(hiddens), axis=0)
    else:
        merged_c = torch.cat(cells).unsqueeze(0)
        merged_h = torch.cat(hiddens).unsqueeze(0)

    return merged_values, merged_prefixes, merged_logits, merged_states, (merged_c, merged_h)

def state_dict_to_cpu(state_dict):
    """Return a new plain dict holding every tensor of ``state_dict`` moved to CPU."""
    cpu_state = {}
    for name, tensor in state_dict.items():
        cpu_state[name] = tensor.cpu()
    return cpu_state

def get_ddp_model_weights(ddp_model):
    """
    Get weights of a DDP model.

    If ``ddp_model`` has a ``module`` attribute (as a DistributedDataParallel
    wrapper does), the wrapped model is used. Returns a dict mapping each
    sub-network name to a CPU copy of that sub-network's ``state_dict()``.
    """
    model = getattr(ddp_model, 'module', ddp_model)
    sub_networks = ('representation_network', 'dynamics_network', 'prediction_network')
    return {name: state_dict_to_cpu(getattr(model, name).state_dict()) for name in sub_networks}

def load_model_weights(model, weights):
    """Load per-sub-network state dicts into ``model``.

    ``weights`` must contain the keys ``'representation_network'``,
    ``'dynamics_network'`` and ``'prediction_network'``. They are loaded in
    that order into the attribute of the same name on ``model``.
    """
    for name in ('representation_network', 'dynamics_network', 'prediction_network'):
        getattr(model, name).load_state_dict(weights[name])

class BaseNet(nn.Module):
    """Base class for a representation / dynamics / prediction model.

    Subclasses must implement ``representation``, ``dynamics`` and
    ``prediction``. They must also provide the ``representation_network``,
    ``dynamics_network`` and ``prediction_network`` sub-modules used by
    ``get_weights``/``set_weights``, and a ``project(state, with_grad=...)``
    method used by ``forward`` for the consistency loss.
    """

    def __init__(self, inverse_value_transform, inverse_reward_transform, lstm_hidden_size):
        """Base Network.

        Parameters
        ----------
        inverse_value_transform: Any
            A function that maps value supports into value scalars
        inverse_reward_transform: Any
            A function that maps reward (value prefix) supports into scalars
        lstm_hidden_size: int
            dim of lstm hidden
        """
        super(BaseNet, self).__init__()
        self.inverse_value_transform = inverse_value_transform
        self.inverse_reward_transform = inverse_reward_transform
        self.lstm_hidden_size = lstm_hidden_size

        # Cached zero-initialised reward (value prefix) hidden state.
        # Rebuilt whenever the batch size or the target device changes.
        self.__rank = None
        self.__reward_hidden = None
        self.__reward_hidden_device = None
        self.__num = None

        # default device for the reward hidden state when no ``rank`` is given
        self.__device = "cuda"

    def set_device(self, device):
        """Set the default device used by ``reward_hidden`` when ``rank`` is None."""
        self.__device = device

    def reward_hidden(self, num, rank=None):
        """Return a cached pair of zero tensors of shape ``(1, num, lstm_hidden_size)``.

        The tensors are placed on ``rank`` when it is given, otherwise on the
        device set through ``set_device``. Repeated calls with the same ``num``
        and device return the very same tensor objects.
        """
        device = self.__device if rank is None else rank
        if (self.__reward_hidden is None
                or num != self.__num
                or device != self.__reward_hidden_device):
            # The device is part of the cache key. Without it, a device change
            # (set_device / different rank) with an unchanged ``num`` returned
            # tensors still living on the previous device.
            self.__num = num
            self.__rank = rank
            self.__reward_hidden_device = device
            self.__reward_hidden = (torch.zeros(1, num, self.lstm_hidden_size).to(device),
                                    torch.zeros(1, num, self.lstm_hidden_size).to(device))
        return self.__reward_hidden

    def prediction(self, state):
        raise NotImplementedError

    def representation(self, obs_history):
        raise NotImplementedError

    def dynamics(self, state, reward_hidden, action):
        raise NotImplementedError

    def initial_inference(self, obs, rank=None, to_numpy=True, state_only=False) -> "NetworkOutput":
        """Encode ``obs`` and predict policy/value from the resulting state.

        In training mode the raw network outputs are returned. In eval mode:

        - ``value`` is passed through ``inverse_value_transform``;
        - ``value`` and the policy logits are always returned as numpy arrays;
        - ``to_numpy`` only decides whether ``state`` and ``reward_hidden``
          become numpy arrays (True) or stay as detached tensors (False).

        ``value_prefix`` is a list of ``obs.size(0)`` zeros. With
        ``state_only=True`` only ``hidden_state`` is filled.
        """
        num = obs.size(0)

        state = self.representation(obs)
        if state_only:
            if not self.training and to_numpy:
                state = state.detach().cpu().numpy()
            return NetworkOutput(None, None, None, state, None)
        actor_logit, value = self.prediction(state)
        # zero initialization for reward (value prefix) hidden states
        reward_hidden = self.reward_hidden(num, rank=rank)

        if not self.training:
            # if not in training, obtain the scalars of the value/reward
            value = self.inverse_value_transform(value).detach().cpu().numpy()
            actor_logit = actor_logit.detach().cpu().numpy()
            state = state.detach()
            reward_hidden = (reward_hidden[0].detach(), reward_hidden[1].detach())
            if to_numpy:
                state = state.cpu().numpy()
                reward_hidden = (reward_hidden[0].cpu().numpy(), reward_hidden[1].cpu().numpy())

        return NetworkOutput(value, [0. for _ in range(num)], actor_logit, state, reward_hidden)

    def recurrent_inference(self, hidden_state, reward_hidden, action, to_numpy=True) -> "NetworkOutput":
        """Apply one dynamics step and predict policy/value from the next state.

        In eval mode, value and value prefix are mapped to scalars through the
        inverse transforms. Only when ``to_numpy`` is True are all outputs
        detached and converted to numpy; otherwise they are returned as
        (non-detached) tensors.
        """
        state, reward_hidden, value_prefix = self.dynamics(hidden_state, reward_hidden, action)
        actor_logit, value = self.prediction(state)

        if not self.training:
            # if not in training, obtain the scalars of the value/reward
            value = self.inverse_value_transform(value)
            value_prefix = self.inverse_reward_transform(value_prefix)

            if to_numpy:
                value = value.detach().cpu().numpy()
                value_prefix = value_prefix.detach().cpu().numpy()
                state = state.detach().cpu().numpy()
                reward_hidden = (reward_hidden[0].detach().cpu().numpy(), reward_hidden[1].detach().cpu().numpy())
                actor_logit = actor_logit.detach().cpu().numpy()

        return NetworkOutput(value, value_prefix, actor_logit, state, reward_hidden)

    def train_initial_inference(self, obs, rank=None) -> "NetworkOutput":
        """Training-time encoding: return only the hidden state and a zero reward hidden state."""
        num = obs.size(0)

        state = self.representation(obs)
        reward_hidden = self.reward_hidden(num, rank=rank)

        return NetworkOutput(None, None, None, state, reward_hidden)

    def train_recurrent_inference(self, hidden_state, reward_hidden, action) -> "NetworkOutput":
        """Training-time dynamics step: return value prefix, next state and reward hidden state."""
        state, reward_hidden, value_prefix = self.dynamics(hidden_state, reward_hidden, action)

        return NetworkOutput(None, value_prefix, None, state, reward_hidden)

    def get_weights(self):
        """Return the state dicts of the three sub-networks, keyed by attribute name."""
        return {
            'representation_network': self.representation_network.state_dict(),
            'dynamics_network': self.dynamics_network.state_dict(),
            'prediction_network': self.prediction_network.state_dict(),
        }

    def set_weights(self, weights):
        """Load state dicts produced by ``get_weights`` into the three sub-networks."""
        self.representation_network.load_state_dict(weights['representation_network'])
        self.dynamics_network.load_state_dict(weights['dynamics_network'])
        self.prediction_network.load_state_dict(weights['prediction_network'])

    def get_gradients(self):
        """Return one entry per parameter: its gradient as a numpy array, or None."""
        grads = []
        for p in self.parameters():
            grad = None if p.grad is None else p.grad.data.cpu().numpy()
            grads.append(grad)
        return grads

    def set_gradients(self, gradients):
        """Assign gradients (e.g. from ``get_gradients``) to the parameters, in order.

        ``None`` entries are skipped. Each gradient is converted to the
        parameter's device and dtype. PyTorch rejects a ``.grad`` whose device
        or dtype differs from the parameter's, which previously broke GPU
        models and float64 arrays.
        """
        for g, p in zip(gradients, self.parameters()):
            if g is not None:
                p.grad = torch.as_tensor(g, dtype=p.dtype, device=p.device)

    def forward(self, config, rank,
                obs_batch, action_batch, obs_target_batch, mask_batch,
                consist_loss_func,
                target_policy,
                target_value_phi, target_value_prefix_phi,
                indices):
        """Unroll the model for ``config.num_unroll_steps`` steps and compute losses.

        Policy, value and value-prefix losses are cross-entropies between the
        log-softmax of the predictions and the targets. The targets are
        transposed on their first two axes, presumably from (batch, step, ...)
        to the (step, batch, ...) layout of the stacked predictions; TODO
        confirm. Each loss is summed over the last axis and over unroll steps.
        ``indices`` is currently unused.

        Returns
        -------
        tuple
            ``(policy_loss, value_loss, value_prefix_loss, consistency_loss,
            hidden_states, policy_logits, value, value_prefix, [0.])``
        """
        batch_size = obs_batch.size(0)
        _, _, _, hidden_state, reward_hidden = self.train_initial_inference(obs_batch, rank)
        hidden_states = [hidden_state]
        value_prefix = []
        policy_logits = []
        value = []
        consistency_loss = torch.zeros(batch_size).to(rank)

        p, v = self.prediction(hidden_state)
        policy_logits.append(p)
        value.append(v)

        for step_i in range(config.num_unroll_steps):
            _, step_value_prefix, _, hidden_state, reward_hidden = self.train_recurrent_inference(hidden_state, reward_hidden, action_batch[:, step_i])

            # halve the gradient flowing back through each unrolled hidden state
            hidden_state.register_hook(lambda grad: grad * 0.5)

            beg_index = config.image_channel * step_i
            end_index = config.image_channel * (step_i + config.stacked_observations)

            # consistency loss
            # obtain the oracle hidden states from representation function
            presentation_state = self.representation(obs_target_batch[:, beg_index:end_index, :, :])
            # no grad for the presentation_state branch
            dynamic_proj = self.project(hidden_state, with_grad=True)
            observation_proj = self.project(presentation_state, with_grad=False)
            temp_loss = consist_loss_func(dynamic_proj, observation_proj) * mask_batch[:, step_i + 1]
            consistency_loss += temp_loss

            # policy and value
            p, v = self.prediction(hidden_state)
            policy_logits.append(p)
            value.append(v)

            hidden_states.append(hidden_state)
            value_prefix.append(step_value_prefix)

        hidden_states = torch.stack(hidden_states, dim=0)  # [num_unroll_steps + 1, B, D]
        value_prefix = torch.stack(value_prefix, dim=0)  # [num_unroll_steps, B, D]

        policy_logits = torch.stack(policy_logits, dim=0)
        value = torch.stack(value, dim=0)

        # loss
        policy_loss = -(torch.log_softmax(policy_logits, dim=-1) * target_policy.transpose(0, 1)).sum(dim=2).sum(dim=0)
        value_loss = -(torch.log_softmax(value, dim=-1) * target_value_phi.transpose(0, 1)).sum(dim=2).sum(dim=0)
        value_prefix_loss = -(torch.log_softmax(value_prefix, dim=-1) * target_value_prefix_phi.transpose(0, 1)[:config.num_unroll_steps]).sum(dim=2).sum(dim=0)

        return policy_loss, value_loss, value_prefix_loss, consistency_loss, hidden_states, policy_logits, value, value_prefix, [0.]


def renormalize(tensor, first_dim=1):
    """Min-max normalise ``tensor`` to [0, 1], independently per leading index.

    The dimensions before ``first_dim`` index independent rows. All
    dimensions from ``first_dim`` onward are flattened and scaled together.

    Parameters
    ----------
    tensor: torch.Tensor
        Tensor to normalise (e.g. hidden states); may be non-contiguous.
    first_dim: int
        First dimension to flatten; negative values count from the end.

    Returns
    -------
    torch.Tensor
        Same shape as ``tensor``. Rows whose values are all equal map to 0
        instead of the NaN produced by 0 / 0.
    """
    if first_dim < 0:
        first_dim = len(tensor.shape) + first_dim
    # reshape (rather than view) also accepts non-contiguous inputs
    flat_tensor = tensor.reshape(*tensor.shape[:first_dim], -1)
    max_val = torch.max(flat_tensor, first_dim, keepdim=True).values
    min_val = torch.min(flat_tensor, first_dim, keepdim=True).values
    value_range = max_val - min_val
    # For constant rows the numerator is 0, so dividing by 1 yields 0 (not NaN)
    # while every non-degenerate row keeps exactly the original result.
    value_range = torch.where(value_range > 0, value_range, torch.ones_like(value_range))
    flat_tensor = (flat_tensor - min_val) / value_range

    return flat_tensor.reshape(*tensor.shape)