import numpy as np
import torch
from tqdm import trange
from torch import nn
import copy
import time

def to_device(x, device='cuda'):
    """Move a tensor, or a (possibly nested) dict of tensors, to ``device``.

    Args:
        x: a ``torch.Tensor`` or a mapping derived from ``dict`` whose values
            are themselves tensors or dicts.
        device: target device passed to ``Tensor.to``.

    Returns:
        The moved tensor, or a new plain ``dict`` with every value moved.
        For any other type a message is printed and ``None`` is returned.
    """
    if torch.is_tensor(x):
        return x.to(device)
    # isinstance (not an exact type check) so that dict subclasses such as
    # OrderedDict / defaultdict are handled instead of being dropped.
    if isinstance(x, dict):
        return {k: to_device(v, device) for k, v in x.items()}
    print(f'Unrecognized type in `to_device`: {type(x)}')
    return None

def batch_to_device(batch, device='cuda:0'):
    """Rebuild ``batch`` with every field moved to ``device``.

    ``batch`` must expose ``_fields`` (as namedtuples do); each named field is
    passed through ``to_device`` and the results are handed positionally to
    the constructor of the batch's own type.
    """
    moved = []
    for name in batch._fields:
        moved.append(to_device(getattr(batch, name), device))
    batch_cls = type(batch)
    return batch_cls(*moved)

@torch.jit.script
def compute_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Pairwise Gaussian-style kernel between the rows of ``x`` and ``y``.

    Args:
        x: tensor of shape (N, D).
        y: tensor of shape (M, D); D is taken from ``x``.

    Returns:
        Tensor of shape (N, M) whose entry (i, j) is
        ``exp(-mean_d((x[i] - y[j])**2) / D)``.
    """
    x_size = x.shape[0]
    y_size = y.shape[0]
    dim = x.shape[1]

    # Broadcasting (N, 1, D) against (1, M, D) yields the same pairwise
    # differences as tiling both operands, without materialising two
    # (N, M, D) copies.
    diff = x.view(x_size, 1, dim) - y.view(1, y_size, dim)

    # NOTE(review): the squared distance is averaged over D and then divided
    # by D again; kept as in the original formulation.
    return torch.exp(-torch.mean(diff ** 2, dim=2) / dim)

@torch.jit.script
def compute_mmd(x, y):
    """Maximum mean discrepancy estimate between two sample sets.

    Both inputs are 2-D (samples, features). The result is the scalar
    ``mean(k(x, x)) + mean(k(y, y)) - 2 * mean(k(x, y))`` where ``k`` is
    ``compute_kernel``.
    """
    within_x = compute_kernel(x, x).mean()
    within_y = compute_kernel(y, y).mean()
    across = compute_kernel(x, y).mean()
    return within_x + within_y - 2 * across

class EMA():
    '''
        exponential moving average of model parameters

        beta is the weight kept from the previous average; (1 - beta) is the
        weight given to the new value.
    '''
    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_model_average(self, ma_model, current_model):
        # Walk both parameter lists in lockstep and overwrite each averaged
        # parameter's data with the blended value.
        paired = zip(ma_model.parameters(), current_model.parameters())
        for averaged, latest in paired:
            averaged.data = self.update_average(averaged.data, latest.data)

    def update_average(self, old, new):
        # With no previous average the new value is taken as-is.
        if old is not None:
            return old * self.beta + (1 - self.beta) * new
        return new

class AllTrainer():
    def __init__(
        self,
        en_model, # encoder
        de_model, # diffusion model
        optimizer,
        batch_size,
        get_batch,
        device,
        et_optimizer,
        w,
        w_std,
        w_optimizer,
        repre_type,
        phi_norm_loss_ratio,
        info_loss_weight,
        sg,
    ):
        """Store the models, optimizers and hyper-parameters used for training.

        Args:
            en_model: encoder producing trajectory representations.
            de_model: diffusion model; an EMA copy of it is kept alongside.
            optimizer: stepped in every training mode.
            batch_size: number of rows expected per batch.
            get_batch: iterator yielding batches (consumed with ``next``).
            device: device that batch tensors are moved to.
            et_optimizer: stepped together with ``optimizer`` in the 'vec'
                and 'dist' modes.
            w: anchor representation, updated through ``w_optimizer``.
            w_std: anchor spread parameter used in the 'dist' mode.
            w_optimizer: optimizer for the anchor update.
            repre_type: 'none', 'vec' or 'dist'; selects the training step.
            phi_norm_loss_ratio: weight of the representation-norm penalty.
            info_loss_weight: stored only; not read by the training steps
                in this class.
            sg: if truthy, representations are detached before being passed
                to the diffusion loss in the 'dist' mode.
        """
        super().__init__()

        # Models and where they run.
        self.en_model = en_model
        self.de_model = de_model
        self.device = device

        # Data source and bookkeeping.
        self.batch_size = batch_size
        self.dataloader = get_batch
        self.diagnostics = {}
        self.repre_type = repre_type

        # Optimizers.
        self.optimizer = optimizer
        self.et_optimizer = et_optimizer
        self.w_optimizer = w_optimizer

        # Anchor parameters.
        self.w = w
        self.w_std = w_std

        # Loss functions and their weights.
        self.phi_loss = nn.MSELoss()
        self.triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)
        self.phi_norm_loss_ratio = phi_norm_loss_ratio
        self.info_loss_weight = info_loss_weight
        self.sg = sg

        # EMA shadow of the diffusion model, synchronised with it at start.
        self.ema = EMA(0.995)
        self.ema_model = copy.deepcopy(de_model)
        self.reset_parameters()
        self.step = 1

    def reset_parameters(self):
        self.ema_model.load_state_dict(self.de_model.state_dict())

    def step_ema(self):
        if self.step < 2000:
            self.reset_parameters()
            return
        self.ema.update_model_average(self.ema_model, self.de_model)

    def train_iteration(self, num_steps, iter_num=0, print_logs=False):
        """Run ``num_steps`` training steps and return aggregated logs.

        The per-step routine is chosen by ``self.repre_type``:
        'none' -> ``train_dd``, 'vec' -> ``train_with_vec``,
        'dist' -> ``train_with_dist``.

        Args:
            num_steps: number of training steps to run.
            iter_num: iteration index, only used when printing.
            print_logs: if true, print every log entry after training.

        Returns:
            dict with the elapsed time, mean/std of the diffusion, inverse
            and info losses, the means of the preference, phi-norm and w
            losses, plus every entry of ``self.diagnostics``.

        Raises:
            ValueError: if ``self.repre_type`` is not one of the known modes
                (previously this surfaced as an ``UnboundLocalError`` inside
                the loop).
        """
        step_fns = {
            'none': self.train_dd,             # decision diffuser
            'vec': self.train_with_vec,        # vector representations
            'dist': self.train_with_dist,      # distributional representations
        }
        if self.repre_type not in step_fns:
            raise ValueError(
                f'Unknown repre_type {self.repre_type!r}; '
                f'expected one of {sorted(step_fns)}'
            )
        # The mode does not change during the loop, so resolve it once.
        train_step = step_fns[self.repre_type]

        diffusion_losses, inv_losses, info_losses = [], [], []
        regress_losses, phi_norm_losses, w_losses = [], [], []
        logs = dict()
        train_start = time.time()

        self.en_model.train()
        self.de_model.train()
        for _ in trange(num_steps, desc='train_step'):
            diffusion_loss, inv_loss, regress_loss, phi_norm_loss, w_loss, info_loss = train_step()

            diffusion_losses.append(diffusion_loss)
            inv_losses.append(inv_loss)
            info_losses.append(info_loss)
            regress_losses.append(regress_loss)
            phi_norm_losses.append(phi_norm_loss)
            w_losses.append(w_loss)

        logs['training/time'] = time.time() - train_start
        logs['training/diffusion_loss_mean'] = np.mean(diffusion_losses)
        logs['training/diffusion_loss_std'] = np.std(diffusion_losses)
        logs['training/inv_loss_mean'] = np.mean(inv_losses)
        logs['training/inv_loss_std'] = np.std(inv_losses)
        logs['training/info_loss_mean'] = np.mean(info_losses)
        logs['training/info_loss_std'] = np.std(info_losses)
        # The third value returned by the step functions is logged as "pref".
        logs['training/pref_loss_mean'] = np.mean(regress_losses)
        logs['training/phi_norm_loss_mean'] = np.mean(phi_norm_losses)
        logs['training/w_loss_mean'] = np.mean(w_losses)
        logs.update(self.diagnostics)

        if print_logs:
            print('=' * 80)
            print(f'Iteration {iter_num}')
            for k, v in logs.items():
                print(f'{k}: {v}')

        return logs

    def train_with_dist(self):
        # for training pref_diffuser
        # start = time.time()
        batch1 = next(self.dataloader)
        batch2 = next(self.dataloader)
        # print('--------load_2_batch:', time.time() - start)
        states_1, states_2 = batch1['samples'].to(self.device), batch2['samples'].to(self.device) # (32, 100, 11 )
        actions_1, actions_2 = batch1['actions'].to(self.device), batch2['actions'].to(self.device) # (32,100,3)
        rtg_1, rtg_2 = batch1['returns'].to(self.device), batch2['returns'].to(self.device) # (32,100)
        timesteps_1, timesteps_2 = batch1['timesteps'].to(self.device), batch2['timesteps'].to(self.device) # (32,100)
        mask_1, mask_2 = batch1['masks'].to(self.device), batch2['masks'].to(self.device) # (32,100)
        lb = (rtg_1[:, 0] - rtg_2[:, 0]) >= 0
        rb = (rtg_2[:, 0] - rtg_1[:, 0]) > 0

        # print('before_encoding:', time.time()-start)
        states = torch.cat([states_1, states_2], dim=0)
        actions = torch.cat([actions_1, actions_2], dim=0)
        timesteps = torch.cat([timesteps_1, timesteps_2], dim=0)
        masks = torch.cat([mask_1, mask_2], dim=0)
        phi_mean, phi_std = self.en_model.forward(states, timesteps, masks)
        
        positive_mean = torch.cat((phi_mean[:self.batch_size][lb], phi_mean[self.batch_size:][rb]), 0)
        negative_mean = torch.cat((phi_mean[self.batch_size:][lb], phi_mean[:self.batch_size][rb]), 0)
        positive_std = torch.cat((phi_std[:self.batch_size][lb], phi_std[self.batch_size:][rb]), 0)
        negative_std = torch.cat((phi_std[self.batch_size:][lb], phi_std[:self.batch_size][rb]), 0)
        
        # phi_mean_1, phi_std_1 = self.en_model.forward(states_1, timesteps_1, mask_1)
        # phi_mean_2, phi_std_2 = self.en_model.forward(states_2, timesteps_2, mask_2)
        # positive_mean = torch.cat([phi_mean_1[lb], phi_mean_2[rb]], 0)
        # negative_mean = torch.cat([phi_mean_2[lb], phi_mean_1[rb]], 0)
        # positive_std = torch.cat([phi_std_1[lb], phi_std_2[rb]], 0)
        # negative_std = torch.cat([phi_std_2[lb], phi_std_1[rb]], 0)
        
        positive_dist = torch.distributions.MultivariateNormal(loc=positive_mean, covariance_matrix=torch.diag_embed(torch.exp(positive_std)))
        negative_dist = torch.distributions.MultivariateNormal(loc=negative_mean, covariance_matrix=torch.diag_embed(torch.exp(negative_std)))
        w_std = torch.clamp(self.w_std, min=-20, max=2).detach()
        anchor_dist = torch.distributions.MultivariateNormal(loc=self.w.detach(), covariance_matrix=torch.diag_embed(torch.exp(w_std)))
        positive_kl = torch.distributions.kl.kl_divergence(anchor_dist, positive_dist).mean()
        negative_kl = torch.distributions.kl.kl_divergence(anchor_dist, negative_dist).mean()
        kl_loss = positive_kl + 1.0 / negative_kl
        anchor_mean = self.w.expand(positive_mean.shape[0], -1).detach()
        trip_loss = self.triplet_loss(anchor_mean, positive_mean, negative_mean)
        phi_norm_loss = self.phi_loss(torch.norm(phi_mean, dim=1), torch.ones(2*self.batch_size).to(self.device))
        # phi_norm_loss = self.phi_loss(torch.norm(phi_mean_1, dim=1), torch.ones(self.batch_size).to(self.device)) + self.phi_loss(torch.norm(phi_mean_2, dim=1), torch.ones(self.batch_size).to(self.device))
        pref_loss = trip_loss + kl_loss + self.phi_norm_loss_ratio * phi_norm_loss
        # pref_loss = kl_loss + self.phi_norm_loss_ratio * phi_norm_loss
        
        # states = torch.cat([states_1, states_2], dim=0)
        # actions = torch.cat([actions_1, actions_2], dim=0)
        # phi_mean = torch.cat([phi_mean_1, phi_mean_2], dim=0)
        # phi_std = torch.cat([phi_std_1, phi_std_2], dim=0)
        conditions = states[:,0,:] # condition在当前状态下，用于做planning
        trajectories = torch.concat([actions, states], dim=-1)  # 将这段state和action合并
        if self.sg: # stop gradients
            diff_loss, inv_loss, info_loss = self.de_model.loss(trajectories, conditions, phi_mean.detach(), phi_std.detach())
        else:
            diff_loss, inv_loss, info_loss = self.de_model.loss(trajectories, conditions, phi_mean, phi_std)
        diffusion_loss = diff_loss + inv_loss
        diffusion_loss += pref_loss
        # print('compute diffusion loss:', time.time()-start)

        self.optimizer.zero_grad()
        self.et_optimizer.zero_grad()
        diffusion_loss.backward()
        self.optimizer.step()
        self.et_optimizer.step()
        if self.step % 10 == 0:
            self.step_ema()
        self.step += 1
        
        # print('update encoder and diffusion model:', time.time()-start)
        
        # when encoder outputs distributions
        phi_mean, phi_std = self.en_model.forward(states, timesteps, masks)
        positive_mean = torch.cat((phi_mean[:self.batch_size][lb], phi_mean[self.batch_size:][rb]), 0).detach() # (60, 16)
        negative_mean = torch.cat((phi_mean[self.batch_size:][lb], phi_mean[:self.batch_size][rb]), 0).detach() # (60, 16)
        positive_std = torch.cat((phi_std[:self.batch_size][lb], phi_std[self.batch_size:][rb]), 0).detach()
        negative_std = torch.cat((phi_std[self.batch_size:][lb], phi_std[:self.batch_size][rb]), 0).detach()
        
        # phi_mean_1, phi_std_1 = self.en_model.forward(states_1, timesteps_1, mask_1)
        # phi_mean_2, phi_std_2 = self.en_model.forward(states_2, timesteps_2, mask_2)
        # positive_mean = torch.cat([phi_mean_1[lb], phi_mean_2[rb]], 0).detach()
        # negative_mean = torch.cat([phi_mean_2[lb], phi_mean_1[rb]], 0).detach()
        # positive_std = torch.cat([phi_std_1[lb], phi_std_2[rb]], 0).detach()
        # negative_std = torch.cat([phi_std_2[lb], phi_std_1[rb]], 0).detach()
        
        positive_dist = torch.distributions.MultivariateNormal(loc=positive_mean, covariance_matrix=torch.diag_embed(torch.exp(positive_std)))
        negative_dist = torch.distributions.MultivariateNormal(loc=negative_mean, covariance_matrix=torch.diag_embed(torch.exp(negative_std)))
        w_std = torch.clamp(self.w_std, min=-20, max=2)
        anchor_dist = torch.distributions.MultivariateNormal(loc=self.w, covariance_matrix=torch.diag_embed(torch.exp(w_std)))
        positive_kl = torch.distributions.kl.kl_divergence(anchor_dist, positive_dist).mean()
        negative_kl = torch.distributions.kl.kl_divergence(anchor_dist, negative_dist).mean()
        
        kl_loss = positive_kl + 1.0 / negative_kl
        anchor_mean = self.w.expand(positive_mean.shape[0], -1)
        trip_loss = self.triplet_loss(anchor_mean, positive_mean, negative_mean)
        w_loss = trip_loss + kl_loss
        # w_loss = kl_loss
        self.w_optimizer.zero_grad()
        w_loss.backward()
        self.w_optimizer.step()
        
        # print('update w:', time.time()-start)

        return diff_loss.detach().cpu().item(), inv_loss.detach().cpu().item(), pref_loss.detach().cpu().item(), phi_norm_loss.detach().cpu().item(), w_loss.detach().cpu().item(), info_loss.detach().cpu().item()

    def train_with_vec(self):
        # for training pref_diffuser
        batch1 = next(self.dataloader)
        batch2 = next(self.dataloader)
        states_1, states_2 = batch1['samples'].to(self.device), batch2['samples'].to(self.device) # (32, 100, 11 )
        actions_1, actions_2 = batch1['actions'].to(self.device), batch2['actions'].to(self.device) # (32,100,3)
        rtg_1, rtg_2 = batch1['returns'].to(self.device), batch2['returns'].to(self.device) # (32,100)
        timesteps_1, timesteps_2 = batch1['timesteps'].to(self.device), batch2['timesteps'].to(self.device) # (32,100)
        mask_1, mask_2 = batch1['masks'].to(self.device), batch2['masks'].to(self.device) # (32,100)

        lb = (rtg_1[:, 0] - rtg_2[:, 0]) >= 0
        rb = (rtg_2[:, 0] - rtg_1[:, 0]) > 0

        # pref loss and phi norm loss, when representation is vector
        phi_1 = self.en_model.forward(states_1, timesteps_1, mask_1)
        phi_2 = self.en_model.forward(states_2, timesteps_2, mask_2)
        phi_norm_loss = (self.phi_loss(torch.norm(phi_1, dim=1), torch.ones(self.batch_size).to(self.device))
                + self.phi_loss(torch.norm(phi_2, dim=1), torch.ones(self.batch_size).to(self.device)))
        positive = torch.cat((phi_1[lb], phi_2[rb]), 0)
        negative = torch.cat((phi_2[lb], phi_1[rb]), 0)
        anchor = self.w.expand(positive.shape[0], -1).detach()
        trip_loss = self.triplet_loss(anchor, positive, negative)
        pref_loss = trip_loss + self.phi_norm_loss_ratio * phi_norm_loss

        # update diffusion
        states = torch.cat([states_1, states_2], dim=0)
        actions = torch.cat([actions_1, actions_2], dim=0)
        phis = torch.cat([phi_1, phi_2], dim=0)
        conditions = states[:,0,:] # condition在当前状态下，用于做planning
        trajectories = torch.concat([actions, states], dim=-1)  # 将这段state和action合并
        diff_loss, inv_loss, info_loss = self.de_model.loss(trajectories, conditions, phis) # compute loss
        
        # # add ood info loss
        # noise =  torch.randn(phis.shape, device=self.device)
        # anchor_mean = self.w.expand(phis.shape[0], -1) + 0.01 * noise
        # generated_phi_mean, generated_phi_std = self.de_model.generate(conditions, anchor_mean) # (batch, 16)
        # info_loss = compute_mmd(anchor_mean, generated_phi_mean)

        diffusion_loss = diff_loss + inv_loss
        # diffusion_loss += self.info_loss_weight * info_loss
        diffusion_loss += pref_loss

        self.optimizer.zero_grad()
        self.et_optimizer.zero_grad()
        diffusion_loss.backward()
        self.optimizer.step()
        self.et_optimizer.step()
        if self.step % 10 == 0:
            self.step_ema()
        self.step += 1

        # update anchor vector
        phi_1 = self.en_model.forward(states_1, timesteps_1, mask_1)
        phi_2 = self.en_model.forward(states_2, timesteps_2, mask_2)
        positive = torch.cat((phi_1[lb], phi_2[rb]), 0)
        negative = torch.cat((phi_2[lb], phi_1[rb]), 0)
        anchor = self.w.expand(positive.shape[0], -1)
        w_loss = self.triplet_loss(anchor, positive, negative)
        self.w_optimizer.zero_grad()
        w_loss.backward()
        self.w_optimizer.step()

        return diff_loss.detach().cpu().item(), inv_loss.detach().cpu().item(), pref_loss.detach().cpu().item(), phi_norm_loss.detach().cpu().item(), w_loss.detach().cpu().item(), info_loss.detach().cpu().item()

    def train_dd(self):
        # for reproducing decision diffuser
        batch = next(self.dataloader)
        states = batch['samples'].to(self.device)
        actions = batch['actions'].to(self.device)
        conditions = batch['conditions'][0].to(self.device)
        phis = batch['returns'].to(self.device)
        trajectories = torch.concat([actions, states], dim=-1)  # 将这段state和action合并
        diff_loss, inv_loss, _ = self.de_model.loss(trajectories, conditions, phis) # compute loss

        diffusion_loss = diff_loss + inv_loss

        self.optimizer.zero_grad()
        diffusion_loss.backward()
        self.optimizer.step()
        if self.step % 10 == 0:
            self.step_ema()
        self.step += 1

        return diff_loss.detach().cpu().item(), inv_loss.detach().cpu().item(), 0., 0., 0., 0.0
    
    # maximizing the mutual information between w and x_0
        # generated_phi_mean, generated_phi_std = self.de_model.generate(conditions, phis) # (batch, 16)
        # generated_phi_dist = torch.distributions.MultivariateNormal(loc=generated_phi_mean, 
        #                                                         covariance_matrix=torch.diag_embed(torch.exp(generated_phi_std)))
        # phi_dist = torch.distributions.MultivariateNormal(loc=phis, 
        #                                                       covariance_matrix=torch.diag_embed(torch.exp(phis_std)))
        # info_loss = torch.distributions.kl_divergence(generated_phi_dist, phi_dist).mean()
        # diff_loss += 0.1 * info_loss
        # maximizing the mutual information between w* and x_{t-10}, ood info loss
        # noise =  torch.randn(phis.shape, device=self.device)
        # anchor_mean = self.w.expand(phis.shape[0], -1) + 0.01 * noise
        # generated_phi_mean, generated_phi_std = self.de_model.generate(conditions, anchor_mean) # (batch, 16)
        # generated_phi_dist = torch.distributions.MultivariateNormal(loc=generated_phi_mean, 
        #                                                         covariance_matrix=torch.diag_embed(torch.exp(generated_phi_std)))
        # phi_dist = torch.distributions.MultivariateNormal(loc=anchor_mean, 
        #                                                       covariance_matrix=torch.diag_embed(torch.exp(self.w_std)))
        # info_loss = torch.distributions.kl_divergence(generated_phi_dist, phi_dist).mean()
        # diffusion_loss += self.info_loss_weight * info_loss