import os
import sys
import copy
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms

# Resolve the directory that holds this file (symlinks followed), walk two
# levels up from it, and register that ancestor directory on the import search
# path. Presumably this is what lets the `modules.*` imports below resolve.
FILE_DIR = os.path.split(os.path.realpath(__file__))[0]
MODULE_DIR = os.path.split(os.path.split(FILE_DIR)[0])[0]
sys.path.append(MODULE_DIR)

from modules.velap.rl.model_td3_bc import Actor, Critic
from modules.utils import contr_loss


class TD3BCEncoder():
    """Two-level TD3+BC agent trained jointly with a latent encoder and dynamics model.

    The class owns two actor/critic pairs:

    * a "high" pair whose networks are called on the latent state only, and
    * a "low" pair whose networks are additionally called with an encoded goal.

    Both critics, the encoder (``model_enc``) and the dynamics model
    (``model_dyn``) are optimised together by one Adam optimizer
    (``critic_optimizer``); each actor has its own Adam optimizer.
    """

    def __init__(
            self,
            z_dim,
            action_dim,
            goal_dim,
            n_qs=2,
            w_bc_low=0.001,
            w_bc_exp_low=0.0,
            w_bc_high=0.001,
            w_bc_exp_high=0.5,
            w_dyn=0.1,
            w_rl_low=1.0,
            w_rl_high=1.0,
            w_ens_div=0.0,
            discount=0.96,
            lmda=0.9,
            tau=0.005,
            max_action=1.0,
            policy_noise=0.2,
            noise_clip=0.5,
            policy_freq=2,
            lr_policy=1e-4,
            lr_critic=1e-4,
            T_contr=1.0,
            model_enc=None,
            model_dyn=None,
            use_batch_norm=1,
            add_neg_noise_samples=False,
            dyn_loss_type="contrastive",
            dyn_stop_grad=0,
            device="cuda",
    ):
        """Build the actors, critics, their target copies and the optimizers.

        Args:
            z_dim: Size of the latent state; also passed as the goal size of the
                low-level actor/critic.
            action_dim: Action size passed to every actor/critic.
            goal_dim: Goal size passed to the high-level actor/critic.
            n_qs: Forwarded as the fourth constructor argument of the low-level
                ``Critic`` (presumably the number of Q heads; confirm in ``Critic``).
            w_bc_low, w_bc_high: Weights of the behaviour-cloning MSE term in the
                low/high actor loss.
            w_bc_exp_low, w_bc_exp_high: Weights of the extra behaviour-cloning
                term computed only on samples selected by ``batch['exp_traj']``.
            w_dyn: Weight of the dynamics loss in the joint critic/encoder loss.
            w_rl_low, w_rl_high: Weights of the low/high critic losses.
            w_ens_div: Stored on the instance; not used by the code in this class.
            discount: Factor applied to the bootstrapped target Q value.
            lmda: Base of the geometric per-step weight of the dynamics loss.
            tau: Interpolation coefficient of the target-network update.
            max_action: Bound used to clamp the noisy target action; also passed
                to the actors.
            policy_noise: Scale of the Gaussian noise added to target actions.
            noise_clip: Clamp applied to that noise.
            policy_freq: Actors are updated on every ``policy_freq``-th call of
                ``train_rl_and_embedding``.
            lr_policy: Learning rate of both actor optimizers.
            lr_critic: Learning rate of the shared critic/encoder/dynamics optimizer.
            T_contr: Passed as ``T`` to ``contr_loss`` (presumably a temperature).
            model_enc: Encoder module mapping observations to latents. Required.
            model_dyn: Dynamics module called as ``model_dyn(z, action)``; its
                output is added to ``z``. Required.
            use_batch_norm: Forwarded to the ``Actor``/``Critic`` constructors.
            add_neg_noise_samples: If truthy, the low-level RL batch is extended
                with noisy-goal samples carrying zero reward.
            dyn_loss_type: ``"mse"`` or ``"contrastive"``.
            dyn_stop_grad: If truthy, the encoded ground-truth latents are
                detached in the dynamics loss.
            device: Device the networks are moved to and random tensors are
                created on.

        Raises:
            ValueError: If ``model_enc`` or ``model_dyn`` is ``None``.
        """
        # Both modules are dereferenced below and on every training step, so
        # fail early with a clear message instead of an AttributeError on None.
        if model_enc is None or model_dyn is None:
            raise ValueError("TD3BCEncoder requires both `model_enc` and `model_dyn`")

        self.device = device

        self.model_enc = model_enc
        self.model_dyn = model_dyn

        self.action_dim = action_dim
        self.max_action = max_action
        self.z_dim = z_dim
        self.goal_dim = goal_dim

        self.w_rl_low = w_rl_low
        self.w_rl_high = w_rl_high
        self.w_bc_low = w_bc_low
        self.w_bc_exp_low = w_bc_exp_low
        self.w_bc_high = w_bc_high
        self.w_bc_exp_high = w_bc_exp_high
        self.w_dyn = w_dyn
        self.w_ens_div = w_ens_div
        self.T_contr = T_contr
        self.n_qs = n_qs
        self.discount = discount
        self.lmda = lmda
        self.tau = tau
        self.policy_noise = policy_noise
        self.noise_clip = noise_clip
        self.policy_freq = policy_freq
        self.dyn_stop_grad = dyn_stop_grad
        self.dyn_loss_type = dyn_loss_type
        self.add_neg_noise_samples = add_neg_noise_samples
        self.use_batch_norm = use_batch_norm
        self.total_it = 0

        # Create agent high
        self.actor_high = Actor(self.z_dim, self.action_dim, self.goal_dim, self.max_action,
                                use_batch_norm=self.use_batch_norm).to(device)
        self.actor_target_high = copy.deepcopy(self.actor_high).to(self.device)
        self.actor_optimizer_high = torch.optim.Adam(self.actor_high.parameters(), lr=lr_policy)

        # NOTE(review): the high-level critic is built with a hard-coded 2 where
        # the low-level critic uses `self.n_qs`; confirm this is intentional.
        self.critic_high = Critic(self.z_dim, self.action_dim, self.goal_dim, 2,
                                  use_batch_norm=self.use_batch_norm).to(device)
        self.critic_target_high = copy.deepcopy(self.critic_high).to(self.device)

        # Create agent low
        self.actor_low = Actor(self.z_dim, self.action_dim, self.z_dim, self.max_action,
                               use_batch_norm=self.use_batch_norm).to(device)
        self.actor_target_low = copy.deepcopy(self.actor_low).to(self.device)
        self.actor_optimizer_low = torch.optim.Adam(self.actor_low.parameters(), lr=lr_policy)

        self.critic_low = Critic(self.z_dim, self.action_dim, self.z_dim, self.n_qs,
                                 use_batch_norm=self.use_batch_norm).to(device)
        self.critic_target_low = copy.deepcopy(self.critic_low).to(self.device)

        # One optimizer drives critics, dynamics model and encoder. Each module
        # is registered exactly once: a parameter listed twice in a group makes
        # PyTorch warn and is stepped twice per `optimizer.step()`.
        critic_params = list(self.critic_high.parameters()) + \
                        list(self.critic_low.parameters()) + \
                        list(self.model_dyn.parameters()) + \
                        list(self.model_enc.parameters())
        self.critic_optimizer = torch.optim.Adam(critic_params, lr=lr_critic)

    def _td_target(self, actor_target, critic_target, z_next, action, reward, not_done, *cond):
        """Return the bootstrapped target Q computed with the target networks.

        ``cond`` holds extra inputs (the encoded goal for the low level) that
        are appended to the target actor and target critic calls.
        """
        with torch.no_grad():
            # Select action according to the target policy and add clipped noise
            noise = (
                    torch.randn_like(action) * self.policy_noise
            ).clamp(-self.noise_clip, self.noise_clip)
            z_next = z_next.detach()
            action_next = (
                    actor_target(z_next, *cond) + noise
            ).clamp(-self.max_action, self.max_action)

            # Element-wise minimum over the Q estimates returned by the critic
            target_Qs = critic_target(z_next, action_next, *cond)
            target_Q = torch.min(torch.stack(target_Qs, dim=0), dim=0)[0]
            return reward + not_done * self.discount * target_Q

    @staticmethod
    def _critic_loss(current_Qs, target_Q):
        """Sum the MSE between every Q estimate and the shared target."""
        return torch.sum(torch.stack([F.mse_loss(current_Q, target_Q) for current_Q in current_Qs]))

    def _dynamics_loss(self, z_n_step, action_n_step, mean_z_abs_diff, z_same_context):
        """Roll the dynamics model forward from the first latent and score the rollout.

        Raises:
            NotImplementedError: If ``self.dyn_loss_type`` is neither ``"mse"``
                nor ``"contrastive"``.
        """
        n_step = len(z_n_step) - 1

        # The dynamics model predicts a latent delta that is accumulated open-loop.
        z_tmp = z_n_step[0]
        z_pred = []
        for i in range(n_step):
            z_tmp = z_tmp + self.model_dyn(z_tmp, action_n_step[i])
            z_pred.append(z_tmp)
        z_pred = torch.stack(z_pred, 0)
        z_gt = torch.stack(z_n_step[1:], 0)

        if self.dyn_stop_grad:
            z_gt = z_gt.detach()

        # Geometric weights lmda**0, lmda**1, ... as plain Python floats.
        step_weights = [self.lmda ** i for i in range(n_step)]

        if self.dyn_loss_type == "mse":
            dyn_loss = [step_weights[i] * F.mse_loss(z_pred[i], z_gt[i]) for i in range(n_step)]
        elif self.dyn_loss_type == "contrastive":
            bs = z_gt[0].shape[0]
            # Noisy copies of the ground-truth latents; the noise is scaled by a
            # random integer in [1, 4] and the batch-mean absolute latent change.
            z_neg_noise = z_gt + torch.randn(z_gt.shape, device=self.device) * \
                          torch.randint(1, 5, (z_gt.shape[0], z_gt.shape[1], 1)).to(self.device) * mean_z_abs_diff
            # Negatives passed to contr_loss: shuffled ground-truth latents,
            # the noisy copies, and the encoded `obs_same_context` batch.
            dyn_loss = [step_weights[i] * contr_loss(anchor=z_gt[i],
                                                     pos=z_pred[i],
                                                     neg=[z_gt[i][torch.randperm(bs).to(self.device)],
                                                          z_neg_noise[i],
                                                          z_same_context],
                                                     T=self.T_contr) for i in range(n_step)]
        else:
            raise NotImplementedError("Unknown dyn_loss_type: {!r}".format(self.dyn_loss_type))

        return torch.stack(dyn_loss).mean()

    def _update_actor(self, tag, actor, critic, optimizer, z, action, w_bc, w_bc_exp,
                      exp_traj, metrics, *cond):
        """Take one TD3+BC actor step and record its losses under ``tag``.

        ``cond`` holds extra inputs (the encoded goal for the low level) passed
        to both actor and critic. All latent inputs are detached, so this step
        does not send gradients into the encoder.
        """
        z = z.detach()
        cond = tuple(c.detach() for c in cond)

        # Compute actor loss
        action_pred = actor(z, *cond)
        Qs = critic(z, action_pred, *cond)
        Q = torch.min(torch.stack(Qs, dim=0), dim=0)[0]

        bc_loss = F.mse_loss(action, action_pred)
        actor_loss = -Q.mean() + w_bc * bc_loss

        if exp_traj is not None and torch.any(exp_traj):
            # `exp_traj` covers only the original batch; the slice drops any
            # augmented samples appended behind it before the mask is applied.
            n_orig = exp_traj.shape[0]
            bc_loss_exp = F.mse_loss(action[:n_orig][exp_traj], action_pred[:n_orig][exp_traj])
            actor_loss = actor_loss + w_bc_exp * bc_loss_exp
            metrics[tag + "/bc_loss_exp"] = bc_loss_exp.detach()

        metrics[tag + "/actor_loss"] = actor_loss.detach()
        metrics[tag + "/bc_loss"] = bc_loss.detach()

        # Optimize the actor
        optimizer.zero_grad()
        actor_loss.backward()
        optimizer.step()

    def _soft_update(self, net, target_net):
        """Move ``target_net`` towards ``net`` by the interpolation factor ``tau``."""
        for param, target_param in zip(net.parameters(), target_net.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

    def train_rl_and_embedding(self, batch):
        """Run one joint training step on ``batch``.

        The step (1) encodes observations and goals, (2) builds the high- and
        low-level critic losses and the dynamics loss, (3) takes one step of the
        shared critic/encoder/dynamics optimizer, (4) on every
        ``policy_freq``-th call updates both actors, and (5) updates all target
        networks (on every call).

        Args:
            batch: Mapping with the entries read here:
                ``'obs_n_step'`` (indexed along its first axis, ``n_step + 1``
                observation batches), ``'action_n_step'`` (indexed per step; the
                last entry is the action used for both RL losses),
                ``'obs_goal'``, ``'obs_same_context'``, ``'reward_high'``,
                ``'done_high'``, ``'reward_low'``, ``'done_low'`` and optionally
                ``'exp_traj'`` (used to index actions; presumably a boolean mask
                of expert samples, confirm against the data loader).

        Returns:
            dict: Detached loss tensors. Always ``critic_loss_low``,
            ``critic_loss_high`` and ``dyn_loss``; on actor-update steps also
            ``policy_{high,low}/actor_loss``, ``policy_{high,low}/bc_loss`` and,
            when ``exp_traj`` selects any sample, ``policy_{high,low}/bc_loss_exp``.

        Side effects:
            Sets ``self.norm_dz``, ``self.last_Qs``, ``self.last_z`` and
            ``self.last_z_next`` (all detached) and increments ``self.total_it``.
        """
        self.total_it += 1
        metrics = {}

        obs_n_step = batch['obs_n_step']
        n_step = obs_n_step.shape[0] - 1
        exp_traj = batch['exp_traj'] if "exp_traj" in batch else None
        obs_goal = batch['obs_goal']
        action_n_step = batch["action_n_step"]
        reward_high = batch['reward_high']
        not_done_high = torch.logical_not(batch['done_high'])
        reward_low = batch['reward_low']
        not_done_low = torch.logical_not(batch['done_low'])

        # Encode states and goals
        z_n_step = [self.model_enc(obs_n_step[i]) for i in range(n_step + 1)]
        z_goal = self.model_enc(obs_goal)

        ### Critic Update High ###
        # Both levels learn from the last transition of the n-step window.
        action_high = action_n_step[-1]
        z_high = z_n_step[-2]
        z_next_high = z_n_step[-1]

        target_Q = self._td_target(self.actor_target_high, self.critic_target_high,
                                   z_next_high, action_high, reward_high, not_done_high)
        critic_loss_high = self._critic_loss(self.critic_high(z_high, action_high), target_Q)

        action_low = action_n_step[-1]
        z_low = z_n_step[-2]
        z_next_low = z_n_step[-1]

        # Per-dimension mean absolute latent change over the batch; used as the
        # noise scale for the augmented goals and the contrastive negatives.
        mean_z_abs_diff = torch.abs(z_next_low - z_low).detach().mean(0)
        z_same_context = self.model_enc(batch['obs_same_context'])

        # Augment rl data with noisy goals
        if self.add_neg_noise_samples:
            # Base goals come from non-terminal transitions only. `squeeze(-1)`
            # keeps the mask 1-D even for a batch of one.
            valid_goals = z_goal[not_done_low.squeeze(-1)]

            # Adds up to 20 percent negative goal samples, never more than the
            # number of base goals available (otherwise the indexing below
            # would run out of bounds).
            n = min(int(0.2 * z_low.shape[0]), valid_goals.shape[0])

            if n > 0:
                # Generate Gaussian noise scaled by the mean absolute
                # differences between z vectors in the batch
                gaussian_noise = torch.normal(mean=torch.zeros((n, self.z_dim), device=self.device),
                                              std=2 * mean_z_abs_diff * torch.ones((n, self.z_dim),
                                                                                  device=self.device))

                # Augment batch: noisy goals with zero reward, flagged not-done
                ids = torch.randperm(n)
                z_goal_aug = valid_goals[ids] + gaussian_noise
                z_goal = torch.cat([z_goal, z_goal_aug], dim=0)
                z_low = torch.cat([z_low, z_low[ids]], dim=0)
                z_next_low = torch.cat([z_next_low, z_next_low[ids]], dim=0)
                action_low = torch.cat([action_low, action_low[ids]], dim=0)
                reward_low = torch.cat([reward_low, torch.zeros((n, 1), device=self.device)], dim=0)
                not_done_low = torch.cat([not_done_low, torch.ones((n, 1), device=self.device)], dim=0)

        ### Critic Update Low ###
        target_Q = self._td_target(self.actor_target_low, self.critic_target_low,
                                   z_next_low, action_low, reward_low, not_done_low, z_goal)
        current_Qs = self.critic_low(z_low, action_low, z_goal)
        critic_loss_low = self._critic_loss(current_Qs, target_Q)

        # Compute dynamics loss
        dyn_loss = self._dynamics_loss(z_n_step, action_n_step, mean_z_abs_diff, z_same_context)

        # Compute total loss
        encoder_loss = self.w_rl_low * critic_loss_low + self.w_rl_high * critic_loss_high + self.w_dyn * dyn_loss

        # Optimize critics, encoder and dynamics model together
        self.critic_optimizer.zero_grad()
        encoder_loss.backward()
        self.critic_optimizer.step()

        # For logging: detached so the stored tensors carry no autograd history.
        metrics["critic_loss_low"] = critic_loss_low.detach()
        metrics["critic_loss_high"] = critic_loss_high.detach()
        metrics["dyn_loss"] = dyn_loss.detach()
        self.norm_dz = torch.norm(z_n_step[0] - z_n_step[1], dim=1).mean().unsqueeze(-1).detach()
        self.last_Qs = [current_Q.detach() for current_Q in current_Qs]
        self.last_z = z_n_step[0].detach()
        self.last_z_next = z_n_step[1].detach()

        update_actors = self.total_it % self.policy_freq == 0

        # Update policy high
        if update_actors:
            self._update_actor("policy_high", self.actor_high, self.critic_high,
                               self.actor_optimizer_high, z_high, action_high,
                               self.w_bc_high, self.w_bc_exp_high, exp_traj, metrics)

        # Update the frozen target models (runs on every call)
        self._soft_update(self.critic_high, self.critic_target_high)
        self._soft_update(self.actor_high, self.actor_target_high)

        # Update policy low
        if update_actors:
            self._update_actor("policy_low", self.actor_low, self.critic_low,
                               self.actor_optimizer_low, z_low, action_low,
                               self.w_bc_low, self.w_bc_exp_low, exp_traj, metrics, z_goal)

        # Update the frozen target models (runs on every call)
        self._soft_update(self.critic_low, self.critic_target_low)
        self._soft_update(self.actor_low, self.actor_target_low)

        return metrics

    def save(self, dir, name="model"):
        """Write the state dicts of all trained networks into ``dir``.

        Files are named ``name`` plus one of the suffixes ``_actor_high``,
        ``_critic_high``, ``_actor_low``, ``_critic_low``, ``_encoder`` and
        ``_dynamics``. Target networks and optimizer states are not saved.

        Args:
            dir: Directory the files are written to. The name shadows the
                builtin but is kept because it is part of the signature.
            name: Common file-name prefix.
        """
        # Save policy high
        torch.save(self.actor_high.state_dict(), os.path.join(dir, name + "_actor_high"))
        torch.save(self.critic_high.state_dict(), os.path.join(dir, name + "_critic_high"))

        # Save policy low
        torch.save(self.actor_low.state_dict(), os.path.join(dir, name + "_actor_low"))
        torch.save(self.critic_low.state_dict(), os.path.join(dir, name + "_critic_low"))

        torch.save(self.model_enc.state_dict(), os.path.join(dir, name + "_encoder"))
        torch.save(self.model_dyn.state_dict(), os.path.join(dir, name + "_dynamics"))