import sys
import pickle
import torchvision
import torchvision.transforms as transforms
import torch
import scipy.linalg as scilin
import random
import models
#from models.resnet import ResNet18
from models.sup_con_original import SupConResNet, SupCEResNet, LinearClassifier
import argparse
import os
import numpy as np
import matplotlib.pyplot as plt
from define_data import get_data
from data_loader.mini_imagenet import MiniImagenet
from utils import load_from_state_dict,load_from_state_dict_without_fc
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class TransferModel(nn.Module):
    """Linear classification head trained on top of frozen backbone features.

    A single affine layer ``classifier`` maps a flattened feature vector of
    size ``inp_ch`` to ``num_classes`` logits. No weight or feature
    normalization is applied.

    Args:
        inp_ch: Dimension of the flattened input feature vector.
        num_classes: Number of output logits.
        bias: Whether the linear layer has a bias term. Defaults to ``True``,
            which matches the previous hard-coded behaviour.
    """

    def __init__(self, inp_ch, num_classes, bias=True):
        super().__init__()
        self.classifier = nn.Linear(inp_ch, num_classes, bias=bias)

    def forward(self, x):
        """Return the logits ``self.classifier(x)``; ``x`` should have last dim ``inp_ch``."""
        return self.classifier(x)
    
def train_transfer(args, FModel, TModel, dataloader, optimizer, criterion, cur_epoch):
    """Train the head ``TModel`` for one epoch on features from the frozen ``FModel``.

    Args:
        args: Namespace providing ``device``, ``dataset`` and ``val_choice``.
        FModel: Feature extractor. Its parameters are frozen and it is put in
            eval mode. With ``val_choice == "feature"`` the features come from
            ``FModel.encoder``; with ``"all"`` they come from the full
            ``FModel`` forward pass.
        TModel: Trainable head applied to the flattened features.
        dataloader: Iterable of ``(inputs, targets)`` batches.
        optimizer: Optimizer over ``TModel``'s parameters.
        criterion: Loss function called as ``criterion(outputs, targets)``.
        cur_epoch: Epoch index, used only for logging.

    Returns:
        ``(top1, top5)``: sample-weighted training accuracies in percent.

    Raises:
        ValueError: If ``args.val_choice`` is neither ``"feature"`` nor ``"all"``.
    """
    TModel.train()
    # Freeze the backbone. Gradients are disabled, and eval mode keeps its
    # BatchNorm/dropout layers from reacting to the training batches.
    FModel.eval()
    for param in FModel.parameters():
        param.requires_grad = False

    top1 = AverageMeter()
    top5 = AverageMeter()
    running_loss = 0.0
    batches_since_log = 0  # number of batches summed into running_loss
    for batch_idx, (inputs, targets) in enumerate(dataloader):
        inputs, targets = inputs.to(args.device), targets.to(args.device)

        if "mnist" in args.dataset:  # also matches "fashionmnist"
            # Replicate the channel dim 3x for the RGB backbone.
            # Assumes single-channel NCHW input; TODO confirm.
            inputs = inputs.repeat(1, 3, 1, 1)

        with torch.no_grad():
            if args.val_choice == "feature":
                features_cur = FModel.encoder(inputs)
            elif args.val_choice == "all":
                features_cur = FModel(inputs)
            else:
                raise ValueError("The val_choice type is not supported")

        fea_inps = features_cur.view(len(targets), -1)
        outputs = TModel(fea_inps)
        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        prec1, prec5 = compute_accuracy(outputs.detach(), targets, topk=(1, 5))
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))

        running_loss += loss.item()
        batches_since_log += 1
        if batch_idx % 100 == 0:  # log the first batch, then every 100 mini-batches
            # Average over the batches actually accumulated. The first log
            # covers only one batch, so a fixed /100 would under-report it.
            print(f'[Epoch {cur_epoch}, {batch_idx + 1:5d}] loss: {running_loss / batches_since_log:.3f}, acc: {top1.avg}')
            running_loss = 0.0
            batches_since_log = 0

    return top1.avg, top5.avg

def evaluate_transfer(args, FModel, TModel, dataloader, criterion):
    """Evaluate the head ``TModel`` on features from the frozen ``FModel``.

    All computation runs under ``torch.no_grad`` with both models in eval mode.

    Args:
        args: Namespace providing ``device``, ``dataset`` and ``val_choice``.
        FModel: Feature extractor. ``FModel.encoder`` is used for
            ``val_choice == "feature"``; the full forward is used for ``"all"``.
        TModel: Head applied to the flattened features.
        dataloader: Iterable of ``(inputs, targets)`` batches.
        criterion: Unused. Kept so existing call sites keep working.

    Returns:
        ``(top1, top5)``: sample-weighted test accuracies in percent.

    Raises:
        ValueError: If ``args.val_choice`` is neither ``"feature"`` nor ``"all"``.
    """
    TModel.eval()
    FModel.eval()
    test_top1 = AverageMeter()
    test_top5 = AverageMeter()
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(args.device), targets.to(args.device)

            if "mnist" in args.dataset:  # also matches "fashionmnist"
                # Replicate the channel dim 3x for the RGB backbone.
                # Assumes single-channel NCHW input; TODO confirm.
                inputs = inputs.repeat(1, 3, 1, 1)

            if args.val_choice == "feature":
                features_cur = FModel.encoder(inputs)
            elif args.val_choice == "all":
                features_cur = FModel(inputs)
            else:
                raise ValueError("The val_choice type is not supported")

            outputs = TModel(features_cur.view(len(targets), -1))

            prec1, prec5 = compute_accuracy(outputs, targets, topk=(1, 5))
            test_top1.update(prec1.item(), inputs.size(0))
            test_top5.update(prec5.item(), inputs.size(0))

    return test_top1.avg, test_top5.avg


def get_dim(args, model, dataloader):
    """Return the flattened feature size that ``model`` produces for one batch.

    Only the first batch of ``dataloader`` is used. Features are computed
    without gradients, exactly as the training and evaluation loops do, so
    the result matches the input size the transfer head needs.

    Args:
        args: Namespace providing ``device``, ``dataset`` and ``val_choice``.
        model: Feature extractor. ``model.encoder`` is used for
            ``val_choice == "feature"``; ``model(...)`` is used for ``"all"``.
        dataloader: Iterable of ``(inputs, targets)`` batches.

    Returns:
        int: Number of features per sample after flattening.

    Raises:
        ValueError: If ``dataloader`` yields no batches, or if
            ``args.val_choice`` is unsupported.
    """
    first_batch = next(iter(dataloader), None)
    if first_batch is None:
        raise ValueError("Cannot infer the feature dimension from an empty dataloader")
    inputs, targets = first_batch
    inputs = inputs.to(args.device)

    if "mnist" in args.dataset:  # also matches "fashionmnist"
        # Replicate the channel dim 3x for the RGB backbone.
        # Assumes single-channel NCHW input; TODO confirm.
        inputs = inputs.repeat(1, 3, 1, 1)

    with torch.no_grad():
        if args.val_choice == "feature":
            features_cur = model.encoder(inputs)
        elif args.val_choice == "all":
            features_cur = model(inputs)
        else:
            raise ValueError("The val_choice type is not supported")

    return features_cur.view(len(targets), -1).shape[1]

    
def compute_accuracy(output, target, topk=(1,)):
    """Return top-k accuracy (in percent) for each ``k`` in ``topk``.

    Args:
        output: Score tensor whose dim 1 indexes classes.
        target: Tensor of true class indices, one per row of ``output``.
        topk: Iterable of ``k`` values to evaluate.

    Returns:
        list: One scalar tensor per ``k`` holding ``100 * hits / batch_size``.
    """
    largest_k = max(topk)
    n_samples = target.size(0)

    _, ranked = output.topk(largest_k, dim=1, largest=True, sorted=True)
    ranked = ranked.t()  # (largest_k, n_samples): row r holds each sample's rank-r guess
    hits = ranked.eq(target.view(1, -1).expand_as(ranked))

    scale = 100.0 / n_samples
    return [hits[:k].reshape(-1).float().sum(0).mul_(scale) for k in topk]

class AverageMeter:
    """Track the latest value and a sample-weighted running mean.

    Adapted from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean over ``n`` samples and refresh ``avg``."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
        
def set_seed(manualSeed=666):
    """Seed Python, NumPy and torch (CPU + CUDA) RNGs and force deterministic cuDNN.

    Args:
        manualSeed: Seed value applied to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(manualSeed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(manualSeed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    

def parse_eval_args():
    """Parse command-line options for the transfer (linear-probe) evaluation.

    Returns:
        argparse.Namespace: Model, hardware, data and optimisation settings.
        ``--no-bias`` is stored as ``bias`` (default ``True``) and ``--rm``
        as ``remove_last_relu``.
    """
    parser = argparse.ArgumentParser()
    opt = parser.add_argument

    # --- model selection ---
    opt('--model', type=str, default='resnet18')
    opt('--model_name', type=str, default='resnet18')  # backbone architecture
    opt('--head_type', type=str, default='mlp')  # projection head type of the backbone
    opt('--val_choice', type=str, default='feature')  # which part of the model yields features
    opt('--no-bias', dest='bias', action='store_false')
    opt('--ETF_fc', dest='ETF_fc', action='store_true')
    opt('--fixdim', dest='fixdim', type=int, default=0)
    opt('--SOTA', dest='SOTA', action='store_true')
    opt("--rm", dest='remove_last_relu', action="store_true")

    # --- MLP settings (mlp and res_adapt heads; res_adapt only uses width) ---
    opt('--width', type=int, default=1024)
    opt('--depth', type=int, default=6)

    # --- hardware / reproducibility ---
    opt('--gpu_id', type=int, default=0)
    opt('--seed', type=int, default=6)

    # --- data and checkpoints; --dataset is the dataset transferred to ---
    dataset_choices = ['mnist', 'cifar10', 'fashionmnist', 'miniimagenet', 'cifar100', 'aircraft', 'dtd', 'pet']
    opt('--dataset', type=str, choices=dataset_choices, default='cifar10')
    opt('--data_dir', type=str, default='/scratch/qingqu_root/qingqu1/xlxiao/DL/data')
    opt('--load_path', type=str, default=None)

    # --- optimisation ---
    opt('--epochs', type=int, default=200, help='Max Epochs')
    opt('--load_epoch', type=int, default=199, help='Checkpoint load epoch')
    opt('--batch_size', type=int, default=128, help='Batch size')
    opt('--sample_size', type=int, default=None, help='sample size PER CLASS')

    return parser.parse_args()


def main():
    """Run a linear-probe transfer evaluation of a pre-trained backbone.

    Steps:
      1. Parse CLI arguments and build the train/test loaders for ``--dataset``.
      2. Build a ``SupCEResNet`` backbone, either randomly initialised
         (``--load_path random``) or loaded from
         ``<load_path>/model_epoch_<load_epoch>.pth``.
      3. Train a ``TransferModel`` head on the frozen features for ``--epochs``
         epochs.
      4. Pickle a summary dict to ``<load_path>/transfer_<dataset>.pkl``.

    Exits via ``sys.exit`` when ``--load_path`` is not given. Raises
    ``ValueError`` for an unsupported ``--model``.
    """
    args = parse_eval_args()

    if args.load_path is None:
        sys.exit('Need to input the path to a pre-trained model!')

    # NOTE: --gpu_id is not used here. The first visible CUDA device is always
    # selected, so restrict devices via CUDA_VISIBLE_DEVICES instead.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    args.device = device

    set_seed(args.seed)

    # Dataset part
    print(f"Using dataset {args.dataset}")
    print()
    trainloader, testloader, num_classes = get_data(args.dataset, args.data_dir, args.batch_size, do_transform = True)

    if args.model == "supce":
        model_train = SupCEResNet(name=args.model_name, head=args.head_type, remove_last_relu=args.remove_last_relu).to(device)
    else:
        raise ValueError(f"Model type {args.model} not supported")

    # NOTE(review): model_ft is never used below. It is kept because building it
    # presumably draws from the seeded torch RNG (parameter init). Removing it
    # would change the transfer head's initialisation relative to earlier
    # seeded runs.
    model_ft = LinearClassifier(name='resnet50', num_classes=num_classes).to(device)

    print(f"Features getting are from {args.val_choice}")
    if args.load_path == "random":
        print("Use random initialized model")
    else:
        print(f"Use model load from checkpoint epoch {args.load_path}/{args.load_epoch}")
        ckpt_path = os.path.join(args.load_path, f"model_epoch_{args.load_epoch}.pth")
        # torch.load unpickles the file: only point --load_path at trusted checkpoints.
        checkpoint = torch.load(ckpt_path, map_location=device)
        model_train.load_state_dict(checkpoint['state_dict'])

    # Freeze the backbone. eval() fixes BatchNorm/dropout behaviour, and
    # requires_grad_(False) really disables parameter gradients. Assigning
    # `model.requires_grad = False` would only set an unused attribute.
    model_train.eval()
    model_train.requires_grad_(False)

    # Get the correct classifier dimension
    print("Get Dimension")
    classifier_size = get_dim(args, model_train, trainloader)
    print(classifier_size)

    # Now we consider the transfer learning
    TModel = TransferModel(classifier_size, num_classes).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(TModel.parameters(), lr=0.1, momentum = 0.9, weight_decay=1e-4)
    # One cosine cycle spanning all training epochs (T_0 == epochs).
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=args.epochs, eta_min=0.001)

    best_test_acc = 0.0
    train_accs = []
    test_accs = []
    for epoch in range(args.epochs):
        train_top1, train_top5 = train_transfer(args, model_train, TModel, trainloader, optimizer, criterion, epoch)
        test_top1, test_top5 = evaluate_transfer(args, model_train, TModel, testloader, criterion)
        train_accs.append(train_top1)
        test_accs.append(test_top1)
        scheduler.step()

        if test_top1 > best_test_acc:
            best_test_acc = test_top1
        # get_last_lr() is the supported accessor. Calling get_lr() outside
        # step() is deprecated and emits a warning.
        print(f"Finish Epoch {epoch}, Training Acc: {train_top1}, Test Acc: {test_top1}, Current best Test Acc: {best_test_acc}, current LR: {scheduler.get_last_lr()}")

    print(f"Training Accs are: {train_accs}")
    print(f"Test Accs are: {test_accs}")

    info_dict = {"load_path": args.load_path,
                 "head_type": args.head_type,
                 "transfer_layer": args.val_choice,
                 "best_test": best_test_acc,
                 "final_train": train_top1,
                 }

    # With --load_path random the directory usually does not exist. Create it
    # so the results are not lost after the whole training run.
    os.makedirs(args.load_path, exist_ok=True)
    with open(os.path.join(args.load_path, f"transfer_{args.dataset}.pkl"), 'wb') as f:
        pickle.dump(info_dict, f)



if __name__ == "__main__":
    # Run the evaluation only when this file is executed as a script, not on import.
    main()