import random
import torch
import numpy as np
import numpy.random as nr
import pandas as pd
from tqdm import tqdm
from torch import nn
from torch.nn.functional import softmax, one_hot, sigmoid, relu
from torch.distributions.beta import Beta
from torchvision import models
from pingouin import multivariate_normality as mvn
from torch.utils.data import DataLoader
import pdb
import itertools
from noise_gen_flow_based import s_theta
from noise_gen_flow_based import Flow
import torch.nn.functional as F
import os
from torchvision.utils import save_image
from matplotlib import pyplot as plt
import network
from sklearn.cluster import KMeans
from torchvision import transforms



# Device used for training and for the noise generators: the GPU with index
# _CUDA_DEVICE_INDEX (clamped to the GPUs actually present) when CUDA is available,
# otherwise the CPU. torch.cuda.set_device rejects non-CUDA devices, so it is only
# called on the CUDA path.
_CUDA_DEVICE_INDEX = 1
if torch.cuda.is_available():
    device = torch.device(f"cuda:{min(_CUDA_DEVICE_INDEX, torch.cuda.device_count() - 1)}")
    torch.cuda.set_device(device)
else:
    device = torch.device("cpu")

class AugNet(nn.Module):
    """Augment image batches with flow-based noise generators at several MMD levels.

    Call order implied by the methods below:

    1. ``set_pool(pool)`` caches the training images (all of them and per class)
       on ``device`` and fills ``mmd_matrix_img`` with pairwise class MMDs.
    2. ``epsilon_list()`` turns those MMDs into ``epsilon_level`` target values.
    3. ``augment_pool()`` trains and saves one noise-generator pair per target,
       or ``load_augment_pool()`` reloads previously saved pairs.
    4. ``forward(inp, normalizing, lbl)`` returns the batch plus one augmented
       copy per noise level.

    Args:
        device: device for the cached pools and the loaded generators.
        backbone: stored as ``self.feat``; not used by the methods of this class.
        classifier: accepted for interface compatibility; unused.
        layer_index: accepted for interface compatibility; unused.
        nc: number of classes.
    """

    # Checkpoint directory the original code hard-coded; can be overridden per call.
    DEFAULT_SAVE_DIR = '/home/hyewon/noise_gen/final/CIFAR5/'

    def __init__(self, device, backbone, classifier, layer_index=-1, nc=10):
        super(AugNet, self).__init__()
        self.device = device
        self.nc = nc

        self.pool = [None] * nc
        self.centroids = [None] * nc
        self.global_pool = None
        self.image_pool = None

        # Number of noise levels; one generator pair is kept per level.
        self.epsilon_level = 15

        self.global_pool_by_class = []
        self.image_pool_by_class = []

        self.noise_gen_pool = [None] * self.epsilon_level
        self.s_theta_pool = [None] * self.epsilon_level

        # Image-space generators used by forward(), indexed by noise level.
        self.image_noise_gen = [None] * self.epsilon_level
        self.image_s_theta = [None] * self.epsilon_level

        self.mmd_matrix = None
        self.mmd_matrix_img = None

        self.image_epsilon_lists = None
        self.feature_epsilon_lists = None

        self.feat = backbone

    def forward(self, inp, normalizing, lbl=None):
        """Return ``inp`` concatenated with one augmented copy per noise level.

        Each copy is ``0.6 * normalizing(clamp(G_inv(G(inp) + s(G(inp))), 0, 1))
        + 0.4 * inp`` for the level's flow ``G`` and perturbation net ``s``.

        Args:
            inp: 4-D image batch; copies are stacked along dim 0.
            normalizing: callable applied to each generated image after clamping
                to [0, 1].
            lbl: optional label tensor for ``inp``.

        Returns:
            ``(imgs_out, labels_out)``. ``imgs_out`` holds ``epsilon_level + 1``
            copies of the batch along dim 0. ``labels_out`` repeats ``lbl`` once per
            copy, or is ``None`` when ``lbl`` is ``None``.

        Raises:
            ValueError: if ``inp`` is not 4-D.
            RuntimeError: if a level's generators have not been trained or loaded.
        """
        if inp.dim() != 4:
            raise ValueError(f"expected a 4-D input batch, got shape {tuple(inp.shape)}")

        aug_list = [inp]
        with torch.no_grad():
            for i in range(self.epsilon_level):
                noise_gen = self.image_noise_gen[i]
                s_theta_f = self.image_s_theta[i]
                if noise_gen is None or s_theta_f is None:
                    raise RuntimeError(
                        f"noise generator for level {i} is not initialised; "
                        "call augment_pool() or load_augment_pool() first")

                z, _ = noise_gen(inp)
                # NOTE(review): Augmenter.generate trains with s_theta_f(z) * 1e-1,
                # but the perturbation here is unscaled; confirm this is intended.
                eps = s_theta_f(z)
                aug_img = noise_gen.inverse(z + eps)
                aug_img = normalizing(torch.clamp(aug_img, 0, 1))
                # Blend with the input: 0.6 / 0.4 for PACS and TAU
                # (0.8 / 0.2 was used for CIFAR-10-C).
                aug_img = 0.6 * aug_img + 0.4 * inp
                aug_list.append(aug_img)

            imgs_out = torch.cat(aug_list, dim=0)
            # One copy of the labels per image copy; nothing to repeat without labels.
            labels_out = None if lbl is None else torch.cat([lbl] * len(aug_list), dim=0)
        return imgs_out, labels_out

    def set_pool(self, pool):
        """Cache ``pool`` on ``self.device`` and compute the pairwise class MMD matrix.

        ``pool`` is iterated as ``(data, label)`` pairs (e.g. a Subset dataset);
        ``label`` indexes a class in ``range(nc)``. Every sample is moved to
        ``self.device``, so this can be very memory-intensive.

        Sets ``image_pool`` (all samples stacked along dim 0),
        ``image_pool_by_class`` (one tensor per class, rebuilt on every call),
        ``mmd_matrix`` (zeros; not filled here), and ``mmd_matrix_img``
        (biased RBF-MMD between every pair of classes, gamma = 0.05).

        Raises:
            ValueError: if some class in ``range(nc)`` has no samples.
        """
        num_classes = int(self.nc)
        per_class = [[] for _ in range(num_classes)]
        image_data = []
        with torch.no_grad():
            for data, label in pool:
                data = data.to(self.device).unsqueeze(0)
                image_data.append(data)
                per_class[label].append(data)

            self.image_pool = torch.cat(image_data, dim=0).detach()

            self.img_centroids = [None] * num_classes
            self.centroids = [None] * num_classes
            # Rebuild from scratch so repeated calls do not accumulate stale classes.
            self.image_pool_by_class = []
            for i, samples in enumerate(per_class):
                if not samples:
                    raise ValueError(f"class {i} has no samples in the pool")
                self.image_pool_by_class.append(torch.cat(samples, dim=0).detach())

            pars = [0.05]
            self.mmd_matrix = torch.zeros((num_classes, num_classes), device=self.device)
            self.mmd_matrix_img = torch.zeros((num_classes, num_classes), device=self.device)
            for i, X in enumerate(self.image_pool_by_class):
                for j, Y in enumerate(self.image_pool_by_class):
                    self.mmd_matrix_img[i, j] = self.mmd(X, Y, pars, X.shape[0], Y.shape[0])
        print(self.mmd_matrix, self.mmd_matrix_img)

    def intra_class_mmd(self, class_img, pars):
        """Return the mean MMD between two random disjoint halves of ``class_img``.

        The estimate is averaged over 50 random splits. When the sample count is
        odd, one sample is left out of each split.

        Raises:
            ValueError: if ``class_img`` has fewer than 2 samples (no split exists).
        """
        half = class_img.size(0) // 2
        if half == 0:
            raise ValueError("intra-class MMD needs at least 2 samples")

        mmd_val_list = []
        for _ in range(50):
            idx = torch.randperm(class_img.size(0))
            c1 = class_img[idx[:half]]
            c2 = class_img[idx[half:half * 2]]
            mmd_val_list.append(self.mmd(c1, c2, pars, c1.size(0), c2.size(0)).item())

        mean_mmd = np.mean(mmd_val_list)
        print(mean_mmd)
        return mean_mmd

    def epsilon_list(self):
        """Derive ``epsilon_level`` MMD targets and store them in ``image_epsilon_lists``.

        The largest target is ``0.5 * min_between / avg_intra``. Here
        ``min_between`` is the smallest off-diagonal entry of ``mmd_matrix_img``
        and ``avg_intra`` is the mean intra-class MMD. The targets are
        ``k / epsilon_level`` of it for k = 1..epsilon_level.

        Must run after ``set_pool``. Overwrites the diagonals of ``mmd_matrix``
        and ``mmd_matrix_img`` with ``inf`` in place.

        Raises:
            ValueError: if there is no finite off-diagonal MMD (fewer than 2 classes).
        """
        # Mask self-distances so the minimum is taken over distinct class pairs.
        self.mmd_matrix_img.diagonal().fill_(float('inf'))
        self.mmd_matrix.diagonal().fill_(float('inf'))

        min_dist_img = torch.min(self.mmd_matrix_img).item()
        if not np.isfinite(min_dist_img):
            raise ValueError("epsilon_list needs at least two classes")

        # Per-pixel variance of each class, averaged (printed for diagnostics only).
        var_img_list = []
        for i in range(self.nc):
            class_img = self.image_pool_by_class[i].view(self.image_pool_by_class[i].size(0), -1)
            var_img_list.append(torch.var(class_img, dim=0, unbiased=False).cpu().numpy())
        avg_var_img = np.mean(var_img_list)
        print(avg_var_img)

        intra_class_mmd_list = [self.intra_class_mmd(img, [0.05]) for img in self.image_pool_by_class]
        print(intra_class_mmd_list)
        avg_intra = np.mean(intra_class_mmd_list)
        print(avg_intra)

        # 1e-8 guards against division by zero for degenerate (constant) classes.
        mmd_threshold_img = 0.5 * (min_dist_img / (avg_intra + 1e-8))
        epsilon_list_img = [mmd_threshold_img * (i + 1) / self.epsilon_level
                            for i in range(self.epsilon_level)]

        print("epsilon_list_img:", epsilon_list_img)
        print(self.nc)

        self.image_epsilon_lists = epsilon_list_img

    def _generator_paths(self, save_dir, index):
        """Return the (noise_gen, s_theta) checkpoint paths for 0-based level ``index``."""
        return (os.path.join(save_dir, f'{index + 1}_image_noise_gen.pth'),
                os.path.join(save_dir, f'{index + 1}_image_s_theta.pth'))

    def augment_pool(self, save_dir=None):
        """Train, install, and save one noise-generator pair per epsilon target.

        Must run after ``set_pool`` and ``epsilon_list``. Training is delegated to
        ``Augmenter(self.image_pool).generate`` with 3 input channels. Entries 4
        and 5 of each returned result are loaded as the Flow and s_theta state
        dicts.

        Args:
            save_dir: checkpoint directory, created if missing. Defaults to
                ``DEFAULT_SAVE_DIR``. Files are named ``{k}_image_noise_gen.pth``
                and ``{k}_image_s_theta.pth`` for level k = 1, 2, ...
        """
        save_dir = self.DEFAULT_SAVE_DIR if save_dir is None else save_dir
        # Create the directory before the long training run so saving cannot fail at the end.
        os.makedirs(save_dir, exist_ok=True)

        image_aug = Augmenter(self.image_pool)
        image_aug_results = image_aug.generate(input_channel=3, epsilons=self.image_epsilon_lists)

        with torch.no_grad():
            for i, result in enumerate(image_aug_results):
                # Build on self.device, consistent with load_augment_pool's map_location.
                noise_gen = Flow(input_channels=3, output_channels=3).to(self.device)
                s_theta_f = s_theta(input_channels=3).to(self.device)
                noise_gen.load_state_dict(result[4])
                s_theta_f.load_state_dict(result[5])
                self.image_noise_gen[i] = noise_gen
                self.image_s_theta[i] = s_theta_f
                noise_gen_path, s_theta_path = self._generator_paths(save_dir, i)
                torch.save(noise_gen, noise_gen_path)
                torch.save(s_theta_f, s_theta_path)

    def load_augment_pool(self, save_dir=None):
        """Load the ``epsilon_level`` generator pairs saved by ``augment_pool``.

        Args:
            save_dir: checkpoint directory; defaults to ``DEFAULT_SAVE_DIR``.

        The loaded modules are mapped to ``self.device`` and put in eval mode.
        """
        save_dir = self.DEFAULT_SAVE_DIR if save_dir is None else save_dir
        for i in range(self.epsilon_level):
            noise_gen_path, s_theta_path = self._generator_paths(save_dir, i)
            # weights_only=False unpickles whole modules: only load trusted checkpoints.
            self.image_noise_gen[i] = torch.load(noise_gen_path, map_location=self.device, weights_only=False)
            self.image_s_theta[i] = torch.load(s_theta_path, map_location=self.device, weights_only=False)

            self.image_noise_gen[i].eval()
            self.image_s_theta[i].eval()

    def make_gram(self, X, params, Y=None, kernel='rbf'):
        """Return the (n, m) RBF Gram matrix exp(-gamma * ||x_i - y_j||^2).

        ``X`` and ``Y`` are flattened to one row per sample. ``Y`` defaults to
        ``X``, and ``gamma = params[0]``. Unlike ``Augmenter.make_gram``, the
        exponent is not clamped.

        Raises:
            ValueError: for any kernel other than ``'rbf'``.
        """
        if kernel != 'rbf':
            raise ValueError(f"unsupported kernel {kernel!r}; only 'rbf' is implemented")
        X = X.reshape(X.size(0), -1)
        Y = X if Y is None else Y.reshape(Y.size(0), -1)
        n, m = X.shape[0], Y.shape[0]

        gamma = params[0]
        # -gamma * (|x|^2 - 2 x.y + |y|^2), built from broadcastable pieces.
        C = -gamma * (X * X).sum(dim=1).unsqueeze(1).expand(n, m)
        B = 2 * gamma * torch.mm(X, Y.T)
        D = -gamma * (Y * Y).sum(dim=1).unsqueeze(1).expand(m, n)
        return torch.exp(C + B + D.T)

    def mmd(self, X, Y, pars, n, m):
        """Return the biased squared-MMD estimate between samples ``X`` (n rows) and ``Y`` (m rows).

        Computed as ``mean(K_XX) - 2 * mean(K_XY) + mean(K_YY)`` with the RBF
        kernel of ``make_gram``. The diagonal terms are included, which is why
        the estimate is biased.
        """
        A = self.make_gram(X, pars)
        B = self.make_gram(X, pars, Y)
        C = self.make_gram(Y, pars)
        return A.sum() / (n ** 2) - 2 * B.sum() / (n * m) + C.sum() / (m ** 2)
        

class Augmenter:
    """Train flow-based noise generators whose perturbations hit a target MMD.

    Args:
        X: data tensor, moved to the module-level ``device``; ``generate``
            mini-batches it along dim 0.
    """

    def __init__(self, X):
        super().__init__()
        self.X = X.to(device)
        self.X.requires_grad = False

    def mmd(self, X, Y, pars, n, m):
        """Return the biased squared-MMD estimate between ``X`` (n rows) and ``Y`` (m rows).

        Computed as ``mean(K_XX) - 2 * mean(K_XY) + mean(K_YY)`` with the clamped
        RBF kernel of ``make_gram``. The diagonal terms are included (biased).
        """
        A = self.make_gram(X, pars)
        B = self.make_gram(X, pars, Y)
        C = self.make_gram(Y, pars)
        return A.sum() / (n ** 2) - 2 * B.sum() / (n * m) + C.sum() / (m ** 2)

    @staticmethod
    def _snapshot(module):
        """Return ``module.state_dict()`` with every tensor cloned.

        ``state_dict()`` tensors share storage with the live parameters. Without
        cloning, later optimizer steps would silently change a "saved" result.
        """
        state = module.state_dict()
        for key in list(state):
            state[key] = state[key].clone()
        return state

    def generate(self, input_channel=None, lam1=1, lam2=0.1, lam3=1e-5,
                 epsilons=(0.0001, 0.0005, 0.001, 0.005, 0.01, 0.05, 0.1),
                 lr=5e-4, epochs=100, verbose=True):
        """Train one (Flow, s_theta) pair per target ``epsilon``.

        For each epsilon, a fresh pair is trained with Adam on mini-batches of
        512 rows of ``self.X``. The objective is
        ``lam1 * |MMD(X_batch, X2) - epsilon| - lam2 * mean(eps ** 2)``, where
        ``X2 = Flow.inverse(z + 0.1 * s_theta(z))`` and ``z = Flow(X_batch)``.
        Training stops early when the epoch-mean of the MMD loss term exceeds 10.

        Args:
            input_channel: channel count passed to ``Flow`` and ``s_theta``.
            lam1: weight of the MMD-matching term.
            lam2: weight of the perturbation-energy reward.
            lam3: unused.
            epsilons: iterable of target MMD values.
            lr: Adam learning rate.
            epochs: maximum epochs per epsilon.
            verbose: print per-epoch statistics.

        Returns:
            A list with one tuple per epsilon (when ``epochs >= 1``):
            ``(None, None, avg_mmd_loss, avg_loss, flow_state, s_theta_state)``.
            On divergence, the tuple comes from the last epoch before divergence,
            or from the diverged state if the first epoch already diverged.

        Raises:
            ValueError: if ``self.X`` is empty.
        """
        X = self.X
        n = X.shape[0]
        if n == 0:
            raise ValueError("cannot train noise generators on an empty data tensor")

        pars = [0.05]
        bs = 512
        best_results = []

        for epsilon in epsilons:
            noise_gen = Flow(input_channels=input_channel, output_channels=input_channel).to(device)
            s_theta_f = s_theta(input_channels=input_channel).to(device)
            params = list(noise_gen.parameters()) + list(s_theta_f.parameters())
            opt = torch.optim.Adam(params, lr=lr)

            last_result = None

            for epoch in range(epochs):
                opt.zero_grad()
                total_loss = 0
                total_mmd = 0
                n_batches = 0

                perm = torch.randperm(n)
                for batch_start in range(0, n, bs):
                    X_batch = X[perm[batch_start: batch_start + bs]]

                    z, _ = noise_gen(X_batch)
                    # The learned perturbation is scaled by 0.1 before the inverse flow.
                    eps = s_theta_f(z) * 1e-1
                    X2 = noise_gen.inverse(z + eps)

                    mmd_val = self.mmd(X_batch, X2, pars, X_batch.size(0), X2.size(0))
                    # Reward larger perturbations while pulling the MMD toward epsilon.
                    noise_sum = torch.mean(eps ** 2) * lam2
                    mmd_loss = torch.abs(mmd_val - epsilon) * lam1
                    obj = -noise_sum + mmd_loss
                    obj.backward()

                    torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)
                    opt.step()
                    opt.zero_grad()

                    total_loss += obj.item()
                    total_mmd += mmd_loss.item()
                    n_batches += 1
                    del z, eps, X2, mmd_loss, obj

                # Average over the batches actually run; the last one may be partial.
                avg_loss = total_loss / n_batches
                avg_mmd = total_mmd / n_batches

                if avg_mmd > 10:
                    # Diverged: keep the previous epoch's snapshot. If there is none,
                    # record the current state so each epsilon still yields a result.
                    if last_result is None:
                        last_result = (None, None, avg_mmd, avg_loss,
                                       self._snapshot(noise_gen), self._snapshot(s_theta_f))
                    break

                last_result = (None, None, avg_mmd, avg_loss,
                               self._snapshot(noise_gen), self._snapshot(s_theta_f))

                if verbose:
                    print(f"Epoch {epoch + 1}: Epsilon: {epsilon} MMD: {avg_mmd} Loss: {avg_loss} Noise_Sum: {noise_sum}")

                torch.cuda.empty_cache()

            if last_result is not None:
                best_results.append(last_result)
            torch.cuda.empty_cache()
        return best_results

    def make_gram(self, X, params, Y=None, kernel='rbf'):
        """Return the (n, m) RBF Gram matrix exp(clamp(-gamma * ||x_i - y_j||^2, -30, 30)).

        ``X`` and ``Y`` are flattened to one row per sample. ``Y`` defaults to
        ``X``, and ``gamma = params[0]``. The exponent is clamped to [-30, 30]
        for numerical safety.

        Raises:
            ValueError: for any kernel other than ``'rbf'``.
        """
        if kernel != 'rbf':
            raise ValueError(f"unsupported kernel {kernel!r}; only 'rbf' is implemented")
        X = X.reshape(X.size(0), -1)
        Y = X if Y is None else Y.reshape(Y.size(0), -1)
        n, m = X.shape[0], Y.shape[0]

        gamma = params[0]
        # -gamma * (|x|^2 - 2 x.y + |y|^2), built from broadcastable pieces.
        C = -gamma * (X * X).sum(dim=1).unsqueeze(1).expand(n, m)
        B = 2 * gamma * torch.mm(X, Y.T)
        D = -gamma * (Y * Y).sum(dim=1).unsqueeze(1).expand(m, n)
        E = torch.clamp(C + B + D.T, -30, 30)
        return torch.exp(E)



def test_normality(X, num_tests=10):
    """Print multivariate-normality test results for ``X`` before and after augmentation.

    Runs pingouin's ``multivariate_normality`` (imported as ``mvn``, alpha = 0.05)
    on ``X`` once. Then, ``num_tests`` times, it runs the test on a row-shuffled
    copy whose first 10% of rows (at least one) have generator output added.

    NOTE(review): this appears stale relative to ``Augmenter.generate`` in this
    file. ``generate`` takes ``input_channel`` as its first parameter and returns
    a list of per-epsilon tuples. Here, a data tensor is passed positionally
    (landing in ``input_channel``) and the result is unpacked into four names.
    As written, the loop is expected to fail; TODO update to the current API.

    Args:
        X: 2-D array-like of samples; rows are permuted as samples.
        num_tests: number of augmentation trials.
    """
    X = torch.tensor(X)
    n = X.shape[0]
    n_samp = max(1, int(n * 0.1))
    # Baseline normality of the unaugmented data.
    res = mvn(X.detach().numpy(), alpha=0.05)
    print(res)
    for i in range(num_tests):
        aug = Augmenter(X)
        idx = torch.randperm(n)
        X_shuff = X[idx, :]
        # Y aliases X_shuff (no copy), so the in-place add below also modifies X_shuff.
        Y = X_shuff
        gen, mmd_val, obj, model = aug.generate(X_shuff[:n_samp, :], epochs=40, verbose=True)
        Y[:n_samp] += gen[:n_samp].to(Y.device)
        # NOTE(review): .numpy() requires a CPU tensor; Y stays on X's device here.
        res = mvn(Y.detach().numpy(), alpha=0.05)
        print(res)

if __name__ == "__main__":
    # Smoke test on synthetic data: 10,000 draws from a standard normal in n_dims dimensions.
    # Keep n_dims small: above roughly 20 dimensions the normality test's p-values come out NaN.
    n_dims = 10
    mean = np.zeros(n_dims)
    cov = np.eye(n_dims)
    samples = nr.multivariate_normal(mean, cov, 10000)
    test_normality(samples)