import random
import torch
import numpy as np
import numpy.random as nr
import pandas as pd
from tqdm import tqdm
from torch import nn
from torch.nn.functional import softmax, one_hot, sigmoid, relu
from torch.distributions.beta import Beta
from torchvision import models
from pingouin import multivariate_normality as mvn
from torch.utils.data import DataLoader
import pdb
import itertools

from speech_flow_noise_gen import s_theta
from speech_flow_noise_gen import Flow

import torch.nn.functional as F
import os
from torchvision.utils import save_image
from matplotlib import pyplot as plt
import network
from sklearn.cluster import KMeans
from torchvision import transforms



# Use the first CUDA device when one is available; otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# torch.cuda.set_device() only accepts CUDA devices. Calling it with a CPU
# device raises, which would make this module impossible to import on
# CPU-only hosts, so it is skipped there.
if device.type == "cuda":
    torch.cuda.set_device(device)

class AugNet(nn.Module):
    def __init__(self, device, backbone, classifier, layer_index=-1, nc=10): # base_net, train_bs
        super(AugNet,self).__init__()
        # model = eval('models.{}(pretrained=True)'.format(base_net))
        self.device = device
        self.nc = nc
        
        self.pool = [None]*nc
        self.centroids = [None]*nc
        self.global_pool = None
        self.image_pool = None
        
        # self.epsilon_level = int((32/train_bs) * 15)
        self.epsilon_level = 3
        
        
        self.global_pool_by_class = []
        self.image_pool_by_class = []
        
        self.noise_gen_pool = [None]*self.epsilon_level
        self.s_theta_pool = [None]*self.epsilon_level
        
        self.image_noise_gen = [None]*self.epsilon_level
        self.image_s_theta = [None]*self.epsilon_level
        
        self.mmd_matrix = None
        self.mmd_matrix_img = None
        
        self.image_epsilon_lists = None
        self.feature_epsilon_lists = None
        
        # for p in model.parameters(): p.requires_grad = True
        
        # model = list(model.children())  # conv1, bn1, relu, maxpool, layer1, layer2, layer3, layer4, avgpool, fc

        self.feat = backbone

       
        
    # def get_aug_net(self):
    #     params = []
    #     for i in range(self.epsilon_level):
    #         params += list(self.image_noise_gen[i].parameters())
    #         params += list(self.image_s_theta[i].parameters())
    #     return params
        

    def forward(self, inp, lbl=None):
        _, c, _ = inp.shape
        
        aug_list = [inp]
        label_list = [lbl]

        with torch.no_grad():
            for i in range(self.epsilon_level):
                noise_gen = self.image_noise_gen[i]
                s_theta_f = self.image_s_theta[i]

                z, _ = noise_gen(inp)
                eps = s_theta_f(z)
            
                aug_img = noise_gen.inverse(z + eps)

                aug_img -= aug_img.min()
                aug_img /= (aug_img.max()+1e-6)
                aug_img = aug_img.clamp(0,1)
                # aug_img = normalizing(aug_img)
                aug_img = 0.6 * aug_img + 0.4 * inp
                # aug_img = 0.8 * aug_img + 0.2*inp
                
                aug_list.append(aug_img)
                label_list.append(lbl)

            imgs_out = torch.cat(aug_list, dim=0)
            labels_out = torch.cat(label_list, dim=0)
        return imgs_out, labels_out

       




        
    # 'pool' is a Subset dataset. This is potentially a very memory-intensive routine
    def set_pool(self, pool):
        tmp = [[] for x in self.pool] # []*nc
        img = [[] for x in self.pool]
        global_data = [] # trainset
        image_data = []
        with torch.no_grad():
            # class-wise (centroids calculation) & global data (pretraining noise_gen, s_theta)
            for i, (data, label) in enumerate(pool):
                data = data.to(self.device).unsqueeze(0)
                image_data.append(data) # [1, 3, h, w]
                
                # feature = self.feat(data)
                
                # global_data.append(feature)
                # print(feature.shape) # [1, 512, 7, 7]
                img[label].append(data)
                # tmp[label].append(feature)
        
        image_data = torch.cat(image_data, dim=0)
        # global_data = torch.cat(global_data, dim=0)
        # self.global_pool = global_data.to(self.device).detach()
        self.image_pool = image_data.to(self.device).detach()
        
        # class-wise centroid calculation (feature level)
        # self.pool = [None]*self.nc
        self.img_centroids = [None]*self.nc
        self.centroids = [None]*self.nc
        for i in range(int(self.nc)):
            # print(len(tmp[i]))
            # class_feat = torch.cat(tmp[i], dim=0)
            class_img = torch.cat(img[i], dim=0)
            # self.global_pool_by_class.append(class_feat.to(self.device).detach())
            self.image_pool_by_class.append(class_img.to(self.device).detach())
            # self.pool[i] = class_feat
            # class_feat_2d = class_feat.view(class_feat.shape[0], -1) # [class size, 512*7*7]                
            class_img_2d = class_img.view(class_img.shape[0], -1)
            
            # kmeans = KMeans(n_clusters=1, random_state=0, n_init=10)
            # # kmeans.fit(class_feat_2d.cpu().numpy())
            
            # self.centroids[i] = torch.tensor(kmeans.cluster_centers_).to(self.device)
            
            # kmeans.fit(class_img_2d.cpu().numpy())
            # self.img_centroids[i] = torch.tensor(kmeans.cluster_centers_).to(self.device)
        # print(self.centroids)
        
        # mmd between classes
        self.mmd_matrix = torch.zeros((int(self.nc), int(self.nc)), device=self.device)
        self.mmd_matrix_img = torch.zeros((int(self.nc), int(self.nc)), device=self.device)
        # for i in range(int(self.nc)):
        #     for j in range(int(self.nc)):
        #         X = torch.cat(tmp[i])
        #         Y = torch.cat(tmp[j])
        #         n = X.shape[0]
        #         m = Y.shape[0]
        #         pars = [0.05]
                
        #         class_mmd = self.mmd(X, Y, pars, n, m)
        #         self.mmd_matrix[i, j] = class_mmd
        for i in range(int(self.nc)):
            for j in range(int(self.nc)):
                X = torch.cat(img[i])
                Y = torch.cat(img[j])
                n = X.shape[0]
                m = Y.shape[0]
                pars = [0.05]
                
                class_mmd = self.mmd(X, Y, pars, n, m)
                self.mmd_matrix_img[i, j] = class_mmd
        print(self.mmd_matrix, self.mmd_matrix_img)
    
    def intra_class_mmd(self, class_img, pars):
        mmd_val_list = []
        for _ in range(50):
            idx = torch.randperm(class_img.size(0))
            half  = class_img.size(0) // 2
            c1 = class_img[idx[:half]]
            c2 = class_img[idx[half:half*2]]

            mmd2 = self.mmd(c1, c2, pars, c1.size(0), c2.size(0)).item()
            mmd_val_list.append(mmd2)
        print(np.mean(mmd_val_list))
        return np.mean(mmd_val_list)
    
    
    # def intra_class_mmd(self, class_img, pars):
    #     kernel_matrix = self.make_gram(class_img, pars) # kernel gram matrix
    #     m = kernel_matrix.size(0)
    #     print(kernel_matrix.size) 
    #     diag = kernel_matrix.diagonal().sum()/m # E[k(x,x)]
    #     offdiag = (kernel_matrix.sum()-kernel_matrix.diagonal().sum())/(m*(m-1)) # E[k(x,x')]
    #     print(diag - offdiag)
    #     variance = diag - offdiag
    #     return variance.item()
    
            
    def epsilon_list(self):
        self.mmd_matrix_img.diagonal().fill_(float('inf'))
        self.mmd_matrix.diagonal().fill_(float('inf'))
        
        min_dist_img = torch.min(self.mmd_matrix_img).item()
        # #min_dist_feat = torch.min(self.mmd_matrix).item()
        
        var_img_list = []
        var_feat_list = []

        for i in range(self.nc):
            # class_feat = self.global_pool_by_class[i].view(self.global_pool_by_class[i].size(0), -1)
            class_img = self.image_pool_by_class[i].view(self.image_pool_by_class[i].size(0), -1)

            # var_feat = torch.var(class_feat, dim=0, unbiased=False).cpu().numpy()
            var_img  = torch.var(class_img, dim=0, unbiased=False).cpu().numpy()

            # var_feat_list.append(var_feat)
            var_img_list.append(var_img)

        # avg_var_feat = np.mean(var_feat_list)
        avg_var_img  = np.mean(var_img_list)

        # print(avg_var_feat)
        print(avg_var_img) # 0.043524105

        # intra-class mmd (단위 맞추기 위한 experiment)
        intra_class_mmd_list = []
        for img in self.image_pool_by_class:
            intra_class_mmd_list.append(self.intra_class_mmd(img, [0.05]))
        print(intra_class_mmd_list)
        avg_intra = np.mean(intra_class_mmd_list)
        print(avg_intra) # 0.004021623474545777
        
        # 0.001, 0.005, 0.01, 0.05, 0.1
        # mmd_threshold_img = 0.5 * (min_dist_img / (avg_var_img + 1e-8))
        mmd_threshold_img = 0.5 * (min_dist_img / (avg_intra + 1e-8))
        epsilon_list_img  = [mmd_threshold_img * (i + 1) / self.epsilon_level for i in range(self.epsilon_level)]

        print("epsilon_list_img:", epsilon_list_img)
        print(self.nc)

        self.image_epsilon_lists = epsilon_list_img




            
    # after set_pool(self,pool)
    def augment_pool(self, save_dir='/home/hyewon/noise_gen/tau5/', input_channel=1):
        """Train one (flow, s_theta) pair per epsilon level and store them.

        Must be called after ``set_pool()`` and ``epsilon_list()``. An
        ``Augmenter`` is trained on ``self.image_pool`` for every value in
        ``self.image_epsilon_lists``; the resulting weights are loaded into
        fresh modules, kept in ``self.image_noise_gen`` /
        ``self.image_s_theta`` and pickled to ``save_dir`` as
        ``{level}_image_noise_gen.pth`` / ``{level}_image_s_theta.pth``.

        NOTE(review): the default ``save_dir`` ends in ``tau5/`` while
        ``load_augment_pool`` reads from ``.../tau/`` -- confirm which one is
        intended.

        Args:
            save_dir: Directory for the pickled modules; created if missing.
            input_channel: Channel count passed to ``Flow`` and ``s_theta``.

        Raises:
            RuntimeError: If ``epsilon_list()`` has not been run yet.
        """
        if self.image_epsilon_lists is None:
            raise RuntimeError("augment_pool: call epsilon_list() before augment_pool()")

        image_aug = Augmenter(self.image_pool)
        image_aug_results = image_aug.generate(input_channel=input_channel, epsilons=self.image_epsilon_lists)

        # torch.save does not create missing parent directories.
        os.makedirs(save_dir, exist_ok=True)

        with torch.no_grad():
            # generate() may return fewer results than levels (diverged runs
            # are dropped); the remaining slots then stay None.
            for i, result in enumerate(image_aug_results):
                # Use the instance device (as load_augment_pool does) rather
                # than the module-level global.
                noise_gen = Flow(input_channels=input_channel, output_channels=input_channel).to(self.device)
                s_theta_f = s_theta(input_channels=input_channel).to(self.device)
                noise_gen.load_state_dict(result[4])
                s_theta_f.load_state_dict(result[5])
                # Inference-only from here on; mirrors load_augment_pool().
                noise_gen.eval()
                s_theta_f.eval()
                self.image_noise_gen[i] = noise_gen
                self.image_s_theta[i] = s_theta_f
                torch.save(noise_gen, os.path.join(save_dir, f'{i+1}_image_noise_gen.pth'))
                torch.save(s_theta_f, os.path.join(save_dir, f'{i+1}_image_s_theta.pth'))
                
                
        # global_data = Augmenter(self.global_pool)
        # best_results = global_data.generate(input_channel=128, epsilons=self.feature_epsilon_lists)
        
        
        # with torch.no_grad():
        #     for i, result in enumerate(best_results):
        #         noise_gen = Flow(input_channels=128, output_channels=128).to(device)
        #         s_theta_f = s_theta(input_channels=128).to(device)
        #         noise_gen.load_state_dict(result[4])
        #         s_theta_f.load_state_dict(result[5])
        #         self.noise_gen_pool[i] = noise_gen
        #         self.s_theta_pool[i] = s_theta_f
        #         # os.makedirs(os.path.join('/home/hyewon/noise_gen', str(i), exist_ok=True))
        #         torch.save(noise_gen, os.path.join('/home/hyewon/noise_gen', f'{i+1}_feature_noise_gen.pth'))
        #         torch.save(s_theta_f, os.path.join('/home/hyewon/noise_gen', f'{i+1}_feature_s_theta.pth'))
        
    def load_augment_pool(self):
        for i in range(self.epsilon_level):
            noise_gen_path = f'/home/hyewon/noise_gen/tau/{i+1}_image_noise_gen.pth'
            s_theta_path    = f'/home/hyewon/noise_gen/tau/{i+1}_image_s_theta.pth'

            self.image_noise_gen[i] = torch.load(noise_gen_path, map_location=self.device, weights_only=False)
            self.image_s_theta[i]   = torch.load(s_theta_path, map_location=self.device, weights_only=False)

            self.image_noise_gen[i].eval()
            self.image_s_theta[i].eval()
        
    def make_gram(self, X, params, Y=None, kernel='rbf'):
        X = X.view(X.size(0), -1)
        if Y is None: Y = X
        else: Y = Y.view(Y.size(0), -1)
        n, m = X.shape[0], Y.shape[0]
        # print(X.shape, Y.shape)

        if kernel == 'rbf':
            gamma = params[0]
            # X*X => element-wisie multiplication
            # [batch_size, C, H, W] => flatten to [batch_size, C * H * W]
            C = -gamma * (X * X).sum(dim=1).unsqueeze(1).expand(n, m)
            B = 2 * gamma * torch.mm(X, Y.T)
            D = -gamma * (Y * Y).sum(dim=1).unsqueeze(1).expand(m, n)
            E = C + B + D.T
            # E = torch.clamp(E, -30, 30)
            L = torch.exp(E)  # Shape: [n, m]

        return L
    def mmd(self, X, Y, pars, n, m):
        # MMD (A.sum() / (n ** 2) - 2 * B.sum() / (n * m) + C.sum() / (m ** 2))
        A = self.make_gram(X, pars)
        B = self.make_gram(X, pars, Y)
        C = self.make_gram(Y, pars)
        #obj = A.sum() / (n ** 2) - 2 * B.sum() / (n * m) + C.sum() / (m ** 2) + \
        #        lam * (c - eps.norm())
        mmd = A.sum() / (n ** 2) -2 * B.sum() / (n * m) + C.sum() / (m ** 2) # This is a *biased* estimate!

        return mmd
        

class Augmenter:
    # X: n x d
    def __init__(self, X):
        """Store the training data on the module-level ``device`` without grad.

        Args:
            X: Data tensor with samples on dim 0.
        """
        super().__init__()
        # detach() instead of setting ``requires_grad = False``: the flag can
        # only be changed on leaf tensors (a non-leaf input raised), and when
        # ``.to(device)`` is a no-op the assignment silently flipped the flag
        # on the caller's own tensor.
        self.X = X.detach().to(device)
        
        
        
    def mmd(self, X, Y, pars, n, m):
        # MMD (A.sum() / (n ** 2) - 2 * B.sum() / (n * m) + C.sum() / (m ** 2))
        A = self.make_gram(X, pars)
        B = self.make_gram(X, pars, Y)
        C = self.make_gram(Y, pars)
        #obj = A.sum() / (n ** 2) - 2 * B.sum() / (n * m) + C.sum() / (m ** 2) + \
        #        lam * (c - eps.norm())
        mmd = A.sum() / (n ** 2) -2 * B.sum() / (n * m) + C.sum() / (m ** 2) # This is a *biased* estimate!

        return mmd

    

    def generate(self, input_channel=None, lam1=1, lam2 = 0.01, lam3 = 1e-5, epsilons = [0.0001, 0.0005, 0.001, 0.005, 0.01, 0.05, 0.1], lr=5e-4, epochs=100, verbose=True, batch_size=256):
        """Train one (Flow, s_theta) pair per target MMD value.

        For every ``epsilon`` a fresh flow and offset network are trained so
        that the MMD between a batch and its perturbed reconstruction
        ``inverse(z + s_theta(z))`` approaches ``epsilon`` while the mean
        squared offset is encouraged to grow:
        ``loss = |mmd - epsilon| * lam1 - mean(eps**2) * lam2``.

        Training for an epsilon stops early when the epoch-average of
        ``|mmd - epsilon| * lam1`` exceeds 10; the snapshot from the previous
        epoch (if any) is returned for that epsilon.

        Args:
            input_channel: Channel count passed to ``Flow`` and ``s_theta``.
            lam1: Weight of the MMD-matching term.
            lam2: Weight of the offset-magnitude term.
            lam3: Unused; kept for interface compatibility.
            epsilons: Target MMD values (the default list is never mutated).
            lr: Adam learning rate.
            epochs: Maximum number of epochs per epsilon.
            verbose: Print per-epoch statistics.
            batch_size: Mini-batch size.

        Returns:
            List with at most one entry per epsilon, each a tuple
            ``(None, None, avg_mmd_loss, avg_loss, flow_state_dict,
            s_theta_state_dict)``. Epsilons that diverged in their first
            epoch contribute no entry.

        Raises:
            ValueError: If the data pool is empty or ``batch_size`` < 1.
        """
        import copy  # function scope: the file header does not import copy

        X = self.X
        n = X.shape[0]  # number of samples in the pool
        if n == 0:
            raise ValueError("generate: the data pool is empty")
        if batch_size < 1:
            raise ValueError("generate: batch_size must be a positive integer")

        pars = [0.05]
        best_results = []

        for epsilon in epsilons:
            early_stop = False
            noise_gen = Flow(input_channels=input_channel, output_channels=input_channel).to(device)
            s_theta_f = s_theta(input_channels=input_channel).to(device)

            trainable = list(noise_gen.parameters()) + list(s_theta_f.parameters())
            opt = torch.optim.Adam(trainable, lr=lr)

            last_result = None

            for epoch in range(epochs):
                opt.zero_grad()
                total_loss = 0
                total_mmd = 0
                num_batches = 0

                perm = torch.randperm(n)
                for batch_start in range(0, n, batch_size):
                    batch_idx = perm[batch_start: batch_start + batch_size]
                    X_batch = X[batch_idx]

                    z, _ = noise_gen(X_batch)
                    eps = s_theta_f(z)
                    X2 = noise_gen.inverse(z + eps)

                    mmd_val = self.mmd(X_batch, X2, pars, X_batch.size(0), X2.size(0))

                    noise_sum = torch.mean(eps ** 2)*lam2
                    mmd_loss = torch.abs(mmd_val - epsilon)*lam1

                    obj = -noise_sum + mmd_loss
                    obj.backward()

                    torch.nn.utils.clip_grad_norm_(trainable, max_norm=1.0)
                    opt.step()
                    opt.zero_grad()

                    with torch.no_grad():
                        total_loss += obj.item()
                        total_mmd += mmd_loss.item()
                    num_batches += 1

                    del z, eps, X2, mmd_loss, obj

                # Average over the batches actually processed. Dividing by
                # n // batch_size is zero when n < batch_size and too small
                # whenever the last batch is partial.
                avg_loss = total_loss / num_batches
                avg_mmd = total_mmd / num_batches

                if avg_mmd > 10:
                    early_stop = True
                    if last_result is not None:
                        best_results.append(last_result)
                    print(avg_mmd, noise_sum)
                    break

                # state_dict() returns tensors that share storage with the
                # live parameters; deep-copy so this snapshot is not
                # overwritten by later optimizer steps (otherwise the
                # early-stop rollback would return the diverged weights).
                current_result = (None, None, avg_mmd, avg_loss,
                                copy.deepcopy(noise_gen.state_dict()),
                                copy.deepcopy(s_theta_f.state_dict()))

                if verbose:
                    print(f"Epoch {epoch + 1}: Epsilon: {epsilon} MMD: {avg_mmd} Loss: {avg_loss} Noise_Sum: {noise_sum}")

                torch.cuda.empty_cache()

                last_result = current_result
            if last_result is not None and not early_stop:
                best_results.append(last_result)
            torch.cuda.empty_cache()
        return best_results



    # X: n x d, Y: m x d (if Y is None, we'll use X only)
    # Return an n x m kernel matrix L, where L(i,j)=k(x_i, y_j)
    # Caution: Do not use kernel='poly'.. That code isn't ready yet.
    def make_gram(self, X, params, Y=None, kernel='rbf'):
        X = X.view(X.size(0), -1)
        if Y is None: Y = X
        else: Y = Y.view(Y.size(0), -1)
        n, m = X.shape[0], Y.shape[0]
        # print(X.shape, Y.shape)

        if kernel == 'rbf':
            gamma = params[0]
            # X*X => element-wisie multiplication
            # [batch_size, C, H, W] => flatten to [batch_size, C * H * W]
            C = -gamma * (X * X).sum(dim=1).unsqueeze(1).expand(n, m)
            B = 2 * gamma * torch.mm(X, Y.T)
            D = -gamma * (Y * Y).sum(dim=1).unsqueeze(1).expand(m, n)
            E = C + B + D.T
            E = torch.clamp(E, -30, 30)
            L = torch.exp(E)  # Shape: [n, m]

        return L



def test_normality(X, num_tests=10):
    """Print a multivariate-normality test for X and for perturbed copies of X.

    ``mvn`` (pingouin's ``multivariate_normality``, per the import at the top
    of the file) is run once on the raw data and then, ``num_tests`` times, on
    a shuffled copy whose first 10% of rows have generated noise added.

    NOTE(review): the ``aug.generate(...)`` call below does not match the
    ``Augmenter.generate`` defined in this file. There the first positional
    parameter is ``input_channel`` (a tensor slice is passed here) and the
    return value is a list of result tuples, not the 4-tuple
    ``(gen, mmd_val, obj, model)`` unpacked here. This function looks like it
    was written against an earlier API and is expected to fail as written --
    confirm and update.

    Args:
        X: Array-like of shape (n, d); converted with ``torch.tensor``.
        num_tests: Number of shuffle-and-perturb rounds.
    """
    #df = pd.DataFrame(X)
    X = torch.tensor(X)
    n = X.shape[0]
    # Perturb 10% of the rows, but at least one.
    n_samp = max(1, int(n * 0.1))
    # Baseline: normality test on the unperturbed data.
    res = mvn(X.detach().numpy(), alpha=0.05)
    print(res)
    for i in range(num_tests):
        aug = Augmenter(X)
        idx = torch.randperm(n)
        X_shuff = X[idx, :]
        # Y is the same tensor object as X_shuff, so the in-place add below
        # modifies X_shuff as well.
        Y = X_shuff
        # NOTE(review): argument and return shape disagree with the current
        # generate() signature -- see the docstring.
        gen, mmd_val, obj, model = aug.generate(X_shuff[:n_samp, :], epochs=40, verbose=True)
        Y[:n_samp] += gen[:n_samp].to(Y.device)
        res = mvn(Y.detach().numpy(), alpha=0.05)
        print(res)

if __name__ == "__main__":
    # Keep the dimension small: values above roughly 20 make the normality
    # test report NaN p-values.
    dim = 10
    # 10000 draws from a standard multivariate normal in `dim` dimensions.
    X = nr.multivariate_normal(mean=np.zeros(dim), cov=np.eye(dim), size=10000)
    test_normality(X)