import torch
import numpy as np
from scipy.stats import pearsonr
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import random


def collate_fn(batch):
    """Collate training samples into batched tensors.

    Each sample is a 6-tuple ``(data_v, data_a, data_t, label_norm, label,
    dataset_flag)``. Samples are reordered by ``len(data_t)`` (longest first,
    stable for ties) and every field is reordered the same way, so the
    returned tensors stay aligned.

    Returns:
        ``(stack(data_v), stack(data_a), padded_data_t, tensor(label_norm),
        tensor(label), tensor(dataset_flag))`` where ``padded_data_t`` is
        ``data_t`` zero-padded along its first axis into a batch-first tensor.
    """
    fields = tuple(zip(*batch))
    data_v, data_a, data_t, label_norm, label, dataset_flag = fields

    order = sorted(range(len(data_t)), key=lambda idx: len(data_t[idx]), reverse=True)
    data_v, data_a, data_t, label_norm, label, dataset_flag = (
        [field[idx] for idx in order] for field in fields
    )

    padded_data_t = pad_sequence(data_t, batch_first=True, padding_value=0)
    return (
        torch.stack(data_v),
        torch.stack(data_a),
        padded_data_t,
        torch.tensor(label_norm),
        torch.tensor(label),
        torch.tensor(dataset_flag),
    )


def collate_fn_val(batch):
    """Collate validation/test samples into batched tensors.

    Same contract as ``collate_fn`` except for ``data_t``: each sequence is
    transposed (dims 0 and 1) before padding, padded batch-first, and the
    result is transposed back, so padding is applied along each sequence's
    dim 1 and the output is laid out as ``(batch, dim0, padded dim1)``.

    NOTE(review): the sort key ``len(data_t[k])`` measures dim 0, which
    ``pad_sequence`` requires to be identical across the batch (it becomes the
    trailing dim after the transpose). The stable descending sort therefore
    keeps the original sample order. Positional metric slicing downstream
    depends on that order, so do not change the key without checking callers.
    """
    fields = tuple(zip(*batch))
    data_v, data_a, data_t, label_norm, label, dataset_flag = fields

    order = sorted(range(len(data_t)), key=lambda idx: len(data_t[idx]), reverse=True)
    data_v, data_a, data_t, label_norm, label, dataset_flag = (
        [field[idx] for idx in order] for field in fields
    )

    swapped = [torch.transpose(seq, 0, 1) for seq in data_t]
    padded = pad_sequence(swapped, batch_first=True, padding_value=0)
    padded_data_t = padded.transpose(1, 2)

    return (
        torch.stack(data_v),
        torch.stack(data_a),
        padded_data_t,
        torch.tensor(label_norm),
        torch.tensor(label),
        torch.tensor(dataset_flag),
    )


def kd_loss(pred, label, lt, ls):
    """Prediction-error-weighted L1 distance between ``lt`` and ``ls``.

    Each element of ``|lt - ls|`` is scaled by ``1 - |pred - label|``
    (unsqueezed at dim 1 so it broadcasts across ``lt``/``ls``'s dim 1), and
    the mean of the result is returned.

    NOTE(review): assumes ``pred``/``label`` are per-sample values of shape
    ``(batch,)`` and ``lt``/``ls`` are ``(batch, ...)``. The weight turns
    negative if ``|pred - label| > 1``, so labels are presumably normalised;
    confirm.
    """
    weight = 1 - (pred - label).abs().unsqueeze(1)
    discrepancy = (lt - ls).abs()
    return (weight * discrepancy).mean()


def kl_loss(mu, sigma):
    """Batch-mean KL term against a standard-normal prior.

    Evaluates ``-0.5 * sum(1 + sigma - mu**2 - exp(sigma))`` over dim 1 and
    averages the per-row values over dim 0. This is the closed form of
    KL(N(mu, diag(exp(sigma))) || N(0, I)) when ``sigma`` holds log-variances.
    NOTE(review): the parameter name suggests a standard deviation; confirm
    callers actually pass log-variance.
    """
    inner = 1 + sigma - mu.pow(2) - sigma.exp()
    per_row = -0.5 * inner.sum(dim=1)
    return per_row.mean(dim=0)


def contrast_homo_loss(f_a, f_v, f_t, label):
    """Contrastive loss for the homogeneous (shared) audio/visual/text features.

    The result is the sum of two terms:

    * the mean of the summed ``torch.cdist`` matrices for the modality pairs
      (a, v), (t, v) and (a, t). ``cdist`` compares every row of one input
      with every row of the other, so this covers all cross-sample pairs,
      not only same-sample pairs;
    * a label-aware term on the concatenated features. With
      ``gap = |label_i - label_j|``, pairs are penalised by
      ``gap * max(30 * gap - dist, 0)`` (push apart up to a margin that grows
      with the label gap) plus ``(1 - gap) * dist`` (pull together).

    NOTE(review): assumes 2-D ``(batch, dim)`` features and a 1-D ``label``.
    ``1 - gap`` is only a non-negative weight if labels lie in [0, 1]; confirm.
    """
    cross_modal = sum(torch.cdist(x, y) for x, y in ((f_a, f_v), (f_t, f_v), (f_a, f_t)))
    homo_term = cross_modal.mean()

    fused = torch.cat((f_a, f_v, f_t), 1)
    dist = torch.cdist(fused, fused)
    gap = (label.unsqueeze(0) - label.unsqueeze(1)).abs()
    push_apart = (gap * (30 * gap - dist).clamp(min=0.0)).mean()
    pull_together = ((1 - gap) * dist).mean()
    return homo_term + (push_apart + pull_together)


def contrast_hete_loss(f_a, f_v, f_t, label):
    """Contrastive loss for the heterogeneous (modality-specific) features.

    The result is the sum of two terms:

    * a hinge ``max(10 - cdist(x, y), 0)`` summed over the modality pairs
      (a, v), (t, v) and (a, t), then averaged. It penalises cross-modal
      pairs (all rows vs. all rows) that are closer than 10;
    * the same label-aware term as the homogeneous loss. With
      ``gap = |label_i - label_j|`` on the concatenated features, it adds
      ``gap * max(30 * gap - dist, 0)`` plus ``(1 - gap) * dist``.

    NOTE(review): assumes 2-D ``(batch, dim)`` features and a 1-D ``label``
    with values in [0, 1]; confirm.
    """
    separation = sum(
        (10 - torch.cdist(x, y)).clamp(min=0.0)
        for x, y in ((f_a, f_v), (f_t, f_v), (f_a, f_t))
    )
    hete_term = separation.mean()

    fused = torch.cat((f_a, f_v, f_t), 1)
    dist = torch.cdist(fused, fused)
    gap = (label.unsqueeze(0) - label.unsqueeze(1)).abs()
    push_apart = (gap * (30 * gap - dist).clamp(min=0.0)).mean()
    pull_together = ((1 - gap) * dist).mean()
    return hete_term + (push_apart + pull_together)


def contrast_noise_loss(f_a, f_v, f_t):
    """Label-free contrastive loss for the noise features.

    The result is the sum of two terms:

    * the mean of ``max(10 - cdist(x, y), 0)`` summed over the modality pairs
      (a, v), (t, v) and (a, t). This penalises cross-modal pairs closer
      than 10;
    * the mean pairwise ``cdist`` among rows of the concatenated
      ``(f_a, f_v, f_t)`` features.

    NOTE(review): assumes 2-D ``(batch, dim)`` inputs; confirm.
    """
    separation = sum(
        (10 - torch.cdist(x, y)).clamp(min=0.0)
        for x, y in ((f_a, f_v), (f_t, f_v), (f_a, f_t))
    )
    fused = torch.cat((f_a, f_v, f_t), 1)
    spread = torch.cdist(fused, fused).mean()
    return separation.mean() + spread


def orth_loss(s1, s2):
    """Mean absolute dot product between every row of ``s1`` and every row of ``s2``.

    ``torch.mm(s1, s2.T)`` forms the matrix of row-wise inner products. The
    loss is zero exactly when each row of ``s1`` is orthogonal to each row of
    ``s2``. Both inputs must be 2-D, as ``torch.mm`` requires.
    """
    similarity = torch.mm(s1, s2.T)
    return similarity.abs().mean()


def std_base_loss(out_t, train_label):
    """Compute every loss component for one forward pass.

    Args:
        out_t: mapping of model outputs. Keys read: 'score'; the nine
            reconstruction pairs ('x_{v,a,t}_{hete,homo}_dec' vs.
            'x_{v,a,t}_{hete,homo}', and 'x_{v,a,t}_n_dec' vs.
            'x_{v,a,t}_noise'); the homo/hete/noise features of the a, v and
            t modalities; 'mu' and 'sigma'.
        train_label: regression target for ``out_t['score']``. It is also
            passed as ``label`` to the homo/hete contrastive losses.

    Returns:
        Tuple ``(mse_t, mae_t, rec_t, con_t, orth_t, kl_t)``:
        - ``mse_t``, ``mae_t``: MSE and L1 loss of the prediction.
        - ``rec_t``: mean of the nine reconstruction MSEs.
        - ``con_t``: 0.01 times the sum of the contrastive losses.
        - ``orth_t``: 0.001 times the homo/hete orthogonality loss.
        - ``kl_t``: 0.001 times the KL term.
    """
    mse = nn.MSELoss()
    mae = nn.L1Loss()

    prediction = out_t['score']
    mse_t = mse(prediction, train_label)
    mae_t = mae(prediction, train_label)

    recon_pairs = (
        ('x_v_hete_dec', 'x_v_hete'),
        ('x_a_hete_dec', 'x_a_hete'),
        ('x_t_hete_dec', 'x_t_hete'),
        ('x_v_homo_dec', 'x_v_homo'),
        ('x_a_homo_dec', 'x_a_homo'),
        ('x_t_homo_dec', 'x_t_homo'),
        ('x_v_n_dec', 'x_v_noise'),
        ('x_a_n_dec', 'x_a_noise'),
        ('x_t_n_dec', 'x_t_noise'),
    )
    rec_t = sum(mse(out_t[decoded], out_t[target]) for decoded, target in recon_pairs) / 9.

    # Feature triples are always ordered (audio, visual, text).
    homo = (out_t['x_a_homo'], out_t['x_v_homo'], out_t['x_t_homo'])
    hete = (out_t['x_a_hete'], out_t['x_v_hete'], out_t['x_t_hete'])
    noise = (out_t['x_a_noise'], out_t['x_v_noise'], out_t['x_t_noise'])

    con_t = 0.01 * (contrast_homo_loss(*homo, train_label)
                    + contrast_hete_loss(*hete, train_label)
                    + contrast_noise_loss(*noise))
    orth_t = 0.001 * orth_loss(torch.cat(homo, 1), torch.cat(hete, 1))
    kl_t = 0.001 * kl_loss(out_t['mu'], out_t['sigma'])
    return mse_t, mae_t, rec_t, con_t, orth_t, kl_t


def random_feature_A_masking(tensor):
    """Randomly zero out time spans of three feature groups.

    Three group ids are drawn from {1, 2, 3} with ``random.choices``, which
    samples with replacement: a group may be masked more than once, or not at
    all. For each draw, a span ``[start, end)`` is drawn with
    ``0 <= start <= end <= t`` and may be empty. That span of dim 1 is zeroed
    for the group's columns of dim 2. The same spans are used for every index
    along dim 0.

    NOTE(review): dim 1 is treated as time (``t = tensor.shape[1]``) and
    dim 2 is split at columns 39 and 127. The meaning of the groups is not
    visible here; confirm the layout against the audio feature extractor.

    Returns:
        A new tensor ``mask * tensor``; the input is not modified.
    """
    groups = {1: slice(None, 39), 2: slice(39, 127), 3: slice(127, None)}
    mask = torch.ones_like(tensor)
    chosen = random.choices([1, 2, 3], k=3)
    t = tensor.shape[1]
    for group_id in chosen:
        start = random.randint(0, t)
        end = random.randint(start, t)
        # `end` is drawn as an absolute index in [start, t], so the span is
        # start:end. Slicing start:start+end would overshoot the drawn end.
        mask[:, start:end, groups[group_id]] = 0
    return mask * tensor


def random_feature_V_masking(tensor):
    """Randomly zero out time spans of five feature groups.

    Five group ids are drawn from {1, ..., 5} with ``random.choices``, which
    samples with replacement: a group may be masked more than once, or not at
    all. For each draw, a span ``[start, end)`` is drawn with
    ``0 <= start <= end <= t`` and may be empty. That span of dim 1 is zeroed
    for the group's columns of dim 2. The same spans are used for every index
    along dim 0.

    NOTE(review): dim 1 is treated as time (``t = tensor.shape[1]``) and
    dim 2 is split at columns 6, 12, 47 and 183. The meaning of the groups is
    not visible here; confirm the layout against the visual feature extractor.

    Returns:
        A new tensor ``mask * tensor``; the input is not modified.
    """
    groups = {
        1: slice(None, 6),
        2: slice(6, 12),
        3: slice(12, 47),
        4: slice(47, 183),
        5: slice(183, None),
    }
    mask = torch.ones_like(tensor)
    chosen = random.choices([1, 2, 3, 4, 5], k=5)
    t = tensor.shape[1]
    for group_id in chosen:
        start = random.randint(0, t)
        end = random.randint(start, t)
        # `end` is drawn as an absolute index in [start, t], so the span is
        # start:end. Slicing start:start+end would overshoot the drawn end.
        mask[:, start:end, groups[group_id]] = 0
    return mask * tensor


def random_feature_T_masking(tensor):
    """Zero out one random span along dim 1 across all features.

    A span ``[start, end)`` is drawn with ``0 <= start <= end <= t``, where
    ``t = tensor.shape[1]``; the span may be empty. Those positions are zeroed
    for every index of dims 0 and 2.

    NOTE(review): presumes dim 1 is the time/token axis; confirm with the
    text feature layout.

    Returns:
        A new tensor ``mask * tensor``; the input is not modified.
    """
    mask = torch.ones_like(tensor)
    t = tensor.shape[1]
    start = random.randint(0, t)
    end = random.randint(start, t)
    # `end` is drawn as an absolute index in [start, t], so the span is
    # start:end. Slicing start:start+end would overshoot the drawn end.
    mask[:, start:end, :] = 0
    return mask * tensor


def corr_coef(tensor1, tensor2):
    """Return Pearson (PCC) and Lin's concordance (CCC) correlation coefficients.

    All second-order statistics (covariance, variances, standard deviations)
    use the same population normalisation (divide by N). With mixed
    normalisations, e.g. a 1/N covariance over unbiased 1/(N-1) standard
    deviations, PCC is scaled by (N-1)/N and cannot reach 1 even for
    perfectly correlated data.

    Both inputs are flattened first. An ``(N, 1)`` tensor paired with an
    ``(N,)`` tensor is then compared element-wise instead of broadcasting into
    an ``(N, N)`` outer product. Inputs must hold the same number of elements.

    Returns:
        ``(pcc, ccc)`` as 0-d tensors. PCC is NaN if either input is constant.
        CCC is NaN only if both inputs are constant and equal.
    """
    x = tensor1.reshape(-1)
    y = tensor2.reshape(-1)

    mean_x = torch.mean(x)
    mean_y = torch.mean(y)
    dev_x = x - mean_x
    dev_y = y - mean_y

    covariance = torch.mean(dev_x * dev_y)
    var_x = torch.mean(dev_x ** 2)
    var_y = torch.mean(dev_y ** 2)

    pcc = covariance / (torch.sqrt(var_x) * torch.sqrt(var_y))
    ccc = 2 * covariance / (var_x + var_y + (mean_x - mean_y) ** 2)
    return pcc, ccc


def test_logging(score_list, pred_list, writer, val_flags, logging='Testing'):
    """Print and log per-subset regression metrics.

    Scores and predictions are split into subsets purely by position along
    dim 0, so both must be ordered identically and match the hard-coded
    boundaries:

    * avec13: ``[50:100]``; avec14: ``[:100]`` (both modes).
    * avec17 / avec19: ``[100:147]`` / ``[147:]`` in 'Testing' mode, and
      ``[100:135]`` / ``[135:]`` in 'Development' mode.

    MAE and RMSE are printed and written via ``writer.add_scalar`` for every
    subset. In 'Testing' mode, PCC and CCC (from ``corr_coef``) are also
    printed and logged.

    Args:
        score_list: ground-truth scores (tensor).
        pred_list: predictions aligned element-for-element with ``score_list``.
        writer: object exposing ``add_scalar(tag, value, global_step=...)``,
            presumably a TensorBoard ``SummaryWriter``; confirm.
        val_flags: value passed as ``global_step`` for every scalar.
        logging: 'Testing' or 'Development'; selects the subset boundaries.

    Raises:
        ValueError: if ``logging`` is neither 'Testing' nor 'Development'.
    """
    if logging == 'Testing':
        avec17_span, avec19_span = slice(100, 147), slice(147, None)
    elif logging == 'Development':
        avec17_span, avec19_span = slice(100, 135), slice(135, None)
    else:
        # Fail fast. Falling through (the old behaviour) left the avec17/avec19
        # metrics undefined and crashed later with an UnboundLocalError.
        raise ValueError('logging must be in "Development", "Testing".')
    avec13_span, avec14_span = slice(50, 100), slice(None, 100)

    mse_loss_func = nn.MSELoss()
    mae_loss_func = nn.L1Loss()

    def mae_rmse(span):
        # (MAE, RMSE) of the predictions within one positional subset.
        scores, preds = score_list[span], pred_list[span]
        return mae_loss_func(scores, preds), torch.sqrt(mse_loss_func(scores, preds))

    avec13_mae, avec13_rmse = mae_rmse(avec13_span)
    avec14_mae, avec14_rmse = mae_rmse(avec14_span)
    avec17_mae, avec17_rmse = mae_rmse(avec17_span)
    avec19_mae, avec19_rmse = mae_rmse(avec19_span)

    print('Test MAE: avec13: {:.4f}, avec14: {:.4f}, avec17: {:.4f}, avec19: {:.4f}\n'
          '    RMSE: avec13: {:.4f}, avec14: {:.4f}, avec17: {:.4f}, avec19: {:.4f}'.format(
          avec13_mae, avec14_mae, avec17_mae, avec19_mae,
          avec13_rmse, avec14_rmse, avec17_rmse, avec19_rmse))
    for tag, value in (
        ('Val Loss/13 MAE', avec13_mae),
        ('Val Loss/13 RMSE', avec13_rmse),
        ('Val Loss/14 MAE', avec14_mae),
        ('Val Loss/14 RMSE', avec14_rmse),
        ('Val Loss/17 MAE', avec17_mae),
        ('Val Loss/17 RMSE', avec17_rmse),
        ('Val Loss/19 MAE', avec19_mae),
        ('Val Loss/19 RMSE', avec19_rmse),
    ):
        writer.add_scalar(tag, value, global_step=val_flags)

    if logging == 'Testing':
        pcc_13, ccc_13 = corr_coef(pred_list[avec13_span], score_list[avec13_span])
        pcc_14, ccc_14 = corr_coef(pred_list[avec14_span], score_list[avec14_span])
        pcc_17, ccc_17 = corr_coef(pred_list[avec17_span], score_list[avec17_span])
        pcc_19, ccc_19 = corr_coef(pred_list[avec19_span], score_list[avec19_span])
        print('     PCC: avec13: {:.4f}, avec14: {:.4f}, avec17: {:.4f}, avec19: {:.4f}\n'
              '     CCC: avec13: {:.4f}, avec14: {:.4f}, avec17: {:.4f}, avec19: {:.4f}'.format(
            pcc_13, pcc_14, pcc_17, pcc_19, ccc_13, ccc_14, ccc_17, ccc_19))
        for tag, value in (
            ('Val Loss/13 PCC', pcc_13),
            ('Val Loss/13 CCC', ccc_13),
            ('Val Loss/14 PCC', pcc_14),
            ('Val Loss/14 CCC', ccc_14),
            ('Val Loss/17 PCC', pcc_17),
            ('Val Loss/17 CCC', ccc_17),
            ('Val Loss/19 PCC', pcc_19),
            ('Val Loss/19 CCC', ccc_19),
        ):
            writer.add_scalar(tag, value, global_step=val_flags)


