import sys
import pickle
import torchvision
import torchvision.transforms as transforms
import torch
import torch.nn.functional as F
import scipy.linalg as scilin

from models.sup_con_original import SupConResNet, SupCEResNet, LinearClassifier
import argparse
import os
import numpy as np
import matplotlib.pyplot as plt
from torchvision.datasets import CIFAR10, MNIST
from data_loader.mini_imagenet import MiniImagenet
from utils import load_from_state_dict, CDNV
from define_data import get_data


def compute_accuracy(output, target, topk=(1,)):
    """Return the top-k accuracy (in percent) for every k in ``topk``.

    Args:
        output: 2-D score tensor; the top-k search runs along dimension 1.
        target: 1-D tensor of ground-truth class indices, one per row of
            ``output``.
        topk: iterable of k values to evaluate.

    Returns:
        A list of 0-dim float tensors, one per entry of ``topk`` and in the
        same order, each holding the percentage of samples whose target is
        among the k highest-scoring classes.
    """
    largest_k = max(topk)
    n_samples = target.size(0)

    # Indices of the ``largest_k`` best scores per sample, transposed so that
    # row r holds every sample's rank-r guess.
    top_idx = output.topk(largest_k, 1, True, True)[1].t()
    hits = top_idx.eq(target.view(1, -1).expand_as(top_idx))

    scale = 100.0 / n_samples
    return [hits[:k].reshape(-1).float().sum(0).mul_(scale) for k in topk]

class AverageMeter(object):
    """Track the latest value and the running weighted mean of a metric.

    Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262

    Attributes (all reset to 0 by ``reset``):
        val: the most recent value passed to ``update``.
        sum: accumulated ``val * n`` over all updates.
        count: accumulated ``n`` over all updates.
        avg: ``sum / count`` after the latest update.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed with weight ``n`` and refresh the mean."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count

def parse_eval_args():
    """Build the evaluation command-line parser and parse ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``.
    """
    cli = argparse.ArgumentParser()
    add = cli.add_argument

    # ---- model selection ----
    add('--model', type=str, default='resnet18')        # model type
    add('--model_name', type=str, default='resnet18')   # name handed to the model constructors
    add('--head_type', type=str, default='mlp')         # projection head type for this model
    add('--val_choice', type=str, default='feature')    # which part of the model to validate
    add('--after_ft', action='store_true')              # validate the model after fine-tuning
    add('--no-bias', dest='bias', action='store_false')
    add('--ETF_fc', action='store_true')
    add('--fixdim', type=int, default=0)
    add('--SOTA', action='store_true')
    add("--rm", dest='remove_last_relu', action="store_true")

    # ---- MLP settings (only for mlp and res_adapt; for the latter only width matters) ----
    add('--width', type=int, default=1024)
    add('--depth', type=int, default=6)

    # ---- hardware ----
    add('--gpu_id', type=int, default=0)

    # ---- directories ----
    add(
        '--dataset',
        type=str,
        choices=['mnist', 'cifar10', 'cifar100', 'cifar10_random', 'miniimagenet'],
        default='cifar10',
    )
    add('--data_dir', type=str, default='/scratch/qingqu_root/qingqu1/xlxiao/DL/data')
    add('--load_path', type=str, default="saved/")
    add('--p_name', type=str, default="info.pkl")

    # ---- learning options ----
    add('--epochs', type=int, default=199, help='Max Epochs')
    add('--batch_size', type=int, default=128, help='Batch size')
    add('--sample_size', type=int, default=None, help='sample size PER CLASS')

    return cli.parse_args()

def power_iteration(mat, iters=50):
    """Estimate the largest singular value of ``mat`` by power iteration.

    Alternates ``v <- normalize(mat.T @ u)`` and ``u <- normalize(mat @ v)``
    from a random start and stops early once ``u`` changes by less than
    ``0.001`` in squared Euclidean distance (the iteration index at which
    this happens is printed).

    Args:
        mat: 2-D floating-point tensor.
        iters: maximum number of update steps; must be at least 1.

    Returns:
        A ``(1, 1)`` tensor holding ``u.T @ mat @ v``, created with the same
        dtype and device as ``mat``.

    Raises:
        ValueError: if ``iters`` is smaller than 1.
    """
    if iters < 1:
        raise ValueError("iters must be at least 1")

    m, n = mat.shape
    # Build the work vectors like ``mat`` so float64 / CUDA inputs do not hit
    # a dtype or device mismatch in the matrix products below.
    u_prev = torch.zeros(m, 1, dtype=mat.dtype, device=mat.device)
    u_cur = torch.randn(m, 1, dtype=mat.dtype, device=mat.device)
    for it in range(iters):
        # Skip the check on the very first pass: ``v`` does not exist yet, and
        # a small random start could otherwise be mistaken for convergence.
        if it > 0 and torch.sum((u_prev - u_cur) ** 2) < 0.001:
            print(it)
            break
        u_prev = u_cur
        v = mat.T @ u_cur
        v_norm = torch.linalg.norm(v)
        if v_norm == 0:
            # mat.T @ u vanished (e.g. an all-zero or empty matrix): report a
            # zero estimate instead of propagating NaNs from a 0/0 division.
            return torch.zeros(1, 1, dtype=mat.dtype, device=mat.device)
        v = v / v_norm
        u_cur = mat @ v
        u_cur = u_cur / torch.linalg.norm(u_cur)

    return (u_cur.T @ mat) @ v

def compute_numerical_rank(all_features): # newly defined
    """Compute per-class norm statistics of the stacked feature matrices.

    Args:
        all_features: mapping from a class key to a sequence of arrays that
            ``np.vstack`` can stack into one matrix.

    Returns:
        Three dicts keyed like ``all_features``:
        ``(nf_metric_dict, spec_norm_dict, fro_norm_dict)`` where
        ``spec_norm_dict`` holds the value returned by ``power_iteration``,
        ``fro_norm_dict`` the Frobenius norm, and ``nf_metric_dict`` the ratio
        of the squared Frobenius norm to the squared ``power_iteration`` value.
    """
    ratio_by_class, spectral_by_class, frobenius_by_class = {}, {}, {}

    for label, rows in all_features.items():
        stacked = torch.from_numpy(np.vstack(rows))
        print(stacked.shape)

        spectral = power_iteration(stacked)
        frobenius = torch.sqrt(torch.sum(stacked ** 2))

        spectral_by_class[label] = spectral
        frobenius_by_class[label] = frobenius
        ratio_by_class[label] = (frobenius ** 2) / (spectral ** 2)

    return ratio_by_class, spectral_by_class, frobenius_by_class

def split_array(input_array, batchsize=128):
    """Split ``input_array`` along axis 0 into consecutive chunks.

    Every chunk has ``batchsize`` rows except possibly the last one, which
    holds the remainder. (The previous implementation called
    ``np.split(input_array, batchsize)`` when the length was an exact multiple
    of ``batchsize``, which yields ``batchsize`` *sections* rather than chunks
    of ``batchsize`` rows.)

    Args:
        input_array: array-like object supporting ``.shape`` and slicing
            along its first axis.
        batchsize: maximum number of rows per chunk; must be positive.

    Returns:
        A list of slices of ``input_array`` in their original order; empty
        when ``input_array`` has no rows.

    Raises:
        ValueError: if ``batchsize`` is not positive.
    """
    if batchsize <= 0:
        raise ValueError("batchsize must be a positive integer")

    input_size = input_array.shape[0]
    return [input_array[start:start + batchsize]
            for start in range(0, input_size, batchsize)]

def compute_info(args, model, dataloader):
    """Collect normalized features, classifier outputs and accuracy over a loader.

    Args:
        args: namespace providing ``device`` and ``val_choice``
            (``"feature"`` keeps the encoder output, ``"all"`` uses the full
            ``model_train`` forward pass).
        model: pair ``(model_train, model_ft)``. ``model_train`` must expose
            ``encoder`` and ``loss_type`` (``"CE"`` or ``"SupCon"``) and be
            callable; ``model_ft`` maps its input to class scores.
        dataloader: iterable of ``(inputs, targets)`` batches.

    Returns:
        ``(mu_G, mu_c_dict, before_class_dict, after_class_dict, top1, top5)``:
        the global mean of the L2-normalized features, the per-class means,
        per-class lists of normalized feature rows (NumPy), per-class lists of
        classifier output rows (NumPy), and the running top-1 / top-5
        accuracies in percent.

    Raises:
        ValueError: if ``model_train.loss_type`` or ``args.val_choice`` is not
            one of the supported values.
    """
    model_train, model_ft = model
    num_data = 0
    mu_G = 0
    mu_c_dict = dict()
    num_class_dict = dict()
    before_class_dict = dict()
    after_class_dict = dict()
    top1 = AverageMeter()
    top5 = AverageMeter()
    for batch_idx, (inputs, targets) in enumerate(dataloader):

        inputs, targets = inputs.to(args.device), targets.to(args.device)

        with torch.no_grad():
            features = model_train.encoder(inputs)

            if model_train.loss_type == "CE":
                # Trained with cross-entropy: classify the full forward output.
                head = model_train(inputs)
                outputs = model_ft(head)
            elif model_train.loss_type == "SupCon":
                # Supervised contrastive loss: classify the encoder features.
                outputs = model_ft(features)
            else:
                # Fail with a clear message instead of a later unbound
                # ``outputs`` error.
                raise ValueError(f"loss_type {model_train.loss_type} not supported")

            if args.val_choice == "all":
                features = model_train(inputs)
            elif args.val_choice != "feature":
                raise ValueError("val_choice specified not supported")

        # All statistics below are computed on L2-normalized features.
        features = F.normalize(features, dim=1)

        if batch_idx == 0:
            print(f"Feature shape: {features.shape}")
        mu_G += torch.sum(features, dim=0)

        # One host transfer per batch instead of one per sample.
        features_np = features.detach().cpu().numpy()
        outputs_np = outputs.detach().cpu().numpy()

        for b, y in enumerate(targets.tolist()):
            if y not in mu_c_dict:
                # Clone: the accumulator is updated in place below and must not
                # alias a row of ``features`` (on CPU that row also backs the
                # NumPy data stored in ``before_class_dict``).
                mu_c_dict[y] = features[b, :].clone()
                before_class_dict[y] = []
                after_class_dict[y] = []
                num_class_dict[y] = 0
            else:
                mu_c_dict[y] += features[b, :]
            before_class_dict[y].append(features_np[b])
            after_class_dict[y].append(outputs_np[b])
            num_class_dict[y] += 1

        num_data += targets.shape[0]

        prec1, prec5 = compute_accuracy(outputs.data, targets.data, topk=(1, 5))
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))

    mu_G /= num_data
    # Iterate the labels actually seen rather than assuming they are 0..K-1.
    for y, count in num_class_dict.items():
        mu_c_dict[y] /= count

    return mu_G, mu_c_dict, before_class_dict, after_class_dict, top1.avg, top5.avg

def compute_Sigma_W(args, before_class_dict, mu_c_dict, batchsize=128):
    """Compute the within-class covariance of the stored features.

    For every class the features are centered with that class's mean and the
    outer products of the centered rows are summed; the total is divided by
    the number of samples across all classes.

    Args:
        args: namespace providing ``device``.
        before_class_dict: mapping from class key to a list of 1-D feature
            arrays (NumPy).
        mu_c_dict: mapping from the same class keys to the class-mean tensor.
        batchsize: number of rows processed per chunk.

    Returns:
        The covariance matrix as a NumPy array.
    """
    num_data = 0
    Sigma_W = 0

    for target, class_features in before_class_dict.items():
        class_mean = mu_c_dict[target].unsqueeze(0)
        for features in split_array(np.array(class_features), batchsize=batchsize):
            features = torch.from_numpy(features).to(args.device)
            centered = features - class_mean
            # Sum of per-row outer products == centered.T @ centered; this
            # avoids materializing a (batch, dim, dim) intermediate tensor.
            Sigma_W += centered.T @ centered
            num_data += features.shape[0]

    Sigma_W /= num_data
    return Sigma_W.detach().cpu().numpy()


def compute_Sigma_B(mu_c_dict, mu_G):
    """Compute the between-class covariance of the class means.

    Averages the outer products ``(mu_c - mu_G)(mu_c - mu_G)^T`` over all
    classes in ``mu_c_dict``.

    Args:
        mu_c_dict: mapping from class key to a 1-D class-mean tensor. Keys may
            be arbitrary; they are not assumed to be ``0..K-1``.
        mu_G: 1-D global-mean tensor.

    Returns:
        The covariance matrix as a NumPy array.

    Raises:
        ValueError: if ``mu_c_dict`` is empty.
    """
    K = len(mu_c_dict)
    if K == 0:
        raise ValueError("mu_c_dict must contain at least one class mean")

    Sigma_B = 0
    # Iterate over the stored means so non-contiguous labels do not raise a
    # KeyError.
    for mu_c in mu_c_dict.values():
        centered = mu_c - mu_G
        Sigma_B += centered.unsqueeze(1) @ centered.unsqueeze(0)

    Sigma_B /= K

    return Sigma_B.detach().cpu().numpy()

def compute_ETF(W):
    """Measure how far the Gram matrix of ``W`` is from a simplex-ETF Gram matrix.

    The Gram matrix ``W @ W.T`` is scaled to unit Frobenius norm and compared
    with ``(I - 1/K * 11^T) / sqrt(K - 1)``, where ``K`` is the number of rows
    of ``W``.

    Args:
        W: 2-D tensor with one row per class.

    Returns:
        The Frobenius norm of the difference, as a Python float.
    """
    K = W.shape[0]
    WWT = torch.mm(W, W.T)
    WWT = WWT / torch.norm(WWT, p='fro')

    # Build the reference on W's own device/dtype instead of a hard-coded
    # ``.cuda()`` so the metric also works on CPU-only hosts.
    eye = torch.eye(K, dtype=W.dtype, device=W.device)
    ones = torch.ones((K, K), dtype=W.dtype, device=W.device)
    sub = (eye - 1 / K * ones) / pow(K - 1, 0.5)

    ETF_metric = torch.norm(WWT - sub, p='fro')
    return ETF_metric.detach().cpu().numpy().item()

def compute_ETF_feature(mu_c_dict, mu_G): # fnm = normalize(feature_mean), fnm - avg(fnm)
    """Measure how far the centered class means are from a simplex ETF.

    Each class mean has the global mean subtracted and is rescaled to unit
    norm; the Gram matrix of these rows, scaled to unit Frobenius norm, is
    compared with ``(I - 1/K * 11^T) / sqrt(K - 1)``.

    args:
    @ mu_c_dict: dictionary of class feature means
    @ mu_G: global mean of the features
    Both can be obtained from the compute_info function.

    Returns the Frobenius norm of the difference as a Python float.
    """
    dev = mu_G.device
    labels = list(mu_c_dict.keys())
    num_classes = len(labels)
    feat_dim = mu_c_dict[labels[0]].shape[0]

    # One row per class: unit-norm (class mean - global mean), so every class
    # contributes with the same norm.
    centered_means = torch.zeros(num_classes, feat_dim).to(dev)
    for row, label in enumerate(labels):
        offset = mu_c_dict[label] - mu_G
        centered_means[row] = offset / torch.linalg.norm(offset)

    gram = torch.mm(centered_means, centered_means.T)
    gram = gram / torch.norm(gram, p='fro')

    identity = torch.eye(num_classes)
    all_ones = torch.ones((num_classes, num_classes))
    reference = (identity - 1 / num_classes * all_ones).to(dev) / pow(num_classes - 1, 0.5)

    distance = torch.norm(gram - reference, p='fro')
    return distance.detach().cpu().numpy().item()


def main():
    """Evaluate a saved checkpoint and pickle the resulting metrics.

    Steps: parse the CLI arguments, build the data loaders and models, load
    the checkpoint(s) from ``args.load_path``, gather features on the train
    and test loaders, compute the numerical-rank, CDNV, collapse (NC1) and ETF
    metrics, and dump them to ``args.p_name`` inside ``args.load_path``.

    Raises:
        ValueError: if ``args.model`` is not ``"supce"``.
        RuntimeError: if the linear classifier exposes no ``linear.weight``
            parameter.
    """
    args = parse_eval_args()

    if args.load_path is None:
        sys.exit('Need to input the path to a pre-trained model!')

    # NOTE: --gpu_id is not consulted here; the first CUDA device is used
    # whenever CUDA is available.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    args.device = device

    # Dataset part
    print(f"Using dataset {args.dataset}")
    print()
    trainloader, testloader, num_classes = get_data(args.dataset, args.data_dir, args.batch_size, do_transform=False)

    print()
    print(f"We are validating using model type {args.model} and projection head type {args.head_type}, validate the {args.val_choice} part of the model")
    print()

    if args.model == "supce":
        model_train = SupCEResNet(name=args.model_name, head=args.head_type, remove_last_relu=args.remove_last_relu).to(device)
    else:
        raise ValueError(f"Model type {args.model} not supported")

    model_ft = LinearClassifier(name=args.model_name, num_classes=num_classes).to(device)

    info_dict = {
                 'cdnv': [],
                 'cdnv_test': [],
                 'collapse_metric': [],
                 'collapse_metric_test': [],
                 'numerical_rank': [],
                 'ETF_metric': [],
                 'ETF_feature_metric': [],
                 'ETF_feature_metric_test': [],
                 }

    print(args.load_path)
    # Currently evaluates the single epoch ``args.epochs``.
    for i in range(args.epochs, args.epochs+1, 2):
        # Join path components instead of concatenating strings so that a
        # --load_path without a trailing separator still resolves correctly.
        if args.after_ft:
            print(f"Load model from {args.load_path}/model_best.pth")
            ckpt_path = os.path.join(args.load_path, 'model_best' + '.pth')
        else:
            ckpt_path = os.path.join(args.load_path, 'model_epoch_' + str(i) + '.pth')
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        checkpoint = torch.load(ckpt_path, map_location=device)
        model_train.load_state_dict(checkpoint['state_dict'])

        if args.after_ft:
            model_ft.load_state_dict(checkpoint['state_dict_ft'])
        model_train.eval()
        model_ft.eval()

        # Classifier weight matrix used for the ETF metric.
        W = None
        for n, p in model_ft.named_parameters():
            if 'linear.weight' in n:
                W = p.clone()
        if W is None:
            # Fail before the expensive feature pass rather than after it.
            raise RuntimeError("model_ft has no 'linear.weight' parameter")

        mu_G_train, mu_c_dict_train, before_class_dict_train, _, _, _ = compute_info(args, [model_train, model_ft], trainloader)
        mu_G_test, mu_c_dict_test, before_class_dict_test, _, _, _ = compute_info(args, [model_train, model_ft], testloader)

        nr_metric_dict, _, _ = compute_numerical_rank(before_class_dict_train)
        info_dict['numerical_rank'].append(nr_metric_dict) # Numerical rank

        # CDNV
        cdnv = CDNV(before_class_dict_train)
        info_dict['cdnv'].append(cdnv)
        cdnv_test = CDNV(before_class_dict_test)
        info_dict['cdnv_test'].append(cdnv_test)

        Sigma_W = compute_Sigma_W(args, before_class_dict_train, mu_c_dict_train, batchsize=args.batch_size)
        Sigma_W_test = compute_Sigma_W(args, before_class_dict_test, mu_c_dict_test, batchsize=args.batch_size)
        Sigma_B = compute_Sigma_B(mu_c_dict_train, mu_G_train)
        Sigma_B_test = compute_Sigma_B(mu_c_dict_test, mu_G_test)

        # NC1: trace(Sigma_W @ pinv(Sigma_B)) divided by the number of classes.
        collapse_metric = np.trace(Sigma_W @ scilin.pinv(Sigma_B)) / len(mu_c_dict_train)
        collapse_metric_test = np.trace(Sigma_W_test @ scilin.pinv(Sigma_B_test)) / len(mu_c_dict_test)
        print("NC1", collapse_metric)
        print("NC1_test", collapse_metric_test)

        ETF_metric = compute_ETF(W)
        ETF_metric_tilde = compute_ETF_feature(mu_c_dict_train, mu_G_train)
        ETF_metric_tilde_test = compute_ETF_feature(mu_c_dict_test, mu_G_test)
        print(f"ETF metric for feature: {ETF_metric_tilde}")
        print(f"Test ETF metric for feature: {ETF_metric_tilde_test}")

        info_dict['collapse_metric'].append(collapse_metric)
        info_dict['collapse_metric_test'].append(collapse_metric_test)
        info_dict['ETF_metric'].append(ETF_metric)
        info_dict['ETF_feature_metric'].append(ETF_metric_tilde) # ETF for feature mean
        info_dict['ETF_feature_metric_test'].append(ETF_metric_tilde_test) # ETF for feature mean test

        print(f"Epoch {i} is processed")

    with open(os.path.join(args.load_path, args.p_name), 'wb') as f:
        pickle.dump(info_dict, f)



# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()