import copy
import time

import torch
import torch.nn as nn
import torch.nn.functional as F

import os
from .actor import Actor
from .critic import Critic, ValueCritic


class hyper(nn.Module):
    """Three-layer MLP that turns a state-action pair into a scalar weight.

    The network output is added to a caller-supplied ``adv`` term, so ``adv``
    serves as the starting point that the learned residual adjusts.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()

        width = 256
        # The attribute names l1/l2/l3 determine the state_dict keys.
        self.l1 = nn.Linear(state_dim + action_dim, width)
        self.l2 = nn.Linear(width, width)
        self.l3 = nn.Linear(width, 1)

    def forward(self, state, action, adv):
        """Return ``adv`` shifted by the network output for ``(state, action)``.

        ``state`` and ``action`` are concatenated along dimension 1, passed
        through two ReLU-activated linear layers and a final linear layer
        with a single output unit; ``adv`` is added to that output.
        """
        features = torch.cat((state, action), 1)
        for layer in (self.l1, self.l2):
            features = F.relu(layer(features))

        # adv acts as the initial point; the last layer learns an offset.
        residual = self.l3(features)
        return residual + adv


def loss(diff, expectile=0.8):
    """Element-wise asymmetric squared loss.

    Strictly positive entries of ``diff`` are weighted by ``expectile``; all
    other entries are weighted by ``1 - expectile``. No reduction is applied,
    so the result has the same shape as ``diff``.
    """
    weight_if_positive = expectile
    weight_otherwise = 1 - expectile
    asymmetric_weight = torch.where(diff > 0, weight_if_positive, weight_otherwise)
    squared = diff ** 2
    return asymmetric_weight * squared



class VACO(object):
    def __init__(
        self,
        state_dim,
        action_dim,
        expectile,
        discount,
        tau,
        temperature,
        group_size,
        device
    ):
        """Build the networks, their optimizers and the bookkeeping state.

        Creates an actor, a critic with a deep-copied target, a value network
        and the ``hyper`` weighting network, all moved to ``device``. Each
        trainable network gets its own Adam optimizer, and the actor also
        gets a cosine-annealing learning-rate schedule. The remaining
        arguments are stored as attributes for the update methods.
        """
        # Construction order of the networks is kept fixed (actor, critic,
        # value, meta) so any random initialisation they perform happens in
        # the same sequence.
        self.actor = Actor(state_dim, action_dim, 256, 3).to(device)
        self.critic = Critic(state_dim, action_dim).to(device)
        self.critic_target = copy.deepcopy(self.critic)
        self.value = ValueCritic(state_dim, 256, 3).to(device)
        self.meta_w = hyper(state_dim, action_dim).to(device)

        make_adam = torch.optim.Adam
        self.actor_optimizer = make_adam(self.actor.parameters(), lr=3e-4)
        self.critic_optimizer = make_adam(self.critic.parameters(), lr=3e-4)
        self.value_optimizer = make_adam(self.value.parameters(), lr=3e-4)
        # The weighting network uses a smaller learning rate than the others.
        self.meta_optimizer = make_adam(self.meta_w.parameters(), lr=1e-4)

        self.actor_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.actor_optimizer, T_max=int(1e6)
        )

        # Hyper-parameters consumed by the update methods.
        self.expectile = expectile
        self.discount = discount
        self.tau = tau
        self.temperature = temperature
        self.group_size = group_size
        self.device = device

        # Running state: iteration counter and the baseline used by
        # update_actor_bi.
        self.total_it = 0
        self.baseline = 0

    def update_v(self, states, actions):
        """Take one optimizer step on the value network.

        The regression target is the element-wise minimum of the two target
        critic outputs for ``(states, actions)``, computed without gradients.
        The residual ``q - v`` is scored with the module-level ``loss`` using
        ``self.expectile`` and averaged over all elements.
        """
        with torch.no_grad():
            target_q1, target_q2 = self.critic_target(states, actions)
            q_target = torch.minimum(target_q1, target_q2)

        residual = q_target - self.value(states)
        value_loss = loss(residual, self.expectile).mean()

        optimizer = self.value_optimizer
        optimizer.zero_grad()
        value_loss.backward()
        optimizer.step()

    def update_q(self, states, actions, rewards, next_states, not_dones):
        with torch.no_grad():
            next_v = self.value(next_states)
            target_q = (rewards + self.discount * not_dones * next_v).detach()

        q1, q2 = self.critic(states, actions)
        critic_loss = ((q1 - target_q)**2 + (q2 - target_q)**2).mean()

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

    def update_target(self):
        for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

    def update_actor(self, states, actions):
        with torch.no_grad():
            v = self.value(states)
            q1, q2 = self.critic_target(states, actions)
            q = torch.minimum(q1, q2)
            exp_a = torch.exp((q - v) * (self.temperature))
            exp_a = torch.clamp(exp_a, max=100.0).squeeze(-1).detach()


        mu = self.actor(states)
        actor_loss = (exp_a.unsqueeze(-1) * ((mu - actions)**2)).mean()

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()
        self.actor_scheduler.step()

    def select_action(self, state):
        state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)
        return self.actor.get_action(state).cpu().data.numpy().flatten()


    def train(self, get_batch, batch_size=256):
        """Run one full training iteration on a freshly sampled batch.

        Increments ``self.total_it``, draws a batch with
        ``get_batch(batch_size)`` (unpacked as ``state, action, next_state,
        reward, not_done``) and then updates, in this order, the value
        network, the actor, the critic and the target critic.
        """
        self.total_it = self.total_it + 1

        # Sample the replay buffer once and share the batch across updates.
        batch = get_batch(batch_size)
        obs, act, next_obs, rew, mask = batch

        self.update_v(obs, act)
        self.update_actor(obs, act)
        self.update_q(obs, act, rew, next_obs, mask)
        self.update_target()

    def train_ori(self, get_batch, batch_size=256):
        """Run one actor-only iteration using ``update_actor``.

        Increments ``self.total_it`` and samples a five-element batch with
        ``get_batch(batch_size)``; only the states and actions are used.
        """
        self.total_it = self.total_it + 1

        # Sample the replay buffer; the last three fields are not needed here.
        batch = get_batch(batch_size)
        obs, act, _next_obs, _rew, _mask = batch

        self.update_actor(obs, act)

    def train_critic(self, get_batch, batch_size=256):
        """Run one iteration that trains only the value and critic networks.

        Increments ``self.total_it``, samples a batch with
        ``get_batch(batch_size)`` and then updates the value network, the
        critic and the target critic, in that order. The actor is untouched.
        """
        self.total_it = self.total_it + 1

        # Sample the replay buffer.
        batch = get_batch(batch_size)
        obs, act, next_obs, rew, mask = batch

        self.update_v(obs, act)
        self.update_q(obs, act, rew, next_obs, mask)
        self.update_target()


    def train_actor(self, get_batch, batch_size=256):
        """Run one actor iteration using ``update_actor_bi``.

        Increments ``self.total_it`` and samples a five-element batch with
        ``get_batch(batch_size)``; only the states and actions are used.
        """
        self.total_it = self.total_it + 1

        # Sample the replay buffer; the last three fields are not needed here.
        batch = get_batch(batch_size)
        obs, act, _next_obs, _rew, _mask = batch

        self.update_actor_bi(obs, act)

    def update_actor_bi(self, states, actions):
        """Update the actor with learned per-sample weights, then update the weighting network.

        Steps, in order:

        1. Compute ``q - v`` without gradients from the target critic and the
           value network, and feed it to ``self.meta_w`` to get the per-sample
           weights ``exp_a``.
        2. Split the batch into ``self.group_size`` contiguous groups and
           record, for each group, the flattened gradient of its mean squared
           error ``(actor(states) - actions)**2`` w.r.t. the actor parameters.
        3. Step the actor on the squared error weighted by ``exp_a`` (detached,
           so this step does not train ``meta_w``).
        4. Using the updated actor on noise-perturbed states, compute the
           flattened gradient of the mean target-critic minimum w.r.t. the
           actor parameters.
        5. Score each group by the dot product of its gradient from step 2
           with the gradient from step 4 (divided by 100), and step ``meta_w``
           on the sum over groups of ``mean(exp_a[group] * (score - baseline))``.
        6. Move ``self.baseline`` towards the mean group score.

        NOTE(review): ``self.actor_scheduler`` is not stepped here, unlike in
        ``update_actor`` -- confirm this is intended.
        """
        with torch.no_grad():
            v = self.value(states)
            q1, q2 = self.critic_target(states, actions)
            q = torch.minimum(q1, q2)

        # Gradients flow into meta_w through exp_a only via meta_loss below.
        exp_a = self.meta_w(states, actions, (q-v).detach())

        mu = self.actor(states)
        # Unreduced element-wise squared error; reduced per group and again
        # for the weighted actor step.
        actor_loss = ((mu - actions)**2)

        batch_size = states.shape[0]
        dt_bc_grad_list = []
        # NOTE(review): if batch_size is not a multiple of group_size, the
        # trailing samples fall in no group; if group_size > batch_size then
        # bs is 0 and every group slice is empty -- confirm callers avoid this.
        bs = int(batch_size / self.group_size)
        for k in range(self.group_size):
            loss_sum = torch.mean(actor_loss[k*bs:(k+1)*bs, :])
            # retain_graph keeps the graph alive for the next group and for
            # the backward pass of the weighted loss below.
            dt_td_grad = torch.autograd.grad(loss_sum, self.actor.parameters(), retain_graph=True)
            dt_bc_grad_list.append(torch.cat([grad.view(-1) for grad in dt_td_grad], dim=0))

        # exp_a is detached here so the actor step does not train meta_w.
        actor_loss = torch.mean(actor_loss * exp_a.detach())

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # Noise scale shrinks linearly with total_it.
        # NOTE(review): the factor becomes negative once total_it exceeds
        # 1,000,000 -- confirm training never runs longer than that.
        random_noise = (torch.randn(states.size())).to(self.device) * (1 - self.total_it / 1000000 * 1) * 0.5
        # Evaluated with the actor parameters after the optimizer step above.
        mu_new = self.actor(states + random_noise)
        q1, q2 = self.critic_target(states, mu_new)
        q = torch.minimum(q1, q2).mean()
        dt_val_loss_grad = torch.autograd.grad(q, self.actor.parameters())
        dt_val_loss_grad = torch.cat([grad.view(-1) for grad in dt_val_loss_grad], dim=0).detach()

        meta_loss = torch.tensor([0.]).to(self.device)
        r_tensor = torch.zeros([self.group_size]).to(self.device)
        for k in range(self.group_size):
            # Group score: dot product of the two flattened gradients, scaled
            # by 1/100.
            r_tensor[k] = torch.sum(dt_val_loss_grad * dt_bc_grad_list[k]) / 100
            meta_loss = meta_loss + torch.mean(exp_a[k*bs:(k+1)*bs] * (r_tensor[k] - self.baseline).detach())
        # NOTE(review): the sum runs over group_size terms but is divided by
        # batch_size -- confirm the intended normalisation.
        meta_loss = meta_loss / batch_size

        self.meta_optimizer.zero_grad()
        meta_loss.backward()
        self.meta_optimizer.step()

        # baseline starts as the int 0 and is replaced by a tensor on the
        # first call; afterwards it moves 10% of the way to the new mean.
        if self.baseline == 0:
            self.baseline = torch.mean(r_tensor)
        else:
            self.baseline = self.baseline - 0.1 * (self.baseline - torch.mean(r_tensor))

    def save_critic(self, model_dir):
        if not os.path.exists(model_dir):
            os.mkdir(model_dir)
        torch.save(self.critic.state_dict(), os.path.join(model_dir, f"critic_bis{str(self.total_it)}.pth"))
        torch.save(self.critic_target.state_dict(), os.path.join(model_dir, f"critic_target_bis{str(self.total_it)}.pth"))
        torch.save(self.value.state_dict(), os.path.join(model_dir, f"value_bis{str(self.total_it)}.pth"))

    def save_actor(self, model_dir):
        torch.save(self.actor.state_dict(), os.path.join(model_dir, f"actor_bis{str(self.total_it)}.pth"))
        torch.save(self.meta_w.state_dict(), os.path.join(model_dir, f"meta_bis{str(self.total_it)}.pth"))

    def save_ori(self, model_dir):
        torch.save(self.actor.state_dict(), os.path.join(model_dir, f"actor_ls{str(self.total_it)}.pth"))

    def save(self, model_dir):
        if not os.path.exists(model_dir):
            os.mkdir(model_dir)
        torch.save(self.critic.state_dict(), os.path.join(model_dir, f"critic_s{str(self.total_it)}.pth"))
        torch.save(self.critic_target.state_dict(), os.path.join(model_dir, f"critic_target_s{str(self.total_it)}.pth"))
        torch.save(self.critic_optimizer.state_dict(), os.path.join(
            model_dir, f"critic_optimizer_s{str(self.total_it)}.pth"))

        torch.save(self.actor.state_dict(), os.path.join(model_dir, f"actor_s{str(self.total_it)}.pth"))
        torch.save(self.actor_optimizer.state_dict(), os.path.join(
            model_dir, f"actor_optimizer_s{str(self.total_it)}.pth"))
        torch.save(self.actor_scheduler.state_dict(), os.path.join(
            model_dir, f"actor_scheduler_s{str(self.total_it)}.pth"))

        torch.save(self.value.state_dict(), os.path.join(model_dir, f"value_s{str(self.total_it)}.pth"))
        torch.save(self.value_optimizer.state_dict(), os.path.join(
            model_dir, f"value_optimizer_s{str(self.total_it)}.pth"))
