import sys
from pathlib import Path
import pickle
import torchvision
import torchvision.transforms as transforms
import torch
import torch.nn.functional as F
import scipy.linalg as scilin

import argparse
import os
import numpy as np
from define_data import get_data
from model import VisionTransformer
from config import get_val_nc_config
from checkpoint import load_checkpoint

class ForwardHook():
    """Record the positional inputs a module receives on each forward call.

    A forward *pre*-hook is registered on ``module`` when the object is built,
    so ``outputs`` collects the input tuple of every call (not the module's
    return value) until ``clear`` is called. ``close`` unregisters the hook.
    """

    def __init__(self, module):
        handle = module.register_forward_pre_hook(self.hook_fn)
        self.hook = handle
        self.outputs = list()

    def hook_fn(self, module, module_in):
        """Store the input tuple of the current call; returns None so the input is unchanged."""
        self.outputs += [module_in]

    def clear(self):
        """Forget every captured input."""
        self.outputs = list()

    def close(self):
        """Detach the pre-hook from the module it was registered on."""
        self.hook.remove()

def split_array(input_array, batchsize=128):
    """Split ``input_array`` along axis 0 into consecutive chunks of ``batchsize`` rows.

    Every chunk has exactly ``batchsize`` rows except possibly the last, which
    holds the remainder. An empty input yields an empty list. Chunks are views
    into ``input_array`` (basic slicing), not copies.

    Args:
        input_array: NumPy array (any rank >= 1) to split along its first axis.
        batchsize: maximum number of rows per chunk; must be >= 1.

    Returns:
        List of array slices, in order, whose concatenation is ``input_array``.

    Raises:
        ValueError: if ``batchsize`` is smaller than 1.
    """
    if batchsize < 1:
        raise ValueError(f"batchsize must be positive, got {batchsize}")
    # A single slicing pass covers both the evenly-divisible and remainder
    # cases, so chunks always have ``batchsize`` rows (np.split(arr, k) would
    # instead produce k equal pieces).
    return [input_array[start:start + batchsize]
            for start in range(0, input_array.shape[0], batchsize)]


def compute_info(config, model, hook, dataloader):
    """Collect L2-normalized hooked features plus their global and per-class means.

    ``model`` is run on every batch (under ``torch.no_grad``) only to trigger
    ``hook``; the first positional input captured by the hook on the batch's
    first call is used as that batch's feature tensor.

    Args:
        config: object whose ``device`` attribute is where batches are moved.
        model: callable; its return value is ignored.
        hook: object exposing ``outputs`` (list of captured input tuples) and
            ``clear()``, such as ``ForwardHook``.
        dataloader: iterable of ``(inputs, targets)`` batches; targets are
            1-D tensors of integer class labels.

    Returns:
        ``(mu_G, mu_c_dict, before_class_dict)`` where ``mu_G`` is the mean
        normalized feature over all samples (a tensor), ``mu_c_dict`` maps each
        label seen to its class-mean feature tensor, and ``before_class_dict``
        maps each label to the list of its per-sample normalized features as
        independent NumPy arrays.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    num_data = 0
    mu_G = 0
    mu_c_dict = dict()
    num_class_dict = dict()
    before_class_dict = dict()

    for inputs, targets in dataloader:
        inputs, targets = inputs.to(config.device), targets.to(config.device)

        with torch.no_grad():
            model(inputs)  # run only for the hook's side effect

        # Each captured entry is the module's positional-input tuple; take the
        # first argument from the first call of this batch.
        features = hook.outputs[0][0]
        hook.clear()

        ##### Change if want to calculate NC on features after pooling #####
        if len(features.shape) == 3:
            # Average over dimension 1 (presumably the token axis -- confirm
            # against the model's layout).
            features = features.transpose(1, 2).unsqueeze(2)  # bs * c * 1 * image size
            features = F.adaptive_avg_pool2d(features, (1, 1))  # bs * c * 1 * 1
        features = features.reshape(len(targets), -1)
        ##### Change if want to calculate NC on features after pooling #####

        # Need to normalize feature
        features = F.normalize(features, dim=1).detach()

        mu_G = mu_G + features.sum(dim=0)

        # Convert to host memory / Python ints once per batch instead of per sample.
        features_np = features.cpu().numpy()
        for b, y in enumerate(targets.tolist()):
            if y not in mu_c_dict:
                # clone(): a bare row view would share storage with
                # ``features``, so the in-place ``+=`` below would overwrite
                # the stored sample (and, on CPU, the NumPy copy as well).
                mu_c_dict[y] = features[b].clone()
                before_class_dict[y] = [features_np[b].copy()]
                num_class_dict[y] = 1
            else:
                mu_c_dict[y] += features[b]
                before_class_dict[y].append(features_np[b].copy())
                num_class_dict[y] += 1

        num_data += targets.shape[0]

    if num_data == 0:
        raise ValueError("dataloader yielded no samples")

    mu_G /= num_data
    # Iterate the labels actually seen; they need not be 0..K-1.
    for y, count in num_class_dict.items():
        mu_c_dict[y] /= count

    return mu_G, mu_c_dict, before_class_dict

def compute_Sigma_W(config, before_class_dict, mu_c_dict, batchsize=128):
    """Return the within-class covariance of the stored features as a NumPy array.

    Sigma_W = (1/N) * sum over every sample i of (f_i - mu_{y_i})(f_i - mu_{y_i})^T,
    where N is the total number of samples across all classes.

    Args:
        config: object whose ``device`` attribute is where features are moved.
        before_class_dict: maps each label to a list of per-sample feature
            vectors (NumPy arrays).
        mu_c_dict: maps each label to its class-mean feature tensor.
        batchsize: number of samples moved to ``config.device`` at a time.

    Returns:
        (d, d) NumPy array.

    Raises:
        ValueError: if ``batchsize`` < 1 or there are no samples at all.
    """
    if batchsize < 1:
        raise ValueError(f"batchsize must be positive, got {batchsize}")

    num_data = 0
    Sigma_W = 0

    for target, rows in before_class_dict.items():
        class_features = np.array(rows)
        mu_c = mu_c_dict[target].unsqueeze(0)
        for start in range(0, class_features.shape[0], batchsize):
            features = torch.from_numpy(class_features[start:start + batchsize]).to(config.device)
            centered = features - mu_c
            # (d x B) @ (B x d) equals the sum of the B per-sample outer
            # products without materializing a (B, d, d) tensor.
            Sigma_W = Sigma_W + centered.T @ centered
            num_data += features.shape[0]

    if num_data == 0:
        raise ValueError("before_class_dict contains no samples")

    Sigma_W = Sigma_W / num_data
    return Sigma_W.detach().cpu().numpy()


def compute_Sigma_B(mu_c_dict, mu_G):
    """Return the between-class covariance of the class means as a NumPy array.

    Sigma_B = (1/K) * sum over classes c of (mu_c - mu_G)(mu_c - mu_G)^T,
    where K is the number of entries in ``mu_c_dict``.

    Args:
        mu_c_dict: maps each label (any hashable; not required to be 0..K-1)
            to its class-mean feature tensor of shape (d,).
        mu_G: global mean feature tensor of shape (d,).

    Returns:
        (d, d) NumPy array.

    Raises:
        ValueError: if ``mu_c_dict`` is empty.
    """
    if not mu_c_dict:
        raise ValueError("mu_c_dict must contain at least one class mean")

    # One row per class, centered on the global mean: (K, d).
    centered = torch.stack(list(mu_c_dict.values())) - mu_G
    Sigma_B = centered.T @ centered / len(mu_c_dict)
    return Sigma_B.detach().cpu().numpy()

def CDNV(all_class_dict):
    """Return the class-distance normalized variance averaged over all class pairs.

    For each class c the features are stacked, the class mean mu_c is taken,
    and Var_c = mean_i ||f_i - mu_c||^2. For each unordered pair (c1, c2):

        CDNV(c1, c2) = (Var_c1 + Var_c2) / (2 * ||mu_c1 - mu_c2||^2)

    Using squared distances in both numerator and denominator makes the value
    invariant to a global rescaling of the features.

    Args:
        all_class_dict: maps each class label to a sequence of 1-D feature
            vectors (anything ``np.vstack`` accepts).

    Returns:
        Mean pairwise CDNV, or ``nan`` when fewer than two classes are given.
    """
    means = []
    variances = []
    for class_rows in all_class_dict.values():
        class_features = np.vstack(class_rows)
        mu = class_features.mean(axis=0)
        sq_dist = np.sum((class_features - mu[None, :]) ** 2, axis=1)
        means.append(mu)
        variances.append(sq_dist.mean())

    k = len(means)
    pair_cdnv = [
        (variances[i] + variances[j]) / (2 * np.sum((means[i] - means[j]) ** 2))
        for i in range(k)
        for j in range(i + 1, k)
    ]
    if not pair_cdnv:
        # No class pairs: avoid np.mean([]) and its RuntimeWarning.
        return float("nan")
    return np.mean(pair_cdnv)

def _hooked_module(model, layer_num, num_layers):
    """Return the module whose forward *input* is probed for ``layer_num``.

    ``layer_num`` in ``1 .. num_layers - 1`` selects ``norm1`` of encoder layer
    ``layer_num``; ``layer_num == num_layers`` selects ``model.classifier``.
    Encoder layer 0's ``norm1`` is never probed.
    """
    # NOTE(review): assumes model.transformer.encoder_layers has num_layers
    # entries and that norm1's input is the previous layer's output -- confirm
    # against model.py.
    if layer_num < num_layers:
        return model.transformer.encoder_layers[layer_num].norm1
    return model.classifier


def _load_pretrained_weights(model, config, num_classes):
    """Load ``config.checkpoint_path``/``config.pretrain_model_name`` into ``model`` in place.

    If the checkpoint's classifier has a different output size than
    ``num_classes``, the classifier weight/bias are dropped and the remaining
    weights are loaded non-strictly, leaving the classifier freshly initialized.
    """
    weights_path = os.path.join(config.checkpoint_path, config.pretrain_model_name)
    state_dict = load_checkpoint(weights_path)
    if num_classes != state_dict['classifier.weight'].size(0):
        del state_dict['classifier.weight']
        del state_dict['classifier.bias']
        print("re-initialize fc layer")
        model.load_state_dict(state_dict, strict=False)
    else:
        model.load_state_dict(state_dict)
    print(f"Load pretrained weights from {weights_path}")


def main():
    """Measure CDNV of training-set features at each probed ViT depth and pickle it.

    Reads settings from ``get_val_nc_config()``, builds a ``VisionTransformer``,
    loads pretrained weights unless ``config.vanilla`` is set, and for each of
    ``config.num_layers`` probed modules collects normalized features through a
    forward pre-hook and records their CDNV. Results are written to
    ``<checkpoint_path or .>/<dataset>_nc/info_vanilla.pkl`` (vanilla runs) or
    ``info_new.pkl``.
    """
    config = get_val_nc_config()

    if not config.vanilla and config.checkpoint_path is None:
        sys.exit('Need to input the path to a pre-trained model!')

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    config.device = device

    # Dataset part
    print(f"Using dataset {config.dataset}")
    print()
    # Only the training split feeds the statistics below.
    trainloader, _, num_classes = get_data(config.dataset, config.data_dir, config.image_size, config.batch_size, do_transform = False)

    # create model
    print("create model")
    model = VisionTransformer(
             image_size=(config.image_size, config.image_size),
             patch_size=(config.patch_size, config.patch_size),
             emb_dim=config.emb_dim,
             mlp_dim=config.mlp_dim,
             num_heads=config.num_heads,
             num_layers=config.num_layers,
             num_classes=num_classes,
             attn_dropout_rate=config.attn_dropout_rate,
             dropout_rate=config.dropout_rate)

    # load checkpoint (checkpoint_path is guaranteed non-None here by the exit above)
    if config.vanilla:
        print("Use untrained model to validate NC!")
    else:
        _load_pretrained_weights(model, config, num_classes)

    # send model to device
    model = model.to(device)
    model.eval()

    # Create a container to store NC information
    info_dict = {
                 'collapse_metric': [],
                 'cdnv': []
                 }

    print("Validating All Layers!")
    num_layers = config.num_layers
    for layer_num in range(1, num_layers + 1):
        layer_hook = ForwardHook(_hooked_module(model, layer_num, num_layers))
        try:
            print("collect info!")
            mu_G_train, mu_c_dict_train, before_class_dict_train = compute_info(config, model, layer_hook, trainloader)
        finally:
            # Always unregister so this hook cannot fire during later passes.
            layer_hook.close()

        print("here we go!")
        # NC1 (disabled; uses mu_G_train / mu_c_dict_train)
        #Sigma_W = compute_Sigma_W(config, before_class_dict_train, mu_c_dict_train, batchsize=config.batch_size)
        #Sigma_B = compute_Sigma_B(mu_c_dict_train, mu_G_train)
        #collapse_metric = np.trace(Sigma_W @ scilin.pinv(Sigma_B)) / len(mu_c_dict_train)
        #info_dict['collapse_metric'].append(collapse_metric)
        # CDNV
        cdnv = CDNV(before_class_dict_train)
        info_dict['cdnv'].append(cdnv)

    # Store the information
    print(info_dict)
    # Vanilla runs may have no checkpoint_path; fall back to the working directory.
    base_dir = config.checkpoint_path if config.checkpoint_path else "."
    save_dir = Path(base_dir) / f"{config.dataset}_nc"
    save_dir.mkdir(parents=True, exist_ok=True)
    file_name = "info_vanilla.pkl" if config.vanilla else "info_new.pkl"
    with open(save_dir / file_name, 'wb') as f:
        pickle.dump(info_dict, f)


if __name__ == "__main__":
    main()