#!/usr/bin/env python
# using the mean feature of each class as the linear head
import argparse
import builtins
import os
import shutil
import copy
import ipdb
import copy
import numpy as np

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim
import torchvision.models as models
import torch.nn.functional as F

import sys

sys.path.extend(['..', '.'])
from datasets.dataset_tinyimagenet import load_train, load_val_loader, num_classes_dict
from tools.store import ExperimentLogWriter
import models.builder as model_builder
import utils
from augmentations import get_aug
from models import get_model
from datasets import get_dataset

# Candidate values for --arch: every lower-case, non-dunder callable exposed by
# the `models` namespace, followed by three extra resnet50 MLP-head variants.
model_names = sorted(
    attr_name
    for attr_name, attr in models.__dict__.items()
    if callable(attr) and attr_name.islower() and not attr_name.startswith("__")
)
model_names.extend([
    'resnet50_mlp1024_norelu_3layer',
    'resnet50_mlp2048_norelu_2layer',
    'resnet50_mlp2048_norelu_3layer',
])

# Command-line interface. ArgumentDefaultsHelpFormatter appends each option's
# default value to its help text automatically.
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training',
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)

# --- model / data / optimisation-style options ---
parser.add_argument('--dataset', choices=['living17', 'entity30'],
                    default='living17',
                    help='Which dataset to evaluate on.')
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
                    choices=model_names,
                    help='model architecture: ' +
                         ' | '.join(model_names) +
                         ' (default: resnet50)')
parser.add_argument('-j', '--workers', default=32, type=int, metavar='N',
                    help='number of data loading workers (default: 32)')
parser.add_argument('--epochs', default=100, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch_size', default=256, type=int,
                    metavar='N',
                    help='mini-batch size (default: 256), this is the total '
                         'batch size of all GPUs on the current node when '
                         'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--lr', '--learning-rate', default=30., type=float,
                    metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--schedule', default=[60, 80], nargs='*', type=int,
                    help='learning rate schedule (when to drop lr by a ratio)')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--wd', '--weight_decay', default=0., type=float,
                    metavar='W', help='weight decay (default: 0.)',
                    dest='weight_decay')
parser.add_argument('-p', '--print_freq', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')

# --- distributed / device options ---
parser.add_argument('--world_size', default=-1, type=int,
                    help='number of nodes for distributed training')
parser.add_argument('--rank', default=-1, type=int,
                    help='node rank for distributed training')
parser.add_argument('--dist_url', default='tcp://224.66.41.62:23456', type=str,
                    help='url used to set up distributed training')
parser.add_argument('--dist_backend', default='nccl', type=str,
                    help='distributed backend')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training. ')
parser.add_argument('--gpu', default=None, type=int,
                    help='GPU id to use.')
parser.add_argument('--multiprocessing_distributed', action='store_true',
                    help='Use multi-processing distributed training to launch '
                         'N processes per node, which has N GPUs. This is the '
                         'fastest way to use PyTorch for either single node or '
                         'multi node data parallel training')

# --- options specific to this evaluation script ---
parser.add_argument('--dir', type=str, default='/tiger/u/kshen6/outs',
                    help='Directory where all of the runs are located.')
parser.add_argument('--num_per_class', type=int, default=int(1e10),
                    help='Number of images per class for getting a subset of Imagenet')
parser.add_argument('--val_every', type=int, default=5, help='How often to evaluate lincls')
parser.add_argument('--latest_only', action='store_true', help='if set, only evaluate the latest_ checkpoints')
parser.add_argument('--mpd', action='store_true', help='short hand for multi-gpu training')
parser.add_argument('--dist_url_add', default=0, type=int, help='to avoid collisions of tcp')
parser.add_argument('--specific_ckpts', nargs='*', help='filenames of specific checkpoints to evaluate')
parser.add_argument('--use_random_labels', action='store_true', help='whether to evaluate using the random labels')
# The help text here used to be a copy of the --use_random_labels help; it now
# describes what the flag does in validate().
parser.add_argument('--normalize', action='store_true',
                    help='whether to L2-normalize the per-class mean features when computing validation predictions')
parser.add_argument('--power', type=int, default=1, help='the power of the preconditioner')
parser.add_argument('--use_test', action='store_true', help='whether use test data in the computation of the preconditioner')
parser.add_argument('--load_head', action='store_true', help='whether load the pretrained head')
parser.add_argument('--downstream_percentage', default=100, type=float, help='what percentage of data is used for downstream')
# parser.add_argument('--sqrtpower', action='store_true', help='whether square root the preconditioner')

# Module-level best-accuracy placeholder (never updated by the code in this file).
best_acc1 = 0


def main():
    """Parse the command line, expand the ``--mpd`` shorthand and launch workers."""
    args = parser.parse_args()

    if args.mpd:
        # --mpd: single-node multi-process setup talking over localhost; the
        # port is offset by --dist_url_add so concurrent runs do not collide.
        port = 10001 + args.dist_url_add
        args.multiprocessing_distributed = True
        args.world_size, args.rank = 1, 0
        args.dist_url = 'tcp://127.0.0.1:' + str(port)

    utils.spawn_processes(main_worker, args)


def main_worker(gpu, ngpus_per_node, args):
    """Per-process entry point: evaluate each selected checkpoint in ``args.dir``.

    Args:
        gpu: GPU index handed to this process (may be ``None``); stored in
            ``args.gpu``.
        ngpus_per_node: forwarded to ``utils.init_proc_group`` and ``eval_ckpt``.
        args: parsed command-line namespace. ``args.distributed`` is read here
            but is not set anywhere in this file -- presumably filled in by
            ``utils.spawn_processes``; TODO confirm.

    Checkpoints are taken from ``<args.dir>/checkpoints`` in sorted filename
    order, filtered by ``--latest_only`` and ``--specific_ckpts``.
    """
    args.gpu = gpu

    # Suppress printing if not master. The replacement accepts keyword
    # arguments too, so calls such as print(..., flush=True) or
    # print(..., end='') do not raise TypeError on non-master processes.
    if args.multiprocessing_distributed and args.gpu != 0:
        def print_pass(*_args, **_kwargs):
            pass

        builtins.print = print_pass

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        utils.init_proc_group(args, ngpus_per_node)

    logger = ExperimentLogWriter(args.dir)

    # loop through checkpoints and set pre-trained
    ckpt_dir = os.path.join(args.dir, 'checkpoints')
    lineval_dir = os.path.join(args.dir, 'lin_eval_ckpt')
    for fname in sorted(os.listdir(ckpt_dir)):
        if args.latest_only and not fname.startswith('latest_'):
            continue
        if args.specific_ckpts is not None and fname not in args.specific_ckpts:
            continue

        args.pretrained = os.path.join(ckpt_dir, fname)
        # Checked on every iteration on purpose: eval_ckpt creates this
        # directory, so the warning also fires for later checkpoints.
        if os.path.exists(lineval_dir):
            print('linear evaluation dir exists at {}, may overwrite...'.format(lineval_dir))
        eval_ckpt(
            copy.deepcopy(args),  # because args.batch_size and args.workers are changed
            ngpus_per_node,
            fname,
            logger)


def eval_ckpt(args, ngpus_per_node, ptrain_fname, logger):
    """Evaluate one pre-trained checkpoint with a preconditioned class-mean head.

    Steps, in order:
      1. build the backbone, replace its ``fc`` layer with ``nn.Identity`` and
         freeze every parameter;
      2. load ``args.pretrained`` if that file exists;
      3. estimate the feature second-moment matrix ("preconditioner") on the
         full training loader (optionally averaged with the validation one
         when ``--use_test`` is set) and divide it by its largest eigenvalue;
      4. compute per-class mean features on the (possibly sub-sampled)
         downstream training loader;
      5. report train/validation accuracy of the resulting classifier and log
         them through ``logger``.

    With ``--evaluate`` only the validation pass of step 5 is run and nothing
    is logged.

    Args:
        args: (copied) command-line namespace; ``args.pretrained`` must be set.
        ngpus_per_node: forwarded to ``utils.init_data_parallel`` and used to
            decide which process writes the log.
        ptrain_fname: checkpoint file name; the part before the first ``.``
            becomes the prefix of the log identifier.
        logger: object providing ``create_data_dict`` / ``update_data_dict`` /
            ``save_data_dict``.
    """
    # Identifier under which this run's metrics are stored; it encodes every
    # option that changes the evaluation protocol.
    pretrained_id = ptrain_fname.split('.')[0]
    dict_id = pretrained_id + '_lineval_mean_feature_preconditioner_power:{}'.format(args.power)
    if args.use_test:
        dict_id += '_use_test'
    if args.load_head:
        dict_id += '_load_head'
    if args.downstream_percentage != 100:
        dict_id += '_downstream_percentage:{}'.format(args.downstream_percentage)
    if args.normalize:
        dict_id += '_normalize'
    if args.use_random_labels:
        dict_id += '_random_labels'

    ckpt_dir = os.path.join(args.dir, 'lin_eval_ckpt')
    os.makedirs(ckpt_dir, exist_ok=True)

    logger.create_data_dict(
        ['epoch', 'train_acc', 'val_acc', 'train_loss', 'val_loss', 'train5', 'val5'],
        dict_id=dict_id)

    # create model
    model = model_builder.get_model(num_classes_dict[args.dataset], arch=args.arch)
    # Only the shape of this tensor is used downstream (same shape as fc.weight).
    mean_feature = torch.zeros(model.fc.weight.shape).to('cuda')
    # Expose the backbone features directly.
    model.fc = nn.Identity()
    # freeze all layers
    for param in model.parameters():
        param.requires_grad = False

    # load from pre-trained, before DistributedDataParallel constructor
    if args.pretrained:
        if os.path.isfile(args.pretrained):
            checkpoint = torch.load(args.pretrained, map_location='cpu')
            state_dict = checkpoint['state_dict']
            model_builder.load_checkpoint(model, state_dict, args.pretrained, args=args, load_pretrained_head=args.load_head)
        else:
            print("=> no checkpoint found at '{}'".format(args.pretrained))

    model = utils.init_data_parallel(args, model, ngpus_per_node)

    # The criterion is only used to report the training loss.
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    cudnn.benchmark = True

    # Data loading code
    if args.use_random_labels:
        random_labels = torch.load(os.path.join(args.dir, 'saved_tensors', 'random_labels.pth')).numpy()
    else:
        random_labels = None
    # Downstream subset: used for the class means.
    _, train_loader = load_train(args.dataset, args.num_per_class, args.distributed,
                                 args.batch_size, args.workers, data_aug='mocov2',
                                 random_labels=random_labels, percentage=args.downstream_percentage)
    # Full training set: used for the preconditioner and the train accuracy.
    _, all_train_loader = load_train(args.dataset, args.num_per_class, args.distributed,
                                     args.batch_size, args.workers, data_aug='mocov2',
                                     random_labels=random_labels, percentage=100)

    val_loader = load_val_loader(args.dataset, args.batch_size, args.workers)

    preconditioner = train_preconditioner(all_train_loader, model, args, mean_feature)
    if args.use_test:
        preconditioner = test_preconditioner(val_loader, model, args, mean_feature, preconditioner)
    # Rescale so that the largest eigenvalue of the preconditioner is 1.
    eigval, _ = torch.lobpcg(preconditioner, 1, largest=True)
    preconditioner /= eigval

    mean_feature = compute_feature(train_loader, model, args, mean_feature)

    if args.evaluate:
        # validate() needs the class means and the preconditioner, so this
        # early exit must come after they are computed. It was previously
        # called with those two arguments missing, which raised TypeError.
        validate(val_loader, model, criterion, args, mean_feature, preconditioner)
        return

    # accuracy / loss on the full training set
    top1, top5, losses = train(all_train_loader, model, criterion, args, mean_feature, preconditioner)

    # evaluate on validation set
    acc1, acc5, val_losses = validate(val_loader, model, criterion, args, mean_feature, preconditioner)

    if not args.multiprocessing_distributed or args.rank % ngpus_per_node == 0:
        logger.update_data_dict(
            {
                'epoch': 1,
                'train_acc': top1.item(),
                'val_acc': acc1.item(),
                'train_loss': losses,
                'val_loss': val_losses,
                'train5': top5.item(),
                'val5': acc5.item()
            }, dict_id=dict_id)
        logger.save_data_dict(dict_id=dict_id)


def compute_feature(train_loader, model, args, mean_feature):
    """Return the mean feature vector of every class over ``train_loader``.

    The mean of class ``c`` is the sum of the features of all samples labelled
    ``c`` divided by the number of such samples. (A running average of
    per-batch class means is *not* equivalent: a batch without samples of a
    class would pull that class towards zero, and batches would be weighted
    equally regardless of how many samples of the class they hold.)

    Args:
        train_loader: iterable of ``(images, target)`` batches; ``target``
            holds integer class ids in ``[0, mean_feature.shape[0])``.
        model: feature extractor; ``model(images)`` must return one row per
            sample with ``mean_feature.shape[1]`` columns.
        args: namespace providing ``gpu``. When it is not ``None`` the batch is
            moved to that GPU; otherwise it is moved to the default CUDA
            device if CUDA is available and left where it is if not.
        mean_feature: 2-D tensor used only for its ``(num_classes, feat_dim)``
            shape.

    Returns:
        Tensor of shape ``mean_feature.shape`` located on the device of the
        model output. Rows of classes that never occur are all zeros.
    """
    # Switch to eval mode: under the frozen-feature protocol no part of the
    # pre-trained model may change, and BatchNorm in train mode would update
    # its running statistics.
    model.eval()

    num_classes, feat_dim = mean_feature.shape
    class_sums = None
    class_counts = None

    with torch.no_grad():
        for images, target in train_loader:
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
                target = target.cuda(args.gpu, non_blocking=True)
            elif torch.cuda.is_available():
                images = images.cuda(non_blocking=True)
                target = target.cuda(non_blocking=True)

            output_feature = model(images)

            if class_sums is None:
                # Allocate lazily on the device the features actually live on.
                class_sums = torch.zeros(num_classes, feat_dim, device=output_feature.device)
                class_counts = torch.zeros(num_classes, device=output_feature.device)

            target = target.to(output_feature.device).long()
            # Scatter-add each sample's feature into the row of its class.
            class_sums.index_add_(0, target, output_feature.to(class_sums.dtype))
            class_counts += torch.bincount(target, minlength=num_classes).to(class_counts.dtype)

    if class_sums is None:
        # Empty loader: nothing observed.
        return torch.zeros(mean_feature.shape, device=mean_feature.device)

    # clamp avoids 0/0 for classes that never appeared (their sum is zero).
    return class_sums / class_counts.clamp(min=1).unsqueeze(1)


def train(all_train_loader, model, criterion, args, mean_feature, preconditioner):
    """Measure loss and accuracy of the mean-feature classifier on a loader.

    Despite the name, nothing is optimised: logits are computed as
    ``features @ preconditioner**power @ mean_feature.T`` and only metrics are
    collected.

    Args:
        all_train_loader: iterable of ``(images, target)`` batches.
        model: frozen feature extractor.
        criterion: loss used for reporting.
        args: namespace providing ``gpu``, ``power``, ``dataset`` and
            ``print_freq``.
        mean_feature: per-class mean features, one row per class.
        preconditioner: square matrix applied to the features (raised to
            ``args.power`` with ``torch.matrix_power``).

    Returns:
        ``(top1.avg, top5.avg, losses.avg)`` from the ``utils.AverageMeter``s.

    NOTE(review): unlike ``validate``, ``--normalize`` is not applied here (the
    corresponding code was commented out) -- confirm this is intended.
    """
    losses = utils.AverageMeter('Loss', ':.4e')
    top1 = utils.AverageMeter('Acc@1', ':6.2f')
    top5 = utils.AverageMeter('Acc@5', ':6.2f')
    progress = utils.ProgressMeter(
        len(all_train_loader),
        [losses, top1, top5],
        prefix="Epoch: [{}]".format(0))

    # Switch to eval mode: under the protocol of linear classification on
    # frozen features/models it is not legitimate to change any part of the
    # pre-trained model. BatchNorm in train mode may revise running mean/std
    # (even if it receives no gradient), which are part of the model too.
    model.eval()

    # Loop-invariant quantities, computed once instead of once per batch.
    feature_to_pred = mean_feature
    preconditioner_tmp = torch.matrix_power(preconditioner, args.power)
    topk = (1, 1) if 'imagenet-2class' in args.dataset else (1, 5)

    # No gradients are needed: this pass only collects metrics.
    with torch.no_grad():
        for i, (images, target) in enumerate(all_train_loader):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
                target = target.cuda(args.gpu, non_blocking=True)
            else:
                images = images.cuda(non_blocking=True)
                target = target.cuda(non_blocking=True)

            output_feature = model(images)
            output = torch.matmul(torch.matmul(output_feature, preconditioner_tmp), feature_to_pred.T)

            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = utils.accuracy(output, target, topk=topk)
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            if (i + 1) % args.print_freq == 0:
                progress.display(i)

    return top1.avg, top5.avg, losses.avg


def validate(val_loader, model, criterion, args, mean_feature, preconditioner):
    """Measure accuracy of the mean-feature classifier on the validation loader.

    Logits are ``features @ preconditioner**power @ M.T`` where ``M`` is
    ``mean_feature``, L2-normalised along dim 1 when ``args.normalize`` is set.

    Args:
        val_loader: iterable of ``(images, target)`` batches.
        model: frozen feature extractor.
        criterion: unused; kept for interface compatibility (no loss is
            computed here).
        args: namespace providing ``gpu``, ``power``, ``normalize`` and
            ``dataset``.
        mean_feature: per-class mean features, one row per class.
        preconditioner: square matrix applied to the features.

    Returns:
        ``(top1.avg, top5.avg, 0)``; the third element is a placeholder for
        the loss, which is not measured.
    """
    top1 = utils.AverageMeter('Acc@1', ':6.2f')
    top5 = utils.AverageMeter('Acc@5', ':6.2f')

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        # Loop-invariant quantities, computed once instead of once per batch.
        preconditioner_tmp = torch.matrix_power(preconditioner, args.power)
        feature_to_pred = mean_feature
        if args.normalize:
            feature_to_pred = F.normalize(feature_to_pred, dim=1)
        topk = (1, 1) if 'imagenet-2class' in args.dataset else (1, 5)

        for images, target in val_loader:
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)

            # compute output
            output_feature = model(images)
            output_feature = torch.matmul(output_feature, preconditioner_tmp)
            output = torch.matmul(output_feature, feature_to_pred.T)

            # measure accuracy
            acc1, acc5 = utils.accuracy(output, target, topk=topk)
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

        # TODO: this should also be done with the ProgressMeter
        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))

    return top1.avg, top5.avg, 0  # loss is not measured


def train_preconditioner(all_train_loader, model, args, mean_feature):
    """Estimate the feature second-moment matrix over ``all_train_loader``.

    For every batch with feature matrix ``F`` (one row per sample) the product
    ``F.T @ F`` is formed; the result is the equally-weighted running average
    of these per-batch products.

    Args:
        all_train_loader: iterable of ``(images, target)`` batches; the targets
            are ignored.
        model: frozen feature extractor.
        args: namespace providing ``gpu``; images are moved to that GPU only
            when it is not ``None``.
        mean_feature: tensor used only for its second dimension (feature size).

    Returns:
        ``(feat_dim, feat_dim)`` tensor on the device of the model output (or
        zeros on ``mean_feature``'s device if the loader is empty).
    """
    model.eval()

    feat_dim = mean_feature.shape[1]
    # Allocated on first use so it lands on the same device as the features
    # instead of a hard-coded 'cuda'.
    epoch_preconditioner = None

    with torch.no_grad():
        for i, (images, _) in enumerate(all_train_loader):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)

            # compute output
            output_feature = model(images)
            if epoch_preconditioner is None:
                epoch_preconditioner = torch.zeros([feat_dim, feat_dim], device=output_feature.device)
            # running mean over batches of F^T F
            epoch_preconditioner = epoch_preconditioner * i / (i + 1) + output_feature.T @ output_feature / (i + 1)

    if epoch_preconditioner is None:
        epoch_preconditioner = torch.zeros([feat_dim, feat_dim], device=mean_feature.device)

    return epoch_preconditioner


def test_preconditioner(val_loader, model, args, mean_feature, preconditioner):
    model.eval()

    epoch_preconditioner = torch.zeros([mean_feature.shape[1],mean_feature.shape[1]]).to('cuda')

    with torch.no_grad():
        for i, (images, target) in enumerate(val_loader):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)

            # compute output
            output_feature = model(images)
            epoch_preconditioner = epoch_preconditioner * i / (i + 1) + output_feature.T @ output_feature / (i + 1)

    preconditioner = preconditioner / 2 + epoch_preconditioner / 2

    return preconditioner


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
