import sys
import torch
import random
from torchvision.models import resnet18, resnet50
import argparse
import os
import numpy as np
from define_data import get_data
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class LinearClassifier(nn.Module):
    """Single linear layer mapping backbone features to class logits.

    Args:
        name: backbone architecture; it fixes the input width of the layer
            ('resnet18' -> 512 features, 'resnet50' -> 2048 features).
        num_classes: number of output logits.

    Raises:
        ValueError: if ``name`` is not one of the supported backbones.
    """

    def __init__(self, name='resnet50', num_classes=10):
        super().__init__()

        # Scan the supported (architecture, feature width) pairs in order;
        # the for/else raises only when no pair matched.
        for arch, width in (('resnet18', 512), ('resnet50', 2048)):
            if name == arch:
                feat_dim = width
                break
        else:
            raise ValueError("Please use ResNet18 or ResNet50!")

        self.linear = nn.Linear(feat_dim, num_classes, bias=True)

    def forward(self, features):
        """Return the logits for a batch of backbone ``features``."""
        logits = self.linear(features)
        return logits
    
def train_epoch(args, FModel, TModel, dataloader, optimizer, criterion, cur_epoch, log_interval=100):
    """Run one training epoch of backbone ``FModel`` followed by head ``TModel``.

    ``FModel`` is switched to train mode only when ``args.open_norm`` is set
    (see ``set_to_train``); otherwise it keeps its current mode. ``TModel`` is
    always put in train mode. Only the parameters held by ``optimizer`` are
    updated.

    Args:
        args: namespace providing ``device`` and ``open_norm``.
        FModel: feature extractor; its output is fed to ``TModel``.
        TModel: classifier head producing logits.
        dataloader: iterable of ``(inputs, targets)`` batches.
        optimizer: optimizer over the trainable parameters.
        criterion: loss function ``criterion(outputs, targets)``.
        cur_epoch: epoch number, used only for logging.
        log_interval: print progress every ``log_interval`` batches
            (the first batch is always logged).

    Returns:
        ``(top1, top5)`` sample-weighted training accuracies in percent.
    """
    set_to_train(args, FModel)
    TModel.train()

    top1 = AverageMeter()
    top5 = AverageMeter()
    running_loss = 0.0
    running_batches = 0  # batches accumulated in running_loss since the last log line
    for batch_idx, (inputs, targets) in enumerate(dataloader):
        inputs, targets = inputs.to(args.device), targets.to(args.device)

        outputs = TModel(FModel(inputs))
        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        prec1, prec5 = compute_accuracy(outputs.detach(), targets.detach(), topk=(1, 5))
        batch_size = inputs.size(0)
        top1.update(prec1.item(), batch_size)
        top5.update(prec5.item(), batch_size)

        running_loss += loss.item()
        running_batches += 1
        if batch_idx % log_interval == 0:
            # loss.item() is already one value per batch, so average over the
            # batches actually accumulated (1 at batch 0, log_interval afterwards).
            print(f'[Epoch {cur_epoch}, {batch_idx + 1:5d}] loss: {running_loss / running_batches:.3f}, acc: {top1.avg}')
            running_loss = 0.0
            running_batches = 0

    return top1.avg, top5.avg

def evaluate_epoch(args, FModel, TModel, dataloader, criterion):
    """Evaluate backbone ``FModel`` + head ``TModel`` on ``dataloader``.

    Both models are put in eval mode and run under ``torch.no_grad()``.

    Args:
        args: namespace providing ``device``.
        FModel: feature extractor; its output is fed to ``TModel``.
        TModel: classifier head producing logits.
        dataloader: iterable of ``(inputs, targets)`` batches.
        criterion: unused; accepted so the call mirrors ``train_epoch``.

    Returns:
        ``(top1, top5)`` sample-weighted test accuracies in percent.
    """
    FModel.eval()
    TModel.eval()

    test_top1 = AverageMeter()
    test_top5 = AverageMeter()
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(args.device), targets.to(args.device)
            outputs = TModel(FModel(inputs))

            prec1, prec5 = compute_accuracy(outputs, targets, topk=(1, 5))
            batch_size = inputs.size(0)
            test_top1.update(prec1.item(), batch_size)
            test_top5.update(prec5.item(), batch_size)

    return test_top1.avg, test_top5.avg

def set_to_train(args, model):
    """Switch ``model`` to train mode, but only when ``args.open_norm`` is set.

    When ``open_norm`` is false the model's current train/eval mode is left
    untouched. For a backbone kept in eval mode, normalization layers such as
    BatchNorm therefore keep using their stored running statistics.
    Returns ``None``.
    """
    if not args.open_norm:
        return
    model.train()
    
def compute_accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) for each k in ``topk``.

    Args:
        output: score tensor indexed as (batch, classes); ranking is taken
            along dim 1.
        target: class index per row of ``output``.
        topk: k values to evaluate.

    Returns:
        List of 0-dim float tensors, one per k, in the order of ``topk``. Each
        is the percentage of rows whose target is among the k highest scores.
        A k larger than the number of classes is clamped, so every row counts
        as correct (100%) instead of ``Tensor.topk`` raising.
    """
    num_classes = output.size(1)
    maxk = min(max(topk), num_classes)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()  # (maxk, batch): row r holds every sample's rank-r prediction
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # Slicing past maxk simply takes all rows, which is the clamping above.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

class AverageMeter(object):
    """Sample-weighted running average of a scalar metric.

    Tracks the latest value (``val``), the weighted total (``sum``), the total
    weight (``count``) and their ratio (``avg``).
    Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    """
    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh ``avg``.

        Raises ZeroDivisionError if the accumulated ``count`` is still 0.
        """
        self.val = val
        weighted = val * n
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count
        
def set_seed(manualSeed=666):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs and make cuDNN deterministic.

    Args:
        manualSeed: seed value applied to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(manualSeed)

    # Trade cuDNN autotuning speed for reproducible kernel selection.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # PYTHONHASHSEED is read at interpreter start-up, so setting it here only
    # affects Python processes launched afterwards, not the current one.
    os.environ['PYTHONHASHSEED'] = str(manualSeed)
    

def parse_eval_args(argv=None):
    """Parse the command-line options for block-wise fine-tuning.

    Args:
        argv: list of argument strings to parse; ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the options declared below. ``int_layers`` is
        always a list; it is empty when the flag is omitted, meaning only the
        linear head is trained.
    """
    parser = argparse.ArgumentParser()

    # parameters
    # Model Selection
    parser.add_argument('--model', type=str, default='resnet18')
    # default=[]: with nargs='*' and no default, omitting the flag yields None,
    # which callers cannot take len() of.
    parser.add_argument('--int_layers', nargs='*', default=[], help="All layers that will be finetuned")
    parser.add_argument("--open_norm", dest='open_norm', action="store_true")

    # Hardware Setting
    parser.add_argument('--gpu_id', type=int, default=0)
    parser.add_argument('--seed', type=int, default=6)
    parser.add_argument('--lr', type=float, default=0.1, help='Learning rate')

    # Directory Setting
    # --dataset selects which dataset to transfer-learn on.
    parser.add_argument('--dataset', type=str, choices=['cifar10','cifar100','pet','dtd','aircraft'], default='cifar10')
    parser.add_argument('--data_dir', type=str, default='<path to folder where data should be saved>')
    parser.add_argument('--load_path', type=str, default=None)

    # Learning Options
    parser.add_argument('--epochs', type=int, default=200, help='Max Epochs')
    parser.add_argument('--batch_size', type=int, default=128, help='Batch size')

    return parser.parse_args(argv)

def _resnet_block_table(backbone):
    """Map user-facing block names to ``(submodule path, block index)`` for a torchvision ResNet.

    ``"inp_layer"`` is the stem convolution (``"conv1"``, index 0).
    ``"l{s}_b{b}"`` is the b-th (1-based) residual block of stage ``s``
    (submodule ``"layer{s}.{b-1}"``). Block indices run consecutively from 1 in
    stage/block order, so ResNet-18 yields indices 0-8 and ResNet-50 yields 0-16.
    """
    table = {"inp_layer": ("conv1", 0)}
    index = 1
    for stage in range(1, 5):
        for block in range(len(getattr(backbone, f"layer{stage}"))):
            table[f"l{stage}_b{block + 1}"] = (f"layer{stage}.{block}", index)
            index += 1
    return table


def _build_backbone(model_name, device):
    """Return an ImageNet-pretrained ResNet on ``device`` with its ``fc`` head replaced.

    Raises:
        ValueError: if ``model_name`` is neither ``"resnet18"`` nor ``"resnet50"``.
    """
    if model_name == "resnet18":
        backbone = resnet18(weights='IMAGENET1K_V1')
    elif model_name == "resnet50":
        backbone = resnet50(weights='IMAGENET1K_V1')
    else:
        raise ValueError(f"Model type {model_name} not supported")
    backbone = backbone.to(device)
    # An empty nn.Sequential returns its input unchanged, so the backbone now
    # outputs the features that LinearClassifier consumes.
    backbone.fc = nn.Sequential()
    return backbone


def main():
    """Fine-tune selected blocks of an ImageNet-pretrained ResNet together with a linear head.

    Reads options via ``parse_eval_args``. It trains for ``--epochs`` epochs
    with SGD and cosine warm restarts, and saves checkpoints into a folder whose
    name combines ``--load_path``, the dataset, the fine-tuned block indices and
    ``--open_norm``.
    """
    args = parse_eval_args()

    if args.load_path is None:
        sys.exit('Need to input the path to a pre-trained model!')

    # --gpu_id selects the CUDA device; fall back to CPU when CUDA is unavailable.
    device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
    args.device = device
    print(device)
    set_seed(args.seed)

    # Dataset part
    print(f"Using dataset {args.dataset}")
    print()
    trainloader, testloader, num_classes = get_data(args.dataset, args.data_dir, args.batch_size, do_transform=True)

    model_train = _build_backbone(args.model, device)
    model_ft = LinearClassifier(name=args.model, num_classes=num_classes).to(device)

    # Freeze the whole backbone; the requested blocks are unfrozen below.
    # requires_grad_() flags every parameter. Assigning `module.requires_grad = ...`
    # would only create an unrelated plain attribute and freeze nothing.
    model_train.eval()
    model_train.requires_grad_(False)
    block_table = _resnet_block_table(model_train)

    print(f"The Batchnorm status is {args.open_norm}")
    # `or []` tolerates a None from the CLI. dict.fromkeys drops duplicates
    # (keeping order) so no parameter is handed to the optimizer twice.
    int_layers = list(dict.fromkeys(args.int_layers or []))
    print(f"Fine tune for layers {int_layers}")
    fine_tune_all = int_layers == ["all"]
    if not fine_tune_all:
        unknown = [name for name in int_layers if name not in block_table]
        if unknown:
            raise ValueError(f"Unknown layer name(s) {unknown}; expected 'all' or any of {list(block_table)}")

    # Indices of the fine-tuned blocks (empty for "all" and for head-only
    # training); used in the output folder name and stored in checkpoints.
    args.layer_to_change = [] if fine_tune_all else [block_table[name][1] for name in int_layers]
    print(args.layer_to_change)

    need_param_list = []  # backbone parameters that will be trained
    if not int_layers:
        print("Transfer Learning")
    elif fine_tune_all:
        print("Will fine tune the whole model")
        model_train.requires_grad_(True)
        need_param_list = list(model_train.parameters())
    else:
        for name in int_layers:
            module_name = block_table[name][0]
            print(module_name)
            block = model_train.get_submodule(module_name)
            block.requires_grad_(True)
            need_param_list.extend(block.parameters())

    criterion = nn.CrossEntropyLoss()
    head_params = list(model_ft.parameters())
    print(f"Trainable layers: model_train: {len(need_param_list)}, model_ft: {len(head_params)}")
    optimizer = optim.SGD(need_param_list + head_params, lr=args.lr, momentum=0.9, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=args.epochs, eta_min=args.lr / 1000)

    save_name = "all" if fine_tune_all else "-".join(str(i) for i in args.layer_to_change)
    # load_path is used as a plain string prefix, so it should normally end
    # with a path separator.
    checkpoint_dir = args.load_path + f"fine_tune(blockwise)_{args.dataset}_{save_name}_{args.open_norm}/"
    os.makedirs(checkpoint_dir, exist_ok=True)

    best_test_acc = 0.0
    train_accs = []
    test_accs = []
    for epoch in range(args.epochs):
        train_top1, _ = train_epoch(args, model_train, model_ft, trainloader, optimizer, criterion, epoch)
        test_top1, _ = evaluate_epoch(args, model_train, model_ft, testloader, criterion)
        train_accs.append(train_top1)
        test_accs.append(test_top1)
        scheduler.step()

        is_best = test_top1 > best_test_acc
        if is_best:
            best_test_acc = test_top1
        # get_last_lr() is the supported accessor for the LR set by step();
        # get_lr() called outside step() warns in recent PyTorch releases.
        print(f"Finish Epoch {epoch}, Training Acc: {train_top1}, Test Acc: {test_top1}, Current best Test Acc: {best_test_acc}, current LR: {scheduler.get_last_lr()}")

        path = None
        if is_best:
            print("Save current model (best)")
            path = checkpoint_dir + 'model_best.pth'
        elif (epoch + 1) % 20 == 0:
            print("Save current model (epoch)")
            path = checkpoint_dir + f'model_epoch_{epoch}.pth'

        if path is not None:
            state = {
                'arch': args.model,
                'epoch': epoch,
                'layer_fted': args.layer_to_change,
                'state_dict': model_train.state_dict(),
                'state_dict_ft': model_ft.state_dict(),
                'optimizer': optimizer.state_dict(),
                'test_acc': test_top1,
            }
            torch.save(state, path)

    print(f"Training Accs are: {train_accs}")
    print(f"Test Accs are: {test_accs}")

if __name__ == "__main__":
    # Script entry point: main() parses the CLI options and runs the fine-tuning loop.
    main()
