import argparse
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import scipy.linalg as scilin
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet18, resnet50

from define_data import get_data
from utils import CDNV

class LinearClassifier(nn.Module):
    """Affine head mapping ResNet penultimate features to ``num_classes`` logits.

    Args:
        name: backbone name; 'resnet18' gives a 512-wide input, 'resnet50' a 2048-wide one.
        num_classes: number of output logits.

    Raises:
        ValueError: for any other ``name``.
    """

    def __init__(self, name='resnet50', num_classes=10):
        super().__init__()
        penultimate_dims = {'resnet18': 512, 'resnet50': 2048}
        if name not in penultimate_dims:
            raise ValueError("Please use ResNet18 or ResNet50!")
        self.linear = nn.Linear(penultimate_dims[name], num_classes, bias=True)

    def forward(self, features):
        """Return the logits ``self.linear(features)``."""
        return self.linear(features)
        
class ForwardPreHook:
    """Record the positional-argument tuple passed to a module on every forward call.

    Captured tuples accumulate in ``self.outputs`` until ``clear()`` is called;
    ``close()`` detaches the hook from the module.
    """

    def __init__(self, module):
        self.outputs = list()
        # Keep the handle so the hook can be removed later.
        self.hook = module.register_forward_pre_hook(self.hook_fn)

    def hook_fn(self, module, module_in):
        # module_in is the tuple of positional args given to forward(); returning
        # None leaves those inputs untouched.
        self.outputs.append(module_in)

    def clear(self):
        # Rebind rather than empty in place, so lists handed out earlier are kept.
        self.outputs = list()

    def close(self):
        self.hook.remove()
        
class ForwardHook:
    """Record the value a module returns on every forward call.

    Captured outputs accumulate in ``self.outputs`` until ``clear()`` is called;
    ``close()`` detaches the hook from the module.
    """

    def __init__(self, module):
        self.outputs = list()
        # Keep the handle so the hook can be removed later.
        self.hook = module.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, module_in, module_out):
        # Returning None leaves the module's output unchanged.
        self.outputs.append(module_out)

    def clear(self):
        # Rebind rather than empty in place, so lists handed out earlier are kept.
        self.outputs = list()

    def close(self):
        self.hook.remove()

def compute_accuracy(output, target, topk=(1,)):
    """Return precision@k, in percent, for every k in ``topk``.

    ``output`` holds scores whose dim 1 indexes classes and whose dim 0 matches
    ``target``'s length; ``target`` holds class indices. One 0-dim tensor is
    returned per k, in the order of ``topk``.
    """
    n_samples = target.size(0)
    _, top_classes = output.topk(max(topk), dim=1, largest=True, sorted=True)
    # Transpose to (max_k, batch): row j holds every sample's (j+1)-th best guess.
    top_classes = top_classes.t()
    hits = top_classes.eq(target.view(1, -1).expand_as(top_classes))
    return [
        hits[:k].reshape(-1).float().sum(0).mul_(100.0 / n_samples)
        for k in topk
    ]

class AverageMeter(object):
    """Track the most recent value and a running, count-weighted mean of a metric.

    Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the last value, running sum, sample count and average."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean over ``n`` samples and refresh the average."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count

def parse_eval_args():
    """Parse command-line options for the neural-collapse evaluation script.

    Returns:
        argparse.Namespace with model, hardware, path, dataset and batching options.
        ``int_layers`` is always a list (empty when the flag is omitted) and
        ``load_epoch`` is None unless given.
    """
    parser = argparse.ArgumentParser()

    # Model selection
    parser.add_argument('--model', type=str, default='resnet18')  # Model type
    parser.add_argument('--after_ft', dest='after_ft', action='store_true')  # Validate model after fine tune
    # default=[] so callers can iterate / take len() even when the flag is omitted.
    parser.add_argument('--int_layers', nargs='*', default=[], help="All layers that will be finetuned")

    # Hardware setting
    parser.add_argument('--gpu_id', type=int, default=0)

    # Directory setting
    parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'pet', 'dtd', 'aircraft'], default='cifar10')
    parser.add_argument('--data_dir', type=str, default='<path to folder where data should be saved>')
    parser.add_argument('--save_path', type=str, default="saved/")
    parser.add_argument('--load_path', type=str, default=None)
    # Previously missing: main() reads args.load_epoch when --after_ft is not set.
    parser.add_argument('--load_epoch', type=str, default=None,
                        help="Epoch N of the checkpoint model_epoch_N.pth to load when --after_ft is not set")

    # Learning options
    parser.add_argument('--batch_size', type=int, default=128, help='Batch size')

    return parser.parse_args()

def split_array(input_array, batchsize=128):
    """Split ``input_array`` along axis 0 into consecutive chunks of ``batchsize`` rows.

    The last chunk holds the remainder when the length is not a multiple of
    ``batchsize``. An empty input yields an empty list. Chunks are slices (views)
    of ``input_array``.

    Raises:
        ValueError: if ``batchsize`` is not positive.
    """
    if batchsize <= 0:
        raise ValueError(f"batchsize must be positive, got {batchsize}")
    # Slicing handles both the exact-multiple and remainder cases. The previous
    # np.split(input_array, batchsize) produced `batchsize` pieces, not pieces of
    # `batchsize` rows.
    return [input_array[start:start + batchsize]
            for start in range(0, input_array.shape[0], batchsize)]

def compute_info(args, model, layer_hook_list, dataloader):
    """Extract head-input features over ``dataloader`` and accumulate class statistics.

    For each batch, the backbone output is averaged with the globally pooled,
    zero-padded outputs of the hooked intermediate layers. That average is fed to
    the linear head, and a forward pre-hook captures it as the feature vector.

    Args:
        args: namespace providing ``model``, ``device`` and ``layer_to_change``.
        model: ``[backbone, linear_head]`` pair; the head must expose ``.linear``.
        layer_hook_list: forward hooks on the intermediate layers, one per entry of
            ``args.layer_to_change`` (same order).
        dataloader: iterable of ``(inputs, targets)`` batches.

    Returns:
        ``(mu_G, mu_c_dict, before_class_dict)``:
        - ``mu_G``: mean feature vector over all samples.
        - ``mu_c_dict``: label -> mean feature vector of that class.
        - ``before_class_dict``: label -> list of that class's per-sample feature
          vectors as numpy arrays, in loader order.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    penulti_length = 2048 if args.model == "resnet50" else 512
    model_train, model_ft = model
    num_skips = len(args.layer_to_change)
    # Captures the exact tensor fed to the linear head.
    hook = ForwardPreHook(model_ft.linear)
    num_data = 0
    mu_G = 0
    mu_c_dict = dict()
    num_class_dict = dict()
    before_class_dict = dict()
    try:
        for batch_idx, (inputs, targets) in enumerate(dataloader):
            inputs, targets = inputs.to(args.device), targets.to(args.device)

            with torch.no_grad():
                feature_before_fc = model_train(inputs)

                all_skip_connect_sum = 0
                for i in range(num_skips):
                    layer_hook = layer_hook_list[i]
                    # Pool to (N, C); flatten(1) keeps the batch dim even when N == 1
                    # (the previous .squeeze() dropped it and broke shape[1]).
                    pooled_layer_out = F.adaptive_avg_pool2d(layer_hook.outputs[0], (1, 1)).flatten(1)
                    out_length = pooled_layer_out.shape[1]
                    # Zero-pad channels up to the penultimate width so vectors can be summed.
                    all_skip_connect_sum += F.pad(pooled_layer_out, [0, penulti_length - out_length, 0, 0])
                    layer_hook.clear()

                feature_skip = (feature_before_fc + all_skip_connect_sum) / (num_skips + 1)

                # Run the head only so that the pre-hook records its input.
                model_ft(feature_skip)
                features = hook.outputs[0][0].view(len(targets), -1)
                hook.clear()

            if batch_idx == 0:
                print(f"Feature shape: {features.shape}")
            mu_G += torch.sum(features, dim=0)

            # One owned host copy per batch. On CPU, .numpy() shares memory with the
            # tensor, so without .copy() later in-place ops could corrupt these rows.
            batch_np = features.detach().cpu().numpy().copy()
            for b, y in enumerate(targets.tolist()):
                if y in mu_c_dict:
                    # Out-of-place add: never write through a view of `features`.
                    mu_c_dict[y] = mu_c_dict[y] + features[b]
                    before_class_dict[y].append(batch_np[b])
                    num_class_dict[y] += 1
                else:
                    mu_c_dict[y] = features[b].clone()
                    before_class_dict[y] = [batch_np[b]]
                    num_class_dict[y] = 1

            num_data += targets.shape[0]
    finally:
        # Detach the pre-hook so repeated calls do not stack hooks (and leak memory)
        # on the same head.
        hook.close()

    if num_data == 0:
        raise ValueError("dataloader yielded no samples")
    mu_G /= num_data
    # Iterate the actual labels; they need not be exactly 0..K-1.
    for y in mu_c_dict:
        mu_c_dict[y] /= num_class_dict[y]

    return mu_G, mu_c_dict, before_class_dict

def compute_Sigma_W(args, before_class_dict, mu_c_dict, batchsize=128):
    """Compute the within-class covariance of the collected features.

    Sigma_W = (1/N) * sum_c sum_{i in c} (h_i - mu_c)(h_i - mu_c)^T, where N is the
    total number of feature vectors across all classes.

    Args:
        args: namespace providing ``device``.
        before_class_dict: label -> list of 1-D feature arrays, all of length D.
        mu_c_dict: label -> 1-D class-mean tensor of length D.
        batchsize: number of feature vectors moved to ``args.device`` at a time.

    Returns:
        (D, D) numpy array.

    Raises:
        ValueError: if ``before_class_dict`` contains no feature vectors.
    """
    num_data = 0
    Sigma_W = 0
    for target, class_features in before_class_dict.items():
        class_array = np.array(class_features)
        mu_c = mu_c_dict[target].unsqueeze(0)
        for start in range(0, class_array.shape[0], batchsize):
            chunk = torch.from_numpy(class_array[start:start + batchsize]).to(args.device)
            centered = chunk - mu_c
            # centered^T @ centered equals the sum of per-row outer products, but avoids
            # materialising a (batch, D, D) tensor (~2 GB at D=2048, batch=128).
            Sigma_W += centered.T @ centered
            num_data += centered.shape[0]

    if num_data == 0:
        raise ValueError("before_class_dict contains no features")
    Sigma_W /= num_data
    return Sigma_W.detach().cpu().numpy()

def compute_Sigma_B(mu_c_dict, mu_G):
    """Compute the between-class covariance of the class means.

    Sigma_B = (1/K) * sum_c (mu_c - mu_G)(mu_c - mu_G)^T over the K classes in
    ``mu_c_dict``. Labels may be arbitrary keys, not only 0..K-1.

    Args:
        mu_c_dict: label -> 1-D class-mean tensor of length D.
        mu_G: 1-D global-mean tensor of length D.

    Returns:
        (D, D) numpy array.

    Raises:
        ValueError: if ``mu_c_dict`` is empty.
    """
    if not mu_c_dict:
        raise ValueError("mu_c_dict is empty")
    class_means = torch.stack(list(mu_c_dict.values()))   # (K, D)
    centered = class_means - mu_G.unsqueeze(0)
    # centered^T @ centered is the sum of the K outer products.
    Sigma_B = (centered.T @ centered) / class_means.shape[0]
    return Sigma_B.cpu().numpy()

def _build_backbone(model_name, device):
    """Return an ImageNet-pretrained ResNet on ``device`` whose ``fc`` is an identity.

    Raises:
        ValueError: if ``model_name`` is not 'resnet18' or 'resnet50'.
    """
    if model_name == "resnet18":
        backbone = resnet18(weights='IMAGENET1K_V1')
    elif model_name == "resnet50":
        backbone = resnet50(weights='IMAGENET1K_V1')
    else:
        raise ValueError(f"Model type {model_name} not supported")
    backbone = backbone.to(device)
    # An empty nn.Sequential returns its input, so forward() yields penultimate features.
    backbone.fc = nn.Sequential()
    return backbone


def _layer_registry(backbone):
    """Map CLI layer names to ``(module, flat_index)`` pairs.

    'inp_layer' is the stem's max-pool (index 0). 'l<s>_b<b>' is block ``b``
    (1-based) of ``backbone.layer<s>``; indices run consecutively across stages
    starting at 1.
    """
    registry = {"inp_layer": (backbone.maxpool, 0)}
    flat_index = 1
    for stage in range(1, 5):
        for block_num, block in enumerate(getattr(backbone, f"layer{stage}"), start=1):
            registry[f"l{stage}_b{block_num}"] = (block, flat_index)
            flat_index += 1
    return registry


def _checkpoint_path(args):
    """Return the checkpoint file inside ``args.load_path`` to evaluate.

    Raises:
        ValueError: if ``--load_path`` is missing, or if ``--load_epoch`` is missing
            when ``--after_ft`` is not set.
    """
    if args.load_path is None:
        raise ValueError("--load_path is required")
    if args.after_ft:
        return os.path.join(args.load_path, "model_best.pth")
    load_epoch = getattr(args, "load_epoch", None)
    if load_epoch is None:
        raise ValueError("--load_epoch is required unless --after_ft is set")
    return os.path.join(args.load_path, f"model_epoch_{load_epoch}.pth")


def main():
    """Measure NC1 and CDNV of skip-augmented penultimate features on train and test.

    Loads a checkpoint from ``--load_path``, extracts features for both splits, and
    pickles the metrics to ``<load_path>/info_penulti_<dataset>.pkl``.
    """
    args = parse_eval_args()
    # Resolve the checkpoint before any data loading so bad CLI arguments fail fast.
    ckpt_path = _checkpoint_path(args)

    device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
    args.device = device

    # Dataset part
    print(f"Using dataset {args.dataset}")
    print()
    trainloader, testloader, num_classes = get_data(args.dataset, args.data_dir,
                                                    args.batch_size, do_transform=False)

    print()
    print(f"We are validating using model type {args.model}!")
    print()

    model_train = _build_backbone(args.model, device)
    model_ft = LinearClassifier(name=args.model, num_classes=num_classes).to(device)

    # Hook the intermediate blocks whose pooled outputs compute_info averages into
    # the penultimate features.
    registry = _layer_registry(model_train)
    int_layers = args.int_layers or []
    unknown = [name for name in int_layers if name not in registry]
    if unknown:
        raise ValueError(f"Unknown layer name(s) {unknown}; choose from {list(registry)}")
    args.layer_to_change = [registry[name][1] for name in int_layers]
    layer_hook_list = [ForwardHook(registry[name][0]) for name in int_layers]
    if int_layers:
        print(f"Fine tune was done on layers {int_layers}")
        print(args.layer_to_change)

    print(f"Load model from {ckpt_path}")
    checkpoint = torch.load(ckpt_path, map_location=device)
    if args.after_ft:
        model_ft.load_state_dict(checkpoint['state_dict_ft'])
    model_train.load_state_dict(checkpoint["state_dict"])

    model_train.eval()
    model_ft.eval()

    # ETF keys are kept (unfilled) so the pickled schema stays unchanged.
    info_dict = {
                 'collapse_metric': [],
                 'collapse_metric_test': [],
                 'ETF_feature_metric': [],
                 'ETF_feature_metric_test': [],
                 'cdnv': [],
                 'cdnv_test': [],
                 }

    try:
        mu_G_train, mu_c_dict_train, before_class_dict_train = compute_info(
            args, [model_train, model_ft], layer_hook_list, trainloader)
        mu_G_test, mu_c_dict_test, before_class_dict_test = compute_info(
            args, [model_train, model_ft], layer_hook_list, testloader)
    finally:
        for layer_hook in layer_hook_list:
            layer_hook.close()

    info_dict['cdnv'].append(CDNV(before_class_dict_train))
    info_dict['cdnv_test'].append(CDNV(before_class_dict_test))

    Sigma_W = compute_Sigma_W(args, before_class_dict_train, mu_c_dict_train, batchsize=args.batch_size)
    Sigma_W_test = compute_Sigma_W(args, before_class_dict_test, mu_c_dict_test, batchsize=args.batch_size)
    Sigma_B = compute_Sigma_B(mu_c_dict_train, mu_G_train)
    Sigma_B_test = compute_Sigma_B(mu_c_dict_test, mu_G_test)

    # NC1 = tr(Sigma_W Sigma_B^+) / K. Sigma_B has rank <= K-1, hence the pseudo-inverse.
    collapse_metric = np.trace(Sigma_W @ scilin.pinv(Sigma_B)) / len(mu_c_dict_train)
    collapse_metric_test = np.trace(Sigma_W_test @ scilin.pinv(Sigma_B_test)) / len(mu_c_dict_test)
    print("NC1", collapse_metric)
    print("NC1_test", collapse_metric_test)

    info_dict['collapse_metric'].append(collapse_metric)
    info_dict['collapse_metric_test'].append(collapse_metric_test)

    out_file = os.path.join(args.load_path, "info_penulti_" + args.dataset + ".pkl")
    with open(out_file, 'wb') as f:
        pickle.dump(info_dict, f)


if __name__ == "__main__":
    main()