import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
from utils.logger import logger

from agents.diffusion import Diffusion
from agents.model import MLP
from agents.helpers import EMA


class Critic(nn.Module):
    """Twin Q-network: two independent MLP heads scoring (state, action) pairs.

    Each head maps ``concat(state, action)`` through three hidden Linear layers
    of width ``hidden_dim`` (Mish activations) down to a single scalar.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        in_features = state_dim + action_dim

        def make_head():
            # Layout (and therefore state_dict indices 0..6) per head:
            # Linear, Mish, Linear, Mish, Linear, Mish, Linear(-> 1)
            return nn.Sequential(
                nn.Linear(in_features, hidden_dim),
                nn.Mish(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.Mish(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.Mish(),
                nn.Linear(hidden_dim, 1),
            )

        # q1 is built before q2 so parameter initialisation consumes the RNG
        # in the same order as building the two heads inline.
        self.q1_model = make_head()
        self.q2_model = make_head()

    @staticmethod
    def _joint_input(state, action):
        # Both heads see state and action concatenated along the last axis.
        return torch.cat([state, action], dim=-1)

    def forward(self, state, action):
        """Return ``(q1, q2)``, the outputs of both heads."""
        sa = self._joint_input(state, action)
        return self.q1_model(sa), self.q2_model(sa)

    def q1(self, state, action):
        """Return only the first head's output."""
        return self.q1_model(self._joint_input(state, action))

    def q_min(self, state, action):
        """Return the element-wise minimum of the two heads' outputs."""
        first, second = self.forward(state, action)
        return torch.minimum(first, second)


class DiffCPS(object):
    """Actor-critic agent with a diffusion policy and a learnable loss weight.

    The actor is a ``Diffusion`` model wrapping an ``MLP``; the critic is a
    twin-head ``Critic``. ``train`` regresses both critic heads onto the sampled
    reward, updates the actor every ``policy_freq`` steps on
    ``clamp(LA) * kl_loss + q_loss``, and updates the multiplier ``LA`` by a
    gradient step on ``(target_kl - kl_loss) * LA``.
    """

    def __init__(
        self,
        state_dim,
        action_dim,
        max_action,
        device,
        discount,
        tau,
        max_q_backup=False,
        LA=1.0,
        beta_schedule="linear",
        n_timesteps=100,
        ema_decay=0.995,
        step_start_ema=1000,
        update_ema_every=5,
        lr=3e-4,
        lr_decay=False,
        lr_maxt=1000,
        grad_norm=1.0,
        # policy_noise=0.2,
        # noise_clip=0.1,
        policy_freq=10,
        target_kl=0.05,
        lambda_max=100,
        lambda_min=0.5,
    ):
        """Build the networks, optimizers and (optionally) LR schedulers.

        Args:
            state_dim: Size of the state vector.
            action_dim: Size of the action vector.
            max_action: Passed through to ``Diffusion``.
            device: Torch device for the actor, critic and ``LA``.
            discount: Stored on the instance; not used by this class's methods.
            tau: Soft-update rate for ``critic_target`` (applied every step).
            max_q_backup: Stored on the instance; not used by this class's
                methods.
            LA: Initial value of the learnable multiplier on ``kl_loss``.
            beta_schedule: Passed through to ``Diffusion``.
            n_timesteps: Passed through to ``Diffusion``.
            ema_decay: Decay passed to ``EMA``.
            step_start_ema: ``ema_model`` is not updated before this step.
            update_ema_every: ``step_ema`` is called when ``step`` is a
                multiple of this value.
            lr: Actor learning rate. The critic uses a fixed 3e-4 and ``LA``
                uses 3e-2.
            lr_decay: If True, cosine-anneal all three optimizers once per
                ``train`` call.
            lr_maxt: ``T_max`` of the cosine schedules.
            grad_norm: Stored only. NOTE(review): no gradient clipping is
                applied in ``train``; confirm whether that is intended.
            policy_freq: Actor / ``LA`` update period, counted in critic steps.
            target_kl: Target value for ``kl_loss`` in the ``LA`` update.
            lambda_max: Upper clamp for ``LA`` when it is used or logged.
            lambda_min: Lower clamp for ``LA`` when it is used or logged. The
                raw ``LA`` parameter itself is never clamped.
        """
        self.model = MLP(state_dim=state_dim, action_dim=action_dim, device=device)

        self.actor = Diffusion(
            state_dim=state_dim,
            action_dim=action_dim,
            model=self.model,
            max_action=max_action,
            beta_schedule=beta_schedule,
            n_timesteps=n_timesteps,
        ).to(device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr)

        self.lr_decay = lr_decay
        self.grad_norm = grad_norm

        self.step = 0
        self.step_start_ema = step_start_ema
        self.ema = EMA(ema_decay)
        self.ema_model = copy.deepcopy(self.actor)
        self.update_ema_every = update_ema_every
        self.policy_freq = policy_freq

        self.critic = Critic(state_dim, action_dim).to(device)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4)

        # Learnable scalar weight on kl_loss, optimized by its own Adam.
        self.LA = torch.tensor(LA, dtype=torch.float).to(device)
        self.lambda_min = lambda_min
        self.lambda_max = lambda_max

        self.LA.requires_grad = True
        self.LA_optimizer = torch.optim.Adam([self.LA], lr=3e-2)

        if lr_decay:
            self.actor_lr_scheduler = CosineAnnealingLR(
                self.actor_optimizer, T_max=lr_maxt, eta_min=0.0
            )
            self.critic_lr_scheduler = CosineAnnealingLR(
                self.critic_optimizer, T_max=lr_maxt, eta_min=0.0
            )
            self.lambda_lr_scheduler = CosineAnnealingLR(
                self.LA_optimizer, T_max=lr_maxt, eta_min=0.0
            )

        self.state_dim = state_dim
        self.max_action = max_action
        self.action_dim = action_dim
        self.discount = discount
        self.tau = tau

        self.target_kl = target_kl
        self.device = device
        self.max_q_backup = max_q_backup

    def step_ema(self):
        """Blend the actor into ``ema_model``; a no-op before ``step_start_ema``."""
        if self.step < self.step_start_ema:
            return
        self.ema.update_model_average(self.ema_model, self.actor)

    def train(self, replay_buffer, iterations, batch_size=100, log_writer=None):
        """Run ``iterations`` training steps on batches from ``replay_buffer``.

        Args:
            replay_buffer: Object whose ``sample(batch_size)`` returns
                ``(state, action, reward)``. These are presumably tensors already
                on ``self.device``; TODO confirm against the buffer.
            iterations: Number of critic steps to run.
            batch_size: Passed to ``replay_buffer.sample``.
            log_writer: Accepted for interface compatibility; not used here.

        Returns:
            Dict of lists keyed by "kl_loss", "ql_loss", "actor_loss",
            "critic_loss" and "lambda". One entry is appended per policy-update
            step (every ``policy_freq`` steps). Critic losses from the other
            steps are not recorded.
        """
        metric = {
            "kl_loss": [],
            "ql_loss": [],
            "actor_loss": [],
            "critic_loss": [],
            "lambda": [],
        }

        for _ in range(iterations):
            state, action, reward = replay_buffer.sample(batch_size)

            # Q training: both heads regress directly onto the sampled reward
            # (no bootstrapped target is used here).
            current_q1, current_q2 = self.critic(state, action)

            critic_loss = F.mse_loss(current_q1, reward) + F.mse_loss(
                current_q2, reward
            )

            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            self.critic_optimizer.step()

            # Policy and multiplier updates run once every policy_freq steps.
            if self.step % self.policy_freq == 0:
                kl_loss = self.actor.loss(action, state)
                new_action = self.actor(state)

                q1_new_action, q2_new_action = self.critic(state, new_action)
                # Randomly pick which head to maximise. Dividing by the other
                # head's detached mean magnitude normalises q_loss's scale.
                if np.random.uniform() > 0.5:
                    q_loss = -q1_new_action.mean() / q2_new_action.abs().mean().detach()
                else:
                    q_loss = -q2_new_action.mean() / q1_new_action.abs().mean().detach()
                actor_loss = (
                    self.LA.clamp(self.lambda_min, self.lambda_max).detach() * kl_loss
                    + q_loss
                )

                # This backward also deposits gradients on the critic's
                # parameters. critic_optimizer.zero_grad() clears them at the
                # start of the next iteration, before the critic steps again.
                self.actor_optimizer.zero_grad()
                actor_loss.backward()
                self.actor_optimizer.step()

                # Lambda update: d(lambda_loss)/d(LA) = target_kl - kl_loss, so
                # descending it raises LA while kl_loss exceeds target_kl.
                lambda_loss = (self.target_kl - kl_loss).detach() * self.LA
                self.LA_optimizer.zero_grad()
                lambda_loss.backward()
                self.LA_optimizer.step()

                metric["actor_loss"].append(actor_loss.item())
                metric["kl_loss"].append(kl_loss.item())
                metric["ql_loss"].append(q_loss.item())
                metric["critic_loss"].append(critic_loss.item())
                metric["lambda"].append(
                    self.LA.clamp(self.lambda_min, self.lambda_max).item()
                )

            # Target networks: periodic actor EMA, per-step critic soft update.
            if self.step % self.update_ema_every == 0:
                self.step_ema()

            for param, target_param in zip(
                self.critic.parameters(), self.critic_target.parameters()
            ):
                target_param.data.copy_(
                    self.tau * param.data + (1 - self.tau) * target_param.data
                )

            self.step += 1

        if self.lr_decay:
            self.actor_lr_scheduler.step()
            self.critic_lr_scheduler.step()
            self.lambda_lr_scheduler.step()

        return metric

    def sample_action(self, state, num_samples=50):
        """Choose one action for a single state.

        ``num_samples`` candidate actions are drawn from the actor for copies
        of ``state``. Each is scored with ``critic_target.q_min``, and one
        candidate is drawn with probability ``softmax(Q)`` over the candidates.

        Args:
            state: Array with a ``reshape`` method (e.g. a NumPy array). It is
                flattened into a single row.
            num_samples: Number of candidate actions to score (default 50).

        Returns:
            The chosen action as a flat NumPy array.
        """
        state = torch.as_tensor(
            state.reshape(1, -1), dtype=torch.float32, device=self.device
        )
        state_rpt = torch.repeat_interleave(state, repeats=num_samples, dim=0)
        with torch.no_grad():
            action = self.actor.sample(state_rpt)

            q_value = self.critic_target.q_min(state_rpt, action).flatten()
            # q_value is 1-D, so dim=0 is the candidate axis. The implicit-dim
            # form of softmax is deprecated. multinomial returns one index per
            # row of its input; here a single index of shape (1,).
            idx = torch.multinomial(F.softmax(q_value, dim=0), 1)
        return action[idx].detach().cpu().numpy().flatten()

    def save_model(self, dir, id=None):
        """Save actor and critic state_dicts as ``{dir}/actor[_id].pth`` / ``critic[_id].pth``."""
        suffix = "" if id is None else f"_{id}"
        torch.save(self.actor.state_dict(), f"{dir}/actor{suffix}.pth")
        torch.save(self.critic.state_dict(), f"{dir}/critic{suffix}.pth")

    def load_model(self, dir, id=None):
        """Load weights written by ``save_model`` and resync ``critic_target``.

        Tensors are mapped onto ``self.device``, so checkpoints saved on another
        device (e.g. GPU) can be loaded. ``critic_target`` is not saved.
        Because ``sample_action`` scores candidates with it, it is reset to the
        loaded critic instead of keeping its previous weights.
        """
        suffix = "" if id is None else f"_{id}"
        self.actor.load_state_dict(
            torch.load(f"{dir}/actor{suffix}.pth", map_location=self.device)
        )
        self.critic.load_state_dict(
            torch.load(f"{dir}/critic{suffix}.pth", map_location=self.device)
        )
        self.critic_target.load_state_dict(self.critic.state_dict())
