import os
import math
import torch
import numpy as np
from torch import optim
import torch.nn.functional as F
from sklearn.mixture import GaussianMixture
# from sklearn.utils.linear_assignment_ import linear_assignment
from hungarian import linear_assignment as linear_assignment
import argparse

import statistics

from models import Autoencoder, VaDE

from draw import draw_all, draw_together

import matplotlib.pyplot as plt

from scipy.linalg import sqrtm, pinv
from torch.linalg import pinv as tpinv

import scipy
from numpy.linalg import svd
from scipy.optimize import linear_sum_assignment
from scipy.stats import spearmanr

import itertools

# def make_pinwheel_data(radial_std, tangential_std, num_classes, num_per_class, rate):
#     # code from Johnson et. al. (2016)
#     rads = np.linspace(0, 2*np.pi, num_classes, endpoint=False)

#     np.random.seed(1)

#     features = np.random.randn(num_classes*num_per_class, 2) \
#         * np.array([radial_std, tangential_std])
#     features[:,0] += 1.
#     labels = np.repeat(np.arange(num_classes), num_per_class)

#     angles = rads[labels] + rate * np.exp(features[:,0])
#     rotations = np.stack([np.cos(angles), -np.sin(angles), np.sin(angles), np.cos(angles)])
#     rotations = np.reshape(rotations.T, (-1, 2, 2))

#     feats = 10 * np.einsum('ti,tij->tj', features, rotations)

#     data = np.random.permutation(np.hstack([feats, labels[:, None]]))
#     labels = np.zeros((num_classes*num_per_class, ))
#     labels = data[:, 2].astype(int)
#     return torch.Tensor(data[:, 0:2]), torch.Tensor(labels).int()

def _get_l2_loss(n_classes, mu_prior1, mu_prior2, means1, means2, covar1, covar2):
    # https://stats.stackexchange.com/questions/71879/distance-between-two-gaussian-mixtures-to-evaluate-cluster-solutions

    def get_integral(mean1, var1, mean2, var2):
        ret = 1. / np.sqrt(np.linalg.det(2 * np.pi * (var1 + var2)))
        ret *= np.exp(-.5 * (mean1 - mean2).T @ (np.linalg.inv(var1 + var2) @ (mean1 - mean2)))
        # ret = -.5 * (mean1 - mean2).T @ (np.linalg.inv(var1 + var2) @ (mean1 - mean2))
        # ret -= .5 * np.log(np.linalg.det(2 * np.pi * (var1 + var2)))
        # ret = np.exp(ret)
        return ret
    
    ret = 0
    for i in range(n_classes):
        for j in range(n_classes):
            ret += mu_prior1[i] * mu_prior1[j] * get_integral(means1[i], covar1[i], means1[j], covar1[j])
            if mu_prior2 is not None:
                ret += mu_prior2[i] * mu_prior2[j] * get_integral(means2[i], covar2[i], means2[j], covar2[j])
                ret -= 2 * mu_prior1[i] * mu_prior2[j] * get_integral(means1[i], covar1[i], means2[j], covar2[j])
    return ret

def get_l2_loss(n_classes, mu_prior1, mu_prior2, means1, means2, covar1, covar2):
    """Return the normalised L2 distance between two Gaussian mixtures.

    The numerator is the integral of (f - g)^2 dx. It is divided by
    sqrt(integral of f^2 dx) * sqrt(integral of g^2 dx), where f and g are the
    mixtures described by the ``*1`` and ``*2`` arguments respectively.
    """
    self_norm_first = _get_l2_loss(n_classes, mu_prior1, None, means1, None, covar1, None)
    self_norm_second = _get_l2_loss(n_classes, mu_prior2, None, means2, None, covar2, None)
    squared_distance = _get_l2_loss(n_classes, mu_prior1, mu_prior2, means1, means2, covar1, covar2)
    return squared_distance / (np.sqrt(self_norm_first) * np.sqrt(self_norm_second))

def _load_prior_params(args, model_path):
    """Load every checkpoint in ``model_path`` and extract its latent GMM prior.

    Returns four parallel lists: checkpoint file names, normalised mixture
    weights, component means, and per-component covariance matrices.
    """
    names, pis, mus, covs = [], [], [], []
    # Sort so model indices in the printed report are reproducible
    # (os.listdir returns entries in arbitrary order).
    for name in sorted(os.listdir(model_path)):
        print(name)
        # The covariance type is taken from the first four characters of the
        # checkpoint file name.
        vade = VaDE(args.in_dim, args.latent_dim, args.n_classes, covariance=name[:4])
        # torch.load unpickles: only point this at trusted checkpoints.
        vade.load_state_dict(torch.load(os.path.join(model_path, name), map_location="cpu"))
        names.append(name)
        pi = vade.pi_prior.detach().numpy()
        # Normalise so the mixture is a proper density (draw() does the same).
        pis.append(pi / np.sum(pi))
        mus.append(vade.mu_prior.detach().numpy())
        if vade.covariance == "full":
            sqrt_prec = vade.sqrt_var_prior.detach().numpy()
            # NOTE(review): precision is taken as L.T @ L, matching draw().
            # The earlier L @ L only agrees when L is symmetric; confirm
            # against the VaDE model definition.
            covs.append(np.array([np.linalg.inv(sqrt_prec[k].T @ sqrt_prec[k])
                                  for k in range(args.n_classes)]))
        else:
            variances = torch.exp(vade.log_var_prior).detach().numpy()
            # Build one diagonal covariance matrix per component. np.diag on
            # the whole 2-D array would instead extract its diagonal.
            # Assumes log_var_prior is (n_classes, latent_dim) — TODO confirm.
            covs.append(np.array([np.diag(v) for v in variances]))
    return names, pis, mus, covs


def _fit_linear_map(x, y):
    """Return the d x d matrix P minimising ||x @ P - y|| via homogeneous SVD.

    Each (component i, coordinate j) pair contributes one row. The row holds
    the coefficients of P[:, j] (i.e. x[i]) followed by -y[i, j]. The
    right-singular vector for the smallest singular value is rescaled so
    that its last entry is 1.
    """
    num_comps, d = x.shape
    rows = []
    for i in range(num_comps):
        for j in range(d):
            coeffs = np.zeros((d, d))
            coeffs[:, j] = x[i]
            rows.append(np.append(coeffs.reshape(-1), -y[i, j]))
    _, _, vh = np.linalg.svd(np.array(rows))
    p = vh[-1]
    return (p / p[-1])[:-1].reshape((d, d))


def _aligned_l2_loss(pi_a, mu_a, cov_a, pi_b, mu_b, cov_b):
    """Smallest normalised L2 loss between two GMMs over component matchings.

    Both sets of means are centred. For each permutation of model b's
    components, a linear map of model a's latent space onto b's is fitted.
    Model a's means and covariances are transformed by that map before
    comparing.
    """
    x = mu_a - np.mean(mu_a, axis=0)
    y = mu_b - np.mean(mu_b, axis=0)
    num_comps = x.shape[0]
    best_loss = np.inf
    for perm in itertools.permutations(range(num_comps)):
        # A list (not a tuple) so numpy treats it as a row index, not a
        # multi-dimensional index.
        perm = list(perm)
        y_perm = y[perm]
        best_map = _fit_linear_map(x, y_perm)
        cur_loss = get_l2_loss(
            n_classes=num_comps,
            mu_prior1=pi_a,
            mu_prior2=pi_b[perm],
            means1=x @ best_map,
            means2=y_perm,
            # A row-vector map x -> x @ P sends covariance S to P.T @ S @ P.
            covar1=np.array([best_map.T @ (cov_a[k] @ best_map) for k in range(num_comps)]),
            covar2=cov_b[perm],
        )
        if cur_loss < best_loss:
            best_loss = cur_loss
    return best_loss


def main(args):
    """Print the pairwise aligned L2 loss between all saved models' priors,
    followed by the mean and population standard deviation over distinct pairs.
    """
    model_path = "saved_models/25_11_10_23/"
    #model_path = "saved_models/25_13_24_35/"
    #model_path = "saved_models/25_13_49_14/"
    #model_path = "saved_models/cross-loss/"
    _, all_pis, all_mus, all_vars = _load_prior_params(args, model_path)
    num_models = len(all_pis)
    if num_models < 2:
        # Averages over distinct pairs are undefined with fewer than 2 models.
        print("Need at least two models in {} to compare, found {}".format(model_path, num_models))
        return

    all_loss = []
    for a in range(num_models):
        for b in range(num_models):
            best_loss = _aligned_l2_loss(all_pis[a], all_mus[a], all_vars[a],
                                         all_pis[b], all_mus[b], all_vars[b])
            print("Loss between models {}, {} = {}".format(a, b, best_loss))
            # Self-comparisons are printed as a sanity check but excluded
            # from the statistics.
            if a != b:
                all_loss.append(best_loss)

    avg_loss = sum(all_loss) / len(all_loss)
    print("Average loss between models = {}".format(avg_loss))
    print("Std loss is {}".format(statistics.pstdev(all_loss)))
    
    
def draw(args):
    """Plot the observed data next to samples from each saved model's latent GMM prior.

    Loads every checkpoint under ``saved_models/<path_token>/``. For each one,
    builds an sklearn ``GaussianMixture`` directly from the prior parameters
    (no fitting) and draws 6400 samples. The first two sample coordinates are
    scattered in one subplot per model. The figure is saved to
    ``summary<path_token>.png`` and then shown.
    """
    path_token = "25_11_10_23"
    #path_token = "25_13_24_35"
    #path_token = "25_13_49_14"
    model_path = "saved_models/{}/".format(path_token)

    all_pis = []
    all_mus = []
    all_pres = []

    # Sorted so subplot numbering is reproducible across runs.
    for model in sorted(os.listdir(model_path)):
        print(model)
        vade = VaDE(args.in_dim, args.latent_dim, args.n_classes, covariance=model[:4])
        # torch.load unpickles: only point this at trusted checkpoints.
        vade.load_state_dict(torch.load(os.path.join(model_path, model), map_location="cpu"))
        all_pis.append(vade.pi_prior.detach().numpy())
        all_mus.append(vade.mu_prior.detach().numpy())
        if vade.covariance == "full":
            sqrt_prec = vade.sqrt_var_prior.detach().numpy()
            all_pres.append(np.array([sqrt_prec[k].T @ sqrt_prec[k] for k in range(args.n_classes)]))
        else:
            # Precision is 1 / variance = exp(-log_var), and must be positive.
            # Build one diagonal matrix per component; np.diag on the 2-D array
            # would extract its diagonal instead.
            # Assumes log_var_prior is (n_classes, latent_dim) — TODO confirm.
            precisions = torch.exp(-vade.log_var_prior).detach().numpy()
            all_pres.append(np.array([np.diag(p) for p in precisions]))

    plt.figure(figsize=(12, 2.2))
    num_gr = len(all_mus)

    #data, true_lab = make_pinwheel_data(0.3, 0.05, args.n_classes, 500, 0.25)
    data = torch.load("outputs/{}/data.pth".format(path_token)).numpy()
    true_labels = torch.load("outputs/{}/labels.pth".format(path_token)).numpy()
    plt.subplot(1, num_gr + 1, 1)
    plt.scatter(data[:, 1], data[:, 2], c=true_labels, s=6, alpha=0.3)  # for rectangles
    plt.xticks([])
    plt.yticks([])
    plt.title('Observed X', fontsize=12, family='serif')

    for i in range(num_gr):
        weights = all_pis[i] / np.sum(all_pis[i])
        lgmm = GaussianMixture(args.n_classes, weights_init=weights,
                               means_init=all_mus[i], precisions_init=all_pres[i])
        # Populate the fitted attributes by hand so sample() can be called
        # without running fit().
        lgmm.converged_ = True
        lgmm.weights_ = weights
        lgmm.means_ = all_mus[i]
        lgmm.precisions_ = all_pres[i]
        lgmm.covariances_ = np.linalg.inv(all_pres[i])
        samples, _ = lgmm.sample(100 * 64)

        plt.subplot(1, num_gr + 1, i + 2)
        plt.scatter(samples[:, 0], samples[:, 1], s=6, alpha=0.3)
        plt.xticks([])
        plt.yticks([])
        plt.title('Latent GMM #{}'.format(i + 1), fontsize=12, family='serif')

    # Lay out and save before show(): with blocking backends the figure may
    # be torn down once the window closes, which would leave savefig blank.
    plt.tight_layout()
    plt.savefig('summary{}.png'.format(path_token))
    plt.show()
    plt.close()
            
    
    

if __name__ == '__main__':
    # Several options (epochs, patience, lr, pretrain, output_dir, noise_var)
    # are not read by this script. They are kept so the same command line can
    # be shared with the training script.
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=20,
                        help="number of iterations")
    parser.add_argument("--patience", type=int, default=50,
                        help="Patience for Early Stopping")
    parser.add_argument('--lr', type=float, default=2e-4,
                        help='learning rate')
    parser.add_argument("--batch_size", type=int, default=64,
                        help="Batch size")
    parser.add_argument('--pretrain', type=int, default=1,
                        help='Pretraining flag (not used by this script)')
    # The previous default was the builtin function `dir`, not a path.
    parser.add_argument('--output_dir', type=str, default="outputs",
                        help='Output dir')
    parser.add_argument("--in_dim", type=int, default=5,
                        help="Input dimension")
    parser.add_argument("--latent_dim", type=int, default=2,
                        help="Latent dimension")
    parser.add_argument("--n_classes", type=int, default=3,
                        help="Num classes")
    parser.add_argument("--noise_var", type=float, default=0.001,
                        help="Variance of Gaussian noise")
    args = parser.parse_args()
    main(args)
    #draw(args)