import sys
import pickle
import torchvision
import torchvision.transforms as transforms
import torch
import random
import argparse
import os
import numpy as np
from torchvision.datasets import CIFAR10, MNIST, CIFAR100#, FGVCAircraft
from model_to_ft import VisionTransformer
from config import get_transfer_config
from define_data import get_data
from checkpoint import load_checkpoint
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class TransferModel(nn.Module):
    """Linear probe head.

    A single biased ``nn.Linear`` that maps an ``inp_ch``-dimensional feature
    vector to ``num_classes`` logits.
    """

    def __init__(self, inp_ch, num_classes):
        super().__init__()
        # The attribute name "classifier" determines the state_dict keys
        # ("classifier.weight", "classifier.bias"), so it must stay as is.
        self.classifier = nn.Linear(in_features=inp_ch, out_features=num_classes, bias=True)

    def forward(self, x):
        """Return the logits ``classifier(x)``."""
        return self.classifier(x)
    
def train_epoch(config, FModel, layer_fc, dataloader, optimizer, scheduler, criterion, cur_epoch):
    """Train the linear probe ``layer_fc`` for one epoch on top of a frozen backbone.

    For each batch, the backbone ``FModel`` is run to obtain every layer's
    output. Layer ``config.int_layer`` is average-pooled over its token axis
    and fed to ``layer_fc``. One optimizer and one scheduler step are then
    taken on the head's loss.

    Args:
        config: run configuration; reads ``device`` and ``int_layer``.
        FModel: backbone whose forward returns ``(outputs, all_layer_out)``.
            It is treated as frozen: no gradients flow through it.
        layer_fc: the trainable head applied to the pooled features.
        dataloader: iterable of ``(inputs, targets)`` batches.
        optimizer: optimizer over ``layer_fc``'s parameters.
        scheduler: LR scheduler, stepped once per batch.
        criterion: loss applied to ``(logits, targets)``.
        cur_epoch: epoch number, used only for logging.

    Returns:
        ``(top1, top5)``: sample-weighted training accuracies in percent.
    """
    # The backbone is a frozen feature extractor: keep it in eval mode so
    # dropout does not perturb the probed features; only the head trains.
    FModel.eval()
    layer_fc.train()
    layer_idx = int(config.int_layer)

    top1 = AverageMeter()
    top5 = AverageMeter()
    running_loss = 0.0
    batches_since_log = 0
    for batch_idx, (inputs, targets) in enumerate(dataloader):
        inputs, targets = inputs.to(config.device), targets.to(config.device)

        # Only layer_fc is optimised, so there is no need to build an
        # autograd graph through the backbone.
        with torch.no_grad():
            _, all_layer_out = FModel(inputs)
            # Pooling, assuming (bs, tokens, c) per the original comments (TODO confirm):
            # (bs, tokens, c) -> (bs, c, 1, tokens) -> (bs, c, 1, 1) -> (bs, c)
            layer_out = all_layer_out[layer_idx]
            layer_out = layer_out.transpose(1, 2).unsqueeze(2)
            layer_out = F.adaptive_avg_pool2d(layer_out, (1, 1))
            layer_out = torch.flatten(layer_out, 1)

        logits_cur = layer_fc(layer_out)
        loss = criterion(logits_cur, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        prec1, prec5 = compute_accuracy(logits_cur.detach(), targets, topk=(1, 5))
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))

        running_loss += loss.item()
        batches_since_log += 1
        if batch_idx % 100 == 0:  # print every 100 mini-batches
            # Average over the batches actually accumulated since the last
            # print. The loss is assumed to already be a per-batch mean (true
            # for the nn.CrossEntropyLoss() that main builds), so it is not
            # divided by the batch size again.
            avg_loss = running_loss / batches_since_log
            print(f'[Epoch {cur_epoch}, {batch_idx + 1:5d}] loss: {avg_loss:.3f}, acc: {top1.avg}')
            running_loss = 0.0
            batches_since_log = 0

    return top1.avg, top5.avg

def evaluate_epoch(config, FModel, layer_fc, dataloader, criterion):
    """Evaluate the linear probe ``layer_fc`` over one pass of ``dataloader``.

    Runs under ``torch.no_grad()`` with the backbone in eval mode. The head is
    put in eval mode for the pass and afterwards restored to the mode the
    caller had it in.

    Args:
        config: run configuration; reads ``device`` and ``int_layer``.
        FModel: backbone whose forward returns ``(outputs, all_layer_out)``.
        layer_fc: head applied to the pooled output of layer ``config.int_layer``.
        dataloader: iterable of ``(inputs, targets)`` batches.
        criterion: not used here; kept for signature parity with ``train_epoch``.

    Returns:
        ``(top1, top5)``: sample-weighted accuracies in percent.
    """
    FModel.eval()
    head_was_training = layer_fc.training
    layer_fc.eval()
    layer_idx = int(config.int_layer)

    test_top1 = AverageMeter()
    test_top5 = AverageMeter()
    try:
        with torch.no_grad():
            for inputs, targets in dataloader:
                inputs, targets = inputs.to(config.device), targets.to(config.device)

                _, all_layer_out = FModel(inputs)

                # Pooling, assuming (bs, tokens, c) per the original comments (TODO confirm):
                # (bs, tokens, c) -> (bs, c, 1, tokens) -> (bs, c, 1, 1) -> (bs, c)
                layer_out = all_layer_out[layer_idx]
                layer_out = layer_out.transpose(1, 2).unsqueeze(2)
                layer_out = F.adaptive_avg_pool2d(layer_out, (1, 1))
                layer_out = torch.flatten(layer_out, 1)
                logits_cur = layer_fc(layer_out)

                prec1, prec5 = compute_accuracy(logits_cur, targets, topk=(1, 5))
                test_top1.update(prec1.item(), inputs.size(0))
                test_top5.update(prec5.item(), inputs.size(0))
    finally:
        # Restore the caller's mode so a following training epoch is unaffected.
        layer_fc.train(head_was_training)

    return test_top1.avg, test_top5.avg

def compute_accuracy(output, target, topk=(1,)):
    """Compute precision@k, in percent, for each k in ``topk``.

    Args:
        output: scores with classes along dim 1, i.e. ``(batch, num_classes)``.
        target: class indices, one per row of ``output``.
        topk: k values to evaluate. A k larger than the number of classes is
            clamped, so it counts every sample as a hit instead of raising.

    Returns:
        A list of 0-dim tensors, one per k: the percentage of samples whose
        target is among the top-k predictions.
    """
    # Tensor.topk raises if k exceeds the size of the class dimension
    # (e.g. top-5 on a 2-class problem), so never ask for more than exist.
    maxk = min(max(topk), output.size(1))
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()  # (maxk, batch): row i holds every sample's i-th best class
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # Slicing past maxk simply takes all rows, which is the clamped case.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

class AverageMeter:
    """Track the latest value and a weighted running average.

    ``update(val, n)`` records ``val`` as the most recent value and adds it
    to the running sum with weight ``n``. ``avg`` is then ``sum / count``.
    Adapted from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, average, sum and count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average."""
        self.val = val
        weighted = val * n
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count
        
def set_seed(manualSeed=666):
    """Seed every RNG the script uses and force deterministic cuDNN kernels.

    Seeds Python's ``random``, NumPy's global RNG, and torch: CPU, the
    current CUDA device, and all CUDA devices. It then disables cuDNN
    autotuning in favour of deterministic algorithms and records the seed in
    the ``PYTHONHASHSEED`` environment variable.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(manualSeed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # NOTE: the running interpreter reads PYTHONHASHSEED only at startup, so
    # this affects only Python processes launched afterwards.
    os.environ['PYTHONHASHSEED'] = str(manualSeed)
    

def main():
    """Train a linear probe on one intermediate layer of a frozen, pre-trained ViT.

    Steps:
      1. Read the run configuration from ``get_transfer_config()`` and build
         the data loaders with ``get_data``.
      2. Construct a ``VisionTransformer``, load the weights from
         ``config.checkpoint_path`` / ``config.pretrain_model_name``, and
         freeze it.
      3. Train a ``TransferModel`` head on the pooled output of layer
         ``config.int_layer`` with SGD and a OneCycleLR schedule.

    Whenever test top-1 accuracy improves, a checkpoint is written to
    ``<checkpoint_path>/<dataset>_nc/transfer_<int_layer>/model_best.pth``.
    Exits the process if no checkpoint path is configured.
    """
    config = get_transfer_config()

    # An empty path would silently skip weight loading and probe a randomly
    # initialised backbone, so treat None and "" alike.
    if not config.checkpoint_path:
        sys.exit('Need to input the path to a pre-trained model!')

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    config.device = device

    set_seed(config.seed)

    # Dataset part
    print(f"Using dataset {config.dataset}")
    print()
    trainloader, testloader, num_classes = get_data(config.dataset, config.data_dir, config.image_size, config.batch_size, do_transform = True)

    # create model
    print("create model")
    model = VisionTransformer(
             image_size=(config.image_size, config.image_size),
             patch_size=(config.patch_size, config.patch_size),
             emb_dim=config.emb_dim,
             mlp_dim=config.mlp_dim,
             num_heads=config.num_heads,
             num_layers=config.num_layers,
             num_classes=num_classes,
             attn_dropout_rate=config.attn_dropout_rate,
             dropout_rate=config.dropout_rate)

    # Load the checkpoint. os.path.join gives the same path whether or not
    # checkpoint_path has a trailing separator; plain '+' did not.
    pretrained_path = os.path.join(config.checkpoint_path, config.pretrain_model_name)
    state_dict = load_checkpoint(pretrained_path)
    if num_classes != state_dict['classifier.weight'].size(0):
        # The pre-training head has a different class count: drop it and keep
        # the model's freshly initialised classifier.
        del state_dict['classifier.weight']
        del state_dict['classifier.bias']
        print("re-initialize fc layer")
        model.load_state_dict(state_dict, strict=False)
    else:
        model.load_state_dict(state_dict)
    print(f"Load pretrained weights from {pretrained_path}")

    # Send the model to the device and freeze it; only layer_fc is trained.
    model = model.to(device)
    model.eval()
    # Module.requires_grad_() turns off requires_grad on every parameter.
    # Assigning `model.requires_grad = False` would only set a plain Python
    # attribute and leave all backbone weights trainable.
    model.requires_grad_(False)

    # Which layer's output to probe. Presumably 0-11 index the encoder layers
    # and 12 the classifier — confirm against VisionTransformer's all_layer_out.
    int_layer = config.int_layer
    print(f"We do transfer learning on layer {int_layer}'s output")

    # NOTE(review): hard-coded width (768 = ViT-B); presumably should follow
    # config.emb_dim for other model sizes — confirm what all_layer_out holds.
    layer_dim = 768
    layer_fc = TransferModel(layer_dim, num_classes).to(device)

    print(f"Trainable classifiers: {len(list(layer_fc.parameters()))}")

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(
        params=layer_fc.parameters(),
        lr=config.lr,
        weight_decay=config.wd,
        momentum=0.9)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer=optimizer,
        max_lr=config.lr,
        pct_start=config.warmup_steps / config.train_steps,
        total_steps=config.train_steps)

    # The scheduler is stepped once per batch, so only run as many full
    # epochs as fit within total_steps (OneCycleLR raises if over-stepped).
    epochs = config.train_steps // len(trainloader)
    if epochs == 0:
        print(f"Warning: train_steps ({config.train_steps}) is smaller than one epoch ({len(trainloader)} batches); no training will run")

    best_test_acc = 0.0
    train_accs = []
    test_accs = []
    save_name = str(config.int_layer)  # the folder to save fine tune results
    checkpoint_dir = os.path.join(config.checkpoint_path, config.dataset + "_nc", f"transfer_{save_name}")
    os.makedirs(checkpoint_dir, exist_ok=True)

    for epoch in range(epochs):
        train_top1, train_top5 = train_epoch(config, model, layer_fc, trainloader, optimizer, scheduler, criterion, epoch)
        test_top1, test_top5 = evaluate_epoch(config, model, layer_fc, testloader, criterion)
        train_accs.append(train_top1)
        test_accs.append(test_top1)

        is_best = test_top1 > best_test_acc
        if is_best:
            best_test_acc = test_top1
        # get_last_lr() is the supported accessor; calling get_lr() outside
        # step() is deprecated and emits a warning.
        print(f"Finish Epoch {epoch}, Training Acc: {train_top1}, Test Acc: {test_top1}, Current best Test Acc: {best_test_acc}, current LR: {scheduler.get_last_lr()}")

        if is_best:
            state = {
                'arch': config.model_arch,
                'epoch': epoch,
                'state_dict': model.state_dict(),
                # The backbone is frozen, so the trained weights live only in
                # layer_fc; without this the checkpoint cannot restore the probe.
                'layer_fc_state_dict': layer_fc.state_dict(),
                'int_layer': config.int_layer,
                'optimizer': optimizer.state_dict(),
                'test_acc': test_top1,
            }
            print("Save current model (best)")
            torch.save(state, os.path.join(checkpoint_dir, 'model_best.pth'))

    print(f"Training Accs are: {train_accs}")
    print(f"Test Accs are: {test_accs}")

# Run the linear-probe experiment only when executed as a script, so the
# helpers above can be imported without side effects.
if __name__ == "__main__":
    main()