from argparse import ArgumentParser
from functools import partial
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torchvision

import ignite.distributed as idist

from datasets import load_datasets
from utils import Logger

def collect_features(backbone,
                     dataloader,
                     device,
                     normalize=True,
                     dst=None,
                     verbose=False):
    """Run ``backbone`` in eval mode over ``dataloader`` and gather its outputs.

    Args:
        backbone: module mapping an image batch to one embedding per image.
        dataloader: iterable of ``(inputs, labels)`` batches. A 5-D ``inputs``
            batch has its second axis folded into the batch axis, and each label
            is repeated once per entry on that axis (presumably several crops
            per image -- confirm against the dataset).
        device: device the inputs are moved to before the forward pass.
        normalize: if true, L2-normalise each embedding along its last dim.
        dst: device on which per-batch results are stored; defaults to ``device``.
        verbose: if true, print the 1-based batch count every 10 batches.

    Returns:
        ``(features, labels)``: the concatenated per-batch tensors, gathered
        across processes with ``idist.utils.all_gather``.
    """
    target = device if dst is None else dst

    backbone.eval()
    feature_chunks, label_chunks = [], []
    with torch.no_grad():
        for batch_idx, (images, targets) in enumerate(dataloader, start=1):
            if images.ndim == 5:
                # Fold axis 1 into the batch; the repeat keeps labels aligned
                # with the row order produced by view().
                _, per_sample, channels, height, width = images.shape
                images = images.view(-1, channels, height, width)
                targets = targets.view(-1, 1).repeat(1, per_sample).view(-1)

            embeddings = backbone(images.to(device))
            if normalize:
                embeddings = F.normalize(embeddings, dim=-1)

            feature_chunks.append(embeddings.to(target).detach())
            label_chunks.append(targets.to(target).detach())

            if verbose and batch_idx % 10 == 0:
                print(batch_idx)

        stacked_features = torch.cat(feature_chunks, 0).detach()
        stacked_labels = torch.cat(label_chunks, 0).detach()
        features = idist.utils.all_gather(stacked_features)
        labels = idist.utils.all_gather(stacked_labels)

    return features, labels

def build_step(X, Y, classifier, optimizer, w):
    """Build the closure that ``optimizer.step`` evaluates (e.g. for LBFGS).

    Each call clears gradients and computes the summed cross-entropy of
    ``classifier(X)`` against ``Y``. It then adds ``w`` times the squared L2
    norm of every parameter of ``classifier``, back-propagates, and returns
    the loss tensor.
    """
    def closure():
        optimizer.zero_grad()
        logits = classifier(X)
        objective = F.cross_entropy(logits, Y, reduction='sum')
        # Penalty is added one parameter tensor at a time (weights and biases alike).
        for param in classifier.parameters():
            objective = objective + param.pow(2).sum().mul(w)
        objective.backward()
        return objective

    return closure


def compute_accuracy(X, Y, classifier, metric):
    """Return the accuracy of ``classifier`` on ``(X, Y)`` as a Python float.

    Predictions are the argmax over dim 1 of ``classifier(X)``.

    Args:
        X: inputs passed to ``classifier``.
        Y: integer class labels aligned with the rows of ``X``.
        classifier: callable returning per-class scores.
        metric: ``'top1'`` for the overall fraction correct, or ``'class-avg'``
            for the mean of per-class accuracies over the classes in
            ``0..Y.max()`` that actually occur in ``Y``.

    Raises:
        ValueError: if ``metric`` is not recognised. It is checked before the
            forward pass so no compute is wasted.
    """
    if metric not in ('top1', 'class-avg'):
        raise ValueError(f'Unknown metric: {metric}')

    with torch.no_grad():
        preds = classifier(X).argmax(1)
        if metric == 'top1':
            return (preds == Y).float().mean().item()

        # class-avg: classes absent from Y are skipped, not counted as 0.
        per_class = []
        for y in range(0, Y.max().item() + 1):
            mask = Y == y
            if mask.any():
                per_class.append((preds[mask] == y).float().mean().item())
        return sum(per_class) / len(per_class)


# Checkpoint key prefix under which each supported pre-training method stores
# the backbone weights.
_BACKBONE_PREFIXES = {
    'moco': 'module.encoder_q.',
    'supcon': 'module.encoder.',
    'siam': 'module.backbone.',
}


def _extract_backbone_state_dict(state_dict, prefix):
    """Return a new dict of the entries under ``prefix``, with the prefix stripped.

    Keys outside ``prefix`` and the ``fc`` head beneath it are dropped. A new
    dict is built rather than renaming keys in place, so a stripped name can
    never clobber another original key.
    """
    head = prefix + 'fc'
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix) and not key.startswith(head)
    }


def main(local_rank, args):
    """Linear-probe evaluation of a frozen pre-trained backbone.

    Steps:
      1. Load the checkpoint ``args.ckpt``.
      2. Extract frozen features for the train/val/test splits.
      3. Sweep the L2 weight ``w`` over a log grid. Each value fits an
         LBFGS-trained linear classifier, scored on validation accuracy.
      4. Refit the best classifier on train+val and log its test accuracy.

    Args:
        local_rank: local process rank supplied by ``parallel.run`` (unused).
        args: parsed command-line namespace (see the ``__main__`` block).

    Raises:
        ValueError: if ``args.method`` is not a supported pre-training method.
    """
    # Fail fast, before any dataset or checkpoint loading.
    if args.method not in _BACKBONE_PREFIXES:
        raise ValueError(f'==> your method is currently not supported: {args.method!r} '
                         f'(expected one of {sorted(_BACKBONE_PREFIXES)})')

    cudnn.benchmark = True
    device = idist.device()
    logger = Logger(None)

    # DATASETS
    datasets = load_datasets(dataset=args.dataset,
                             datadir=args.datadir,
                             pretrain_data=args.pretrain_data)
    build_dataloader = partial(idist.auto_dataloader,
                               batch_size=args.batch_size,
                               num_workers=args.num_workers,
                               shuffle=True,
                               pin_memory=True)
    trainloader = build_dataloader(datasets['train'], drop_last=False)
    valloader   = build_dataloader(datasets['val'],   drop_last=False)
    testloader  = build_dataloader(datasets['test'],  drop_last=False)
    num_classes = datasets['num_classes']

    # MODELS
    # NOTE: torch.load unpickles the file; only point --ckpt at trusted checkpoints.
    ckpt = torch.load(args.ckpt, map_location=device)
    state_dict = _extract_backbone_state_dict(ckpt['state_dict'],
                                              _BACKBONE_PREFIXES[args.method])

    backbone = torchvision.models.__dict__[args.model]()
    # Size the linear probe from the backbone itself (e.g. resnet18 != resnet50).
    # Fall back to the CLI value for models without a Linear ``fc`` head.
    fc = getattr(backbone, 'fc', None)
    num_features = fc.in_features if isinstance(fc, nn.Linear) else args.num_backbone_features
    backbone.fc = nn.Identity()
    # strict=False silently ignores key mismatches, including a wrong prefix
    # that would leave the backbone randomly initialised, so report them.
    load_result = backbone.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        logger.log_msg(f'WARNING: missing keys: {load_result.missing_keys}; '
                       f'unexpected keys: {load_result.unexpected_keys}')
    build_model = partial(idist.auto_model, sync_bn=True)
    backbone = build_model(backbone)
    logger.log_msg(f'num_backbone_features = {num_features}')

    # EXTRACT FROZEN FEATURES
    logger.log_msg('collecting features ...')
    X_train, Y_train = collect_features(backbone, trainloader, device, normalize=False)
    X_val,   Y_val   = collect_features(backbone, valloader,   device, normalize=False)
    X_test,  Y_test  = collect_features(backbone, testloader,  device, normalize=False)
    classifier = nn.Linear(num_features, num_classes).to(device)
    optim_kwargs = {
        'line_search_fn': 'strong_wolfe',
        'max_iter': 5000,
        'lr': 1.,
        'tolerance_grad': 1e-10,
        'tolerance_change': 0,
    }
    logger.log_msg('collecting features ... done')

    # Sweep the L2 weight on a log grid. The same classifier is reused across
    # iterations, so each fit starts from the previous w's solution.
    best_acc = 0.
    best_w = 0.
    best_classifier = None
    for w in torch.logspace(-6, 5, steps=45).tolist():
        optimizer = optim.LBFGS(classifier.parameters(), **optim_kwargs)
        optimizer.step(build_step(X_train, Y_train, classifier, optimizer, w))
        acc = compute_accuracy(X_val, Y_val, classifier, args.metric)

        # Always keep the first candidate so the refit below never sees None
        # (e.g. when every validation accuracy is 0 or NaN).
        if best_classifier is None or best_acc < acc:
            best_acc = acc
            best_w = w
            best_classifier = deepcopy(classifier)

        logger.log_msg(f'w={w:.4e}, acc={acc:.4f}')

    logger.log_msg(f'BEST: w={best_w:.4e}, acc={best_acc:.4f}')

    # Refit the selected classifier on train+val, then evaluate once on test.
    X = torch.cat([X_train, X_val], 0)
    Y = torch.cat([Y_train, Y_val], 0)
    optimizer = optim.LBFGS(best_classifier.parameters(), **optim_kwargs)
    optimizer.step(build_step(X, Y, best_classifier, optimizer, best_w))
    acc = compute_accuracy(X_test, Y_test, best_classifier, args.metric)
    logger.log_msg(f'test acc={acc:.4f}')

if __name__ == '__main__':
    # Command-line entry point: parse options and run ``main`` under ignite's
    # distributed launcher (NCCL backend when --distributed is given).
    parser = ArgumentParser()
    parser.add_argument('--ckpt', type=str, required=True)
    parser.add_argument('--pretrain-data', type=str, default='imagenet100')
    parser.add_argument('--dataset', type=str, default='cifar10')
    parser.add_argument('--datadir', type=str, default='/data')
    parser.add_argument('--batch-size', type=int, default=256)
    parser.add_argument('--num-workers', type=int, default=4)
    parser.add_argument('--model', type=str, default='resnet50')
    parser.add_argument('--print-freq', type=int, default=10)
    parser.add_argument('--distributed', action='store_true')
    # choices= makes unsupported values fail at parse time, before any loading.
    parser.add_argument('--metric', type=str, default='top1',
                        choices=['top1', 'class-avg'])
    parser.add_argument('--method', type=str, default='siam',
                        choices=['moco', 'supcon', 'siam'])
    args = parser.parse_args()
    args.backend = 'nccl' if args.distributed else None
    # torchvision ResNet-18/34 end in 512-d features; the deeper ResNets in 2048-d.
    args.num_backbone_features = 512 if args.model.endswith(('resnet18', 'resnet34')) else 2048
    with idist.Parallel(args.backend) as parallel:
        parallel.run(main, args)