import argparse
import shutil
import os
import time
import timm
import torch
import numpy as np
import logging
import torch.nn as nn
from torch import autocast
from torch.cuda.amp import GradScaler
import torch.optim
# from models.resnetv2 import KNOWN_MODELS

from utils.supcon import SupConLoss
from utils import data_loaders
from utils.utils import seed_all, freeze_except_fc, get_per_class_accuracy
from utils.vpt import build_promptmodel
from torchvision.models.vision_transformer import vit_b_16, vit_b_32, ViT_B_32_Weights, ViT_B_16_Weights
from torchvision.models.resnet import resnet18, ResNet18_Weights, resnet50, ResNet50_Weights


def mixup_data(x, y, l):
    """Blend every sample of a batch with a randomly drawn partner sample.

    A random permutation of the batch (first dimension of ``x``) selects the
    partner of each sample. The inputs are combined as
    ``l * x + (1 - l) * partner``.

    Returns:
        A 3-tuple ``(mixed_x, y_a, y_b)``: the blended inputs, the original
        targets ``y``, and the targets of the partner samples.
    """
    batch_size = x.shape[0]
    # Drawn with the default generator, then moved next to the data.
    shuffle = torch.randperm(batch_size).to(x.device)

    partner_inputs = x[shuffle]
    partner_targets = y[shuffle]
    blended = l * x + (1 - l) * partner_inputs
    return blended, y, partner_targets


def mixup_criterion(criterion, pred, y_a, y_b, l):
    """Return the mixup loss: ``criterion`` evaluated against both target sets,
    weighted by ``l`` for ``y_a`` and ``1 - l`` for ``y_b``."""
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return l * loss_a + (1 - l) * loss_b


def zero_fc(model, model_name):
    """Zero-initialise the weight and bias of the model's classification layer.

    Where the classification layer lives depends on ``model_name``:
    ``model.head.conv`` for ``'res50x3'``, ``model.head`` for ``'vitb16'`` /
    ``'vitl16'``, ``model.head.fc`` for ``'swin'`` and ``model.fc`` for
    every other name. The parameters are modified in place.
    """
    # Locate the classification layer first, then reset it once.
    if model_name == 'res50x3':
        classifier = model.head.conv
    elif model_name in ('vitb16', 'vitl16'):
        classifier = model.head
    elif model_name == 'swin':
        classifier = model.head.fc
    else:
        classifier = model.fc

    nn.init.zeros_(classifier.weight)
    nn.init.zeros_(classifier.bias)


def load_moco(model, path='pretrained/r-50-1000ep.pth.tar'):
    """Load MoCo pre-trained backbone weights from ``path`` into ``model``.

    The checkpoint's ``"state_dict"`` entry is filtered down to the keys under
    ``module.base_encoder.`` (excluding the ``fc`` head), the prefix is
    stripped, and the result is loaded non-strictly.

    Args:
        model: module to receive the weights; it is expected to own an ``fc``
            layer that is *not* present in the filtered checkpoint.
        path: location of the checkpoint file.

    Returns:
        The same ``model`` object, with the backbone weights loaded.

    Raises:
        AssertionError: if the parameters missing after loading are anything
            other than exactly ``fc.weight`` and ``fc.bias``.
    """
    print("=> loading checkpoint '{}'".format(path))
    # NOTE(review): torch.load unpickles the file - only load checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(path, map_location="cpu")

    # Keep only the base encoder (without its linear head) and drop the
    # wrapper prefix so the keys match the plain backbone's parameter names.
    encoder_prefix = 'module.base_encoder.'
    head_prefix = encoder_prefix + 'fc'
    backbone_state = {
        key[len(encoder_prefix):]: tensor
        for key, tensor in checkpoint["state_dict"].items()
        if key.startswith(encoder_prefix) and not key.startswith(head_prefix)
    }

    # Legacy side effect: when this file runs as a script, the module-level
    # ``args`` namespace gets its start epoch reset. Looked up defensively so
    # that the function also works when the module is imported and no such
    # global exists.
    script_args = globals().get('args')
    if script_args is not None:
        script_args.start_epoch = 0

    msg = model.load_state_dict(backbone_state, strict=False)
    assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}, \
        "unexpected missing keys after loading '{}': {}".format(path, sorted(msg.missing_keys))
    return model


@torch.no_grad()
def compute_leep(model, train_loader, z, y, n):
    """Compute, print and return the LEEP score of ``model`` on a dataset.

    The loader is traversed twice: the first pass accumulates the empirical
    joint distribution between target labels and the model's softmax outputs,
    the second pass averages the log-likelihood of each sample's label under
    the resulting conditional distribution.

    Args:
        model: classifier producing ``z`` logits per sample; it is switched to
            eval mode and left in that mode.
        train_loader: iterable of ``(images, labels)`` batches. Labels are
            used as integer row indices in ``[0, y)`` (assumes one class
            index per sample).
        z: number of model outputs (columns of the joint table).
        y: number of target classes (rows of the joint table).
        n: normaliser for the distributions and the final mean; callers pass
            the number of samples in the dataset.

    Returns:
        The LEEP score as a scalar tensor (it is also printed).
    """
    model.eval()
    # Run on whatever device the model lives on instead of relying on a
    # script-level global / hard-coded 'cuda'.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    # Pass 1: empirical joint P(y, z) and marginal P(z).
    joint = torch.zeros((y, z), device=device)
    marginal = torch.zeros((1, z), device=device)
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device=device, dtype=torch.long)
        probs = model(images).softmax(dim=1).to(joint.dtype)
        marginal += probs.sum(dim=0, keepdim=True)
        # Row ``labels[k]`` of the joint table receives ``probs[k]``.
        joint.index_add_(0, labels, probs)
    marginal /= n
    joint /= n
    conditional = joint / marginal  # P(y | z), shape (y, z)

    # Pass 2: mean log-likelihood of the true label under sum_z P(y|z) p(z|x).
    leep = torch.zeros((), device=device)
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device=device, dtype=torch.long)
        probs = model(images).softmax(dim=1).to(conditional.dtype)
        likelihood = probs @ conditional.T
        label_likelihood = likelihood.gather(1, labels.unsqueeze(1))
        leep = leep + torch.log(label_likelihood).sum()
    leep = leep / n
    print('The LEEP score is {}'.format(leep))
    return leep


def str2bool(value):
    """Convert a command-line token to a bool.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
    ``"False"`` would become ``True``. This converter interprets the text.

    Args:
        value: a bool (returned unchanged) or a string such as
            ``true/false``, ``yes/no``, ``on/off``, ``1/0``. The empty
            string counts as ``False``.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', 'off', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


if __name__ == '__main__':
    # Script entry point: fine-tune an ImageNet-pretrained backbone on a
    # transfer dataset, optionally repeating the experiment with different
    # seeds, and report the final accuracy of every repeat.

    parser = argparse.ArgumentParser(description='PyTorch Temporal Efficient Training')
    parser.add_argument('-m', '--model', default='res18', type=str, metavar='N',
                        choices=['res18', 'res50', 'vitb16', 'vitb32', 'vptb16', 'vitl16'],
                        help='backbone architecture (default: res18)')
    parser.add_argument('-d', '--dataset', default='flowers', type=str, metavar='N',
                        choices=['caltech101', 'cifar100', 'flowers', 'pets', 'aircraft',
                                 'pascalvoc', 'dtd', 'foods', 'cars', 'cub', 'dogs', 'eurosat', 'sun397'],
                        help='transfer dataset to fine-tune on (default: flowers)')
    parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
                        help='number of data loading workers (default: 8)')
    parser.add_argument('--epochs', default=150, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--port', default='13330', type=str,
                        help='port number (not referenced by this script)')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('--lr', '--learning_rate', default=0.001, type=float, metavar='LR',
                        help='initial learning rate', dest='lr')
    parser.add_argument('--seed', default=1000, type=int,
                        help='seed for initializing training. ')
    # ``type=bool`` would turn "--fixfeature False" into True; str2bool parses
    # the text, and a bare "--fixfeature" enables the option.
    parser.add_argument('--fixfeature', default=False, type=str2bool, nargs='?', const=True,
                        metavar='BOOL', help='if fine tune only the fc layer')
    parser.add_argument('--supcon', action='store_true',
                        help='if use supcon loss to train')
    parser.add_argument('--fewshot', default=None, type=int, metavar='N',
                        help='images per class')
    parser.add_argument("--batch", type=int, default=64,
                        help="Batch size.")
    # NOTE(review): the loop below steps the optimizer after every micro-batch;
    # no gradient accumulation across ``batch_split`` micro-batches is done.
    parser.add_argument("--batch_split", type=int, default=1,
                        help="Number of batches to compute gradient on before updating weights.")
    parser.add_argument("--rep", type=int, default=1,
                        help="Number of repeats of experiments.")
    # NOTE(review): ``type=list`` splits the argument into single characters
    # (strings), so ``sum(args.fuse)`` at checkpoint time raises TypeError for
    # any value given on the command line. Left unchanged because the format
    # expected by ``build_transfer_dataset`` is not visible here - confirm.
    parser.add_argument("--fuse", type=list, default=0,
                        help="index of synthetic dataset.")
    parser.add_argument("--noval", action='store_true',
                        help="do not evaluate during the training, especially for few shot")
    parser.add_argument("--mixup", type=float, default=0.,
                        help="use mixup to train the model")
    parser.add_argument("--device-id", type=int, default=0,
                        help="gpu device id")
    args = parser.parse_args()

    # Final accuracy of each repeat.
    accs = []

    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device_id)

    for r in range(args.rep):
        # Each repeat uses its own seed so the runs are independent but reproducible.
        seed_all(args.seed + r)
        imgnet_norm = 'bit' if args.model == 'res50x3' else 'normal'
        train_dataset, val_dataset, num_cls = data_loaders.build_transfer_dataset(dataset=args.dataset, download=True,
                                                                                  imgnet_norm=imgnet_norm,
                                                                                  fuse=args.fuse, fewshot=args.fewshot,
                                                                                  data_config=None)
        # Build the backbone and swap in a classifier sized for the target task.
        # NOTE(review): ``data_config`` is resolved after the datasets were
        # already built (with data_config=None) and is not used afterwards.
        if args.model == 'res50':
            model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
            model.fc = nn.Linear(512 * 4, num_cls)
            wd = 5e-4
            data_config = None
        elif args.model == 'vitl16':
            model = timm.create_model('vit_large_patch16_224.augreg_in21k_ft_in1k', pretrained=True)
            model.reset_classifier(num_cls)
            wd = 0
            data_config = timm.data.resolve_model_data_config(model)
        elif args.model == 'vitb16':
            model = timm.create_model('vit_base_patch16_224.augreg2_in21k_ft_in1k', pretrained=True)
            model.reset_classifier(num_cls)
            wd = 0
            data_config = timm.data.resolve_model_data_config(model)
        elif args.model == 'res18':
            model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
            model.fc = nn.Linear(512, num_cls)
            wd = 5e-4
            data_config = None
        else:
            raise NotImplementedError

        # model.load_state_dict(torch.load('pretrained/{}_{}_f1000.pt'.format(args.model, args.dataset), map_location='cpu'))
        zero_fc(model, args.model)
        model.cuda()
        device = next(model.parameters()).device

        print('Training dataset size: {}, validation dataset size: {}'.format(len(train_dataset), len(val_dataset)))

        micro_batch_size = args.batch // args.batch_split
        valid_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=micro_batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=True, drop_last=False)

        if micro_batch_size <= len(train_dataset):
            train_loader = torch.utils.data.DataLoader(
                train_dataset, batch_size=micro_batch_size, shuffle=True,
                num_workers=args.workers, pin_memory=True, drop_last=False)
        else:
            # In the few-shot cases, the total dataset size might be smaller than the batch-size.
            # In these cases, the default sampler doesn't repeat, so we need to make it do that
            # if we want to match the behaviour from the paper.
            train_loader = torch.utils.data.DataLoader(
                train_dataset, batch_size=micro_batch_size, num_workers=args.workers, pin_memory=True,
                sampler=torch.utils.data.RandomSampler(train_dataset, replacement=True, num_samples=micro_batch_size))
        criterion = nn.CrossEntropyLoss().cuda() if args.dataset != 'pascalvoc' else nn.BCEWithLogitsLoss().cuda()

        # compute leep
        # compute_leep(model, train_loader, z=model.fc.bias.shape[0], y=num_cls, n=len(train_dataset))
        # exit()

        if args.fixfeature:
            # Linear probing: only the parameters returned by the helper are optimised.
            group = freeze_except_fc(model)
            optimizer = torch.optim.SGD(group, lr=args.lr, weight_decay=wd)
        else:
            if args.model == 'vptb16':
                optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=wd)
            else:
                optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=wd)

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, eta_min=0, T_max=args.epochs)
        scaler = GradScaler()

        best_acc = 0
        best_epoch = 0
        facc = 0
        train_accs = []
        for epoch in range(args.epochs):
            running_loss = 0
            start_time = time.time()
            if args.fixfeature:
                # Frozen backbone: keep the whole model in eval mode while training the head.
                model.eval()
            else:
                model.train()
            M = len(train_loader)
            total = 0
            correct = 0
            for i, (images, labels) in enumerate(train_loader):
                optimizer.zero_grad()
                labels = labels.to(device)
                images = images.to(device)

                # Only the forward pass and the loss run under autocast; the
                # backward pass and optimizer step are issued outside of it,
                # as recommended by the PyTorch AMP documentation.
                with autocast(device_type='cuda', dtype=torch.float16):
                    if args.mixup > 0.0:
                        mixup_l = np.random.beta(args.mixup, args.mixup)
                        images, y_a, y_b = mixup_data(images, labels, mixup_l)
                        outputs = model(images)
                        loss = mixup_criterion(criterion, outputs, y_a, y_b, mixup_l)
                    else:
                        outputs = model(images)
                        loss = criterion(outputs, labels)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                running_loss += loss.item()
                total += float(labels.size(0))
                if args.dataset == 'pascalvoc':
                    # Outputs and labels are both moved to the CPU, matching
                    # how the same helper is called during validation.
                    correct += get_per_class_accuracy(args, val_dataset)(outputs.cpu(), labels.cpu())
                else:
                    _, predicted = outputs.cpu().max(1)
                    correct += float(predicted.eq(labels.cpu()).sum().item())
            print('Epoch [%d/%d], Step [%d/%d], Loss: %.5f, Time elapsed: %.2f, Accuracy: %.2f'
                        % (epoch + 1, args.epochs, i + 1, M, running_loss/(i+1), time.time() - start_time,
                           100 * correct / total))
            train_accs += [100 * correct / total]

            scheduler.step()
            correct = 0
            total = 0
            model.eval()

            # Validation is skipped (accuracy reported as 0) for every epoch
            # but the last in few-shot / --noval runs, and always with --supcon.
            if ((args.fewshot is not None or args.noval) and epoch < args.epochs-1) or args.supcon:
                final_acc = 0.
            else:
                with torch.no_grad():
                    for batch_idx, (inputs, targets) in enumerate(valid_loader):
                        inputs = inputs.to(device)
                        outputs = model(inputs)
                        mean_out = outputs.cpu()
                        total += float(targets.size(0))
                        correct += get_per_class_accuracy(args, val_dataset)(mean_out, targets)

                    final_acc = (100 * correct / total)
                    print('Test Accuracy of the model on test images: %.3f' % final_acc)

            facc = final_acc

            if best_acc < facc:
                best_acc = facc
                best_epoch = epoch + 1
                print('best_acc is: {}, find in epoch: {}'.format(best_acc, best_epoch))
                print('\n')
        # The accuracy of the last epoch (not the best one) is what gets reported.
        accs += [facc]
        # print([round(acc, 2) for acc in train_accs])
        if isinstance(args.fuse, list):
            # Make sure the checkpoint directory exists so a finished run is
            # not lost to a missing folder.
            os.makedirs('pretrained', exist_ok=True)
            mixup_tag = '_mixup' if args.mixup > 0. else ''
            torch.save(model.state_dict(),
                       'pretrained/{}_{}_f{}{}.pt'.format(args.model, args.dataset, sum(args.fuse), mixup_tag))
        # print(model.fc.weight.data.norm() / num_cls)

    print('Final Accuracys: {}'.format([round(a, 1) for a in accs]))
