import os

import torch
import torch.nn as nn
import torchvision.models as models
import numpy as np

from unlabeled_extrapolation.models.imnet_resnet import ResNet50

def get_classes_to_compare(num_classes, num_to_choose, seed):
    """Draw ``num_to_choose`` distinct unordered class pairs, reproducibly.

    Pairs are drawn by rejection sampling from a ``RandomState`` seeded with
    ``seed``: two different class indices are sampled without replacement,
    sorted, and kept only if that pair has not been drawn before.

    Args:
        num_classes: Number of classes; indices are drawn from
            ``range(num_classes)``.
        num_to_choose: Number of distinct pairs to return.
        seed: Seed for the private ``np.random.RandomState``.

    Returns:
        A list of ``num_to_choose`` two-element lists ``[low, high]`` in the
        order they were drawn.

    Raises:
        ValueError: If more pairs are requested than exist; without this check
            the sampling loop below could never terminate.
    """
    max_pairs = num_classes * (num_classes - 1) // 2
    if num_to_choose > max_pairs:
        raise ValueError(
            f'Cannot choose {num_to_choose} distinct pairs from {num_classes} '
            f'classes (only {max_pairs} exist)')

    prng = np.random.RandomState(seed)
    classes = []
    seen = set()  # hashable mirror of ``classes`` for O(1) duplicate checks
    for _ in range(num_to_choose):
        while True:
            class_1, class_2 = prng.choice(num_classes, size=2, replace=False)
            curr_pair = sorted([class_1, class_2])
            key = (int(curr_pair[0]), int(curr_pair[1]))
            if key not in seen:
                seen.add(key)
                classes.append(curr_pair)
                break
    return classes

def _build_model(args):
    """Build the two-class model selected by ``args`` and move it to the GPU."""
    if not args.linear_probe_only:
        model = models.__dict__[args.arch](num_classes=2)
        return model.cuda()

    if args.swav_dir is not None:
        # only allow the last FC layer to be trainable
        # must load ResNet50
        ckpt_path = os.path.join(args.swav_dir, 'checkpoints', f'ckp-{args.swav_ckpt}.pth')
        model = ResNet50(pretrained=True, pretrain_style='swav', checkpoint_path=ckpt_path)
        model.set_requires_grad(False)
        model.new_last_layer(num_classes=2)
        last_layer = model.get_last_layer()
        if isinstance(last_layer, nn.Module):
            # Assigning ``requires_grad`` on a Module only creates a plain
            # attribute; ``requires_grad_`` is what reaches its parameters.
            last_layer.requires_grad_(True)
        else:
            last_layer.requires_grad = True
        model.cuda()
        return model

    model = models.resnet50()
    in_features = model.fc.in_features
    # Drop the head while loading so the checkpoint only has to cover the backbone.
    model.fc = nn.Identity()
    sd = torch.load(args.sentry_ft)
    sd = {key.replace('conv_params.', ''): value for key, value in sd.items()}
    missing, _ = model.load_state_dict(sd, strict=False)
    assert len(missing) == 0, f'Backbone keys missing from {args.sentry_ft}: {missing}'
    for param in model.parameters():
        param.requires_grad = False
    # The fresh head is the only trainable part of the linear probe.
    model.fc = nn.Linear(in_features, 2)
    model.fc.requires_grad_(True)
    return model.cuda()


def _checkpoint_payload(train_acc, test_acc, model, save_model):
    """Assemble the dict written to disk for periodic and final checkpoints."""
    payload = {
        'train_accs': train_acc,
        'test_accs': test_acc,
    }
    if save_model:
        payload['model'] = model.state_dict()
    return payload


def main_loop(train_ds, test_ds, save_dir, identifier, args,
              save_model=False):
    """Train a two-class model and record its train/test accuracy curves.

    Does nothing if ``<save_dir>/<identifier>-final`` already exists. Otherwise
    builds the model, trains it with SGD and a cosine learning-rate schedule,
    evaluates on ``test_ds`` every 5 epochs, writes a rolling checkpoint every
    ``args.save_freq`` epochs and finally writes the ``-final`` file.

    Args:
        train_ds: Training dataset; must support ``len()`` and be usable by
            ``torch.utils.data.DataLoader``.
        test_ds: Evaluation dataset, usable by ``torch.utils.data.DataLoader``.
        save_dir: Directory in which result files are written.
        identifier: File-name prefix for this run inside ``save_dir``.
        args: Namespace providing ``linear_probe_only``, ``swav_dir``,
            ``swav_ckpt``, ``sentry_ft``, ``arch``, ``epochs``, ``lr``,
            ``momentum``, ``weight_decay``, ``batch_size``, ``workers`` and
            ``save_freq``. ``continue_from`` (a checkpoint path holding a
            ``'model'`` entry) and ``continue_from_epoch`` (index of the last
            completed epoch) are optional. It is also forwarded to
            ``train_epoch`` and ``validate``.
        save_model: When true, the model ``state_dict`` is stored in every
            checkpoint and in the final file.

    Returns:
        None.
    """
    base_file_name = os.path.join(save_dir, identifier)
    final_file = f'{base_file_name}-final'

    if os.path.exists(final_file):
        print(f'Already completed {base_file_name}, skipping...')
        return

    model = _build_model(args)

    if len(train_ds) < 300:
        # sufficiently small datasets train for 4/5 of the configured epochs
        epochs = int(args.epochs * 4/5)
    else:
        epochs = args.epochs

    continue_from_epoch = getattr(args, 'continue_from_epoch', -1)
    if continue_from_epoch is None:
        # An unset CLI option should behave like an absent attribute.
        continue_from_epoch = -1

    train_acc = []
    test_acc = []
    if getattr(args, 'continue_from', None) is not None:
        previous = torch.load(args.continue_from)
        model.load_state_dict(previous['model'])
        if continue_from_epoch >= 0:
            # When truly resuming, keep the curves recorded so far so the
            # final file covers the whole run.
            train_acc = list(previous.get('train_accs', []))
            test_acc = list(previous.get('test_accs', []))

    # get objective, optimizer, lr scheduler
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(
        model.parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    if continue_from_epoch != -1:
        # A scheduler created with last_epoch != -1 requires 'initial_lr' in
        # every param group; a freshly built optimizer does not have it yet.
        for group in optimizer.param_groups:
            group.setdefault('initial_lr', group['lr'])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs, last_epoch=continue_from_epoch)

    # get data loaders
    train_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(
        test_ds, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    last_checkpoint = None
    # Start right after the last completed epoch (epoch 0 when not resuming).
    for epoch in range(continue_from_epoch + 1, epochs):
        # train for one epoch
        train_acc.append(train_epoch(train_loader, model, criterion, optimizer, epoch, args))
        scheduler.step()

        # evaluate on test set
        if (epoch + 1) % 5 == 0:
            test_acc.append(validate(test_loader, model, criterion, args))

        if (epoch + 1) % args.save_freq == 0:
            checkpoint_file = f'{base_file_name}-{epoch}'
            torch.save(_checkpoint_payload(train_acc, test_acc, model, save_model),
                       checkpoint_file)
            # Only the most recent periodic checkpoint is kept.
            previous_file = f'{base_file_name}-{epoch - args.save_freq}'
            if os.path.exists(previous_file):
                os.remove(previous_file)
            last_checkpoint = checkpoint_file

    # Write the final result first so a failed save never leaves the run
    # without any checkpoint, then drop the now-redundant periodic one.
    torch.save(_checkpoint_payload(train_acc, test_acc, model, save_model), final_file)
    for stale_file in {f'{base_file_name}-{epochs - 1}', last_checkpoint}:
        if stale_file is not None and os.path.exists(stale_file):
            os.remove(stale_file)

def train_epoch(train_loader, model, criterion, optimizer, epoch, args):
    """Train ``model`` for one pass over ``train_loader``.

    Every batch is moved to the GPU, scored with ``criterion`` and used for one
    optimizer step. Progress is printed every ``args.print_freq`` batches.

    Args:
        train_loader: Iterable yielding ``(images, target)`` batches.
        model: Network to optimise; switched to training mode here.
        criterion: Loss callable taking ``(output, target)``.
        optimizer: Optimizer stepped once per batch.
        epoch: Epoch index, used only in the progress prefix.
        args: Namespace providing ``print_freq``.

    Returns:
        The sample-weighted mean top-1 accuracy (in percent) over the epoch.
    """
    running_loss = AverageMeter('Loss', ':.4e')
    running_acc = AverageMeter('Acc', ':6.2f')
    reporter = ProgressMeter(
        len(train_loader), [running_loss, running_acc], prefix=f'Epoch: [{epoch}]')

    model.train()
    for step, (images, target) in enumerate(train_loader, start=1):
        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # forward pass
        logits = model(images)
        loss = criterion(logits, target)

        # bookkeeping, weighted by the number of samples in the batch
        top1 = accuracy(logits, target)[0]
        n_samples = images.size(0)
        running_loss.update(loss.item(), n_samples)
        running_acc.update(top1, n_samples)

        # backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % args.print_freq == 0:
            # the display shows the zero-based batch index
            reporter.display(step - 1)
    return running_acc.avg

def validate(test_loader, model, criterion, args):
    """Evaluate ``model`` on ``test_loader`` without tracking gradients.

    Args:
        test_loader: Iterable yielding ``(images, target)`` batches.
        model: Network to evaluate; switched to eval mode here.
        criterion: Loss callable taking ``(output, target)``.
        args: Namespace providing ``print_freq``.

    Returns:
        The sample-weighted mean top-1 accuracy (in percent) over the loader.
    """
    running_loss = AverageMeter('Loss', ':.4e')
    running_acc = AverageMeter('Acc', ':6.2f')
    reporter = ProgressMeter(
        len(test_loader), [running_loss, running_acc], prefix='Test: ')

    model.eval()
    with torch.no_grad():
        for step, (images, target) in enumerate(test_loader, start=1):
            images = images.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

            # forward pass
            logits = model(images)
            loss = criterion(logits, target)

            # bookkeeping, weighted by the number of samples in the batch
            top1 = accuracy(logits, target)[0]
            n_samples = images.size(0)
            running_loss.update(loss.item(), n_samples)
            running_acc.update(top1, n_samples)

            if step % args.print_freq == 0:
                # the display shows the zero-based batch index
                reporter.display(step - 1)

        print(' * Test Acc {acc.avg:.3f}'.format(acc=running_acc))
    return running_acc.avg

def accuracy(output, target, topk=(1,)):
    """Return the top-k accuracies, in percent, for each ``k`` in ``topk``.

    Args:
        output: Score tensor whose dimension 1 indexes the classes.
        target: Tensor of true class indices, one per row of ``output``.
        topk: Iterable of ``k`` values to evaluate.

    Returns:
        A list of Python floats, one per entry of ``topk`` and in that order.
    """
    with torch.no_grad():
        n_samples = target.size(0)
        percent_per_hit = 100.0 / n_samples

        # Indices of the largest scores, transposed so row r holds every
        # sample's rank-r prediction.
        _, ranked = output.topk(max(topk), 1, True, True)
        ranked = ranked.t()
        hits = ranked.eq(target.view(1, -1))  # broadcasts over the rank rows

        return [
            hits[:k].float().sum().mul(percent_per_hit).item()
            for k in topk
        ]

class AverageMeter:
    """Keeps the latest value of a metric together with its weighted mean."""

    def __init__(self, name, fmt=':f'):
        # ``fmt`` is a format spec including the leading colon, e.g. ':.4e'.
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        """Clear the latest value, the running sum, the count and the mean."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` new observations."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # Builds e.g. '{name} {val:.4e} ({avg:.4e})' and fills it from the
        # instance attributes.
        template = ''.join(['{name} {val', self.fmt, '} ({avg', self.fmt, '})'])
        return template.format(**vars(self))

class ProgressMeter:
    """Prints a tab-separated progress line built from a set of meters."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the prefix, the batch counter and each meter's ``str()``."""
        header = self.prefix + self.batch_fmtstr.format(batch)
        print('\t'.join([header, *map(str, self.meters)]))

    def _get_batch_fmtstr(self, num_batches):
        """Return a '[{:Nd}/total]' template padded to the width of the total."""
        width = len(str(num_batches // 1))
        slot = f'{{:{width}d}}'
        return f'[{slot}/{slot.format(num_batches)}]'
