import os
import sys
import copy
import numpy as np
import torch
import torch.nn.functional as F

# Resolve this file's directory (following symlinks), climb two levels and put
# that directory on the import path so the `modules.*` imports below resolve.
FILE_DIR = os.path.split(os.path.realpath(__file__))[0]
MODULE_DIR = os.path.split(os.path.split(FILE_DIR)[0])[0]
sys.path.append(MODULE_DIR)

from modules.velap.rl.model_td3_bc import Actor, Critic
from modules.utils import contr_loss


class TD3BC():
    """TD3 with a behaviour-cloning penalty (TD3+BC) over an ensemble of Q heads.

    The actor maximises the minimum Q over the critic ensemble plus ``w_bc``
    times an MSE behaviour-cloning penalty (optionally plus ``w_bc_exp`` times a
    BC penalty restricted to samples flagged in ``batch['exp_traj']``).

    When both ``model_enc`` and ``model_dyn`` are given, their parameters are
    optimised together with the critic (see :meth:`train_rl_and_embedding`),
    using the critic TD loss plus ``w_dyn`` times a latent dynamics loss.
    """

    def __init__(
            self,
            z_dim,
            action_dim,
            goal_dim,
            n_qs=2,
            w_bc=0.1,
            w_bc_exp=0.0,
            w_dyn=0.1,
            w_ens_div=0.0,
            discount=0.96,
            lmda=0.9,
            tau=0.005,
            max_action=1.0,
            policy_noise=0.2,
            noise_clip=0.5,
            policy_freq=2,
            lr_policy=1e-4,
            lr_critic=1e-4,
            T_contr=1.0,
            model_enc=None,
            model_dyn=None,
            use_batch_norm=1,
            add_neg_noise_samples=False,
            dyn_loss_type="contrastive",
            dyn_stop_grad=0,
            device="cuda",
    ):
        self.device = device

        self.model_enc = model_enc
        self.model_dyn = model_dyn

        self.action_dim = action_dim
        self.max_action = max_action
        self.z_dim = z_dim
        self.goal_dim = goal_dim

        self.w_bc = w_bc
        self.w_bc_exp = w_bc_exp
        self.w_dyn = w_dyn
        self.w_ens_div = w_ens_div
        self.T_contr = T_contr
        self.n_qs = n_qs
        self.discount = discount
        self.lmda = lmda
        self.tau = tau
        self.policy_noise = policy_noise
        self.noise_clip = noise_clip
        self.policy_freq = policy_freq
        self.dyn_stop_grad = dyn_stop_grad
        self.dyn_loss_type = dyn_loss_type
        self.add_neg_noise_samples = add_neg_noise_samples
        self.use_batch_norm = use_batch_norm
        self.total_it = 0

        # Actor and its slowly-tracking target copy.
        self.actor = Actor(self.z_dim, self.action_dim, self.goal_dim, self.max_action,
                           use_batch_norm=self.use_batch_norm).to(device)
        self.actor_target = copy.deepcopy(self.actor).to(self.device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr_policy)

        # Critic ensemble and target copy; the critic optimiser also owns the
        # encoder/dynamics parameters when both models are supplied.
        self.critic = Critic(self.z_dim, self.action_dim, self.goal_dim, self.n_qs,
                             use_batch_norm=self.use_batch_norm).to(device)
        self.critic_target = copy.deepcopy(self.critic).to(self.device)
        critic_params = list(self.critic.parameters())
        if self.model_dyn is not None and self.model_enc is not None:
            critic_params += list(self.model_dyn.parameters())
            critic_params += list(self.model_enc.parameters())
        self.critic_optimizer = torch.optim.Adam(critic_params, lr=lr_critic)

    # ------------------------------------------------------------------ #
    # Inference helpers
    # ------------------------------------------------------------------ #
    def select_action(self, state, goal=None):
        """Return the actor's action for a single (numpy) state as a flat numpy array.

        ``state`` (and ``goal`` if given) are flattened to a batch of one.
        Runs under ``torch.no_grad()`` because the result is converted to numpy
        and no gradient is ever needed.
        """
        state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)
        if goal is not None:
            goal = torch.FloatTensor(goal.reshape(1, -1)).to(self.device)
        with torch.no_grad():
            action = self.actor(state, goal)
        return action.cpu().numpy().flatten()

    def select_action_t(self, state, goal=None):
        """Tensor-in/tensor-out actor call (keeps the autograd graph)."""
        return self.actor(state, goal)

    @staticmethod
    def _ensemble_stats(qs):
        """Return (min, std, mean) of stacked Q values along the ensemble dim 0."""
        return torch.min(qs, 0)[0], torch.std(qs, 0), torch.mean(qs, 0)

    def compute_q_min_t(self, state, goal=None):
        """Return (min-over-ensemble Q of the actor's action, that action)."""
        qs, action = self.compute_qs_t(state, goal)
        q_min = torch.min(qs, 0)[0]
        return q_min, action

    def compute_qs_t(self, state, goal=None):
        """Return (Q values of the actor's action stacked along dim 0, that action)."""
        action = self.actor(state, goal)
        qs = self.critic(state, action, goal)
        qs = torch.stack(qs, dim=0)
        return qs, action

    def compute_q_all_given_action_t(self, state, action, goal=None):
        """Return (qs, q_min, q_std, q_mean, action) for a caller-supplied action."""
        qs = torch.stack(self.critic(state, action, goal), dim=0)
        q_min, q_std, q_mean = self._ensemble_stats(qs)
        return qs, q_min, q_std, q_mean, action

    def compute_q_all_t(self, state, goal=None):
        """Return (qs, q_min, q_std, q_mean, action) for the actor's own action."""
        qs, action = self.compute_qs_t(state, goal)
        q_min, q_std, q_mean = self._ensemble_stats(qs)
        return qs, q_min, q_std, q_mean, action

    # ------------------------------------------------------------------ #
    # Shared training helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _detach(value):
        """Detach tensors for logging; pass non-tensors (e.g. floats) through."""
        return value.detach() if torch.is_tensor(value) else value

    @staticmethod
    def _expert_bc_loss(action, action_pred, exp_traj):
        """MSE behaviour-cloning loss on samples flagged by ``exp_traj``.

        ``exp_traj`` is used as a boolean mask over the first ``len(exp_traj)``
        rows of ``action`` / ``action_pred``.  Returns ``None`` when there is no
        mask or no flagged sample: an MSE over an empty selection is NaN and
        would poison the actor loss.
        """
        if exp_traj is None or not torch.any(exp_traj):
            return None
        n = exp_traj.shape[0]
        return F.mse_loss(action[:n][exp_traj], action_pred[:n][exp_traj])

    def _soft_update_targets(self):
        """Polyak-average the online critic/actor weights into their targets with rate ``tau``."""
        for net, target in ((self.critic, self.critic_target), (self.actor, self.actor_target)):
            for param, target_param in zip(net.parameters(), target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

    def _dynamics_loss(self, batch, z_n_step, action_n_step, mean_z_abs_diff):
        """Roll ``model_dyn`` out open-loop from ``z_n_step[0]`` and score it.

        The dynamics model predicts latent deltas that are accumulated step by
        step. Step ``i`` is weighted by ``lmda ** i``, and the weighted per-step
        losses are averaged.

        Raises:
            NotImplementedError: for an unknown ``dyn_loss_type``.
        """
        n_step = len(z_n_step) - 1
        weights = [float(w) for w in self.lmda ** np.arange(0, n_step)]

        z_tmp = z_n_step[0]
        z_pred = []
        for i in range(n_step):
            z_tmp = z_tmp + self.model_dyn(z_tmp, action_n_step[i])
            z_pred.append(z_tmp)
        z_pred = torch.stack(z_pred, 0)
        z_gt = torch.stack(z_n_step[1:], 0)

        if self.dyn_stop_grad:
            z_gt = z_gt.detach()

        if self.dyn_loss_type == "mse":
            losses = [weights[i] * F.mse_loss(z_pred[i], z_gt[i]) for i in range(n_step)]
        elif self.dyn_loss_type == "contrastive":
            z_same_context = self.model_enc(batch['obs_same_context'])
            bs = z_gt[0].shape[0]
            # Noisy negatives: targets perturbed by Gaussian noise scaled by a
            # random integer in [1, 4] times the dim-0 mean of |z_next - z|.
            z_neg_noise = z_gt + torch.randn(z_gt.shape, device=self.device) * \
                torch.randint(1, 5, (z_gt.shape[0], z_gt.shape[1], 1)).to(self.device) * mean_z_abs_diff
            losses = [weights[i] * contr_loss(anchor=z_gt[i],
                                              pos=z_pred[i],
                                              neg=[z_gt[i][torch.randperm(bs).to(self.device)],
                                                   z_neg_noise[i],
                                                   z_same_context],
                                              T=self.T_contr) for i in range(n_step)]
        else:
            raise NotImplementedError(f"unknown dyn_loss_type: {self.dyn_loss_type!r}")

        return torch.stack(losses).mean()

    # ------------------------------------------------------------------ #
    # Training steps
    # ------------------------------------------------------------------ #
    def train_rl_and_embedding(self, batch):
        """One joint update of critic + encoder + dynamics, and periodically the actor.

        Reads ``batch`` keys:
        - ``obs_n_step``: indexed along dim 0 as ``n_step + 1`` observations.
        - ``action_n_step``: indexed along dim 0 per step.
        - ``reward`` and ``done``.
        - ``obs_same_context``: only for the contrastive dynamics loss.
        - ``exp_traj``: optional.

        The TD update uses the last transition of the window. The actor sees
        detached latents, so the actor loss does not update the encoder.

        Returns:
            dict of detached loss values for logging.
        """
        self.total_it += 1
        metrics = {}

        obs_n_step = batch['obs_n_step']
        action_n_step = batch["action_n_step"]
        n_step = obs_n_step.shape[0] - 1
        exp_traj = batch['exp_traj'] if "exp_traj" in batch else None
        reward = batch['reward']
        not_done = torch.logical_not(batch['done'])

        # Encode every observation of the window (gradients reach the encoder).
        z_n_step = [self.model_enc(obs_n_step[i]) for i in range(n_step + 1)]

        action = action_n_step[-1]
        z = z_n_step[-2]
        z_next = z_n_step[-1]
        dz_next = z_next - z
        z_abs_diff = torch.abs(dz_next).detach()
        mean_z_abs_diff = z_abs_diff.mean(0)

        ### Critic Update ###
        with torch.no_grad():
            # Target policy smoothing: clipped Gaussian noise on the target action.
            noise = (
                    torch.randn_like(action) * self.policy_noise
            ).clamp(-self.noise_clip, self.noise_clip)

            action_next = (
                    self.actor_target(z_next.detach()) + noise
            ).clamp(-self.max_action, self.max_action)

            # Pessimistic target: minimum over the target ensemble.
            target_Qs = self.critic_target(z_next.detach(), action_next)
            target_Q = torch.min(torch.stack(target_Qs, dim=0), dim=0)[0]
            target_Q = reward + not_done * self.discount * target_Q

        current_Qs = self.critic(z, action)
        critic_loss = torch.sum(torch.stack([F.mse_loss(current_Q, target_Q) for current_Q in current_Qs]))

        dyn_loss = self._dynamics_loss(batch, z_n_step, action_n_step, mean_z_abs_diff)

        # Ensemble-diversity regulariser, applied exactly as in train_rl.
        ensemble_diversity = self.critic.compute_ensemble_diversity()
        if self.n_qs > 2:
            critic_loss = critic_loss - self.w_ens_div * ensemble_diversity

        encoder_loss = critic_loss + self.w_dyn * dyn_loss

        self.critic_optimizer.zero_grad()
        encoder_loss.backward()
        self.critic_optimizer.step()

        # Logging (detached so no autograd graph is kept alive between steps).
        metrics["critic_loss"] = self._detach(critic_loss)
        metrics["dyn_loss"] = self._detach(dyn_loss)
        metrics["ensemble_diversity"] = self._detach(ensemble_diversity)
        with torch.no_grad():
            self.norm_dz = torch.norm(z_n_step[0] - z_n_step[1], dim=1).mean().unsqueeze(-1)
            self.dz_l2 = torch.norm(dz_next, dim=-1)
        self.last_Qs = [self._detach(q) for q in current_Qs]
        self.last_z = z_n_step[0].detach()
        self.last_z_next = z_n_step[1].detach()
        self.z_abs_diff_batch = z_abs_diff.mean(1)

        ### Delayed policy update ###
        if self.total_it % self.policy_freq == 0:
            z_in = z.detach()
            action_pred = self.actor(z_in)
            Qs = self.critic(z_in, action_pred)
            Q = torch.min(torch.stack(Qs, dim=0), dim=0)[0]

            bc_loss = F.mse_loss(action, action_pred)
            actor_loss = -Q.mean() + self.w_bc * bc_loss

            bc_loss_exp = self._expert_bc_loss(action, action_pred, exp_traj)
            if bc_loss_exp is not None:
                actor_loss = actor_loss + self.w_bc_exp * bc_loss_exp
                metrics["bc_loss_exp"] = bc_loss_exp.detach()

            metrics["actor_loss"] = actor_loss.detach()
            metrics["bc_loss"] = bc_loss.detach()

            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            self._soft_update_targets()

        return metrics

    def train_rl(self, batch):
        """One critic update (and periodic actor update) on precomputed states.

        Reads ``batch`` keys:
        - ``state``, ``state_next``, ``action``, ``reward``, ``done``.
        - ``goal`` and ``exp_traj``: optional.

        Returns:
            dict of detached loss values for logging.
        """
        self.total_it += 1
        metrics = {}

        state = batch['state']
        exp_traj = batch['exp_traj'] if "exp_traj" in batch else None
        state_next = batch['state_next']
        action = batch["action"]
        reward = batch['reward']
        not_done = torch.logical_not(batch['done'])
        goal = batch['goal'] if "goal" in batch else None

        ### Critic Update ###
        with torch.no_grad():
            # Target policy smoothing: clipped Gaussian noise on the target action.
            noise = (
                    torch.randn_like(action) * self.policy_noise
            ).clamp(-self.noise_clip, self.noise_clip)

            a_next = (
                    self.actor_target(state_next, goal) + noise
            ).clamp(-self.max_action, self.max_action)

            # Pessimistic target: minimum over the target ensemble.
            target_Qs = self.critic_target(state_next, a_next, goal)
            target_Q = torch.min(torch.stack(target_Qs, dim=0), dim=0)[0]
            target_Q = reward + not_done * self.discount * target_Q

        ensemble_diversity = self.critic.compute_ensemble_diversity()

        current_Qs = self.critic(state, action, goal)
        critic_loss = torch.sum(torch.stack([F.mse_loss(current_Q, target_Q) for current_Q in current_Qs]))

        if self.n_qs > 2:
            critic_loss = critic_loss - self.w_ens_div * ensemble_diversity

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        metrics["critic_loss"] = self._detach(critic_loss)
        metrics["ensemble_diversity"] = self._detach(ensemble_diversity)
        self.last_Qs = [self._detach(q) for q in current_Qs]

        ### Delayed policy update ###
        if self.total_it % self.policy_freq == 0:
            action_pred = self.actor(state, goal)
            Qs = self.critic(state, action_pred, goal)
            Q = torch.min(torch.stack(Qs, dim=0), dim=0)[0]

            bc_loss = F.mse_loss(action, action_pred)
            policy_loss = -Q.mean()
            actor_loss = policy_loss + self.w_bc * bc_loss

            bc_loss_exp = self._expert_bc_loss(action, action_pred, exp_traj)
            if bc_loss_exp is not None:
                actor_loss = actor_loss + self.w_bc_exp * bc_loss_exp
                metrics["bc_loss_exp"] = bc_loss_exp.detach()

            metrics["policy_loss"] = policy_loss.detach()
            metrics["actor_loss"] = actor_loss.detach()
            metrics["bc_loss"] = bc_loss.detach()

            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            self._soft_update_targets()

        return metrics

    # ------------------------------------------------------------------ #
    # Persistence
    # ------------------------------------------------------------------ #
    def save(self, dir, name="model"):
        """Save state dicts to ``dir/<name>_actor`` and ``dir/<name>_critic``.

        Also writes ``_encoder`` / ``_dynamics`` files when both the encoder
        and the dynamics model are present.
        """
        torch.save(self.actor.state_dict(), os.path.join(dir, name + "_actor"))
        torch.save(self.critic.state_dict(), os.path.join(dir, name + "_critic"))

        if self.model_enc is not None and self.model_dyn is not None:
            torch.save(self.model_enc.state_dict(), os.path.join(dir, name + "_encoder"))
            torch.save(self.model_dyn.state_dict(), os.path.join(dir, name + "_dynamics"))

    def load(self, dir, name="model", type=""):
        """Load weights written by :meth:`save` and reset the target networks.

        ``type`` is appended to every file name.  Tensors are mapped onto
        ``self.device``, so checkpoints saved on another device still load.
        Encoder/dynamics weights are restored when both models exist and both
        files are present; they are skipped otherwise, so actor/critic-only
        checkpoints still load.
        """
        def _path(suffix):
            return os.path.join(dir, name + suffix + type)

        self.actor.load_state_dict(torch.load(_path("_actor"), map_location=self.device))
        self.critic.load_state_dict(torch.load(_path("_critic"), map_location=self.device))
        self.actor_target = copy.deepcopy(self.actor)
        self.critic_target = copy.deepcopy(self.critic)

        if self.model_enc is not None and self.model_dyn is not None:
            enc_path, dyn_path = _path("_encoder"), _path("_dynamics")
            if os.path.exists(enc_path) and os.path.exists(dyn_path):
                self.model_enc.load_state_dict(torch.load(enc_path, map_location=self.device))
                self.model_dyn.load_state_dict(torch.load(dyn_path, map_location=self.device))
