import sys
import pickle
import torchvision
import torchvision.transforms as transforms
import torch
import random
from torchvision.models import resnet18, resnet50
import argparse
import os
import numpy as np
import matplotlib.pyplot as plt
from define_data import get_data
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class LinearClassifier(nn.Module):
    """Single fully connected layer mapping backbone features to class logits.

    The input width is chosen from the backbone name: 512 for ``resnet18``
    and 2048 for ``resnet50``. Any other name raises ``ValueError``.
    """
    def __init__(self, name='resnet50', num_classes=10):
        super().__init__()

        # Width of the feature vector produced by each supported backbone.
        backbone_widths = {'resnet18': 512, 'resnet50': 2048}
        if name not in backbone_widths:
            raise ValueError("Please use ResNet18 or ResNet50!")

        self.linear = nn.Linear(backbone_widths[name], num_classes, bias=True)

    def forward(self, features):
        """Return the logits computed by the linear layer for ``features``."""
        logits = self.linear(features)
        return logits
            
class ForwardHook():
    """Collects the outputs of a module each time its forward pass runs."""

    def __init__(self, module):
        """Register a forward hook on ``module`` and start with an empty buffer."""
        self.outputs = []
        self.hook = module.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, module_in, module_out):
        """Hook callback: store the module's output (inputs are ignored)."""
        self.outputs.append(module_out)

    def clear(self):
        """Drop every captured output by rebinding to a fresh list."""
        self.outputs = []

    def close(self):
        """Detach the hook from the module; no further outputs are captured."""
        self.hook.remove()

    
def train_epoch(args, FModel, TModel, layer_hook_list, dataloader, optimizer, criterion, cur_epoch):
    """Train for one epoch using features fused with hooked layer outputs.

    For every batch the backbone output is averaged with the globally
    average-pooled, zero-padded outputs captured by the forward hooks. The
    fused feature is passed to the classifier head and one optimizer step
    is taken.

    Args:
        args: Namespace providing ``model``, ``device`` and
            ``layer_to_change`` (``open_norm`` is read by ``set_to_train``).
        FModel: Feature extractor, called as ``FModel(inputs)``.
            Assumed to return ``(batch, penulti_length)`` features -- TODO confirm.
        TModel: Classifier head applied to the fused feature.
        layer_hook_list: Hooks whose ``outputs[0]`` holds the activation
            captured during the current forward pass; the first
            ``len(args.layer_to_change)`` entries are used and cleared.
        dataloader: Iterable of ``(inputs, targets)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss callable invoked as ``criterion(outputs, targets)``.
        cur_epoch: Epoch number, used only in the progress log.

    Returns:
        Tuple ``(top1_avg, top5_avg)`` of running accuracies in percent.
    """
    # Width of the final backbone feature; pooled skip features are padded to it.
    penulti_length = 2048 if args.model == "resnet50" else 512
    num_skip_layers = len(args.layer_to_change)

    set_to_train(args, FModel)
    TModel.train()

    top1 = AverageMeter()
    top5 = AverageMeter()
    running_loss = 0.0
    batches_since_report = 0
    for batch_idx, (inputs, targets) in enumerate(dataloader):

        inputs, targets = inputs.to(args.device), targets.to(args.device)

        # The forward pass also fills each registered hook with its layer output.
        feature = FModel(inputs)

        all_skip_connect_sum = 0
        for i in range(num_skip_layers):
            layer_hook = layer_hook_list[i]
            idx_layer_out = layer_hook.outputs[0]
            # flatten(1) keeps the batch axis; a bare squeeze() would also drop
            # it for a batch of one sample and break the shape lookup below.
            pooled_layer_out = F.adaptive_avg_pool2d(idx_layer_out, (1, 1)).flatten(1)
            out_length = pooled_layer_out.shape[1]
            # Right-pad the channel axis with zeros up to the feature width.
            skip_connect_out = F.pad(pooled_layer_out, [0, penulti_length - out_length, 0, 0])
            all_skip_connect_sum += skip_connect_out
            # Empty the buffer so outputs[0] is fresh on the next forward pass.
            layer_hook.clear()

        # Plain average over the backbone feature and every skip feature.
        feature_skip = (feature + all_skip_connect_sum) / (num_skip_layers + 1)

        outputs = TModel(feature_skip)

        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        prec1, prec5 = compute_accuracy(outputs.data, targets.data, topk=(1, 5))
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))

        running_loss += loss.item()
        batches_since_report += 1
        if batch_idx % 100 == 0:    # print every 100 mini-batches
            # Average over the batches actually accumulated since the last
            # report (only one at batch_idx == 0) instead of a fixed 100.
            print(f'[Epoch {cur_epoch}, {batch_idx + 1:5d}] loss: {running_loss / batches_since_report:.3f}, acc: {top1.avg}')
            running_loss = 0.0
            batches_since_report = 0

    return top1.avg, top5.avg

def evaluate_epoch(args, FModel, TModel, layer_hook_list, dataloader, criterion):
    """Evaluate the fused-feature classifier over one pass of ``dataloader``.

    Both models are switched to eval mode (and left in it) and the whole
    pass runs under ``torch.no_grad()``. Features are fused as in training:
    the backbone output is averaged with the pooled, zero-padded outputs
    captured by the forward hooks.

    Args:
        args: Namespace providing ``model``, ``device`` and ``layer_to_change``.
        FModel: Feature extractor, called as ``FModel(inputs)``.
        TModel: Classifier head applied to the fused feature.
        layer_hook_list: Hooks whose ``outputs[0]`` holds the activation
            captured during the current forward pass; the first
            ``len(args.layer_to_change)`` entries are used and cleared.
        dataloader: Iterable of ``(inputs, targets)`` batches.
        criterion: Not used; accepted so the call signature stays unchanged.

    Returns:
        Tuple ``(top1_avg, top5_avg)`` of accuracies in percent.
    """
    # Width of the final backbone feature; pooled skip features are padded to it.
    penulti_length = 2048 if args.model == "resnet50" else 512
    num_skip_layers = len(args.layer_to_change)

    FModel.eval()
    TModel.eval()

    test_top1 = AverageMeter()
    test_top5 = AverageMeter()
    with torch.no_grad():
        for inputs, targets in dataloader:

            inputs, targets = inputs.to(args.device), targets.to(args.device)

            # The forward pass also fills each registered hook with its layer output.
            feature = FModel(inputs)

            all_skip_connect_sum = 0
            for i in range(num_skip_layers):
                layer_hook = layer_hook_list[i]
                idx_layer_out = layer_hook.outputs[0]
                # flatten(1) keeps the batch axis; a bare squeeze() would also
                # drop it for a batch of one sample and break the code below.
                pooled_layer_out = F.adaptive_avg_pool2d(idx_layer_out, (1, 1)).flatten(1)
                out_length = pooled_layer_out.shape[1]
                # Right-pad the channel axis with zeros up to the feature width.
                skip_connect_out = F.pad(pooled_layer_out, [0, penulti_length - out_length, 0, 0])
                all_skip_connect_sum += skip_connect_out
                # Empty the buffer so outputs[0] is fresh on the next forward pass.
                layer_hook.clear()

            # Plain average over the backbone feature and every skip feature.
            feature_skip = (feature + all_skip_connect_sum) / (num_skip_layers + 1)

            outputs = TModel(feature_skip)

            prec1, prec5 = compute_accuracy(outputs.data, targets.data, topk=(1, 5))
            test_top1.update(prec1.item(), inputs.size(0))
            test_top5.update(prec5.item(), inputs.size(0))

    return test_top1.avg, test_top5.avg

def set_to_train(args, model):
    """Switch ``model`` to training mode only when ``args.open_norm`` is truthy.

    When the flag is falsy the model's current mode is left untouched.
    """
    if not args.open_norm:
        return
    model.train()
    
def compute_accuracy(output, target, topk=(1,)):
    """Return precision@k (in percent) for each k in ``topk``.

    ``output`` holds one row of class scores per sample and ``target`` the
    matching class indices. The result is a list with one scalar tensor
    per requested k, in the same order as ``topk``.
    """
    largest_k = max(topk)
    num_samples = target.size(0)

    # Indices of the highest-scoring classes, transposed to (k, batch).
    ranked = torch.topk(output, k=largest_k, dim=1, largest=True, sorted=True).indices.t()
    hits = ranked.eq(target.view(1, -1).expand_as(ranked))

    percent_per_hit = 100.0 / num_samples
    return [
        hits[:k].reshape(-1).float().sum(0).mul_(percent_per_hit)
        for k in topk
    ]

class AverageMeter(object):
    """Tracks the latest value together with a running weighted average.

    Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    """
    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the last value, the average, the weighted sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
        
def set_seed(manualSeed=666):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Also writes the seed to the ``PYTHONHASHSEED`` environment variable.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_rng in seeders:
        seed_rng(manualSeed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    os.environ['PYTHONHASHSEED'] = str(manualSeed)
    

def parse_eval_args(argv=None):
    """Parse the command-line options of the fine-tuning script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, as before.

    Returns:
        The populated ``argparse.Namespace``. ``int_layers`` is always a
        list; it is empty when the option is omitted or given no values.
    """
    parser = argparse.ArgumentParser()

    # Model selection
    parser.add_argument('--model', type=str, default='resnet18')
    # All layers that will be finetuned. The default is an empty list so that
    # callers can take len() of it even when the option is not given.
    parser.add_argument('--int_layers', nargs='*', default=[], help="All layers that will be finetuned")
    parser.add_argument("--open_norm", dest='open_norm', action="store_true")

    # Hardware setting
    parser.add_argument('--gpu_id', type=int, default=0)
    parser.add_argument('--seed', type=int, default=6)

    # Directory setting
    # Note: --dataset selects the dataset to transfer-learn on.
    parser.add_argument('--dataset', type=str, choices=['cifar10','cifar100','pet','dtd','aircraft'], default='cifar10')
    parser.add_argument('--data_dir', type=str, default='<path to folder where data should be saved>')
    parser.add_argument('--load_path', type=str, default=None)

    # Learning options
    parser.add_argument('--epochs', type=int, default=200, help='Max Epochs')
    parser.add_argument('--load_epoch', type=int, default=499, help='Checkpoint load epoch')
    parser.add_argument('--batch_size', type=int, default=128, help='Batch size')
    parser.add_argument('--sample_size', type=int, default=None, help='sample size PER CLASS')

    return parser.parse_args(argv)

def _build_layer_tables(backbone):
    """Build the name -> module and name -> index tables for a ResNet backbone.

    Keys are ``inp_layer`` plus ``l{stage}_b{block}`` for every residual
    block of ``layer1`` .. ``layer4`` (1-based block number). Each entry of
    the first table is ``[module_to_hook, dotted_name_of_module_to_tune]``;
    for ``inp_layer`` the hook sits on ``maxpool`` while ``conv1`` is the
    module that gets fine-tuned. The second table numbers the keys
    consecutively, starting with ``inp_layer`` at 0.
    """
    part_dict = {"inp_layer": [backbone.maxpool, "conv1"]}
    index_dict = {"inp_layer": 0}
    for stage in range(1, 5):
        for block_pos, block in enumerate(getattr(backbone, f"layer{stage}")):
            name = f"l{stage}_b{block_pos + 1}"
            index_dict[name] = len(index_dict)
            part_dict[name] = [block, f"layer{stage}.{block_pos}"]
    return part_dict, index_dict


def main():
    """Fine-tune an ImageNet-pretrained ResNet with hook-based skip features.

    Steps: parse arguments, load the data, build the backbone (its ``fc``
    replaced by an empty ``nn.Sequential``) and a linear head, freeze the
    backbone, unfreeze and hook the layers named in ``--int_layers``
    (``all`` unfreezes the whole backbone without hooks), then train and
    evaluate for ``--epochs`` epochs. A checkpoint is written whenever the
    test accuracy improves, otherwise on every 50th epoch.

    Raises:
        ValueError: If the model type or a layer name is not supported.
    """
    args = parse_eval_args()

    if args.load_path is None:
        sys.exit('Need to input the path to a pre-trained model!')

    # NOTE: --gpu_id is parsed but not used here; "cuda:0" is always selected
    # when CUDA is available.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    args.device = device

    set_seed(args.seed)

    # Dataset part
    print(f"Using dataset {args.dataset}")
    print()
    trainloader, testloader, num_classes = get_data(args.dataset, args.data_dir, args.batch_size, do_transform = True)

    if args.model == "resnet18":
        model_train = resnet18(weights='IMAGENET1K_V1').to(device)
    elif args.model == "resnet50":
        model_train = resnet50(weights='IMAGENET1K_V1').to(device)
    else:
        raise ValueError(f"Model type {args.model} not supported")
    # Drop the ImageNet classifier so the backbone returns its pooled feature.
    model_train.fc = nn.Sequential()

    model_ft = LinearClassifier(name=args.model, num_classes=num_classes).to(device)

    model_train.eval()
    # Freeze every backbone parameter. Assigning ``module.requires_grad`` only
    # sets a plain attribute; requires_grad_() updates the parameters.
    model_train.requires_grad_(False)

    model_part_dict, layer_index_dict = _build_layer_tables(model_train)

    print(f"The Batchnorm status is {args.open_norm}")
    # Guard against a missing option so the length checks below cannot fail.
    int_layers = args.int_layers or []
    fine_tune_all = int_layers == ["all"]

    args.layer_to_change = [] # adding linear classifier's after which layers
    layer_hook_list = []
    if int_layers and not fine_tune_all:
        unknown = [name for name in int_layers if name not in layer_index_dict]
        if unknown:
            raise ValueError(f"Unknown layer name(s) {unknown}; choose from {list(layer_index_dict)} or 'all'")
        print(len(int_layers))
        print(f"Fine tune for layers {int_layers}")
        for layer_name in int_layers:
            args.layer_to_change.append(layer_index_dict[layer_name])
            layer_hook_list.append(ForwardHook(model_part_dict[layer_name][0]))
        print(args.layer_to_change)

    need_param_list = [] # params that will be trained

    if not int_layers:
        print("Transfer Learning")
    elif fine_tune_all: # Fine tune the whole model
        print("Will fine tune the whole model")
        model_train.requires_grad_(True)
        need_param_list = list(model_train.parameters())
    else:
        named_modules = dict(model_train.named_modules())
        for int_layer in int_layers:
            module_name = model_part_dict[int_layer][1]
            print(module_name)
            tuned_module = named_modules[module_name]
            tuned_module.requires_grad_(True)
            need_param_list.extend(tuned_module.parameters())

    criterion = nn.CrossEntropyLoss()
    head_params = list(model_ft.parameters())
    trainable_params = need_param_list + head_params
    print(f"Trainable layers: model_train: {len(need_param_list)}, model_ft: {len(head_params)}")
    optimizer = optim.SGD(trainable_params, lr=0.1, momentum = 0.9, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=args.epochs, eta_min=0.0001)

    best_test_acc = 0.0
    train_accs = []
    test_accs = []
    # Tag used in the results folder name; "all" keeps whole-model runs apart
    # from plain transfer-learning runs, which have an empty tag.
    save_name = "all" if fine_tune_all else "-".join(str(i) for i in args.layer_to_change)

    # os.path.join works whether or not load_path ends with a separator.
    checkpoint_dir = os.path.join(args.load_path, f"fine_tune_skippad(blockwise)_{args.dataset}_{save_name}_{args.open_norm}")
    os.makedirs(checkpoint_dir, exist_ok=True)

    for epoch in range(args.epochs):
        train_top1, _ = train_epoch(args, model_train, model_ft, layer_hook_list, trainloader, optimizer, criterion, epoch)
        test_top1, _ = evaluate_epoch(args, model_train, model_ft, layer_hook_list, testloader, criterion)
        train_accs.append(train_top1)
        test_accs.append(test_top1)
        scheduler.step()

        is_best = test_top1 > best_test_acc
        if is_best:
            best_test_acc = test_top1
        print(f"Finish Epoch {epoch}, Training Acc: {train_top1}, Test Acc: {test_top1}, Current best Test Acc: {best_test_acc}, current LR: {scheduler.get_last_lr()}")

        if is_best:
            print("Save current model (best)")
            path = os.path.join(checkpoint_dir, 'model_best.pth')
        elif (epoch+1) % 50 == 0:
            print("Save current model (epoch)")
            path = os.path.join(checkpoint_dir, f'model_epoch_{epoch}.pth')
        else:
            path = None

        if path is not None:
            state = {
                        'arch': args.model,
                        'epoch': epoch,
                        'layer_fted': args.layer_to_change,
                        'state_dict': model_train.state_dict(),
                        'state_dict_ft': model_ft.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'test_acc': test_top1
                    }
            torch.save(state, path)

    print(f"Training Accs are: {train_accs}")
    print(f"Test Accs are: {test_accs}")

# Run the fine-tuning pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()