import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Tuple

from utils import RunningMeanStd


# Device every network and sampled tensor is moved to: the GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Implementation of Twin Delayed Deep Deterministic Policy Gradients (TD3)
# Paper: https://arxiv.org/abs/1802.09477


class Actor(nn.Module):
    """Deterministic policy network for TD3.

    A multilayer perceptron with two hidden layers of 256 ReLU units. The
    output layer is squashed with ``tanh`` and multiplied by ``max_action``.

    Args:
        state_dim: size of the input state vector.
        action_dim: size of the produced action vector.
        max_action: scale applied to the ``tanh`` output.
    """

    def __init__(self, state_dim, action_dim, max_action):
        super().__init__()
        self.state_dim = state_dim
        width = 256
        # Layer attribute names (l1..l3) define the state_dict keys used by save/load.
        self.l1 = nn.Linear(state_dim, width)
        self.l2 = nn.Linear(width, width)
        self.l3 = nn.Linear(width, action_dim)

        self.max_action = max_action

    def forward(self, state):
        """Map a batch of states to actions ``max_action * tanh(MLP(state))``."""
        features = torch.relu(self.l1(state))
        features = torch.relu(self.l2(features))
        squashed = torch.tanh(self.l3(features))
        return self.max_action * squashed


class Critic(nn.Module):
    """Twin Q-network used by TD3's clipped double-Q learning.

    Two independent heads each consume the concatenated ``(state, action)``.
    Each head is two 256-unit ReLU layers followed by a scalar output.
    Head 1 uses l1..l3 and head 2 uses l4..l6.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        joint_dim = state_dim + action_dim

        # Head 1 (also exposed on its own through Q1)
        self.l1 = nn.Linear(joint_dim, 256)
        self.l2 = nn.Linear(256, 256)
        self.l3 = nn.Linear(256, 1)

        # Head 2
        self.l4 = nn.Linear(joint_dim, 256)
        self.l5 = nn.Linear(256, 256)
        self.l6 = nn.Linear(256, 1)

    @staticmethod
    def _run_head(sa, hidden_a, hidden_b, output):
        """Apply one Q head (two ReLU layers, then a linear output) to ``sa``."""
        return output(F.relu(hidden_b(F.relu(hidden_a(sa)))))

    def forward(self, state, action):
        """Return ``(q1, q2)``, the estimates of both heads for each pair."""
        sa = torch.cat([state, action], 1)
        first = self._run_head(sa, self.l1, self.l2, self.l3)
        second = self._run_head(sa, self.l4, self.l5, self.l6)
        return first, second

    def Q1(self, state, action):
        """Return only the first head's estimate (used for the actor loss)."""
        sa = torch.cat([state, action], 1)
        return self._run_head(sa, self.l1, self.l2, self.l3)

class ICMModel(nn.Module):
    """Intrinsic-curiosity-style model: encoder plus forward and inverse heads.

    ``phi`` encodes a state into a vector of the same size.
    ``forward_net`` predicts the encoded next state from (encoded state, action).
    ``inverse_net`` predicts the action from (encoded state, encoded next state).
    Nothing is detached, so both losses also train ``phi``.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        width = 256
        self.phi = nn.Sequential(
            nn.Linear(state_dim, width),
            nn.ReLU(),
            nn.Linear(width, state_dim),
        )

        # (phi(s), a) -> predicted phi(s')
        self.forward_net = nn.Sequential(
            nn.Linear(state_dim + action_dim, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, state_dim),
        )

        # (phi(s), phi(s')) -> predicted a
        self.inverse_net = nn.Sequential(
            nn.Linear(state_dim + state_dim, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, action_dim),
        )

    def _encode_pair(self, state, action, next_state):
        """Check that batch sizes agree, then encode both states in one ``phi`` pass."""
        assert state.shape[0] == next_state.shape[0]
        assert state.shape[0] == action.shape[0]

        n = state.shape[0]
        # A single batched pass over [state; next_state], split back into halves.
        encoded = self.phi(torch.cat([state, next_state], dim=0))
        encoded_state, encoded_next = encoded.split(n, dim=0)
        return encoded_state, encoded_next

    def forward(self, state, action, next_state):
        """Return the training loss: forward-model MSE plus inverse-model MSE."""
        encoded_state, encoded_next = self._encode_pair(state, action, next_state)

        predicted_next = self.forward_net(torch.cat([encoded_state, action], dim=-1))
        predicted_action = self.inverse_net(torch.cat([encoded_state, encoded_next], dim=-1))

        forward_term = F.mse_loss(predicted_next, encoded_next)
        inverse_term = F.mse_loss(predicted_action, action)
        return forward_term + inverse_term

    def get_uncertainty(self, state, action, next_state):
        """Return the per-sample bonus: half the squared forward error summed over dim 1."""
        encoded_state, encoded_next = self._encode_pair(state, action, next_state)
        predicted_next = self.forward_net(torch.cat([encoded_state, action], dim=-1))
        squared = (predicted_next - encoded_next).pow(2)
        return squared.sum(1) / 2


class DynamicsModel(nn.Module):
    """Forward-dynamics MLP predicting the next state from ``(state, action)``.

    The network has two 256-unit ReLU hidden layers. Its prediction error
    serves as the exploration bonus.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.l1 = nn.Linear(state_dim + action_dim, 256)
        self.l2 = nn.Linear(256, 256)
        self.l3 = nn.Linear(256, state_dim)

    def _predict(self, state, action):
        """Run the MLP on the concatenated ``(state, action)`` batch."""
        hidden = torch.cat([state, action], 1)
        for layer in (self.l1, self.l2):
            hidden = F.relu(layer(hidden))
        return self.l3(hidden)

    def forward(self, state, action, next_state):
        """Return the mean-squared prediction error (the training loss)."""
        return F.mse_loss(self._predict(state, action), next_state)

    def get_uncertainty(self, state, action, next_state):
        """Return the per-sample bonus: half the squared error summed over dim 1."""
        residual = self._predict(state, action) - next_state
        return residual.pow(2).sum(1) / 2

class DisagreementModel(nn.Module):
    """Ensemble of forward models whose prediction spread is the exploration bonus.

    Each of the ``n_models`` members is an MLP with one 256-unit ReLU hidden
    layer. It maps the concatenated ``(obs, action)`` to a next-observation
    prediction.
    """

    def __init__(self, state_dim, action_dim, n_models=5):
        super().__init__()
        members = []
        for _ in range(n_models):
            members.append(
                nn.Sequential(
                    nn.Linear(state_dim + action_dim, 256),
                    nn.ReLU(),
                    nn.Linear(256, state_dim),
                )
            )
        self.ensemble = nn.ModuleList(members)

    def forward(self, obs, action, next_obs):
        """Return the training loss: the sum of every member's MSE against ``next_obs``."""
        assert obs.shape[0] == next_obs.shape[0]
        assert obs.shape[0] == action.shape[0]

        inputs = torch.cat([obs, action], dim=-1)
        return sum(F.mse_loss(member(inputs), next_obs) for member in self.ensemble)

    def get_uncertainty(self, obs, action, next_obs):
        """Return per-sample disagreement.

        The bonus is the variance across members, averaged over the last dim,
        with a singleton dim inserted at position 1. ``next_obs`` is only used
        for the batch-size check.
        """
        assert obs.shape[0] == next_obs.shape[0]
        assert obs.shape[0] == action.shape[0]

        inputs = torch.cat([obs, action], dim=-1)
        stacked = torch.stack([member(inputs) for member in self.ensemble], dim=0)
        spread = torch.var(stacked, dim=0)
        return spread.mean(dim=-1).unsqueeze(1)

class TD3(object):
    """TD3 agent with an optional learned intrinsic-reward (exploration) bonus.

    Implements Twin Delayed Deep Deterministic Policy Gradients
    (https://arxiv.org/abs/1802.09477). When ``explore`` is true, an auxiliary
    model chosen by ``bonus_type`` is trained on replayed transitions. Its
    uncertainty, scaled by ``lam``, is added to the critic's TD target.

    Args:
        state_dim: size of the state vector.
        action_dim: size of the action vector.
        max_action: actions of the target policy are clipped to
            ``[-max_action, max_action]``.
        discount: factor applied to the bootstrapped target Q-value.
        tau: Polyak coefficient for the target-network soft updates.
        policy_noise: std of the Gaussian noise added to target-policy actions.
        noise_clip: the noise is clamped to ``[-noise_clip, noise_clip]``.
        policy_freq: the actor and target networks are updated once every
            ``policy_freq`` calls to :meth:`train`.
        explore: whether to build, train and use the intrinsic model.
        lam: multiplier applied to the intrinsic bonus.
        bonus_type: one of ``'icm'``, ``'dynamics'``, ``'disagreement'``.
    """

    def __init__(
        self,
        state_dim,
        action_dim,
        max_action,
        discount=0.99,
        tau=0.005,
        policy_noise=0.2,
        noise_clip=0.5,
        policy_freq=2,
        explore=False,
        lam=1,
        bonus_type='icm',
    ):
        assert bonus_type in ['icm', 'dynamics', 'disagreement']

        print('Initializing TD3 with Optional Dynamics Module')
        self.actor = Actor(state_dim, action_dim, max_action).to(device)
        self.actor_target = copy.deepcopy(self.actor)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4)

        self.critic = Critic(state_dim, action_dim).to(device)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4)

        self.max_action = max_action
        self.discount = discount
        self.tau = tau
        self.policy_noise = policy_noise
        self.noise_clip = noise_clip
        self.policy_freq = policy_freq

        self.total_it = 0

        self.explore = explore
        if self.explore:
            if bonus_type == 'icm':
                self.int_model = ICMModel(state_dim, action_dim).to(device)
            elif bonus_type == 'dynamics':
                self.int_model = DynamicsModel(state_dim, action_dim).to(device)
            elif bonus_type == 'disagreement':
                self.int_model = DisagreementModel(state_dim, action_dim).to(device)
            else:
                raise NotImplementedError

            self.int_optimizer = torch.optim.Adam(self.int_model.parameters(), lr=1e-4)
            self.lam = lam
            self.bonus_type = bonus_type

        # Diagnostics appended to by train().
        self.Q_loss = []  # critic loss, recorded on actor-update steps only
        self.actor_loss = []  # actor loss, recorded on actor-update steps only
        self.width_value = []  # batch mean of the scaled intrinsic bonus, every step

    @torch.no_grad()
    def get_int_reward(self, state, action, next_state):
        """Return the unscaled intrinsic bonus for a batch of transitions.

        This delegates to ``int_model.get_uncertainty`` for every supported
        ``bonus_type``, so the 'disagreement' model no longer falls through
        to ``None``. The returned shape depends on the model: ``(batch,)`` for
        'icm' / 'dynamics' and ``(batch, 1)`` for 'disagreement', assuming 2-D
        inputs.

        Raises:
            AssertionError: if the agent was built with ``explore=False``.
        """
        assert self.explore, "Not in exploration mode"
        return self.int_model.get_uncertainty(state, action, next_state)

    def select_action(self, state):
        """Return the deterministic policy action for a single state.

        ``state`` is reshaped into a one-row batch (presumably a NumPy array —
        confirm against callers). The action is returned as a flat NumPy array.
        """
        state = torch.FloatTensor(state.reshape(1, -1)).to(device)
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            action = self.actor(state)
        return action.cpu().numpy().flatten()

    def _soft_update(self, source, target):
        """Polyak-average ``source``'s parameters into ``target`` in place."""
        for param, target_param in zip(source.parameters(), target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

    def train(self, replay_buffer, batch_size=256):
        """Run one TD3 optimisation step on a minibatch from ``replay_buffer``.

        Every call updates the critic. Every ``policy_freq``-th call also
        updates the actor, soft-updates both target networks and records the
        losses.

        When ``explore`` is on, three more things happen:
        - the scaled intrinsic bonus is added to the TD target;
        - its batch mean is appended to ``width_value``;
        - the intrinsic model is trained with probability 0.25.

        Args:
            replay_buffer: object whose ``sample(batch_size)`` returns
                ``(state, action, next_state, reward, not_done)`` tensors.
                reward / not_done are presumably shaped ``(batch, 1)`` to
                match the critic output — confirm against the buffer.
            batch_size: number of transitions to sample.
        """
        self.total_it += 1

        # Sample replay buffer
        state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)

        with torch.no_grad():
            # Target policy smoothing: clipped Gaussian noise on the target action.
            noise = (
                torch.randn_like(action) * self.policy_noise
            ).clamp(-self.noise_clip, self.noise_clip)

            next_action = (
                self.actor_target(next_state) + noise
            ).clamp(-self.max_action, self.max_action)

            # Clipped double-Q: bootstrap from the smaller of the two target heads.
            target_Q1, target_Q2 = self.critic_target(next_state, next_action)
            target_Q = torch.min(target_Q1, target_Q2)
            target_Q = reward + not_done * self.discount * target_Q

        # Get current Q estimates
        current_Q1, current_Q2 = self.critic(state, action)

        # Add the intrinsic bonus (undiscounted, like an extra reward term).
        if self.explore:
            with torch.no_grad():
                int_reward = self.int_model.get_uncertainty(state, action, next_state)
                int_reward = int_reward * self.lam
                self.width_value.append(int_reward.mean().item())
                # The bonus models return (batch,) or (batch, 1). Reshape to a
                # column using the real batch size rather than a fixed 256, so
                # any ``batch_size`` works.
                target_Q = target_Q + int_reward.reshape(-1, 1)

        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Delayed policy updates
        if self.total_it % self.policy_freq == 0:
            actor_loss = -self.critic.Q1(state, self.actor(state)).mean()

            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            # Update the frozen target models
            self._soft_update(self.critic, self.critic_target)
            self._soft_update(self.actor, self.actor_target)

            self.Q_loss.append(critic_loss.item())
            self.actor_loss.append(actor_loss.item())

        # Update the intrinsic-bonus model on roughly a quarter of the steps.
        if self.explore and np.random.rand() < 0.25:
            loss = self.int_model(state, action, next_state)
            self.int_optimizer.zero_grad()
            loss.backward()
            self.int_optimizer.step()

    def save(self, filename):
        """Save actor/critic weights and optimizer states to ``filename + suffix`` files.

        The intrinsic model and its optimizer are not saved.
        """
        torch.save(self.critic.state_dict(), filename + "_critic")
        torch.save(self.critic_optimizer.state_dict(), filename + "_critic_optimizer")

        torch.save(self.actor.state_dict(), filename + "_actor")
        torch.save(self.actor_optimizer.state_dict(), filename + "_actor_optimizer")

    def load(self, filename):
        """Load files written by :meth:`save` and reset the target networks to match.

        Tensors are mapped onto the current ``device``, so checkpoints saved
        on a GPU can be loaded on a CPU-only machine.
        """
        self.critic.load_state_dict(torch.load(filename + "_critic", map_location=device))
        self.critic_optimizer.load_state_dict(
            torch.load(filename + "_critic_optimizer", map_location=device)
        )
        self.critic_target = copy.deepcopy(self.critic)

        self.actor.load_state_dict(torch.load(filename + "_actor", map_location=device))
        self.actor_optimizer.load_state_dict(
            torch.load(filename + "_actor_optimizer", map_location=device)
        )
        self.actor_target = copy.deepcopy(self.actor)
        