import os
import numpy as np
import torch
from torch import nn, optim
from tqdm import tqdm

from networks import phi, VAE
from transporter import FeatureEncoder, PoseRegressor, RefineNet, Transporter
from utils import *

class Haptic_Repr:
    """Trains a ``phi`` network that yields an agent latent and a haptic Gaussian.

    ``self.net(obs)`` returns ``(z, mu_h, logvar_h, touch)``. Training combines an
    equivariance loss (latent displacement vs. action), a contrastive loss on the
    touch features, and a haptic loss on the predicted Gaussian ``(mu_h, logvar_h)``.
    """

    def __init__(self, N, a_dim, channels, frq, probabilistic, tau, segment, dim_touch, device):
        super(Haptic_Repr, self).__init__()

        self.device = device

        # Not read by fit_Rep; kept so the attribute and constructor signature stay stable.
        self.frq = frq
        self.probabilistic = probabilistic
        self.tau = tau  # temperature dividing the contrastive squared distances
        self.segment = segment  # 1: haptic target drawn uniformly on the segment zt -> zt1

        self.net = phi(N, channels, a_dim, self.probabilistic, 1, dim_touch, device).to(self.device)
        self.optimizer = optim.Adam(self.net.parameters(), lr=1e-3)
        self.cos_sim = nn.CosineSimilarity(dim=1, eps=1e-6)

    @staticmethod
    def _diag_gaussian_nll(target, mu, logvar):
        """Per-sample NLL of ``target`` under N(mu, diag(exp(logvar))), reduced over the last dim.

        The log-determinant of a diagonal covariance is ``sum(logvar)``; it must not be
        scaled by the dimensionality (only the ``log(2*pi)`` constant is).
        """
        dim = logvar.shape[-1]
        exponent = torch.sum((mu - target) ** 2 / torch.exp(logvar), -1)
        log_det = torch.sum(logvar, -1)
        constant = dim * np.log(2 * np.pi)
        return 0.5 * (exponent + log_det + constant)

    @staticmethod
    def _diag_gaussian_kl(mu0, logvar0, mu1, logvar1):
        """Per-sample KL(N(mu0, exp(logvar0)) || N(mu1, exp(logvar1))) for diagonal Gaussians."""
        dim = logvar0.shape[-1]
        exponent = torch.sum((mu0 - mu1) ** 2 / torch.exp(logvar1), -1)
        log_t = torch.sum(logvar0, -1)
        log_t1 = torch.sum(logvar1, -1)
        ratio = torch.sum(torch.exp(logvar0 - logvar1), -1)
        return 0.5 * (log_t1 - log_t - dim + exponent + ratio)

    @staticmethod
    def _safe_mean(x):
        """Mean of ``x``, or a zero that stays attached to the graph when ``x`` is empty.

        ``torch.mean`` of an empty tensor is NaN, which would poison the total loss.
        """
        return torch.mean(x) if x.numel() > 0 else torch.sum(x)

    def fit_Rep(self, batch, train=True):
        """Compute all losses on one transition batch, optionally step the optimizer.

        Args:
            batch: indexable whose items 0-4 are ``(ot, ot1, at, pos, next_pos)``
                tensors; each is moved to ``self.device``.
            train: if True, back-propagate the total loss and step Adam. If False,
                the forward pass runs with autograd disabled (no graph is recorded).

        Returns:
            dict of logging metrics (Python floats and numpy arrays).
        """
        ot, ot1, at, pos, next_pos = batch[0], batch[1], batch[2], batch[3], batch[4]
        st = ot.to(self.device)
        st1 = ot1.to(self.device)
        at = at.to(self.device)
        pos = pos.to(self.device)
        next_pos = next_pos.to(self.device)

        # Evaluation passes never call backward, so don't record an autograd graph.
        with torch.set_grad_enabled(train):
            zt, mu_ht, logvar_ht, touch_t = self.net(st)
            zt1, mu_ht1, logvar_ht1, touch_t1 = self.net(st1)

            # EQUIVARIANCE: the latent should move by exactly the action.
            loss_Eqv = torch.mean(torch.sum(((zt1 - zt) - at) ** 2, -1))

            # CONTRASTIVE: pull touch_t1 towards the detached touch_t, and push zht1
            # away from every (detached) zht in the batch via logsumexp over dim 0.
            zht = torch.cat([zt.detach(), touch_t], -1)
            zht1 = torch.cat([zt1.detach(), touch_t1], -1)

            # The permutation changes neither the logsumexp over dim 0 nor the logged
            # mean of neg_d; it is kept so the global numpy RNG is consumed as before.
            rnd_idx = np.arange(mu_ht.shape[0])
            np.random.shuffle(rnd_idx)

            pos_d = torch.sum((touch_t.detach() - touch_t1) ** 2, -1) / self.tau
            neg_d = torch.sum((torch.unsqueeze(zht[rnd_idx], 1).detach() - zht1) ** 2, -1) / self.tau
            loss_contrastive = pos_d + torch.logsumexp(-neg_d, 0)

            # HAPTIC: target for mu_ht is zt, or a uniform point between zt and zt1.
            if self.segment == 1:
                z_sampled = torch.rand(zt.shape[0], 1).to(self.device) * (zt1 - zt) + zt
            else:
                z_sampled = zt

            # Only index 0 of dim 1 is used; assumes mu/logvar are (batch, K, dim) — TODO confirm.
            pos_haptic = self._diag_gaussian_nll(z_sampled.detach(), mu_ht[:, 0], logvar_ht[:, 0])
            # KL towards the detached t+1 prediction: only the t-side Gaussian gets gradients.
            neg_haptic = self._diag_gaussian_kl(mu_ht[:, 0], logvar_ht[:, 0],
                                                mu_ht1[:, 0].detach(), logvar_ht1[:, 0].detach())

            # `otsu` (from utils) chooses the split index from the squared touch change.
            err_inv = torch.sum((touch_t - touch_t1) ** 2, -1).detach()
            split_idx = otsu(err_inv.cpu().numpy())
            # NOTE(review): these slices follow the batch's original order, not the
            # order of err_inv. If the split was meant for the batch sorted by
            # torch.argsort(err_inv), index through that permutation — confirm
            # against what `otsu` returns.
            loss_haptic_pos = pos_haptic[:split_idx]
            loss_haptic_neg = neg_haptic[split_idx:]
            # Either slice may be empty at the extremes of split_idx.
            loss_haptic = self._safe_mean(loss_haptic_pos) + self._safe_mean(loss_haptic_neg)

            # TOTAL LOSS
            loss = loss_Eqv + torch.mean(loss_contrastive) + loss_haptic

        if train:
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        # LOGS (all computed from the tensors of the forward pass above)
        avg_cos_sim = torch.mean(self.cos_sim((zt1 - zt), at)).detach().cpu().item()

        h1_err = torch.mean(torch.abs((mu_ht1[:, 0] - mu_ht[:, 0]) - (next_pos[:, 1] - pos[:, 1]))).detach().cpu().item()
        h1_pos_err = torch.mean(torch.abs(mu_ht[:, 0] - pos[:, 1])).detach().cpu().item()

        # Fraction of samples whose NLL term is below their own KL term (per sample).
        frq_pos = torch.mean((pos_haptic < neg_haptic) * 1.).detach().cpu().item()

        # Indices whose summed displacement of pos[:, 1] is non-zero.
        # NOTE(review): displacements whose components cancel out are treated as zero.
        const_t = torch.nonzero(torch.sum(next_pos[:, 1] - pos[:, 1], -1))[:, 0].detach().cpu().numpy()

        avg_pos_haptic = torch.mean(pos_haptic).detach().cpu().item()
        avg_neg_haptic = torch.mean(neg_haptic).detach().cpu().item()

        err_inv_n = err_inv.cpu().numpy()

        avg_contrastive = torch.mean(loss_contrastive).detach().cpu().item()

        avg_pos_d_loss = torch.mean(pos_d).detach().cpu().item()
        avg_neg_d_loss = torch.mean(neg_d).detach().cpu().item()
        avg_haptic_loss = loss_haptic.detach().cpu().item()
        avg_loss = loss.detach().cpu().item()
        avg_equivariance_loss = loss_Eqv.detach().cpu().item()

        vec_pos = pos[:, 0] - pos[:, 1]
        vec_z = zt - mu_ht[:, 0]
        avg_rel_err = torch.mean(torch.sum((vec_pos - vec_z) ** 2, -1)).detach().cpu().item()

        metrics_logs = {'eqv_loss': avg_equivariance_loss,
                        'cos_sim': avg_cos_sim,
                        'avg_pos_d_loss': avg_pos_d_loss,
                        'avg_neg_d_loss': avg_neg_d_loss,
                        'avg_haptic_loss': avg_haptic_loss,
                        'avg_loss': avg_loss,
                        'pos_hapt': avg_pos_haptic,
                        'neg_hapt': avg_neg_haptic,
                        'avg_contrastive': avg_contrastive,
                        'frq_pos': frq_pos,
                        'h_mu_dist': mu_ht.detach().cpu().numpy(),
                        'h_logvar_dist': logvar_ht.detach().cpu().numpy(),
                        'z_dist': zt.detach().cpu().numpy(),
                        'touch_dist': touch_t.detach().cpu().numpy(),
                        'h1_err': h1_err,
                        'h1_pos_err': h1_pos_err,
                        'const_t': const_t,
                        'avg_err_inv': err_inv_n,
                        'avg_rel_err': avg_rel_err}

        return metrics_logs

    def get_rep(self, ot, ot1):
        """Return detached ``(zt, mu_ht, logvar_ht)`` for ``ot``; ``ot1`` is ignored."""
        with torch.no_grad():
            zt, mu_ht, logvar_ht, _ = self.net(ot)

        return zt.detach(), mu_ht.detach(), logvar_ht.detach()

    def save_model(self, path_dir, exp_name):
        """Save the network weights to ``path_dir + exp_name + ".mdl"`` (plain concatenation)."""
        fname_psi = path_dir+exp_name+".mdl"
        torch.save(self.net.state_dict(), fname_psi)

    def load_model(self, path_dir, exp_name):
        """Load weights saved by ``save_model`` onto ``self.device`` and switch to eval mode."""
        fname_psi = path_dir+exp_name+".mdl"
        # torch.load unpickles the file: only load checkpoints from trusted sources.
        state_dict_psi = torch.load(fname_psi, map_location=self.device)
        self.net.load_state_dict(state_dict_psi)
        self.net.eval()


class Transporter_Repr:
    """Representation learner built on a ``Transporter`` network.

    ``self.net(ot, ot1)`` returns a prediction of ``ot1`` and a latent ``zt``; the
    network is trained with its own ``get_loss`` on that prediction.
    """

    def __init__(self, k, channels, device):
        super(Transporter_Repr, self).__init__()

        self.device = device

        # Sub-modules are built in the order encoder, pose regressor, refiner.
        self.net = Transporter(
            FeatureEncoder(channels),
            PoseRegressor(channels, k),
            RefineNet(channels),
        ).to(device)

        self.optimizer = optim.Adam(self.net.parameters(), lr=1e-3)

    def _checkpoint_path(self, path_dir, exp_name):
        # Plain concatenation: nothing is inserted between path_dir and exp_name.
        return path_dir + exp_name + ".mdl"

    def fit_Rep(self, batch, train=True):
        """Reconstruction step on one batch; returns loss and relative-position error.

        Args:
            batch: indexable whose items 0-4 are ``(ot, ot1, at, pos, next_pos)``.
            train: when True, back-propagate and step the optimizer.
        """
        st, st1, at, pos, next_pos = [batch[i].to(self.device) for i in range(5)]

        st1_pred, zt = self.net(st, st1)
        loss = self.net.get_loss(st1, st1_pred)

        # Metrics come from this forward pass, before any parameter update.
        offset_gap = (pos[:, 0] - pos[:, 1]) - (zt[:, 0] - zt[:, 1])
        metrics_logs = {
            'avg_loss': loss.detach().cpu().item(),
            'avg_rel_err': torch.mean(torch.sum(offset_gap ** 2, -1)).detach().cpu().item(),
        }

        if train:
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        return metrics_logs

    def get_rep(self, ot, ot1):
        """Return ``(zt, None, None)`` with ``zt`` detached; both observations are fed in."""
        _, zt = self.net(ot, ot1)
        return zt.detach(), None, None

    def save_model(self, path_dir, exp_name):
        """Write the network's state_dict to ``path_dir + exp_name + ".mdl"``."""
        torch.save(self.net.state_dict(), self._checkpoint_path(path_dir, exp_name))

    def load_model(self, path_dir, exp_name):
        """Restore weights written by ``save_model`` and put the network in eval mode."""
        weights = torch.load(self._checkpoint_path(path_dir, exp_name), map_location=self.device)
        self.net.load_state_dict(weights)
        self.net.eval()


class VAE_Repr:
    """VAE baseline whose latent is additionally tied to the action.

    The loss is the network's ``get_vae_loss`` (reconstruction + KL) on both ``ot``
    and ``ot1``, plus ``vae_equi`` times the equivariance penalty ``||z2 - z1 - at||^2``.
    """

    def __init__(self, N, z_dim, h_features, channels, vae_equi, device):
        super(VAE_Repr, self).__init__()

        self.device = device
        self.vae_equi = vae_equi  # weight of the equivariance term in the total loss
        self.net = VAE(N, channels, z_dim, h_features).to(device)
        self.optimizer = optim.Adam(self.net.parameters(), lr=1e-3)

    def _checkpoint_path(self, path_dir, exp_name):
        # Plain concatenation: nothing is inserted between path_dir and exp_name.
        return path_dir + exp_name + ".mdl"

    def fit_Rep(self, batch, train=True):
        """One VAE + equivariance step on a transition batch.

        Args:
            batch: indexable whose items 0-4 are ``(ot, ot1, at, pos, next_pos)``.
            train: when True, back-propagate and step the optimizer.

        Returns:
            dict of scalar metrics from this batch's forward pass.
        """
        st, st1, at, pos, next_pos = [batch[i].to(self.device) for i in range(5)]

        # Encode both observations (t first, then t+1).
        z1, mu1, logvar1, s_hat1 = self.net(st)
        z2, mu2, logvar2, s_hat2 = self.net(st1)

        mse_t, kl_t = self.net.get_vae_loss(st, s_hat1, mu1, logvar1)
        mse_t1, kl_t1 = self.net.get_vae_loss(st1, s_hat2, mu2, logvar2)
        eq_loss = torch.mean(torch.sum((z2 - z1 - at) ** 2, -1))

        loss = mse_t + kl_t + mse_t1 + kl_t1 + self.vae_equi * eq_loss

        if train:
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        def scalar(t):
            return t.detach().cpu().item()

        # NOTE(review): z1 - mu1 compares the network's `z` and `mu` outputs for st;
        # if `z` is a sample drawn around `mu`, this measures sampling noise rather
        # than a relative position — confirm the intended metric.
        offset_gap = (pos[:, 0] - pos[:, 1]) - (z1 - mu1)

        return {'avg_loss': scalar(loss),
                'avg_mse_loss': scalar(mse_t + mse_t1),
                'avg_kl_loss': scalar(kl_t + kl_t1),
                'avg_eq_loss': scalar(eq_loss),
                'avg_rel_err': scalar(torch.mean(torch.sum(offset_gap ** 2, -1)))}

    def get_rep(self, ot, ot1):
        """Return detached ``(z, mu, logvar)`` for ``ot``; ``ot1`` is ignored."""
        z, mu, logvar, _ = self.net(ot)
        return z.detach(), mu.detach(), logvar.detach()

    def save_model(self, path_dir, exp_name):
        """Write the network's state_dict to ``path_dir + exp_name + ".mdl"``."""
        torch.save(self.net.state_dict(), self._checkpoint_path(path_dir, exp_name))

    def load_model(self, path_dir, exp_name):
        """Restore weights written by ``save_model`` and put the network in eval mode."""
        weights = torch.load(self._checkpoint_path(path_dir, exp_name), map_location=self.device)
        self.net.load_state_dict(weights)
        self.net.eval()
