import argparse
import shutil
import os
import time
import torch
import numpy as np
import logging
import torch.nn as nn
from torch import autocast
from torch.cuda.amp import GradScaler
import torch.optim
from models.resnetv2 import KNOWN_MODELS

from utils.supcon import SupConLoss
from utils import data_loaders
from utils.utils import seed_all, freeze_except_fc, get_per_class_accuracy
from torchvision.models.vision_transformer import vit_b_16, vit_b_32, ViT_B_32_Weights, ViT_B_16_Weights
from torchvision.models.resnet import resnet18, ResNet18_Weights, resnet50, ResNet50_Weights


def mixup_data(x, y, l):
    """Blend every sample of a batch with a randomly chosen partner from the same batch.

    Args:
        x: batch of inputs; the first dimension is treated as the batch axis.
        y: targets, indexable along the same first dimension as ``x``.
        l: mixing coefficient applied to the original sample; the partner
            sample is weighted by ``1 - l``.

    Returns:
        A tuple ``(mixed_x, y_a, y_b)`` where ``mixed_x`` is the blended input,
        ``y_a`` is ``y`` unchanged and ``y_b`` holds the partners' targets.
    """
    batch_size = x.shape[0]
    # The permutation is drawn first and only then moved next to ``x``.
    partner = torch.randperm(batch_size).to(x.device)

    partner_x = x[partner]
    blended = l * x + (1 - l) * partner_x
    return blended, y, y[partner]


def mixup_criterion(criterion, pred, y_a, y_b, l):
    """Combine the losses against both mixup targets.

    Args:
        criterion: callable ``criterion(pred, target)`` returning a loss.
        pred: predictions passed unchanged to ``criterion``.
        y_a: first set of targets, weighted by ``l``.
        y_b: second set of targets, weighted by ``1 - l``.
        l: mixing coefficient.

    Returns:
        ``l * criterion(pred, y_a) + (1 - l) * criterion(pred, y_b)``.
    """
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return l * loss_a + (1 - l) * loss_b


def load_moco(model, path='pretrained/r-50-1000ep.pth.tar'):
    """Load the backbone weights of a MoCo checkpoint into ``model``.

    Only entries whose key starts with ``module.base_encoder`` (excluding the
    ``fc`` head) are kept; the ``module.base_encoder.`` prefix is stripped so
    the keys line up with ``model``. Every other entry is discarded.

    Args:
        model: module exposing ``load_state_dict``; it is updated in place.
        path: checkpoint file containing a ``"state_dict"`` entry.

    Returns:
        The same ``model`` object, with the backbone weights loaded.

    Raises:
        AssertionError: if the keys missing after loading are not exactly
            ``fc.weight`` and ``fc.bias``.
    """
    print("=> loading checkpoint '{}'".format(path))
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(path, map_location="cpu")

    # rename moco pre-trained keys
    encoder_prefix = "module.base_encoder."
    linear_keyword = 'fc'
    head_prefix = 'module.base_encoder.%s' % linear_keyword
    state_dict = checkpoint["state_dict"]
    # Iterate over a snapshot of the keys because the dict is edited in the loop.
    for k in list(state_dict.keys()):
        # retain only base_encoder up to before the embedding layer
        if k.startswith('module.base_encoder') and not k.startswith(head_prefix):
            # remove prefix
            state_dict[k[len(encoder_prefix):]] = state_dict[k]
        # delete renamed or unused k
        del state_dict[k]

    # The script entry point defines a module-level ``args`` namespace. Reset
    # its start epoch when it exists, but do not fail with a NameError when
    # this function is called from an importing module.
    cli_args = globals().get("args")
    if cli_args is not None:
        cli_args.start_epoch = 0

    msg = model.load_state_dict(state_dict, strict=False)
    missing = set(msg.missing_keys)
    # Raised explicitly so the check is not stripped when running with ``-O``.
    if missing != {"fc.weight", "fc.bias"}:
        raise AssertionError(
            "unexpected missing keys after loading '{}': {}".format(path, sorted(missing)))
    return model


@torch.no_grad()
def compute_leep(model, train_loader, z, y, n):
    """Compute and print the LEEP score of ``model`` on ``train_loader``.

    The loader is traversed twice: once to accumulate the empirical joint
    distribution of (target label, model output class) and the marginal over
    model output classes, and once to average the log-likelihood of every
    sample's label under the resulting conditional distribution.

    All tensors are placed on the device of the model's first parameter (CPU
    if the model has no parameters), so the function depends neither on a
    global ``device`` variable nor on CUDA being available.

    Args:
        model: network whose outputs are turned into probabilities with a
            softmax over dimension 1; it is switched to eval mode.
        train_loader: iterable yielding ``(images, labels)`` batches.
            Assumes ``labels`` are integer class indices, one per sample.
        z: number of model output classes (width of the softmax output).
        y: number of target classes.
        n: number of samples used as the normalisation constant.

    Returns:
        The LEEP score as a 0-dim tensor (also printed).
    """
    model.eval()
    first_param = next(model.parameters(), None)
    dev = first_param.device if first_param is not None else torch.device('cpu')

    # Compute the joint and marginal distribution
    joint = torch.zeros((y, z), device=dev)
    source_marginal = torch.zeros((1, z), device=dev)
    for images, labels in train_loader:
        probs = model(images.to(dev)).softmax(dim=1).to(joint.dtype)
        targets = torch.as_tensor(labels, device=dev).long().view(-1)
        source_marginal += probs.sum(dim=0, keepdim=True)
        # Add each sample's probability row to the row of its target label.
        joint.index_add_(0, targets, probs)
    joint = joint / n
    source_marginal = source_marginal / n
    # Conditional distribution of the target label given the model output class.
    conditional = joint / source_marginal

    # compute LEEP
    leep = torch.zeros((), device=dev)
    for images, labels in train_loader:
        probs = model(images.to(dev)).softmax(dim=1).to(joint.dtype)
        targets = torch.as_tensor(labels, device=dev).long().view(-1)
        likelihood = probs @ conditional.T
        # Pick, for every sample, the likelihood assigned to its own label.
        own_label = likelihood.gather(1, targets.unsqueeze(1))
        leep = leep + torch.log(own_label).sum()
    leep = leep / n
    print('The LEEP score is {}'.format(leep))
    return leep


if __name__ == '__main__':
    # Script entry point: for every few-shot setting and every repeat, load a
    # pre-trained backbone, fine-tune it with mixed precision on the selected
    # dataset, and collect the last evaluated accuracy of each run.

    def _str2bool(value):
        """Parse a command-line boolean.

        ``type=bool`` would turn every non-empty string (including "False")
        into True. Here the usual "false" spellings map to False and any
        other value keeps mapping to True, as before.
        """
        if isinstance(value, bool):
            return value
        return value.strip().lower() not in ('', '0', 'false', 'f', 'no', 'n', 'off')

    parser = argparse.ArgumentParser(description='PyTorch Temporal Efficient Training')
    parser.add_argument('-m', '--model', default='res18', type=str, metavar='N',
                        choices=['res18', 'res50', 'vitb16', 'vitb32'],
                        help='backbone architecture (default: res18)')
    # NOTE(review): the original choices listed 'cifar100' twice; possibly
    # 'cifar10' was intended -- confirm against data_loaders.
    parser.add_argument('-d', '--dataset', default='flowers', type=str, metavar='N',
                        choices=['caltech101', 'cifar100', 'flowers', 'pets', 'aircraft',
                                 'pascalvoc', 'dtd', 'foods', 'cars', 'cub', 'dogs', 'eurosat', 'sun397'],
                        help='dataset to fine-tune on (default: flowers)')
    parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
                        help='number of data loading workers (default: 8)')
    parser.add_argument('--epochs', default=150, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--port', default='13330', type=str, help='port number')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('--lr', '--learning_rate', default=0.003, type=float, metavar='LR',
                        help='initial learning rate', dest='lr')
    parser.add_argument('--seed', default=1000, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--fixfeature', default=False, type=_str2bool, metavar='N',
                        help='if fine tune only the fc layer')
    parser.add_argument('--supcon', action='store_true',
                        help='if use supcon loss to train')
    parser.add_argument('--fewshot', default=None, type=int, metavar='N',
                        help='images per class')
    parser.add_argument("--batch", type=int, default=64,
                        help="Batch size.")
    parser.add_argument("--batch_split", type=int, default=1,
                        help="Number of batches to compute gradient on before updating weights.")
    parser.add_argument("--rep", type=int, default=1,
                        help="Number of repeats of experiments.")
    parser.add_argument("--noval", action='store_true',
                        help="do not evaluate during the training, especially for few shot")
    parser.add_argument("--mixup", type=float, default=0.,
                        help="Beta distribution parameter for mixup; 0 disables mixup")
    args = parser.parse_args()

    accs = []

    # NOTE(review): the shot counts are fixed here; ``args.fewshot`` is not
    # passed to the dataset builder and only influences the validation gating
    # further below -- confirm this is intended.
    for shot in [1, 2, 4, 8, 16]:
        for r in range(args.rep):
            seed_all(args.seed + r)
            imgnet_norm = 'normal'
            train_dataset, val_dataset, num_cls = data_loaders.build_transfer_dataset(dataset=args.dataset, download=True,
                                                                                      imgnet_norm=imgnet_norm,
                                                                                      fuse=0, fewshot=shot)
            # Build the backbone and replace its classification head; ``head``
            # keeps a handle on the new layer whatever attribute it lives under.
            if args.model == 'res50':
                model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
                model.fc = nn.Linear(512 * 4, num_cls)
                head = model.fc
                wd = 5e-4
            elif args.model == 'vitb16':
                model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
                model.heads = nn.Linear(model.hidden_dim, num_cls)
                head = model.heads
                wd = 0.
            elif args.model == 'vitb32':
                model = vit_b_32(weights=ViT_B_32_Weights.IMAGENET1K_V1)
                model.heads = nn.Linear(model.hidden_dim, num_cls)
                head = model.heads
                wd = 0.
            elif args.model == 'res18':
                model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
                model.fc = nn.Linear(512, num_cls)
                head = model.fc
                wd = 5e-4
            else:
                raise NotImplementedError

            model.load_state_dict(torch.load('pretrained/{}_{}_f500_mixup.pt'.format(args.model, args.dataset), map_location='cpu'))
            # Restart the classifier from zero. The ViT models have no ``fc``
            # attribute, so the head is reached through ``head``.
            nn.init.zeros_(head.weight)
            nn.init.zeros_(head.bias)
            model.cuda()

            print('Training dataset size: {}, validation dataset size: {}'.format(len(train_dataset), len(val_dataset)))

            # NOTE(review): the batch is only made smaller by ``batch_split``;
            # the loop below steps the optimizer on every micro-batch and does
            # not accumulate gradients.
            micro_batch_size = args.batch // args.batch_split
            valid_loader = torch.utils.data.DataLoader(
                val_dataset, batch_size=micro_batch_size, shuffle=False,
                num_workers=args.workers, pin_memory=True, drop_last=False)

            if micro_batch_size <= len(train_dataset):
                train_loader = torch.utils.data.DataLoader(
                    train_dataset, batch_size=micro_batch_size, shuffle=True,
                    num_workers=args.workers, pin_memory=True, drop_last=False)
            else:
                # In the few-shot cases, the total dataset size might be smaller than the batch-size.
                # In these cases, the default sampler doesn't repeat, so we need to make it do that
                # if we want to match the behaviour from the paper.
                train_loader = torch.utils.data.DataLoader(
                    train_dataset, batch_size=micro_batch_size, num_workers=args.workers, pin_memory=True,
                    sampler=torch.utils.data.RandomSampler(train_dataset, replacement=True, num_samples=micro_batch_size))
            device = next(model.parameters()).device
            criterion = nn.CrossEntropyLoss().cuda() if args.dataset != 'pascalvoc' else nn.BCEWithLogitsLoss().cuda()
            if args.supcon:
                criterion = SupConLoss()

            # compute leep
            # compute_leep(model, train_loader, z=1000, y=num_cls, n=len(train_dataset))
            # exit()

            if args.fixfeature:
                group = freeze_except_fc(model)
                optimizer = torch.optim.SGD(group, lr=args.lr, weight_decay=wd)
            else:
                optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=wd)

            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, eta_min=0, T_max=args.epochs)
            scaler = GradScaler()

            best_acc = 0
            best_epoch = 0
            facc = 0
            train_accs = []
            for epoch in range(args.epochs):
                running_loss = 0
                start_time = time.time()
                # With a frozen backbone the model stays in eval mode during training.
                if args.fixfeature:
                    model.eval()
                else:
                    model.train()
                M = len(train_loader)
                total = 0
                correct = 0
                for i, (images, labels) in enumerate(train_loader):
                    optimizer.zero_grad()
                    labels = labels.to(device)
                    images = images.to(device)

                    # Only the forward pass and the loss run under autocast.
                    with autocast(device_type='cuda', dtype=torch.float16):
                        if args.mixup > 0.0:
                            mixup_l = np.random.beta(args.mixup, args.mixup)
                            images, y_a, y_b = mixup_data(images, labels, mixup_l)
                            outputs = model(images)
                            loss = mixup_criterion(criterion, outputs, y_a, y_b, mixup_l)
                        else:
                            outputs = model(images)
                            loss = criterion(outputs, labels)
                    # Backward and the optimizer step are kept outside the
                    # autocast region, as the PyTorch AMP documentation advises.
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()

                    running_loss += loss.item()
                    total += float(labels.size(0))
                    if args.dataset == 'pascalvoc':
                        correct += get_per_class_accuracy(args, val_dataset)(outputs.cpu(), labels)
                    else:
                        _, predicted = outputs.cpu().max(1)
                        correct += float(predicted.eq(labels.cpu()).sum().item())
                print('Epoch [%d/%d], Step [%d/%d], Loss: %.5f, Time elapsed: %.2f, Accuracy: %.2f'
                            % (epoch + 1, args.epochs, i + 1, M, running_loss/(i+1), time.time() - start_time,
                               100 * correct / total))
                train_accs += [100 * correct / total]

                scheduler.step()
                correct = 0
                total = 0
                model.eval()

                # Validation is skipped (accuracy reported as 0) when training
                # with SupCon, or on every epoch but the last when --fewshot
                # or --noval is given.
                if ((args.fewshot is not None or args.noval) and epoch < args.epochs-1) or args.supcon:
                    final_acc = 0.
                else:
                    with torch.no_grad():
                        for batch_idx, (inputs, targets) in enumerate(valid_loader):
                            inputs = inputs.to(device)
                            outputs = model(inputs)
                            mean_out = outputs.cpu()
                            total += float(targets.size(0))
                            correct += get_per_class_accuracy(args, val_dataset)(mean_out, targets)

                        final_acc = (100 * correct / total)
                        print('Test Accuracy of the model on test images: %.3f' % final_acc)

                facc = final_acc

                if best_acc < facc:
                    best_acc = facc
                    best_epoch = epoch + 1
                    print('best_acc is: {}, find in epoch: {}'.format(best_acc, best_epoch))
                    print('\n')
            # Record the accuracy of the last epoch of this run.
            accs += [facc]

    print('Final Accuracys: {}'.format([round(a, 1) for a in accs]))
