import os
import math
import torch
import numpy as np
from torch import optim
import torch.nn.functional as F
from sklearn.mixture import GaussianMixture
# from sklearn.utils.linear_assignment_ import linear_assignment
from hungarian import linear_assignment as linear_assignment
import argparse 
from torch.utils.data import DataLoader, TensorDataset

from models import Autoencoder, VaDE

from draw import draw_all, draw_together

import matplotlib.pyplot as plt

from scipy.linalg import sqrtm, pinv
from torch.linalg import pinv as tpinv

import scipy
from numpy.linalg import svd
from scipy.optimize import linear_sum_assignment
from scipy.stats import spearmanr

import itertools
from mcc import mean_corr_coef, mean_corr_coef_out_of_sample
from sklearn.cross_decomposition import CCA

import warnings

def get_dataloader(data, labels, batch_size=128):
    """Wrap ``data`` and ``labels`` in an ordered, single-process DataLoader.

    Batches are yielded as ``(data_batch, labels_batch)`` pairs in the
    original sample order (no shuffling), so anything collected batch by
    batch lines up row-for-row with ``data``.
    """
    paired = TensorDataset(data, labels)
    return DataLoader(
        paired,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
    )



def _encode_dataset(vade, dataloader):
    """Encode every batch with ``vade.encode`` and stack the first outputs row-wise."""
    # eval() + no_grad(): inference only. eval() matters if VaDE has
    # train-mode-dependent layers (not visible from here); no_grad() avoids
    # building autograd graphs that were previously discarded via detach().
    vade.eval()
    rows = []
    with torch.no_grad():
        for x, _ in dataloader:
            # Only the first of the two returned values is used as the
            # representation (presumably the latent mean -- verify in models.py).
            latent, _ = vade.encode(x)
            rows.extend(latent.cpu().numpy())
    return np.vstack(rows)


def _compare_pair(rep1, rep2, cca_dim):
    """Return ``(strong_in, strong_out, weak_in, weak_out)`` MCCs for one model pair.

    The first half of the samples is used for fitting, and the second half
    (same size) for out-of-sample evaluation.
    """
    cutoff = rep1.shape[0] // 2
    x_fit, y_fit = rep1[:cutoff], rep2[:cutoff]
    x_test, y_test = rep1[cutoff:2 * cutoff], rep2[cutoff:2 * cutoff]

    strong_out = mean_corr_coef_out_of_sample(x=x_fit, y=y_fit, x_test=x_test, y_test=y_test)
    strong_in = mean_corr_coef(x=x_fit, y=y_fit)

    # CCA cannot fit more components than min(n_samples, n_features_x,
    # n_features_y); newer scikit-learn raises instead of silently clamping.
    n_components = min(cca_dim, x_fit.shape[1], y_fit.shape[1], cutoff)
    cca = CCA(n_components=n_components, max_iter=5000)
    cca.fit(x_fit, y_fit)
    weak_out = mean_corr_coef(*cca.transform(x_test, y_test))
    weak_in = mean_corr_coef(*cca.transform(x_fit, y_fit))
    return strong_in, strong_out, weak_in, weak_out


def get_mcc(args, path_token="25_13_24_35", cca_dim=20):
    """Compare the representations of all saved VaDE models pairwise.

    Steps:

    1. Load the dataset from ``outputs/<path_token>/data.pth`` and
       ``labels.pth``.
    2. Rebuild every file in ``saved_models/<path_token>/`` as a ``VaDE``
       (covariance type taken from the first four characters of the file
       name) and encode the dataset with it.
    3. For each unordered pair of models (files in sorted order), compute:

       * the strong MCC with ``mean_corr_coef_out_of_sample`` on the raw
         encodings;
       * the weak MCC with ``mean_corr_coef`` on CCA projections.

       Both are fitted on the first half of the samples and evaluated on the
       second half.
    4. Print the mean and standard deviation of the out-of-sample scores.

    Args:
        args: namespace providing ``batch_size``, ``in_dim``, ``latent_dim``
            and ``n_classes``.
        path_token: run identifier naming the data and model directories.
        cca_dim: requested number of CCA components. It is clamped per pair
            to what CCA can fit.

    Returns:
        dict of per-pair MCC lists under ``"strong_in"``, ``"strong_out"``,
        ``"weak_in"`` and ``"weak_out"``.
    """
    data_dir = os.path.join("outputs", path_token)
    model_dir = os.path.join("saved_models", path_token)

    # Load on CPU to match the models below, which are loaded with map_location="cpu".
    data = torch.load(os.path.join(data_dir, "data.pth"), map_location="cpu")
    labels = torch.load(os.path.join(data_dir, "labels.pth"), map_location="cpu")
    dataloader = get_dataloader(data, labels, batch_size=args.batch_size)

    model_reps = []
    # sorted(): os.listdir order is arbitrary, so sorting keeps pair ordering reproducible.
    for model_file in sorted(os.listdir(model_dir)):
        model_file_path = os.path.join(model_dir, model_file)
        if not os.path.isfile(model_file_path):
            continue
        print(model_file)
        vade = VaDE(args.in_dim, args.latent_dim, args.n_classes, covariance=model_file[:4])
        vade.load_state_dict(torch.load(model_file_path, map_location="cpu"))
        model_reps.append(_encode_dataset(vade, dataloader))

    results = {"strong_in": [], "strong_out": [], "weak_in": [], "weak_out": []}
    # combinations yields (i, j) with i < j, the same pairs as a nested index loop.
    for rep1, rep2 in itertools.combinations(model_reps, 2):
        strong_in, strong_out, weak_in, weak_out = _compare_pair(rep1, rep2, cca_dim)
        results["strong_in"].append(strong_in)
        results["strong_out"].append(strong_out)
        results["weak_in"].append(weak_in)
        results["weak_out"].append(weak_out)

    all_strong_mccs = results["strong_out"]
    all_weak_mccs = results["weak_out"]
    if not all_strong_mccs:
        # With fewer than two models there are no pairs; np.mean([]) would yield NaN.
        print("Need at least two models to compute MCC; found", len(model_reps))
        return results

    print("Avg strong MCC", np.mean(all_strong_mccs), "Std strong MCC", np.std(all_strong_mccs))
    print("Avg weak MCC", np.mean(all_weak_mccs), "Std weak MCC", np.std(all_weak_mccs))
    return results


if __name__ == '__main__':
    # Integer CLI options used to rebuild the VaDE models and batch the data:
    # (flag, default, help text), registered in this order.
    _int_options = (
        ("--batch_size", 64, "Batch size"),
        ("--in_dim", 5, "Input dimension"),
        ("--latent_dim", 2, "Latent dimension"),
        ("--n_classes", 3, "Num classes"),
    )
    parser = argparse.ArgumentParser()
    for _flag, _default, _help in _int_options:
        parser.add_argument(_flag, type=int, default=_default, help=_help)
    args = parser.parse_args()
    # Hide FutureWarnings raised by libraries during the run so the summary stays readable.
    warnings.simplefilter(action='ignore', category=FutureWarning)
    get_mcc(args)