import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
from utils.logger import logger

from agents.diffusion import Diffusion
from agents.model import MLP
from agents.helpers import EMA


class Critic(nn.Module):
    """Two independent Q-value MLPs evaluated on concatenated (state, action) input.

    Each head consists of three ``Linear`` layers of width ``hidden_dim``, every
    one followed by ``Mish``, and a final ``Linear`` layer that outputs a single
    value per input row.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        """Build both Q heads.

        Args:
            state_dim: Size of the last dimension of the state input.
            action_dim: Size of the last dimension of the action input.
            hidden_dim: Width of every hidden layer.
        """
        super().__init__()
        input_dim = state_dim + action_dim

        def build_q_net():
            # Layers are instantiated in forward order so that parameter
            # initialisation draws from the RNG in the same sequence per head.
            layer_widths = (input_dim, hidden_dim, hidden_dim, hidden_dim)
            modules = []
            for fan_in, fan_out in zip(layer_widths[:-1], layer_widths[1:]):
                modules += [nn.Linear(fan_in, fan_out), nn.Mish()]
            modules.append(nn.Linear(hidden_dim, 1))
            return nn.Sequential(*modules)

        self.q1_model = build_q_net()
        self.q2_model = build_q_net()

    def forward(self, state, action):
        """Return the ``(q1, q2)`` estimates for the given state/action pairs."""
        joint = torch.cat((state, action), dim=-1)
        q1_value = self.q1_model(joint)
        q2_value = self.q2_model(joint)
        return q1_value, q2_value

    def q1(self, state, action):
        """Return only the first head's estimate."""
        return self.q1_model(torch.cat((state, action), dim=-1))

    def q_min(self, state, action):
        """Return the element-wise minimum of the two heads' estimates."""
        return torch.min(*self.forward(state, action))


class DiffCPS(object):
    def __init__(
        self,
        state_dim,
        action_dim,
        max_action,
        device,
        discount,
        tau,
        max_q_backup=False,
        LA=1.0,
        beta_schedule="linear",
        n_timesteps=100,
        ema_decay=0.995,
        step_start_ema=1000,
        update_ema_every=5,
        lr=3e-4,
        lr_decay=False,
        lr_maxt=1000,
        grad_norm=1.0,
        policy_freq=10,
        target_kl=0.05,
        LA_max=100,
        LA_min=0,
    ):
        """Create the diffusion actor, twin critic, Lambda multiplier and optimizers.

        Args:
            state_dim: State size, forwarded to ``MLP``, ``Diffusion`` and ``Critic``.
            action_dim: Action size, forwarded to ``MLP``, ``Diffusion`` and ``Critic``.
            max_action: Forwarded to ``Diffusion``; ``train`` also clamps target
                actions to ``[-max_action, max_action]``.
            device: Device the actor, critic and Lambda tensor are moved to.
            discount: Factor applied to the bootstrapped target Q-value in ``train``.
            tau: Interpolation weight of the critic-target soft update in ``train``.
            max_q_backup: Selects the max-over-samples target in ``train``.
            LA: Initial value of the learnable Lambda multiplier.
            beta_schedule: Forwarded to ``Diffusion``.
            n_timesteps: Forwarded to ``Diffusion``.
            ema_decay: Forwarded to ``EMA``.
            step_start_ema: ``step_ema`` does nothing while ``self.step`` is below this.
            update_ema_every: ``train`` calls ``step_ema`` when the step counter is
                a multiple of this value.
            lr: Learning rate of the actor optimizer only; the critic optimizer
                uses a fixed 3e-4 and the Lambda optimizer a fixed 3e-5.
            lr_decay: When true, cosine-annealing schedulers are created for all
                three optimizers.
            lr_maxt: ``T_max`` of the cosine-annealing schedulers.
            grad_norm: ``max_norm`` used for gradient clipping in ``train``.
            policy_freq: ``train`` updates the actor and Lambda when the step
                counter is a multiple of this value.
            target_kl: Target value the actor's loss is compared against in the
                Lambda loss.
            LA_max: Upper clamp applied to Lambda wherever it is read.
            LA_min: Lower clamp applied to Lambda wherever it is read.
        """
        # Plain configuration values (no dependency on the networks below).
        self.device = device
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.max_action = max_action
        self.discount = discount
        self.tau = tau
        self.max_q_backup = max_q_backup
        self.target_kl = target_kl
        self.LA_min = LA_min
        self.LA_max = LA_max
        self.lr_decay = lr_decay
        self.grad_norm = grad_norm
        self.policy_freq = policy_freq
        self.step = 0
        self.step_start_ema = step_start_ema
        self.update_ema_every = update_ema_every

        # Actor: noise-prediction MLP wrapped by the diffusion sampler. The MLP
        # is created before the critic so module initialisation order is fixed.
        self.model = MLP(state_dim=state_dim, action_dim=action_dim, device=device)
        diffusion = Diffusion(
            state_dim=state_dim,
            action_dim=action_dim,
            model=self.model,
            max_action=max_action,
            beta_schedule=beta_schedule,
            n_timesteps=n_timesteps,
        )
        self.actor = diffusion.to(device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr)

        # Exponential-moving-average copy of the actor.
        self.ema = EMA(ema_decay)
        self.ema_model = copy.deepcopy(self.actor)

        # Twin critic and its target copy.
        self.critic = Critic(state_dim, action_dim).to(device)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4)

        # Learnable Lambda multiplier, optimised with its own Adam instance.
        self.LA = torch.tensor(LA, dtype=torch.float).to(device)
        self.LA.requires_grad_(True)
        self.LA_optimizer = torch.optim.Adam([self.LA], lr=3e-5)

        if lr_decay:

            def cosine_schedule(optimizer):
                return CosineAnnealingLR(optimizer, T_max=lr_maxt, eta_min=0.0)

            self.actor_lr_scheduler = cosine_schedule(self.actor_optimizer)
            self.critic_lr_scheduler = cosine_schedule(self.critic_optimizer)
            self.lambda_lr_scheduler = cosine_schedule(self.LA_optimizer)

    def step_ema(self):
        """Update the EMA actor from the live actor once the warm-up has elapsed.

        Nothing happens while ``self.step`` is still below ``self.step_start_ema``.
        """
        warmup_done = self.step >= self.step_start_ema
        if warmup_done:
            self.ema.update_model_average(self.ema_model, self.actor)

    def train(self, replay_buffer, iterations, batch_size=100, log_writer=None):
        """Run ``iterations`` optimisation steps and return the collected metrics.

        Every iteration updates the critic and soft-updates the critic target.
        The actor and the Lambda multiplier are updated only on iterations where
        ``self.step % self.policy_freq == 0``; the EMA actor is refreshed via
        ``step_ema`` when ``self.step % self.update_ema_every == 0``.

        Args:
            replay_buffer: Object whose ``sample(batch_size)`` returns
                ``(state, action, next_state, reward, not_done)``.
                NOTE(review): the tensors are used directly with the networks,
                so they are assumed to already live on the agent's device.
            iterations: Number of iterations to run.
            batch_size: Batch size requested from ``replay_buffer``.
            log_writer: Optional object exposing ``add_scalar(tag, value, step)``.
                Policy-related scalars are written only on iterations where the
                policy was actually updated.

        Returns:
            dict with the keys ``"kl_loss"``, ``"actor_loss"``, ``"critic_loss"``
            and ``"Lambda"``. Each list receives one entry per policy update
            (not per iteration).

        Notes:
            Gradient clipping is applied only when ``self.grad_norm > 0``; a
            non-positive value disables it. When ``self.lr_decay`` is set, the
            learning-rate schedulers advance once per call, not per iteration.
        """
        metric = {
            "kl_loss": [],
            "actor_loss": [],
            "critic_loss": [],
            "Lambda": [],
        }
        # Number of next actions sampled per state for the max-Q backup.
        num_backup_samples = 10
        # Clipping with a non-positive max_norm would scale every gradient to
        # zero (or flip its sign), so treat it as "clipping disabled".
        clip_enabled = self.grad_norm > 0

        for _ in range(iterations):
            # Sample replay buffer / batch
            state, action, next_state, reward, not_done = replay_buffer.sample(
                batch_size
            )

            # ---- Q training ----
            current_q1, current_q2 = self.critic(state, action)

            # The target is a constant w.r.t. the critic loss, so no autograd
            # graph is needed through the EMA sampler or the target critic.
            with torch.no_grad():
                if self.max_q_backup:
                    next_state_rpt = torch.repeat_interleave(
                        next_state, repeats=num_backup_samples, dim=0
                    )
                    next_action_rpt = self.ema_model(next_state_rpt).clamp(
                        -self.max_action, self.max_action
                    )
                    target_q1, target_q2 = self.critic_target(
                        next_state_rpt, next_action_rpt
                    )
                    # Max over the sampled actions of each state; -1 infers the
                    # number of states actually returned by the buffer.
                    target_q1 = target_q1.view(-1, num_backup_samples).max(
                        dim=1, keepdim=True
                    )[0]
                    target_q2 = target_q2.view(-1, num_backup_samples).max(
                        dim=1, keepdim=True
                    )[0]
                else:
                    next_action = self.ema_model(next_state).clamp(
                        -self.max_action, self.max_action
                    )
                    target_q1, target_q2 = self.critic_target(next_state, next_action)
                target_q = torch.min(target_q1, target_q2)
                target_q = reward + not_done * self.discount * target_q

            critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(
                current_q2, target_q
            )

            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            if clip_enabled:
                critic_grad_norms = nn.utils.clip_grad_norm_(
                    self.critic.parameters(), max_norm=self.grad_norm, norm_type=2
                )
            self.critic_optimizer.step()

            # ---- Policy training (every policy_freq steps) ----
            policy_updated = self.step % self.policy_freq == 0
            if policy_updated:
                kl_loss = self.actor.loss(action, state)
                new_action = self.actor(state)

                q1_new_action, q2_new_action = self.critic(state, new_action)
                # Randomly pick which head drives the policy; the other head's
                # mean magnitude (detached) normalises the scale.
                if np.random.uniform() > 0.5:
                    q_loss = -q1_new_action.mean() / q2_new_action.abs().mean().detach()
                else:
                    q_loss = -q2_new_action.mean() / q1_new_action.abs().mean().detach()
                actor_loss = (
                    self.LA.clamp(self.LA_min, self.LA_max).detach() * kl_loss + q_loss
                )

                self.actor_optimizer.zero_grad()
                actor_loss.backward()
                if clip_enabled:
                    actor_grad_norms = nn.utils.clip_grad_norm_(
                        self.actor.parameters(), max_norm=self.grad_norm, norm_type=2
                    )
                self.actor_optimizer.step()

                # ---- Lambda update ----
                LA_loss = (self.target_kl - kl_loss).detach() * self.LA
                self.LA_optimizer.zero_grad()
                LA_loss.backward()
                if clip_enabled:
                    LA_grad_norms = nn.utils.clip_grad_norm_(
                        self.LA, max_norm=self.grad_norm, norm_type=2
                    )
                self.LA_optimizer.step()

                metric["actor_loss"].append(actor_loss.item())
                metric["kl_loss"].append(kl_loss.item())
                metric["critic_loss"].append(critic_loss.item())
                metric["Lambda"].append(self.LA.clamp(self.LA_min, self.LA_max).item())

            # ---- Step target networks ----
            if self.step % self.update_ema_every == 0:
                self.step_ema()

            for param, target_param in zip(
                self.critic.parameters(), self.critic_target.parameters()
            ):
                target_param.data.copy_(
                    self.tau * param.data + (1 - self.tau) * target_param.data
                )

            self.step += 1

            # ---- Log ----
            if log_writer is not None:
                # Policy scalars exist only when the policy was updated in this
                # iteration; logging them otherwise would read unbound or
                # stale values.
                if policy_updated:
                    if clip_enabled:
                        log_writer.add_scalar(
                            "Actor Grad Norm", actor_grad_norms.max().item(), self.step
                        )
                        log_writer.add_scalar(
                            "Lambda Grad Norm", LA_grad_norms.max().item(), self.step
                        )
                    log_writer.add_scalar("KL Loss", kl_loss.item(), self.step)
                if clip_enabled:
                    log_writer.add_scalar(
                        "Critic Grad Norm", critic_grad_norms.max().item(), self.step
                    )
                log_writer.add_scalar("Critic Loss", critic_loss.item(), self.step)
                log_writer.add_scalar(
                    "Target_Q Mean", target_q.mean().item(), self.step
                )
                log_writer.add_scalar(
                    "Lambda",
                    self.LA.clamp(self.LA_min, self.LA_max).item(),
                    self.step,
                )

        if self.lr_decay:
            self.actor_lr_scheduler.step()
            self.critic_lr_scheduler.step()
            self.lambda_lr_scheduler.step()

        return metric

    def sample_action(self, state):
        """Sample one action for a single state, weighted by target-critic value.

        The state is repeated 50 times, one candidate action per copy is drawn
        from ``self.actor.sample``, and a single candidate is then picked with
        ``torch.multinomial`` using the softmax of the candidates' ``q_min``
        values (from ``self.critic_target``) as probabilities.

        Args:
            state: A single state; it is converted to a float32 tensor and
                reshaped to ``(1, -1)``.

        Returns:
            The chosen action as a flattened NumPy array.
        """
        num_candidates = 50
        state = torch.as_tensor(
            state, dtype=torch.float32, device=self.device
        ).reshape(1, -1)
        state_rpt = torch.repeat_interleave(state, repeats=num_candidates, dim=0)
        with torch.no_grad():
            action = self.actor.sample(state_rpt)
            q_value = self.critic_target.q_min(state_rpt, action).flatten()
            # q_value is 1-D, so the candidates lie along dim 0. Passing dim
            # explicitly avoids the deprecated implicit-dimension softmax.
            probabilities = F.softmax(q_value, dim=0)
            # multinomial returns the index of the drawn candidate (shape (1,)).
            idx = torch.multinomial(probabilities, 1)
        return action[idx].detach().cpu().numpy().flatten()

    def save_model(self, dir, id=None):
        """Write the actor and critic state dicts into ``dir``.

        Files are named ``actor.pth`` / ``critic.pth``, or ``actor_{id}.pth`` /
        ``critic_{id}.pth`` when ``id`` is given.
        """
        if id is None:
            actor_path = f"{dir}/actor.pth"
            critic_path = f"{dir}/critic.pth"
        else:
            actor_path = f"{dir}/actor_{id}.pth"
            critic_path = f"{dir}/critic_{id}.pth"

        actor_state = self.actor.state_dict()
        torch.save(actor_state, actor_path)
        critic_state = self.critic.state_dict()
        torch.save(critic_state, critic_path)

    def load_model(self, dir, id=None):
        """Load actor and critic weights from ``dir`` and resynchronise their copies.

        Reads ``actor.pth`` / ``critic.pth``, or ``actor_{id}.pth`` /
        ``critic_{id}.pth`` when ``id`` is given. Tensors are mapped onto
        ``self.device`` so a checkpoint written on another device can be loaded.

        After loading, ``self.critic_target`` and ``self.ema_model`` are reset
        to the loaded weights: ``sample_action`` ranks candidates with the
        target critic and ``train`` builds its targets from both copies, so
        leaving them at their previous values would mix unrelated parameters.
        """
        if id is None:
            actor_path = f"{dir}/actor.pth"
            critic_path = f"{dir}/critic.pth"
        else:
            actor_path = f"{dir}/actor_{id}.pth"
            critic_path = f"{dir}/critic_{id}.pth"

        # NOTE: torch.load unpickles the file; only load checkpoints from a
        # trusted source.
        actor_state = torch.load(actor_path, map_location=self.device)
        critic_state = torch.load(critic_path, map_location=self.device)

        self.actor.load_state_dict(actor_state)
        self.critic.load_state_dict(critic_state)

        # Keep the derived copies consistent with the freshly loaded weights.
        self.ema_model.load_state_dict(self.actor.state_dict())
        self.critic_target.load_state_dict(self.critic.state_dict())
