import torch
from torch import nn
from torch.distributions import Normal
import numpy as np

from network import RNDModel, DRNDModel, CFNModel


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialise ``layer`` in place and return it.

    The weight is filled with an orthogonal matrix scaled by ``std`` and the
    bias, when the layer has one, is set to ``bias_const``.

    Args:
        layer: Module exposing a ``weight`` tensor and (optionally) a ``bias``.
        std: Gain passed to ``torch.nn.init.orthogonal_``.
        bias_const: Constant written into every bias entry.

    Returns:
        The same ``layer`` object, so the call can be nested inside
        ``nn.Sequential(...)``.
    """
    torch.nn.init.orthogonal_(layer.weight, std)
    # Layers created with ``bias=False`` carry ``bias = None``; skip them
    # instead of crashing inside ``constant_``.
    bias = getattr(layer, "bias", None)
    if bias is not None:
        torch.nn.init.constant_(bias, bias_const)
    return layer


class PPOAgent(nn.Module):
    """Actor-critic agent with a Gaussian policy over continuous actions.

    The critic and the policy mean are two separate 64-64 tanh MLPs fed with
    the observation; the policy log-std is a learnable parameter that does
    not depend on the observation.
    """

    def __init__(self, envs, args):
        """Create the value and policy networks.

        Args:
            envs: Object exposing ``single_observation_space.shape`` and
                ``single_action_space.shape``.
            args: Run configuration; not read by this agent.
        """
        super().__init__()
        obs_dim = np.array(envs.single_observation_space.shape).prod()
        act_dim = np.prod(envs.single_action_space.shape)

        # Critic first, then actor: keeps the order in which parameters are
        # created and initialised.
        self.critic = self._mlp(obs_dim, 1, 1.0)
        self.actor_mean = self._mlp(obs_dim, act_dim, 0.01)
        self.actor_logstd = nn.Parameter(torch.zeros(1, act_dim))

    @staticmethod
    def _mlp(in_dim, out_dim, out_std):
        """Return a 64-64 tanh MLP whose last layer is initialised with ``out_std``."""
        first = layer_init(nn.Linear(in_dim, 64))
        second = layer_init(nn.Linear(64, 64))
        head = layer_init(nn.Linear(64, out_dim), std=out_std)
        return nn.Sequential(first, nn.Tanh(), second, nn.Tanh(), head)

    def get_value(self, x):
        """Return the critic output for observations ``x``."""
        return self.critic(x)

    def get_action_and_value(self, x, action=None):
        """Evaluate the policy and the critic on ``x``.

        Args:
            x: Observations; assumed to be batched along dim 0 since the
                per-dimension terms are summed over dim 1.
            action: Actions to score. When ``None`` an action is sampled
                from the policy.

        Returns:
            Tuple ``(action, log_prob, entropy, value)`` where ``log_prob``
            and ``entropy`` are summed over dim 1.
        """
        mean = self.actor_mean(x)
        std = self.actor_logstd.expand_as(mean).exp()
        dist = Normal(mean, std)
        if action is None:
            action = dist.sample()
        log_prob = dist.log_prob(action).sum(1)
        entropy = dist.entropy().sum(1)
        return action, log_prob, entropy, self.critic(x)


class RNDAgent(nn.Module):
    """Actor-critic agent with a Gaussian policy plus an ``RNDModel``.

    The critic and the policy mean are separate 64-64 tanh MLPs; the policy
    log-std is a learnable, observation-independent parameter. ``self.rnd``
    supplies the predictor/target feature pair used by ``get_rnd_loss``.
    """

    def __init__(self, envs, args):
        """Create the value, policy and RND networks.

        Args:
            envs: Object exposing ``single_observation_space.shape`` and
                ``single_action_space.shape``.
            args: Run configuration; not read by this agent.
        """
        super().__init__()
        obs_dim = np.array(envs.single_observation_space.shape).prod()
        act_dim = np.prod(envs.single_action_space.shape)

        # Same creation order as before: critic, actor, log-std, RND model.
        self.critic = self._mlp(obs_dim, 1, 1.0)
        self.actor_mean = self._mlp(obs_dim, act_dim, 0.01)
        self.actor_logstd = nn.Parameter(torch.zeros(1, act_dim))
        self.rnd = RNDModel(obs_dim, 64)

    @staticmethod
    def _mlp(in_dim, out_dim, out_std):
        """Return a 64-64 tanh MLP whose last layer is initialised with ``out_std``."""
        first = layer_init(nn.Linear(in_dim, 64))
        second = layer_init(nn.Linear(64, 64))
        head = layer_init(nn.Linear(64, out_dim), std=out_std)
        return nn.Sequential(first, nn.Tanh(), second, nn.Tanh(), head)

    def get_value(self, x):
        """Return the critic output for observations ``x``."""
        return self.critic(x)

    def get_action_and_value(self, x, action=None):
        """Evaluate the policy and the critic on ``x``.

        Args:
            x: Observations; assumed to be batched along dim 0 since the
                per-dimension terms are summed over dim 1.
            action: Actions to score. When ``None`` an action is sampled
                from the policy.

        Returns:
            Tuple ``(action, log_prob, entropy, value)`` where ``log_prob``
            and ``entropy`` are summed over dim 1.
        """
        mean = self.actor_mean(x)
        std = self.actor_logstd.expand_as(mean).exp()
        dist = Normal(mean, std)
        if action is None:
            action = dist.sample()
        log_prob = dist.log_prob(action).sum(1)
        entropy = dist.entropy().sum(1)
        return action, log_prob, entropy, self.critic(x)

    def get_rnd_loss(self, obs):
        """Return the squared predictor/target feature error averaged over the last dim.

        The target features are detached, so gradients only reach the
        predictor output.
        """
        predicted, target = self.rnd(obs)
        error = predicted - target.detach()
        return error.pow(2).mean(-1)


class DRNDAgent(nn.Module):
    """Actor-critic agent with a Gaussian policy plus a ``DRNDModel``.

    The critic and the policy mean are separate 64-64 tanh MLPs; the policy
    log-std is a learnable, observation-independent parameter. ``self.drnd``
    returns a predictor feature and a stack of target features (indexed as
    ``[target, sample, feature]`` by the methods below).
    """

    def __init__(self, envs, args):
        """Create the value, policy and DRND networks.

        Args:
            envs: Object exposing ``single_observation_space.shape`` and
                ``single_action_space.shape``.
            args: Run configuration; ``args.alpha`` weights the two terms of
                the intrinsic reward.
        """
        super().__init__()
        obs_dim = np.array(envs.single_observation_space.shape).prod()
        act_dim = np.prod(envs.single_action_space.shape)

        # Same creation order as before: critic, actor, log-std, DRND model.
        self.critic = self._mlp(obs_dim, 1, 1.0)
        self.actor_mean = self._mlp(obs_dim, act_dim, 0.01)
        self.actor_logstd = nn.Parameter(torch.zeros(1, act_dim))
        self.drnd = DRNDModel(obs_dim, 64)

        self.alpha = args.alpha

    @staticmethod
    def _mlp(in_dim, out_dim, out_std):
        """Return a 64-64 tanh MLP whose last layer is initialised with ``out_std``."""
        first = layer_init(nn.Linear(in_dim, 64))
        second = layer_init(nn.Linear(64, 64))
        head = layer_init(nn.Linear(64, out_dim), std=out_std)
        return nn.Sequential(first, nn.Tanh(), second, nn.Tanh(), head)

    def get_value(self, x):
        """Return the critic output for observations ``x``."""
        return self.critic(x)

    def get_action_and_value(self, x, action=None):
        """Evaluate the policy and the critic on ``x``.

        Args:
            x: Observations; assumed to be batched along dim 0 since the
                per-dimension terms are summed over dim 1.
            action: Actions to score. When ``None`` an action is sampled
                from the policy.

        Returns:
            Tuple ``(action, log_prob, entropy, value)`` where ``log_prob``
            and ``entropy`` are summed over dim 1.
        """
        mean = self.actor_mean(x)
        std = self.actor_logstd.expand_as(mean).exp()
        dist = Normal(mean, std)
        if action is None:
            action = dist.sample()
        log_prob = dist.log_prob(action).sum(1)
        entropy = dist.entropy().sum(1)
        return action, log_prob, entropy, self.critic(x)

    def get_rnd_loss(self, obs):
        """Return the per-sample predictor loss against one random target each.

        For every sample a target index is drawn with ``np.random.randint``
        from ``range(self.drnd.num_target)``; the loss is the squared error
        to that target's feature, averaged over dim 1.
        """
        pred_feature, target_feature = self.drnd(obs)
        batch = obs.shape[0]
        idx = np.random.randint(self.drnd.num_target, size=batch)
        # Detach the regression target (as RNDAgent does) so the loss can
        # only ever update the predictor.
        chosen = target_feature[idx, np.arange(batch), :].detach()
        return (pred_feature - chosen).pow(2).mean(1)

    def get_rnd_reward(self, obs):
        """Return the intrinsic reward for ``obs``.

        With ``mu`` and ``B2`` the first and second moments of the target
        features over the target axis, the reward is
        ``alpha * sum((pred - mu)^2)`` plus
        ``(1 - alpha) * mean(sqrt(clip(|pred^2 - mu^2| / (B2 - mu^2), 1e-3, 1)))``,
        both reduced over dim 1.
        """
        pred, targets = self.drnd(obs)
        mu = torch.mean(targets, dim=0)
        second_moment = torch.mean(targets ** 2, dim=0)

        first_term = (pred - mu).pow(2).sum(1)
        # NOTE(review): if ``second_moment - mu ** 2`` and the numerator are
        # both exactly zero the ratio is NaN, which clip() does not remove;
        # confirm whether that can occur with the real target networks.
        ratio = abs(pred ** 2 - mu ** 2) / (second_moment - mu ** 2)
        second_term = torch.mean(torch.sqrt(torch.clip(ratio, 1e-3, 1)), dim=1)
        return self.alpha * first_term + (1 - self.alpha) * second_term

    def to(self, *args, **kwargs):
        """Move/cast the agent and every network in ``self.drnd.target``.

        Accepts the same arguments as ``nn.Module.to`` (device, dtype,
        ``non_blocking``, ...) and forwards them unchanged. The target
        networks are converted explicitly; presumably they are not
        registered submodules of ``self.drnd`` -- confirm against DRNDModel.

        Returns:
            ``self``.
        """
        super().to(*args, **kwargs)
        for t_net in self.drnd.target:
            t_net.to(*args, **kwargs)
        return self


class CFNAgent(nn.Module):
    """Actor-critic agent with a Gaussian policy plus a ``CFNModel``.

    The critic and the policy mean are separate 64-64 tanh MLPs; the policy
    log-std is a learnable, observation-independent parameter. ``self.cfn``
    is regressed onto random +/-1 targets and its output magnitude is used
    as the intrinsic reward.
    """

    def __init__(self, envs, args):
        """Create the value, policy and CFN networks.

        Args:
            envs: Object exposing ``single_observation_space.shape`` and
                ``single_action_space.shape``.
            args: Run configuration; not read by this agent.
        """
        super().__init__()
        obs_dim = np.array(envs.single_observation_space.shape).prod()
        act_dim = np.prod(envs.single_action_space.shape)

        # Same creation order as before: critic, actor, log-std, CFN model.
        self.critic = self._mlp(obs_dim, 1, 1.0)
        self.actor_mean = self._mlp(obs_dim, act_dim, 0.01)
        self.actor_logstd = nn.Parameter(torch.zeros(1, act_dim))
        self.cfn = CFNModel(obs_dim, 64)

    @staticmethod
    def _mlp(in_dim, out_dim, out_std):
        """Return a 64-64 tanh MLP whose last layer is initialised with ``out_std``."""
        first = layer_init(nn.Linear(in_dim, 64))
        second = layer_init(nn.Linear(64, 64))
        head = layer_init(nn.Linear(64, out_dim), std=out_std)
        return nn.Sequential(first, nn.Tanh(), second, nn.Tanh(), head)

    def get_value(self, x):
        """Return the critic output for observations ``x``."""
        return self.critic(x)

    def get_action_and_value(self, x, action=None):
        """Evaluate the policy and the critic on ``x``.

        Args:
            x: Observations; assumed to be batched along dim 0 since the
                per-dimension terms are summed over dim 1.
            action: Actions to score. When ``None`` an action is sampled
                from the policy.

        Returns:
            Tuple ``(action, log_prob, entropy, value)`` where ``log_prob``
            and ``entropy`` are summed over dim 1.
        """
        mean = self.actor_mean(x)
        std = self.actor_logstd.expand_as(mean).exp()
        dist = Normal(mean, std)
        if action is None:
            action = dist.sample()
        log_prob = dist.log_prob(action).sum(1)
        entropy = dist.entropy().sum(1)
        return action, log_prob, entropy, self.critic(x)

    def get_rnd_loss(self, obs):
        """Return the squared error to fresh random +/-1 targets, averaged over the last dim."""
        prediction = self.cfn(obs)
        # Draw {0, 1} with the default (CPU) generator, map to {-1, +1},
        # then move next to the prediction.
        coin = torch.randint(0, 2, prediction.shape, dtype=torch.int)
        signs = (coin * 2 - 1).float().to(prediction.device)
        return (prediction - signs).pow(2).mean(-1)

    def get_rnd_reward(self, obs):
        """Return the root-mean-square of the CFN output over the last dim."""
        prediction = self.cfn(obs)
        mean_square = torch.square(prediction).mean(-1)
        return torch.sqrt(mean_square)


# Registry mapping an algorithm name (upper- or lower-case) to its agent class.
Agent = {
    'PPO': PPOAgent,
    'ppo': PPOAgent,
    'RND': RNDAgent,
    'rnd': RNDAgent,
    'DRND': DRNDAgent,
    'drnd': DRNDAgent,
    'CFN': CFNAgent,
    'cfn': CFNAgent,
}