import numpy as np
import torch
import torch.nn as nn

from torch.autograd import Function
from torch.nn import functional as F
from typing import Dict, List, Union, Tuple, Optional
from copy import deepcopy


import torch
import torch.nn as nn
import torch.nn.functional as F


def soft_update(target: nn.Module, source: nn.Module, tau: float):
    """Blend ``source``'s parameters into ``target`` in place.

    Every target parameter is overwritten with
    ``(1 - tau) * target + tau * source``. Parameters are paired
    positionally through ``parameters()`` (``zip`` stops at the shorter
    sequence). Reads and writes go through ``.data``, so autograd does not
    record the update.
    """
    keep = 1 - tau
    pairs = zip(target.parameters(), source.parameters())
    for dst, src in pairs:
        mixed = keep * dst.data + tau * src.data
        dst.data.copy_(mixed)


class Swish(nn.Module):
    """Elementwise Swish activation: ``x * sigmoid(x)``."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate

class CondNet(nn.Module):
    """Learned transition model whose prediction error is used as an intrinsic reward.

    The mode is fixed at construction time:

    * ``curiosity=True`` (ICM-style): ``feature`` encodes the state and
      ``target_feature`` (a copy refreshed by ``update_target_network``)
      encodes the next state. An inverse model predicts the action from both
      encodings. The forward model predicts the next-state encoding from
      ``(feature(s), a)``.
    * ``curiosity=False``: the forward model output is decoded by
      ``transition_proj`` into a predicted next state and reward.

    The forward-model trunk is ``forward_net_1`` followed by ``resnet_time``
    residual blocks. Each block uses two independent entries of
    ``self.residual``, and the action is concatenated to every sub-network
    input.

    ``td_guide=True`` multiplies the intrinsic reward by
    ``exp(adv_scale * A).clamp(max=300)``, where ``A`` is the standardized
    extrinsic reward (see ``compute_advantage``). ``density_ratio=True``
    additionally builds ``density_fn`` over ``(s, a, r, s')``, which is
    trained with ``density_loss``.
    """

    def __init__(self,
                 state_size,
                 action_size,
                 reward_size,
                 hidden_size,
                 discount: float = 0.99,
                 adv_scale: float = 1.0,
                 curiosity: bool = False,
                 td_guide: bool = False,
                 density_ratio: bool = False,
                 soft_update_rate: float = 1.0):
        """
        Args:
            state_size: last-dimension size of state tensors.
            action_size: last-dimension size of action tensors.
            reward_size: size of the reward slice predicted by ``transition_proj``.
            hidden_size: width of every hidden layer and of the state encoding.
            discount: stored as ``self.discount``; not read by the methods below.
            adv_scale: multiplier applied to the advantage inside ``exp``.
            curiosity: use the feature-space forward model (see class docstring).
            td_guide: weight intrinsic rewards by the exponential advantage.
            density_ratio: build the ``density_fn`` network.
            soft_update_rate: ``tau`` passed to ``soft_update`` for ``target_feature``.
        """
        super(CondNet, self).__init__()

        # Number of residual blocks; each block uses two entries of self.residual.
        self.resnet_time = 4
        self.state_size = state_size
        self.action_size = action_size
        self.reward_size = reward_size
        self.hidden_size = hidden_size
        self.adv_scale = adv_scale
        self.curiosity = curiosity
        self.density_ratio = density_ratio
        self.discount = discount
        self.td = td_guide
        self.soft_update_rate = soft_update_rate

        self.feature = nn.Sequential(
            nn.Linear(self.state_size, self.hidden_size),
            Swish(),
            nn.Linear(self.hidden_size, self.hidden_size),
            Swish(),
            nn.Linear(self.hidden_size, self.hidden_size),
        )
        self.target_feature = deepcopy(self.feature)

        self.inverse_net = nn.Sequential(
            nn.Linear(self.hidden_size * 2, self.hidden_size),
            nn.BatchNorm1d(self.hidden_size),
            Swish(),
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.BatchNorm1d(self.hidden_size),
            Swish(),
            nn.Linear(self.hidden_size, self.action_size)
        )

        # Build one module per slot. ``[module] * k`` would alias a single
        # module k times, so all residual layers would silently share weights.
        self.residual = nn.ModuleList([
            nn.Sequential(
                nn.Linear(self.action_size + self.hidden_size, self.hidden_size),
                Swish(),
                nn.Linear(self.hidden_size, self.hidden_size),
            )
            for _ in range(2 * self.resnet_time)
        ])

        self.forward_net_1 = nn.Sequential(
            nn.Linear(self.action_size + self.hidden_size, self.hidden_size),
            Swish(),
            nn.Linear(self.hidden_size, self.hidden_size),
            Swish(),
            nn.Linear(self.hidden_size, self.hidden_size),
            Swish(),
            nn.Linear(self.hidden_size, self.hidden_size),
            Swish(),
            nn.Linear(self.hidden_size, self.hidden_size)
        )
        self.forward_net_2 = nn.Sequential(
            nn.Linear(self.action_size + self.hidden_size, self.hidden_size),
            Swish(),
            nn.Linear(self.hidden_size, self.hidden_size),
            Swish(),
            nn.Linear(self.hidden_size, self.hidden_size),
            Swish(),
            nn.Linear(self.hidden_size, self.hidden_size),
            Swish(),
            nn.Linear(self.hidden_size, self.hidden_size)
        )
        self.transition_proj = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size),
            Swish(),
            nn.Linear(self.hidden_size, self.hidden_size),
            Swish(),
            nn.Linear(self.hidden_size, self.state_size + self.reward_size),
        )

        if density_ratio:
            # Input is cat([s, a, r, s']) with a scalar reward; the final ReLU
            # keeps the predicted ratio non-negative.
            self.density_fn = nn.Sequential(
                nn.Linear(2*state_size+action_size+1, self.hidden_size),
                nn.ReLU(),
                nn.Linear(self.hidden_size, self.hidden_size),
                nn.ReLU(),
                nn.Linear(self.hidden_size, self.hidden_size),
                nn.ReLU(),
                nn.Linear(self.hidden_size, 1),
                nn.ReLU()
            )

    @property
    def device(self):
        """Device of the first registered parameter."""
        return next(self.parameters()).device

    def update_target_network(self):
        """Move ``target_feature`` toward ``feature`` with rate ``soft_update_rate``."""
        soft_update(self.target_feature, self.feature, self.soft_update_rate)

    def update_scale(self):
        """Increase ``adv_scale`` by a fixed step of 7/20."""
        self.adv_scale = self.adv_scale + 7 / 20

    def forward(self, state, next_state, action, reward=None):
        """Run the forward (and, in curiosity mode, inverse) model.

        ``reward`` is accepted for interface compatibility and is unused.

        Returns:
            curiosity=True: ``(real_next_state_feature, pred_next_state_feature,
            pred_action)``, where the real feature comes from the detached
            target encoder.
            curiosity=False: ``(pred_next_state, pred_reward)``, i.e. the
            ``transition_proj`` output split into ``[state_size, reward_size]``
            along the last dimension.
        """
        encode_state = self.feature(state)

        # Forward-model trunk: forward_net_1, then residual blocks; the action
        # is re-injected into every sub-network input.
        hidden = self.forward_net_1(torch.cat((encode_state, action), dim=-1))
        for i in range(self.resnet_time):
            inner = self.residual[2 * i](torch.cat((hidden, action), dim=-1))
            hidden = self.residual[2 * i + 1](torch.cat((inner, action), 1)) + hidden

        if self.curiosity:
            # Target encoding is a fixed regression target: no gradient into it.
            encode_next_state = self.target_feature(next_state).detach()
            pred_action = self.inverse_net(torch.cat((encode_state, encode_next_state), 1))
            pred_next_state_feature = self.forward_net_2(torch.cat((hidden, action), 1))
            return encode_next_state, pred_next_state_feature, pred_action

        pred_next_transition = self.transition_proj(hidden)
        # Explicit section sizes: a single int would be a *chunk* size and
        # yields more than two pieces when reward_size > state_size.
        pred_next_state, pred_reward = torch.split(
            pred_next_transition, [self.state_size, self.reward_size], dim=-1)
        return pred_next_state, pred_reward

    def compute_advantage(self, inputs, mean=None, std=None):
        '''
        Use GRPO like advantage: flatten ``inputs`` and standardize it as
        ``(inputs - mean) / std``.

        If ``mean`` or ``std`` is None, the statistic is taken from ``inputs``
        itself (``std`` is the unbiased estimate, so it needs at least two
        elements). No epsilon is added: ``std`` must be non-zero.
        '''
        flat = inputs.reshape(-1)
        if mean is None:
            mean = flat.mean()
        if std is None:
            std = flat.std()
        return (flat - mean) / std

    def compute_reward(self, state, next_state, action, reward, actor, q, target_q,
                       rew_mean=None, rew_std=None):
        """NumPy front-end for ``compute_reward_torch``.

        Converts each array to a float32 tensor on ``self.device`` and forwards
        all arguments, including the optional reward statistics used when
        ``td_guide`` is enabled.
        """
        def to_tensor(arr):
            return torch.from_numpy(arr).float().to(self.device)

        return self.compute_reward_torch(
            to_tensor(state), to_tensor(next_state), to_tensor(action), to_tensor(reward),
            actor, q, target_q, rew_mean=rew_mean, rew_std=rew_std)

    def compute_reward_torch(self, state, next_state, action, reward, actor, q, target_q,
                             rew_mean=None, rew_std=None):
        """Intrinsic reward from the model's prediction error, shape ``(batch, 1)``.

        curiosity=True: mean squared error between target and predicted
        next-state features. curiosity=False: mean squared error between
        predicted and real next states. With ``td_guide`` the error is scaled
        by ``exp(adv_scale * A).clamp(max=300)`` with
        ``A = compute_advantage(reward, rew_mean, rew_std)``.

        The whole computation runs under ``torch.no_grad()``, so the result
        carries no autograd graph. ``actor``, ``q`` and ``target_q`` are unused
        and kept for interface compatibility.
        """
        with torch.no_grad():
            if self.curiosity:
                real_feat, pred_feat, _ = self.forward(state, next_state, action)
                icm_reward = (real_feat - pred_feat).pow(2).mean(1, keepdim=True)
            else:
                pred_next_state, _ = self.forward(state, next_state, action)
                icm_reward = F.mse_loss(pred_next_state, next_state, reduction='none').mean(1, keepdim=True)

            if self.td:
                td_loss = self.compute_advantage(reward, rew_mean, rew_std)
                # Clamp keeps large advantages from producing inf weights.
                soft_td_scale = torch.exp(td_loss[..., None] * self.adv_scale).clamp(max=300.0)
                icm_reward = icm_reward * soft_td_scale
        return icm_reward

    def compute_density(self, state, next_state, action, reward, rew_mean, rew_std):
        """Predicted density ratio for ``(s, a, r, s')``, shape ``(batch, 1)``.

        Requires ``density_ratio=True`` (``density_fn`` exists only then). With
        ``td_guide`` the ratio is multiplied by the clamped exponential
        advantage, as in ``compute_reward_torch``. Computed without gradients.
        """
        batch_size = state.shape[0]
        with torch.no_grad():
            pred_ratio = self.density_fn(torch.cat([state, action, reward.reshape(-1, 1), next_state], dim=-1))
            pred_ratio = pred_ratio.reshape(batch_size, 1)
            if self.td:
                adv = self.compute_advantage(reward, rew_mean, rew_std).reshape(batch_size, 1)
                exp_adv = torch.exp(adv * self.adv_scale).clamp(max=300.0)
                pred_ratio *= exp_adv
        return pred_ratio

    def forward_loss(self, state, action, reward, next_state, actor=None, q=None, target_q=None):
        """Supervised model loss (``actor``, ``q`` and ``target_q`` are unused).

        curiosity=True: feature-space forward error plus inverse-model action
        MSE. curiosity=False: next-state MSE plus reward MSE.
        """
        if self.curiosity:
            real_feat, pred_feat, pred_action = self.forward(state, next_state, action)
            return F.mse_loss(real_feat.detach(), pred_feat) + F.mse_loss(action, pred_action)

        pred_next_state, pred_next_reward = self.forward(state, next_state, action)
        # Rewards may arrive as (B,) while the prediction is (B, reward_size).
        # Align them so mse_loss does not silently broadcast to (B, B).
        reward = reward.reshape_as(pred_next_reward)
        return F.mse_loss(pred_next_state, next_state) + F.mse_loss(pred_next_reward, reward)

    def density_loss(self, offline_batch, online_batch):
        """Jensen-Shannon f-divergence loss for ``density_fn`` (off2on/networks.py).

        With ``w = density_fn(cat[s, a, r, s'])``:
            offline_f_star = -log(2 / (w_off + 1) + 1e-10)
            online_f_prime =  log(2 * w_on / (w_on + 1) + 1e-10)
            loss = mean(offline_f_star) - mean(online_f_prime)

        Each batch is a 6-tuple ``(states, actions, rewards, next_states,
        dones, mc_returns)``; only the first four are used. Taking the two
        means separately lets the offline and online batch sizes differ.
        """
        off_states, off_actions, off_rewards, off_next_states, _, _ = offline_batch
        on_states, on_actions, on_rewards, on_next_states, _, _ = online_batch
        offline_weight = self.density_fn(torch.cat([off_states, off_actions, off_rewards.reshape(-1, 1), off_next_states], dim=-1))
        online_weight = self.density_fn(torch.cat([on_states, on_actions, on_rewards.reshape(-1, 1), on_next_states], dim=-1))

        # "+ 1" as in the reference; without it a zero ReLU output gives log(inf).
        offline_f_star = -torch.log(2.0 / (offline_weight + 1) + 1e-10)
        online_f_prime = torch.log(2.0 * online_weight / (online_weight + 1) + 1e-10)
        return offline_f_star.mean() - online_f_prime.mean()

    def pretrain(self, buffer, optimizer, config, num_epochs=5, steps_per_epoch=10000):
        """Train the model on ``buffer`` and print the mean loss per epoch.

        ``buffer.sample(n)`` is expected to return six tensors
        ``(states, actions, rewards, next_states, dones, mc_returns)``; they
        are moved to ``config.device``. The target encoder is updated after
        every optimizer step.
        """
        for i in range(num_epochs):
            avg_loss = 0.0
            for _ in range(steps_per_epoch):
                batch = [b.to(config.device) for b in buffer.sample(config.batch_size)]
                states, actions, rewards, next_states, _, _ = batch
                loss = self.forward_loss(states, actions, rewards, next_states)
                avg_loss += loss.item()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                self.update_target_network()

            avg_loss = avg_loss / float(steps_per_epoch)
            print(f"Pretraining MDP model - Epochs {i}/{num_epochs} , loss: {avg_loss:4f}")

    def online_train(self, offline_buffer, optimizer, config, actor, q, target_q,
                     online_buffer=None, num_epochs=1, offline_ratio=0.5):
        """Run ``num_epochs`` gradient steps on mini-batches of size ``batch_size // num_epochs``.

        With ``online_buffer``, each mini-batch concatenates an
        ``offline_ratio`` share of offline samples with online samples. All
        tensors are moved to ``config.device``. With ``density_ratio`` the
        ``density_loss`` on the two sub-batches is added, which requires both
        buffers.
        """
        if self.density_ratio:
            assert online_buffer is not None and offline_buffer is not None, \
                "Please specificy offline and online buffer respectively."

        sub_bs = config.batch_size // num_epochs
        for _ in range(num_epochs):
            if online_buffer is not None:
                offline_bs = int(sub_bs * offline_ratio)
                online_bs = sub_bs - offline_bs
                offline_batch = [b.to(config.device) for b in offline_buffer.sample(offline_bs)]
                online_batch = [b.to(config.device) for b in online_buffer.sample(online_bs)]
                batch = [torch.cat([off_b, on_b], dim=0)
                         for off_b, on_b in zip(offline_batch, online_batch)]
            else:
                batch = [b.to(config.device) for b in offline_buffer.sample(sub_bs)]

            states, actions, rewards, next_states, _, _ = batch
            loss = self.forward_loss(states, actions, rewards, next_states, actor, q, target_q)
            if self.density_ratio:
                loss = loss + self.density_loss(offline_batch, online_batch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            self.update_target_network()


class MLP(nn.Module):
    """Feed-forward stack: ``(Linear -> activation [-> Dropout]) * k [-> Linear]``.

    Args:
        input_dim: size of the input's last dimension.
        hidden_dims: either an int ``k``, meaning ``k`` hidden layers of width
            256 (the original behaviour), or a sequence of explicit hidden
            widths.
        output_dim: if given, a final ``Linear(last_hidden, output_dim)`` with
            no activation is appended.
        activation: activation class, instantiated once per hidden layer.
        dropout_rate: if not None, ``nn.Dropout(p=dropout_rate)`` follows each
            activation.

    Attributes:
        output_dim: size of the last dimension of the network output.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dims: Union[int, List[int], Tuple[int, ...]],
        output_dim: Optional[int] = None,
        activation: nn.Module = nn.ReLU,
        dropout_rate: Optional[float] = None
    ) -> None:
        super().__init__()
        if isinstance(hidden_dims, int):
            # Backward-compatible meaning: an int is a layer count with fixed width 256.
            widths = [256] * hidden_dims
        else:
            widths = list(hidden_dims)
        dims = [input_dim] + widths

        model = []
        for in_dim, out_dim in zip(dims[:-1], dims[1:]):
            model += [nn.Linear(in_dim, out_dim), activation()]
            if dropout_rate is not None:
                model += [nn.Dropout(p=dropout_rate)]

        self.output_dim = dims[-1]
        if output_dim is not None:
            model += [nn.Linear(dims[-1], output_dim)]
            self.output_dim = output_dim
        self.model = nn.Sequential(*model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)


class ResNetPreActivationLayer(nn.Module):
    """Pre-activation residual block.

    Computes ``x + stage2(stage1(x))``, where each stage applies
    normalization, activation, optional dropout, then a linear map.
    ``stage1`` maps ``input_output_dim -> hidden_dim`` and ``stage2`` maps it
    back, so the output has the same trailing size as the input. A falsy
    ``dropout`` (e.g. 0 or None) disables the dropout modules (stored as None).
    """

    def __init__(
            self,
            input_output_dim,
            hidden_dim,
            activation=nn.ReLU,
            normalization=nn.LayerNorm,
            dropout=0.1,
            ):
        super().__init__()

        self.input_dim = input_output_dim
        self.output_dim = input_output_dim
        self.hidden_dim = hidden_dim

        # Attribute names define the state_dict layout; stage 1 is built
        # before stage 2 so parameter initialization order is unchanged.
        (self.norm_layer_1, self.activation_layer_1,
         self.dropout_1, self.linear_layer_1) = self._stage(
            input_output_dim, hidden_dim, activation, normalization, dropout)
        (self.norm_layer_2, self.activation_layer_2,
         self.dropout_2, self.linear_layer_2) = self._stage(
            hidden_dim, input_output_dim, activation, normalization, dropout)

    @staticmethod
    def _stage(in_dim, out_dim, activation, normalization, dropout):
        """Build one (norm, activation, dropout-or-None, linear) stage."""
        return (
            normalization(in_dim),
            activation(),
            nn.Dropout(dropout) if dropout else None,
            nn.Linear(in_dim, out_dim),
        )

    def forward(self, x):
        stages = (
            (self.norm_layer_1, self.activation_layer_1, self.dropout_1, self.linear_layer_1),
            (self.norm_layer_2, self.activation_layer_2, self.dropout_2, self.linear_layer_2),
        )
        out = x
        for norm, act, drop, lin in stages:
            out = act(norm(out))
            if drop is not None:
                out = drop(out)
            out = lin(out)
        return x + out


class ResNetPreActivation(nn.Module):
    """Linear projection -> ``n_res_layers`` pre-activation residual blocks -> linear projection.

    ``backbones`` holds the full pipeline in order (the two projections are
    also exposed as ``projection_layer`` / ``projection_output``). The module
    is moved to ``device`` at the end of construction.
    """

    def __init__(
            self,
            input_dim,
            out_dim,
            res_dim,
            res_hidden_dim,
            n_res_layers,
            activation=nn.ReLU,
            normalization=nn.LayerNorm,
            dropout=0.1,
            device="cpu",
            ):
        super().__init__()
        self.device = torch.device(device)
        # Both projections are created before the residual blocks, which
        # fixes the parameter initialization order.
        self.projection_layer = nn.Linear(input_dim, res_dim)
        self.projection_output = nn.Linear(res_dim, out_dim)

        blocks = [
            ResNetPreActivationLayer(res_dim, res_hidden_dim, activation, normalization, dropout)
            for _ in range(n_res_layers)
        ]
        self.backbones = nn.ModuleList([self.projection_layer, *blocks, self.projection_output])
        self.to(self.device)

    def forward(self, x):
        out = x
        for stage in self.backbones:
            out = stage(out)
        return out


class GRL(Function):
    """Gradient reversal layer.

    The forward pass is the identity (a fresh view of ``x``). The backward
    pass returns ``-grad_output * alpha`` for ``x`` and no gradient for
    ``alpha``. Call it as ``GRL.apply(x, alpha)``.

    An ``alpha`` schedule noted alongside this layer; it ramps from 0 at
    ``p = 0`` toward 1 as ``p -> 1``::

        p = float(train_cnt.count + epoch * len(train_dl)) / (epochs * len(train_dl))
        alpha = torch.tensor(2. / (1. + np.exp(-10 * p)) - 1)
    """

    @staticmethod
    def forward(ctx, x, alpha):
        # Keep the scale for backward; view_as gives autograd a new output tensor.
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        flipped = torch.neg(grad_output)
        return flipped * ctx.alpha, None
