import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.optim as optim
from torchvision import datasets, transforms
import torch.backends.cudnn as cudnn
import random
import time
import numpy as np
from tqdm import tqdm
import json
import copy
from datetime import datetime

from ncp_acl.builder import get_base_model
from utils import AverageMeter

def _str2bool(value):
    """Convert a command-line string such as ``"False"`` or ``"1"`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, which is
    ``True`` for every non-empty value (including ``"False"``). This converter
    interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays False, as it was under ``type=bool``.
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        f'expected a boolean value (true/false), got {value!r}')


# Command-line interface of the linear fine-tuning script.
parser = argparse.ArgumentParser(
    description='Linear Finetuning (SLF)')

# Encoder configuration.
parser.add_argument('--backbone', type=str, default='resnet18')
parser.add_argument('--teacher_method', type=str, default='SimCLR')
parser.add_argument('--zero_init_residual', type=_str2bool, default=True,
                    help='true/false (default: true)')

# Data.
parser.add_argument('--data_dir', type=str)
parser.add_argument('--dataset', type=str)

# Optimisation of the linear classifier.
parser.add_argument('--batch_size', type=int, default=512,
                    help='input batch size for training (default: 512)')
parser.add_argument('--epochs', type=int, default=25,
                    help='number of epochs to train')
parser.add_argument('--weight-decay', '--wd', default=2e-4,
                    type=float, metavar='W')
parser.add_argument('--lr', type=float, default=1.0,
                    help='learning rate')
parser.add_argument('--momentum', type=float, default=0.9,
                    help='SGD momentum')

# Comma-separated epochs at which the learning rate is multiplied by 0.1.
parser.add_argument('--decreasing_lr', default='15,20',
                    help='decreasing strategy')

parser.add_argument('--device', type=str, default='cuda:0')

# The three switches below default to True, so the original ``store_true``
# flags could never turn them off. The ``--no_*`` counterparts write to the
# same destination and make the options controllable without changing any
# default or any key of the parsed namespace.
parser.add_argument('--multi_bn', action='store_true', default=True)
parser.add_argument('--no_multi_bn', dest='multi_bn', action='store_false',
                    help='disable --multi_bn')

parser.add_argument('--use_normalize', action='store_true', default=True)
parser.add_argument('--no_use_normalize', dest='use_normalize',
                    action='store_false', help='disable --use_normalize')
parser.add_argument('--use_amp', action='store_true', default=True)
parser.add_argument('--no_use_amp', dest='use_amp', action='store_false',
                    help='disable --use_amp')

parser.add_argument('--pretrained_lr', type=float, default=0.5)
parser.add_argument('--pretrained_epoch', type=int, default=100)

parser.add_argument('--seed', type=int, default=0)

parser.add_argument('--checkpoint_path', type=str)

# Side length of the random crop that set_loader applies for each dataset.
CROP_SIZE_PER_DATASET = {
    'cifar10': 32,
    'cifar100': 32,
    'stl10': 96,
}

def fix_seed(SEED):
    """Seed every random number generator used by the script.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices)
    with ``SEED`` and switches cuDNN to deterministic, non-benchmark mode.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(SEED)

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

def main():
    """Train a linear classifier on top of a frozen, pretrained encoder.

    Steps:
      1. Parse the command line, seed the RNGs and validate the inputs.
      2. Create ``<checkpoint dir>/lr_<lr>_batch_size_<bs>`` and dump the
         arguments there as JSON.
      3. Build the data loader and the encoder, load the checkpoint weights
         (converted with ``cvt_state_dict`` when ``--multi_bn`` is set) and
         freeze every encoder parameter.
      4. Train ``nn.Linear(512, num_classes)`` with SGD and a ``MultiStepLR``
         schedule, checking after the first epoch that the encoder weights
         did not change.
      5. Save the classifier and optimizer state of the last epoch.

    Raises:
        FileNotFoundError: if ``--checkpoint_path`` is missing or not a file.
        ValueError: if ``--dataset`` is unsupported or ``--epochs`` is not
            positive.
        NotImplementedError: if ``--backbone`` is not ``resnet18``.
        FileExistsError: if the output directory already exists (raised by
            ``os.makedirs(..., exist_ok=False)``).
    """
    args = parser.parse_args()

    fix_seed(args.seed)

    encoder_path = args.checkpoint_path

    # Explicit exceptions instead of `assert`: asserts vanish under
    # `python -O`, and os.path.isfile(None) raises an unhelpful TypeError when
    # --checkpoint_path is omitted.
    if not encoder_path or not os.path.isfile(encoder_path):
        raise FileNotFoundError('You should input valid checkpoint path!!!')
    if args.dataset not in ('cifar10', 'cifar100', 'stl10'):
        raise ValueError('Only CIFAR10, CIFAR100, and STL10 Datasets are supported')
    # Fail before creating directories or downloading data. The original
    # raised the undefined name `NotImplementError` (a NameError) here.
    if args.backbone != 'resnet18':
        raise NotImplementedError(f'Unsupported backbone: {args.backbone}')
    # With no epoch to run, `epoch` would be unbound when saving the checkpoint.
    if args.epochs <= 0:
        raise ValueError('--epochs must be a positive integer')

    lr_tag = str(args.lr).replace('.', '_')
    batch_tag = str(args.batch_size).replace('.', '_')
    save_dir = os.path.join(os.path.dirname(encoder_path),
                            f'lr_{lr_tag}_batch_size_{batch_tag}')

    # exist_ok=False: an existing directory for the same lr / batch size
    # raises FileExistsError rather than being reused.
    os.makedirs(save_dir, exist_ok=False)

    args_file = os.path.join(
        save_dir, f'SLF_args_{datetime.now().strftime("%Y_%m_%d")}.txt')
    with open(args_file, 'w') as f:
        json.dump(vars(args), f, indent=2)

    train_loader, num_classes = set_loader(args)

    checkpoint = torch.load(encoder_path, map_location='cpu')

    # The encoder is always built with a single BN branch (multi_bn=False);
    # multi-BN checkpoints are reduced to one branch by cvt_state_dict below.
    encoder = get_base_model(args.dataset,
                             args.backbone,
                             hidden_dim=num_classes,
                             use_normalize=args.use_normalize,
                             zero_init_residual=args.zero_init_residual,
                             use_projector=False,
                             multi_bn=False)

    if args.multi_bn:
        state_dict = cvt_state_dict(checkpoint['state_dict'], use_pgd_bn=True)
    else:
        state_dict = checkpoint['state_dict']

    msg = encoder.load_state_dict(state_dict, strict=True)
    print(f'load_state_dict result: {msg}')

    # Freeze the encoder; only the linear classifier below is optimised.
    for param in encoder.parameters():
        param.requires_grad = False
    assert not any(p.requires_grad for p in encoder.parameters())

    # 512 is the hard-coded input width of the linear head.
    classifier = nn.Linear(512, num_classes)

    criterion = nn.CrossEntropyLoss().to(args.device)
    optimizer = optim.SGD(classifier.parameters(), lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    decreasing_lr = list(map(int, args.decreasing_lr.split(',')))
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=decreasing_lr, gamma=0.1)

    start_epoch = 0

    for epoch in range(start_epoch, args.epochs):
        print("current lr is {}".format(
            optimizer.state_dict()['param_groups'][0]['lr']))

        train(args, encoder, classifier, args.device, train_loader, optimizer, criterion, epoch)
        scheduler.step()

        # One epoch is enough to detect accidental updates of the frozen encoder.
        if epoch == 0:
            sanity_check(encoder.state_dict(), state_dict)

    # Save the checkpoint of the last epoch (classifier and optimizer only).
    torch.save({
        'epoch': epoch,
        'state_dict': classifier.state_dict(),
        'optim': optimizer.state_dict(),
    }, os.path.join(save_dir,
                    f'slf_checkpoint{epoch:04d}.pth.tar'))
    
    
def set_loader(args):
    """Create the training data loader for ``args.dataset``.

    The training split is downloaded to ``args.data_dir`` if necessary and
    augmented with a padded random crop and a random horizontal flip.

    Returns:
        A ``(train_loader, num_classes)`` tuple; ``num_classes`` is 100 for
        ``cifar100`` and 10 otherwise.
    """
    transform_train = transforms.Compose([
        transforms.RandomCrop(CROP_SIZE_PER_DATASET[args.dataset], padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])

    # The torchvision dataset class is named after the upper-cased dataset.
    dataset_cls = vars(torchvision.datasets)[args.dataset.upper()]

    # The CIFAR classes select the training data with `train=True`; every
    # other dataset is asked for `split='train'`.
    if args.dataset in ['cifar10', 'cifar100']:
        split_kwargs = {'train': True}
    else:
        split_kwargs = {'split': 'train'}

    train_datasets = dataset_cls(args.data_dir,
                                 download=True,
                                 transform=transform_train,
                                 **split_kwargs)

    if args.dataset == 'cifar100':
        num_classes = 100
    else:
        num_classes = 10

    loader_options = dict(batch_size=args.batch_size,
                          shuffle=True,
                          num_workers=4,
                          pin_memory=True,
                          drop_last=False)
    train_loader = torch.utils.data.DataLoader(train_datasets, **loader_options)

    return train_loader, num_classes

def train(args, encoder, classifier, device, train_loader, optimizer, criterion, epoch):
    """Train the linear classifier for one epoch on frozen encoder features.

    Args:
        args: parsed command-line namespace (not read by this function).
        encoder: frozen feature extractor; must have no trainable parameter.
        classifier: module trained on the encoder output.
        device: device the models and batches are moved to.
        train_loader: iterable of ``(data, target)`` batches; its ``dataset``
            length is used to normalise the accuracy.
        optimizer: optimizer that updates the classifier.
        criterion: loss applied to ``(classifier output, target)``.
        epoch: epoch index, used only for logging.

    Prints the running average loss every 10 batches and the training
    accuracy (in percent) at the end of the epoch.
    """
    # The encoder must be fully frozen: only `classifier` is optimised.
    assert not any(p.requires_grad for p in encoder.parameters())

    encoder.to(device)
    classifier.to(device)

    # Modes are set once; nothing inside the loop switches them.
    encoder.eval()
    classifier.train()

    losses = AverageMeter()
    losses.reset()

    num_batches = len(train_loader)
    num_correct = 0
    # tqdm wraps the loader itself (not `enumerate(...)`) so it can read the
    # number of batches and display a total and an ETA.
    for batch_idx, (data, target) in enumerate(tqdm(train_loader)):
        data, target = data.to(device), target.to(device)

        # Features come from the frozen encoder, so no graph is needed.
        with torch.no_grad():
            features = encoder(data)

        output = classifier(features)
        loss = criterion(output, target)

        # Accumulate a plain Python number so the final log line shows a
        # float instead of `tensor(...)`.
        num_correct += (output.detach().argmax(dim=-1) == target).sum().item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.update(loss.item(), data.shape[0])

        if batch_idx % 10 == 0:
            print(f'Epoch:{epoch}) [{batch_idx}/{num_batches}] \t'
                  f'Average Loss: {losses.avg}')

    train_acc = num_correct / len(train_loader.dataset) * 100
    print(f'Epoch: {epoch}) Standard Accuracy (Train Data): {train_acc}')
    
def cvt_state_dict(state_dict, use_pgd_bn=True):
    """Reduce a multi-BN state dict to a single-BN one and rename downsample keys.

    Every key containing ``'bn'`` must belong to a ``bn_list``; only the
    entries of branch 1 (``use_pgd_bn=True``) or branch 0 are kept, with the
    ``.bn_list.<index>`` part removed from the key. All other entries are
    copied unchanged. Afterwards ``downsample.conv`` / ``downsample.bn`` are
    renamed to ``downsample.0`` / ``downsample.1``.

    Returns:
        A new dict; ``state_dict`` itself is not modified.

    Raises:
        AssertionError: if a key contains ``'bn'`` but not ``'bn_list'``.
    """
    bn_index = 1 if use_pgd_bn else 0
    branch_tag = f'.bn_list.{bn_index}'

    converted = {}
    for name, value in state_dict.items():
        if 'bn' not in name:
            converted[name] = value
            continue
        assert 'bn_list' in name, 'state_dict does not consist of the multi batch normalization layer'
        # Entries of the other BN branch are dropped.
        if branch_tag in name:
            converted[name.replace(branch_tag, '')] = value

    # Iterate over a snapshot because the dict is modified while renaming.
    for old_key in list(converted):
        if 'downsample.conv' in old_key:
            new_key = old_key.replace('downsample.conv', 'downsample.0')
        elif 'downsample.bn' in old_key:
            new_key = old_key.replace('downsample.bn', 'downsample.1')
        else:
            continue
        converted[new_key] = converted.pop(old_key)

    return converted
    
def sanity_check(state_dict, initial_state_dict):
    """Assert that every tensor of ``state_dict`` equals its initial value.

    Each entry of ``state_dict`` is moved to the CPU and compared element-wise
    with the entry of the same key in ``initial_state_dict``.

    Raises:
        AssertionError: naming the first key whose tensor differs.
    """
    print("=> sanity check for teacher model")

    for key, current in list(state_dict.items()):
        unchanged = (current.cpu() == initial_state_dict[key]).all()
        assert unchanged, '{} is changed in training.'.format(key)

    print("=> sanity check passed.")

# Run the fine-tuning pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
