import numpy as np
import torch
import torch.distributions as D
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torch.autograd import Variable
import os
import argparse

from datasets import Fast_MNIST, Fast_SVHN, Fast_3DShapes

from load_model import load_model_from_save_dict
from plotting import plot_sample_generations_from_each_cluster_torch_grid

import sys
from sklearn.manifold import TSNE
from utils import cluster_acc_old
from mcc import mean_corr_coef, mean_corr_coef_out_of_sample
from sklearn.cross_decomposition import CCA

import warnings

import pandas as pd

def compute_mcc_single(model_path="trained_models/mnist_modified/", cca_dim=5, batch_size=512):
    """
    Compute pairwise MCC scores between the representations of saved models.

    Every entry found in ``model_path`` is loaded with
    ``load_model_from_save_dict`` and run on CPU over the ``Fast_MNIST``
    training split (``train=True``). For each model the ``.mean`` of the
    first element of ``q_z_j_x_list`` is collected per batch and stacked into
    one representation matrix. For every unordered pair of models, the first
    half of the aligned rows is used for fitting and the second half for
    evaluation:

    * "strong" MCC: ``mean_corr_coef_out_of_sample`` on the raw representations.
    * "weak" MCC: ``mean_corr_coef`` on the held-out CCA projections, with the
      CCA fitted on the first half.

    The mean and standard deviation over all pairs are printed; nothing is
    returned.

    Args:
        model_path: Directory whose entries are each passed to
            ``load_model_from_save_dict``.
        cca_dim: Number of CCA components used for the weak MCC.
        batch_size: Batch size of the evaluation ``DataLoader``.

    Raises:
        ValueError: If no batch was evaluated for a model (empty dataset or
            ``args.n_test_batches == 0``).
    """

    # Data. shuffle=False keeps the sample order identical for every model, so
    # row k of each representation matrix refers to the same input sample.
    # num_workers must be 0 with GPU-resident data, see:
    # https://discuss.pytorch.org/t/cuda-initialization-error-when-dataloader-with-cuda-tensor/43390
    train_data = Fast_MNIST('./data', train=True, download=True, device="cpu")
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=False, num_workers=0)

    print("Model path =", model_path)

    device = torch.device("cpu")
    model_reps = []

    # os.listdir returns entries in arbitrary order; sort so that the printed
    # log and the pair ordering are reproducible across runs and machines.
    for model in sorted(os.listdir(model_path)):
        print("Working with model", model)
        mfcvae, args = load_model_from_save_dict(os.path.join(model_path, model), map_location="cpu")
        with torch.no_grad():
            # evaluation mode (e.g. dropout, batch norm affected)
            mfcvae.eval()
            mfcvae.device = device
            mfcvae = mfcvae.to(device)

            all_reps = []
            eval_loss = 0.
            n_seen = 0  # samples actually evaluated; the loop may stop early
            for batch_idx, (x, _) in enumerate(train_loader):
                if batch_idx == args.n_test_batches:
                    break
                if args.model_type in ['fc_shared', 'fc_per_facet_enc_shared_dec', 'fc_vlae']:
                    x = x.view(x.size(0), -1).float()
                elif args.model_type in ['conv_vlae']:
                    x = x.float()

                x_hat, q_z_j_x_list, z_sample_q_z_j_x_list = mfcvae.forward(x, 0, batch_idx)
                if args.model_type in ['conv_vlae']:
                    # the loss is computed on flattened inputs/reconstructions
                    x = x.view(x.size(0), -1)
                    x_hat = x_hat.view(x_hat.size(0), -1).float()
                loss, *_ = mfcvae.compute_loss_5terms(x, x_hat, q_z_j_x_list, z_sample_q_z_j_x_list, 0)

                eval_loss += loss.detach() * len(x)
                n_seen += len(x)

                all_reps.append(q_z_j_x_list[0].mean.detach().numpy())

            if n_seen == 0:
                raise ValueError(
                    f"No batches were evaluated for model {model!r}; "
                    "check the dataset and args.n_test_batches."
                )

            # Normalize by the samples that contributed to the sum, not by the
            # full dataset size, because of the n_test_batches early exit.
            print("Eval loss", eval_loss / n_seen)

            model_reps.append(np.concatenate(all_reps, 0))

    all_strong_mccs = []
    all_weak_mccs = []
    n = len(model_reps)
    for i in range(n):
        for j in range(i + 1, n):
            rep1 = model_reps[i]
            rep2 = model_reps[j]
            # Models may have been evaluated on a different number of batches;
            # only the common leading rows refer to the same samples.
            cutoff = min(rep1.shape[0], rep2.shape[0]) // 2
            rep1_fit, rep2_fit = rep1[:cutoff], rep2[:cutoff]
            rep1_held_out, rep2_held_out = rep1[cutoff:2 * cutoff], rep2[cutoff:2 * cutoff]

            mcc_strong_out = mean_corr_coef_out_of_sample(
                x=rep1_fit, y=rep2_fit, x_test=rep1_held_out, y_test=rep2_held_out
            )
            all_strong_mccs.append(mcc_strong_out)

            cca = CCA(n_components=cca_dim, max_iter=5000)
            cca.fit(rep1_fit, rep2_fit)
            res_out = cca.transform(rep1_held_out, rep2_held_out)
            mcc_weak_out = mean_corr_coef(res_out[0], res_out[1])
            all_weak_mccs.append(mcc_weak_out)

    if not all_strong_mccs:
        # np.mean/np.std of an empty list would only yield nan plus a warning.
        print("Need at least two models in", model_path, "to compute pairwise MCC; found", n)
        return

    print("Avg strong MCC", np.mean(all_strong_mccs), "Std strong MCC", np.std(all_strong_mccs))
    print("Avg weak MCC", np.mean(all_weak_mccs), "Std weak MCC", np.std(all_weak_mccs))


if __name__ == "__main__":
    # Print NumPy values with two decimals.
    np.set_printoptions(precision=2)
    # Silence FutureWarning messages for the duration of the run.
    warnings.simplefilter("ignore", category=FutureWarning)
    compute_mcc_single()