import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
import time

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class GaussianPolicy(nn.Module):
    """Tanh-squashed diagonal Gaussian policy with two 256-unit hidden layers.

    The network outputs a mean and a clamped log standard deviation per
    action dimension; ``sample`` squashes the Gaussian draw with ``tanh``
    and scales it by ``max_action``.
    """

    def __init__(self, state_dim, action_dim, max_action):
        super().__init__()
        hidden_units = 256

        # Shared trunk. Creation order is kept so seeded initialisation
        # yields the same weights.
        self.linear1 = nn.Linear(state_dim, hidden_units)
        self.linear2 = nn.Linear(hidden_units, hidden_units)

        # Separate heads for the mean and the log standard deviation.
        self.mean_linear = nn.Linear(hidden_units, action_dim)
        self.log_std_linear = nn.Linear(hidden_units, action_dim)

        self.max_action = max_action

    def forward(self, state):
        """Return ``(mean, log_std)`` with ``log_std`` clamped to [-20, 2]."""
        features = F.relu(self.linear2(F.relu(self.linear1(state))))
        mean = self.mean_linear(features)
        log_std = self.log_std_linear(features).clamp(min=-20, max=2)
        return mean, log_std

    def sample(self, state, deterministic=False):
        """Draw an action and the log-probability of its tanh-squashed value.

        Args:
            state: Input passed to ``forward``.
            deterministic: When true the Gaussian mean is used instead of a
                reparameterised sample (only used for evaluating the policy
                at test time).

        Returns:
            ``(action, log_prob)`` where ``action`` is
            ``tanh(u) * max_action`` and ``log_prob`` is summed over the last
            axis.
        """
        mean, log_std = self.forward(state)
        gaussian = Normal(mean, log_std.exp())
        pre_tanh = mean if deterministic else gaussian.rsample()

        # log N(u) summed over action dimensions.
        gaussian_log_prob = gaussian.log_prob(pre_tanh).sum(axis=-1)
        # log(1 - tanh(u)^2) written as 2 * (log 2 - u - softplus(-2u)),
        # which stays finite for large |u|.
        squash_correction = (
            2 * (np.log(2) - pre_tanh - F.softplus(-2 * pre_tanh))
        ).sum(axis=-1)
        log_prob = gaussian_log_prob - squash_correction

        scaled_action = torch.tanh(pre_tanh) * self.max_action
        return scaled_action, log_prob

class Mlp(nn.Module):
    """Fully connected network with ReLU after every hidden layer.

    Hidden layers are registered as ``fc0``, ``fc1``, ... and also kept, in
    order, in the plain list ``self.fcs``; the output layer is ``last_fc``
    and has no activation.
    """

    def __init__(
            self,
            input_size,
            hidden_sizes,
            output_size
    ):
        super().__init__()
        # TODO: initialization
        self.fcs = []
        widths = [input_size, *hidden_sizes]
        # Consecutive width pairs give each hidden layer's (in, out) sizes.
        for idx, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            layer = nn.Linear(fan_in, fan_out)
            # Attribute assignment registers the layer under the name fc{idx}.
            setattr(self, f'fc{idx}', layer)
            self.fcs.append(layer)
        self.last_fc = nn.Linear(widths[-1], output_size)

    def forward(self, input):
        """Apply the hidden layers with ReLU, then the linear output layer."""
        hidden = input
        for layer in self.fcs:
            hidden = F.relu(layer(hidden))
        return self.last_fc(hidden)

class QNetwotk(nn.Module):
    """Ensemble of independent quantile critics over (state, action) pairs.

    Each member is an ``Mlp`` that maps the concatenated state and action to
    ``n_quantiles`` atoms. Members are registered as ``qf0``, ``qf1``, ...
    and also kept, in order, in the plain list ``self.nets``.

    Args:
        state_dim: Size of the state vector.
        action_dim: Size of the action vector.
        n_nets: Number of critics in the ensemble (default 5).
        n_quantiles: Number of atoms produced by every critic (default 25).
        hidden_sizes: Hidden layer widths of every critic
            (default three layers of 512 units).
    """

    def __init__(self, state_dim, action_dim, n_nets=5, n_quantiles=25,
                 hidden_sizes=(512, 512, 512)):
        super(QNetwotk, self).__init__()

        self.n_nets = n_nets
        self.n_quantiles = n_quantiles
        self.nets = []

        for i in range(n_nets):
            net = Mlp(state_dim + action_dim, list(hidden_sizes), n_quantiles)
            # Registration under qf{i} keeps the state_dict keys stable.
            self.add_module(f'qf{i}', net)
            self.nets.append(net)

    def forward(self, state, action):
        """Return the atoms of every critic stacked along dimension 1.

        ``state`` and ``action`` are concatenated along dimension 1, so both
        must be 2-D with the same number of rows. The result has shape
        ``(batch, n_nets, n_quantiles)``.
        """
        sa = torch.cat((state, action), dim=1)
        return torch.stack([net(sa) for net in self.nets], dim=1)

class TQC(object):
    """Truncated Quantile Critics (TQC) agent with a fixed entropy weight.

    Holds a Gaussian actor and a quantile critic ensemble, each with a
    softly updated target copy, and one Adam optimizer (lr=3e-4) per online
    network. The entropy weight ``alpha`` is fixed at 0.2.
    """

    def __init__(
            self,
            state_dim,
            action_dim,
            max_action,
            discount=0.99,
            tau=0.005,
            top_quantiles_to_drop=10,
    ):
        """Build the networks, their target copies and the optimizers.

        Args:
            state_dim: Size of a flattened observation.
            action_dim: Size of an action vector.
            max_action: Scale the policy applies to its tanh output.
            discount: Reward discount factor.
            tau: Interpolation weight of the soft target update.
            top_quantiles_to_drop: How many of the largest pooled target
                atoms are discarded per sample when the critic target is
                built. The default of 10 keeps 115 of the 5 * 25 atoms of
                the default critic.
        """
        self.device = device
        self.actor = GaussianPolicy(state_dim, action_dim, max_action).to(device)
        self.actor_target = copy.deepcopy(self.actor)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4)

        self.critic = QNetwotk(state_dim, action_dim).to(device)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4)

        self.max_action = max_action
        self.discount = discount
        self.tau = tau
        self.top_quantiles_to_drop = top_quantiles_to_drop

        self.alpha = 0.2

        self.total_it = 0

    def select_action(self, state, deterministic=False):
        """Return the policy's action for a single observation.

        Args:
            state: Array-like observation exposing ``reshape``; it is
                flattened into a batch of one.
            deterministic: Forwarded to the policy's ``sample`` method.

        Returns:
            The action as a flat NumPy array.
        """
        state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)
        # Acting never needs gradients, so skip building the autograd graph.
        with torch.no_grad():
            action, _ = self.actor.sample(state, deterministic)
        return action.detach().cpu().numpy().flatten()

    def train(self, replay_buffer, batch_size=256):
        """Run one critic update, one actor update and the soft target updates.

        Args:
            replay_buffer: Object whose ``sample(batch_size)`` returns
                ``(state, action, next_state, reward, not_done)`` tensors.
                NOTE(review): ``reward`` and ``not_done`` are assumed to
                broadcast against ``(batch, n_kept_atoms)``, e.g. shape
                ``(batch, 1)`` -- confirm against the replay buffer.
            batch_size: Number of transitions requested from the buffer.

        Raises:
            ValueError: If ``top_quantiles_to_drop`` leaves no target atoms.
        """
        self.total_it += 1

        state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)

        with torch.no_grad():
            next_action, next_state_log_pi = self.actor_target.sample(next_state)
            target_z = self.critic_target(next_state, next_action)
            # Pool the atoms of all critics per sample. The real batch size
            # is used so a buffer that returns fewer rows than requested
            # still works.
            sorted_z, _ = torch.sort(target_z.reshape(target_z.shape[0], -1))
            n_keep = sorted_z.shape[1] - self.top_quantiles_to_drop
            if n_keep <= 0:
                raise ValueError(
                    "top_quantiles_to_drop must be smaller than the number of "
                    "pooled target atoms"
                )
            # Truncation: keep only the smallest atoms of the pooled mixture.
            sorted_z_part = sorted_z[:, :n_keep]
            next_state_log_pi = torch.unsqueeze(next_state_log_pi, dim=-1)
            target_z = self.discount * not_done * (sorted_z_part - self.alpha * next_state_log_pi) + reward

        current_z = self.critic(state, action)  # (batch, n_nets, n_quantiles)
        n_quantiles = current_z.shape[2]
        # Quantile midpoints: tau_i = i / n + 1 / (2n).
        tau1 = torch.arange(n_quantiles, device=current_z.device).float() / n_quantiles + 1 / 2 / n_quantiles

        # Quantile Huber loss between every target atom and every critic atom.
        pairwise_delta = target_z[:, None, None, :] - current_z[:, :, :, None]  # (batch, n_nets, n_quantiles, n_kept)
        abs_pairwise_delta = torch.abs(pairwise_delta)
        huber_loss = torch.where(abs_pairwise_delta > 1,
                                 abs_pairwise_delta - 0.5,
                                 pairwise_delta ** 2 * 0.5)
        # NOTE: the indicator uses ``> 0``, mirrored relative to the usual
        # |tau - 1{delta < 0}| weighting. The midpoints are symmetric around
        # 0.5, so the atoms learn the same set of quantiles in reversed index
        # order. This class only consumes atoms through sort and mean, so the
        # order is irrelevant. It is kept as is so that existing checkpoints
        # stay consistent.
        critic_loss = (torch.abs(tau1[None, None, :, None] - (pairwise_delta > 0).float()) * huber_loss).mean()

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Compute actor loss.
        new_action, log_pi = self.actor.sample(state)
        # Mean over atoms, then over critics -> one value per sample.
        q_mean = self.critic(state, new_action).mean(2).mean(1)
        # Both terms are aligned to (batch,). Subtracting a (batch, 1) tensor
        # from a (batch,) tensor would silently broadcast to a
        # (batch, batch) matrix.
        actor_loss = (self.alpha * log_pi.reshape(-1) - q_mean).mean()

        # Optimize the actor.
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # Update the frozen target models.
        self._soft_update(self.critic, self.critic_target)
        self._soft_update(self.actor, self.actor_target)

    def _soft_update(self, net, target_net):
        """Blend ``net`` into ``target_net``: ``tau * net + (1 - tau) * target``."""
        for param, target_param in zip(net.parameters(), target_net.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

    def save(self, filename):
        """Write the actor, the critic and both optimizer states.

        Four files are written using ``filename`` as prefix with the suffixes
        ``_critic``, ``_critic_optimizer``, ``_actor`` and ``_actor_optimizer``.
        """
        torch.save(self.critic.state_dict(), filename + "_critic")
        torch.save(self.critic_optimizer.state_dict(), filename + "_critic_optimizer")

        torch.save(self.actor.state_dict(), filename + "_actor")
        torch.save(self.actor_optimizer.state_dict(), filename + "_actor_optimizer")

    def load(self, filename):
        """Restore the state written by ``save`` and rebuild the target networks.

        Tensors are mapped onto ``self.device`` so a checkpoint saved on a
        GPU can be loaded on a CPU-only host. The targets are fresh copies of
        the loaded online networks.
        """
        self.critic.load_state_dict(torch.load(filename + "_critic", map_location=self.device))
        self.critic_optimizer.load_state_dict(
            torch.load(filename + "_critic_optimizer", map_location=self.device))
        self.critic_target = copy.deepcopy(self.critic)

        self.actor.load_state_dict(torch.load(filename + "_actor", map_location=self.device))
        self.actor_optimizer.load_state_dict(
            torch.load(filename + "_actor_optimizer", map_location=self.device))
        self.actor_target = copy.deepcopy(self.actor)