import copy
import os
import sys
from typing import Callable, Optional

from torch.nn.modules.dropout import Dropout
import torch
import torch.distributions
import torch.nn.functional as F
from torch import nn

def loss(diff, expectile=0.8):
    """Element-wise asymmetric squared (expectile) loss.

    Residuals strictly greater than zero are weighted by ``expectile``; all
    other residuals are weighted by ``1 - expectile``. The result has the
    same shape as ``diff`` and is not reduced.

    Args:
        diff: tensor of residuals.
        expectile: weight applied where ``diff > 0``.

    Returns:
        Tensor of weighted squared residuals, same shape as ``diff``.
    """
    above = diff > 0
    asymmetric_weight = torch.where(above, expectile, 1 - expectile)
    squared = diff.pow(2)
    return asymmetric_weight * squared

class MLP(nn.Module):
    """Fully connected network built from ``n_layers`` affine layers.

    Layer sizes are ``in_dim -> hidden_dim -> ... -> hidden_dim -> out_dim``.
    The activation (followed by dropout when ``dropout_rate`` is given) is
    applied after every layer except the last, and also after the last one
    when ``activate_final`` is true.

    Args:
        in_dim: input feature size.
        out_dim: output feature size.
        hidden_dim: width of every hidden layer.
        n_layers: total number of ``nn.Linear`` layers. Values below 2 still
            produce two layers (input layer and output layer).
        activations: zero-argument callable returning the activation module
            (e.g. ``nn.ReLU``); a single instance is shared by all layers.
        activate_final: whether to apply activation/dropout after the last layer.
        dropout_rate: dropout probability, or ``None`` to disable dropout.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        hidden_dim,
        n_layers,
        activations: Callable = nn.ReLU,
        activate_final: bool = False,
        dropout_rate: Optional[float] = None
    ) -> None:
        super().__init__()

        # Layers are created in input-to-output order; the attribute name
        # ``affines`` determines the state_dict keys ("affines.<i>.weight").
        layers = [nn.Linear(in_dim, hidden_dim)]
        layers.extend(nn.Linear(hidden_dim, hidden_dim) for _ in range(n_layers - 2))
        layers.append(nn.Linear(hidden_dim, out_dim))
        self.affines = nn.ModuleList(layers)

        self.activations = activations()
        self.activate_final = activate_final
        self.dropout_rate = dropout_rate
        if dropout_rate is not None:
            self.dropout = Dropout(self.dropout_rate)

    def forward(self, x):
        """Apply every affine layer in order, with activation/dropout in between."""
        last_index = len(self.affines) - 1
        for index, affine in enumerate(self.affines):
            x = affine(x)
            if index != last_index or self.activate_final:
                x = self.activations(x)
                if self.dropout_rate is not None:
                    x = self.dropout(x)
        return x


class Actor(nn.Module):
    """Deterministic MLP policy that maps states to tanh-squashed actions.

    The underlying MLP emits ``2 * action_dim`` values. Only the first half is
    used (passed through ``tanh``). The second half and the stored
    ``log_std_min`` / ``log_std_max`` bounds are not used by this class.
    """

    def __init__(
        self, state_dim, action_dim, hidden_dim, n_layers, dropout_rate=None,
        log_std_min=-10.0, log_std_max=2.0,
    ):
        super(Actor, self).__init__()

        output_width = 2 * action_dim
        self.mlp = MLP(
            state_dim, output_width, hidden_dim, n_layers, dropout_rate=dropout_rate
        )

        self.log_std_min, self.log_std_max = log_std_min, log_std_max

    def forward(self, states):
        """Return ``tanh`` of the first ``action_dim`` MLP outputs."""
        mean_half, _ignored_log_std = self.mlp(states).chunk(2, dim=-1)
        return torch.tanh(mean_half)

    def get_action(self, states):
        """Alias for :meth:`forward`; returns the deterministic action."""
        return self.forward(states)


class ValueCritic(nn.Module):
    """State-value network V(s): an ``MLP`` with a single output unit.

    Args:
        in_dim: size of the state vector.
        hidden_dim: width of the hidden layers.
        n_layers: total number of linear layers in the MLP.
        **mlp_kwargs: forwarded unchanged to ``MLP`` (e.g. ``activations``,
            ``activate_final``, ``dropout_rate``).
    """

    def __init__(
        self,
        in_dim,
        hidden_dim,
        n_layers,
        **mlp_kwargs
    ) -> None:
        super(ValueCritic, self).__init__()
        scalar_output = 1  # one value per state
        self.mlp = MLP(in_dim, scalar_output, hidden_dim, n_layers, **mlp_kwargs)

    def forward(self, state):
        """Return V(state) as computed by the wrapped MLP."""
        value = self.mlp(state)
        return value


class Critic(nn.Module):
    """Twin Q-network from TD3+BC.

    Two independent heads (``l1-l3`` and ``l4-l6``) each map the concatenated
    ``[state, action]`` vector through two ReLU hidden layers to one scalar.

    Args:
        state_dim: size of the state vector.
        action_dim: size of the action vector.
        hidden_dim: width of the hidden layers of both heads (default 256,
            the TD3+BC architecture).
    """

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super(Critic, self).__init__()

        # Q1 architecture
        self.l1 = nn.Linear(state_dim + action_dim, hidden_dim)
        self.l2 = nn.Linear(hidden_dim, hidden_dim)
        self.l3 = nn.Linear(hidden_dim, 1)

        # Q2 architecture
        self.l4 = nn.Linear(state_dim + action_dim, hidden_dim)
        self.l5 = nn.Linear(hidden_dim, hidden_dim)
        self.l6 = nn.Linear(hidden_dim, 1)

    @staticmethod
    def _concat(state, action):
        """Join state and action along the feature (last) axis."""
        # dim=-1 equals dim=1 for (batch, features) inputs and additionally
        # supports unbatched 1-D inputs, where dim=1 would raise.
        return torch.cat([state, action], dim=-1)

    @staticmethod
    def _head(sa, first, second, out):
        """Evaluate one Q head: two ReLU hidden layers, then a linear output."""
        hidden = F.relu(first(sa))
        hidden = F.relu(second(hidden))
        return out(hidden)

    def forward(self, state, action):
        """Return ``(q1, q2)`` from both heads."""
        sa = self._concat(state, action)
        q1 = self._head(sa, self.l1, self.l2, self.l3)
        q2 = self._head(sa, self.l4, self.l5, self.l6)
        return q1, q2

    def Q1(self, state, action):
        """Return only the first head's Q-value."""
        sa = self._concat(state, action)
        return self._head(sa, self.l1, self.l2, self.l3)


class IQL(object):
    """Implicit Q-Learning agent with a deterministic actor.

    Components:
      * ``critic`` / ``critic_target``: twin Q-networks; the target tracks the
        critic by Polyak averaging with rate ``tau``.
      * ``value``: state-value network fitted with the expectile ``loss``
        towards ``min(Q1, Q2)`` of the target critic on dataset actions.
      * ``actor``: fitted by weighted regression onto dataset actions, with
        per-sample weights ``exp(temperature * (Q - V))`` clipped at 100.

    Args:
        state_dim: size of the state vector.
        action_dim: size of the action vector.
        expectile: expectile passed to ``loss`` in the value update.
        discount: discount factor of the TD target.
        tau: Polyak averaging rate for the target critic.
        temperature: multiplier on the advantage inside ``exp`` for actor weights.
        device: device all networks are moved to (anything ``.to()`` accepts).
        lr: Adam learning rate for actor, critic and value networks.
        max_steps: ``T_max`` of the actor's cosine learning-rate schedule.
    """

    def __init__(
        self,
        state_dim,
        action_dim,
        expectile,
        discount,
        tau,
        temperature,
        device,
        lr=3e-4,
        max_steps=int(1e6),
    ):
        self.device = device

        self.actor = Actor(state_dim, action_dim, 256, 3).to(device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr)
        # update_actor advances this scheduler once per call, so T_max is the
        # number of actor updates over which the LR is annealed.
        self.actor_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.actor_optimizer, T_max=int(max_steps)
        )

        self.critic = Critic(state_dim, action_dim).to(device)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=lr)

        self.value = ValueCritic(state_dim, 256, 3).to(device)
        self.value_optimizer = torch.optim.Adam(self.value.parameters(), lr=lr)

        self.discount = discount
        self.tau = tau
        self.temperature = temperature

        self.total_it = 0
        self.expectile = expectile

    def update_v(self, states, actions):
        """Take one expectile-regression step on the value network."""
        with torch.no_grad():
            q1, q2 = self.critic_target(states, actions)
            q = torch.minimum(q1, q2)

        v = self.value(states)
        value_loss = loss(q - v, self.expectile).mean()

        self.value_optimizer.zero_grad()
        value_loss.backward()
        self.value_optimizer.step()

    def update_q(self, states, actions, rewards, next_states, not_dones):
        """Take one TD step on both Q heads towards r + discount * not_done * V(s')."""
        with torch.no_grad():
            # Bootstrap from the learned V(s') instead of evaluating the critic
            # on new actions: the critic only ever sees batch actions.
            next_v = self.value(next_states)
            target_q = rewards + self.discount * not_dones * next_v

        q1, q2 = self.critic(states, actions)
        critic_loss = ((q1 - target_q) ** 2 + (q2 - target_q) ** 2).mean()

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

    def update_target(self):
        """Polyak-average critic parameters into critic_target with rate ``tau``."""
        for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

    def update_actor(self, states, actions):
        """Take one weighted-regression step on the actor, then advance its LR schedule.

        Each sample's squared action error is weighted by
        ``exp(temperature * (min(Q1, Q2) - V))`` clipped at 100, so samples with
        larger advantage pull the actor harder without unbounded weights.
        """
        with torch.no_grad():
            v = self.value(states)
            q1, q2 = self.critic_target(states, actions)
            q = torch.minimum(q1, q2)
            weights = torch.clamp(torch.exp((q - v) * self.temperature), max=100.0)

        mu = self.actor(states)
        actor_loss = (weights * (mu - actions) ** 2).mean()

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()
        self.actor_scheduler.step()

    def select_action(self, state):
        """Return the actor's action for one state as a flat numpy array.

        ``state`` is reshaped with ``.reshape(1, -1)`` — presumably a numpy
        array; confirm against callers.
        """
        state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)
        with torch.no_grad():
            action = self.actor.get_action(state)
        return action.cpu().numpy().flatten()

    def train(self, replay_buffer, batch_size=256):
        """Sample one batch and run the value, actor, critic and target updates."""
        self.total_it += 1

        # Sample replay buffer
        state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)

        # Update
        self.update_v(state, action)
        self.update_actor(state, action)
        self.update_q(state, action, reward, next_state, not_done)
        self.update_target()

    def _persisted_components(self):
        """Return (name, object) pairs for every component written to disk."""
        return (
            ("critic", self.critic),
            ("critic_target", self.critic_target),
            ("critic_optimizer", self.critic_optimizer),
            ("actor", self.actor),
            ("actor_optimizer", self.actor_optimizer),
            ("actor_scheduler", self.actor_scheduler),
            ("value", self.value),
            ("value_optimizer", self.value_optimizer),
        )

    def save_checkpoint(self, model_dir):
        """Save all state dicts with the current step in the filename.

        Writes ``<name>_s<total_it>.pth`` for every component into ``model_dir``.
        (Previously this was a first ``save`` definition that was silently
        shadowed by the second one and therefore unreachable.)
        """
        for name, component in self._persisted_components():
            torch.save(
                component.state_dict(),
                os.path.join(model_dir, f"{name}_s{self.total_it}.pth"),
            )

    def save(self, path):
        """Save all state dicts as ``policy_<name>.pkl`` files into ``path``."""
        for name, component in self._persisted_components():
            torch.save(component.state_dict(), os.path.join(path, f"policy_{name}.pkl"))

    def load(self, path):
        """Load state dicts written by :meth:`save` from ``path``.

        Tensors are mapped onto ``self.device`` so that a checkpoint saved on a
        GPU can be restored on a CPU-only machine (and vice versa).
        NOTE(review): ``torch.load`` unpickles data; only load trusted files.
        """
        for name, component in self._persisted_components():
            state = torch.load(
                os.path.join(path, f"policy_{name}.pkl"), map_location=self.device
            )
            component.load_state_dict(state)
