"""Train models with coresets and per-batch approximations."""
import argparse
import os
import glob
import numpy as np

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import models

from torch.utils.data import DataLoader
from warnings import simplefilter
from utils import *

# add logger and tensorboard
import logging
from torch.utils.tensorboard import SummaryWriter

# import subset generator
from train_helpers import *

# Silence FutureWarning noise and numpy floating-point error reports.
simplefilter('ignore', category=FutureWarning)
np.seterr(all='ignore')

# Enumerate CUDA devices in PCI bus order (see issue #152).
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

# Lower-case, callable "resnet*" entries exposed by the models package.
model_names = sorted(
    name
    for name, candidate in vars(models).items()
    if name.islower()
    and name.startswith("resnet")
    and not name.startswith("__")
    and callable(candidate)
)

print(model_names)

# Command-line options for the transfer-training script.
parser = argparse.ArgumentParser(description='Propert ResNets for CIFAR in pytorch')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
                    choices=['resnet18', 'resnet34', 'resnet50'])
parser.add_argument('--dataset', default='cifar10',
                    choices=('cifar10', 'cifar100'), help='dataset')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=200, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('-b', '--batch-size', default=128, type=int,
                    metavar='N', help='mini-batch size (default: 128)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', '-m', type=float, metavar='M', default=0.9,
                    help='momentum')
# Help text now states the actual default (1e-4).
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
# Help text now states the actual default (100).
parser.add_argument('--print-freq', '-p', default=100, type=int,
                    metavar='N', help='print frequency (default: 100)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--half', dest='half', action='store_true',
                    help='use half-precision(16-bit) ')
parser.add_argument('--save-dir', dest='save_dir',
                    help='The directory used to save the trained models',
                    default='./save_temp', type=str)
parser.add_argument('--save-every', dest='save_every',
                    help='Saves checkpoints at every specified number of epochs',
                    type=int, default=20)  # default=10)
parser.add_argument('--gpu', default='7', type=str, help='The GPU to be used')
# Each optimizer name must be its own choice: a single comma-joined string
# would make argparse reject every real value passed on the command line.
parser.add_argument('--ig', type=str, help='ig method', default='sgd',
                    choices=['sgd', 'adam', 'adagrad'])
parser.add_argument('--lr_schedule', '-lrs', type=str, help='learning rate schedule', default='mile',
                    choices=['mile', 'exp', 'cnt', 'step', 'cosine', 'plateau'])
parser.add_argument('--gamma', type=float, default=0.2, help='learning rate decay parameter')
parser.add_argument('--warm', '-w', dest='warm_start', action='store_true', help='warm start learning rate ')
parser.add_argument('--no_transforms', action='store_true', help='disable random transformations of the data')
parser.add_argument('--match_num_iter', action='store_true', help='use the same number of iterations for the entire training')
parser.add_argument('--seed', default=11111111, type=int, help="random seed")
parser.add_argument('--run', default=0, type=int, help="run_num")

# Module-level logger; main() hands it to set_logger() and also logs through it directly.
logger = logging.getLogger(__name__)

def main(args):
    """Run the transfer experiment described by ``args`` and save the final model.

    The run named by ``args.resume`` is the source. Its last path component is
    split on ``'_'``: field 0 is compared against ``args.dataset`` and field 2
    (or field 3 when field 2 does not start with ``'resnet'``) is read as the
    source architecture.

    * Source architecture equal to ``args.arch``: the source checkpoint
      (``final.pth``, falling back to ``final_0.pth``) is loaded, every
      parameter is frozen, the ``linear`` layer is replaced by a fresh
      ``nn.Linear`` with ``args.num_classes`` outputs, and training uses the
      full training set.
    * Otherwise: a new ``args.arch`` model is created and trained on the full
      training set (second path component of ``args.resume`` is
      ``'output_record'``), on the union of the ``subset_epoch_*.pth`` index
      files found in ``args.resume`` (run name contains ``'traject'``), or on
      a random subset whose fraction is parsed from the run name.

    With ``args.evaluate`` set, only a validation pass is run. Otherwise the
    model is trained for ``args.epochs`` epochs, metrics are written to
    TensorBoard and the log, and ``final.pth`` is saved in ``args.save_dir``.

    Args:
        args: Parsed command-line namespace. It is mutated in place:
            ``save_dir`` is extended, and ``logger``, ``writer``,
            ``num_classes``, ``class_num`` and ``train_num`` are attached.

    Raises:
        FileExistsError: if the computed ``args.save_dir`` already exists.
        NotADirectoryError: if a ``'traject'`` run is requested but
            ``args.resume`` is not a directory.
        AssertionError: if the source dataset does not fit the selected mode.
    """
    global best_prec1, best_loss

    # The run name (last path component of --resume) carries the source
    # dataset, architecture and subset fraction as '_'-separated fields.
    resume_name = args.resume.split('/')[-1]
    resume_fields = resume_name.split('_')

    args.save_dir = os.path.join(args.save_dir, resume_name)
    args.save_dir += f'_transfer_{args.dataset}'
    args.save_dir += '_match' if args.match_num_iter else ''
    args.save_dir += f'_{args.arch}_bs_{args.batch_size}_{args.ig}_lr_{args.lr}_{args.lr_schedule}_gamma_{args.gamma:.1f}_wd_{args.weight_decay}'
    args.save_dir += '_no_transforms' if args.no_transforms else ''

    print(args.save_dir)
    # No exist_ok: an already existing output directory aborts the run.
    os.makedirs(args.save_dir)

    args.logger = set_logger(args, logger)
    args.logger.info(f'--------- method: {args.ig}, moment: {args.momentum}, '
          f'lr_schedule: {args.lr_schedule}---------------')

    args.writer = SummaryWriter(args.save_dir)

    set_seed(args)

    if args.dataset == 'cifar100':
        args.num_classes = 100
        args.class_num = 100
    else:
        args.num_classes = 10
        args.class_num = 10

    cudnn.benchmark = True

    train_transform, val_transform, _ = get_transforms(args)

    # --no_transforms trains with the validation transform instead.
    if args.no_transforms:
        indexed_dataset = IndexedDataset(args.dataset, train=True, transform=val_transform)
    else:
        indexed_dataset = IndexedDataset(args.dataset, train=True, transform=train_transform)

    args.train_num = len(indexed_dataset)

    indexed_loader = DataLoader(
        indexed_dataset,
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    source_dataset = resume_fields[0]
    ori_arch = resume_fields[2]
    if ori_arch[:6] != 'resnet':
        ori_arch = resume_fields[3]

    if ori_arch == args.arch:
        # Same architecture: reuse the trained features on a different dataset.
        assert source_dataset != args.dataset
        logger.info(f"Transferring from {source_dataset} to {args.dataset} with trained {args.arch}")
        logger.info("==> Loading trained model..")
        # Build the model with the source head size so the checkpoint fits.
        source_classes = 100 if source_dataset == 'cifar100' else 10
        model = torch.nn.DataParallel(models.__dict__[args.arch](num_classes=source_classes))
        try:
            ckpt_path = os.path.join(args.resume, 'final.pth')
            checkpoint = torch.load(ckpt_path, map_location='cpu')
        except Exception:
            # Best-effort fallback to the alternative checkpoint name; unlike a
            # bare except this lets KeyboardInterrupt/SystemExit propagate.
            ckpt_path = os.path.join(args.resume, 'final_0.pth')
            checkpoint = torch.load(ckpt_path, map_location='cpu')
        logger.info(f"==> Loaded trained model {ckpt_path}")
        model.load_state_dict(checkpoint['state_dict'])
        model.cuda()

        logger.info("==> Freezing the feature representation..")
        for param in model.parameters():
            param.requires_grad = False
        logger.info("==> Reinitializing the classifier..")
        num_ftrs = model.module.linear.in_features
        model.module.linear = nn.Linear(num_ftrs, args.num_classes).cuda()  # requires grad by default

        args.logger.info('Training on all the data')
        train_loader = indexed_loader
    else:
        # Different architecture: train a new model on the source run's data selection.
        assert source_dataset == args.dataset
        # Log the architecture actually detected above (field 2 or field 3).
        logger.info(f"Transferring from {ori_arch} to {args.arch} on selected {args.dataset}")
        model = torch.nn.DataParallel(models.__dict__[args.arch](num_classes=args.num_classes))

        if args.resume.split('/')[1] == 'output_record':
            args.logger.info('Training on all the data')
            train_loader = indexed_loader
        else:
            if 'traject' in resume_name:
                if os.path.isdir(args.resume):
                    # Union of every index selected in any saved epoch.
                    selected = np.zeros(args.train_num)
                    for subset_file in glob.glob(f'{args.resume}/subset_epoch_*.pth'):
                        subset = torch.load(subset_file)
                        selected[subset] += 1
                    args.logger.info("=> loaded subset from {}".format(args.resume))
                    subset = np.where(selected>0)[0]
                else:
                    raise NotADirectoryError("=> no checkpoint found at '{}'".format(args.resume))
            else:
                # The fraction sits in field 4 or field 5 of the run name.
                try:
                    frac = float(resume_fields[4])
                except (ValueError, IndexError):
                    frac = float(resume_fields[5])
                B = int(args.train_num * frac)
                # Random subset of B indices.
                subset = np.arange(0, args.train_num)
                np.random.shuffle(subset)
                subset = subset[:B]
            indexed_subset = torch.utils.data.Subset(indexed_dataset, indices=subset)
            train_loader = DataLoader(
                indexed_subset,
                batch_size=args.batch_size, shuffle=True,
                num_workers=args.workers, pin_memory=True)
            args.logger.info(f'{len(subset)} examples in the selected subset')
            torch.save(subset, os.path.join(args.save_dir, 'subset_epoch_0.pth'))

    val_loader = torch.utils.data.DataLoader(
        IndexedDataset(args.dataset, transform=val_transform, train=False),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # Training split with the validation transform and no shuffling.
    train_val_loader = torch.utils.data.DataLoader(
        IndexedDataset(args.dataset, transform=val_transform, train=True),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # NOTE(review): train_criterion / val_criterion are not defined in this
    # file; presumably they come from the star imports (utils / train_helpers)
    # -- confirm.
    if args.half:
        model.half()
        train_criterion.half()
        val_criterion.half()

    if args.evaluate:
        # validate() takes args first, as at the other call sites below.
        validate(args, val_loader, model, val_criterion)
        args.writer.close()
        return

    optimizer, lr_scheduler = get_optimizer_and_scheduler(args, model)

    if args.warm_start:
        args.logger.info('Warm start learning rate')
        lr_scheduler_f = GradualWarmupScheduler(optimizer, 1.0, 20, lr_scheduler)
    else:
        args.logger.info('No Warm start')
        lr_scheduler_f = lr_scheduler

    best_prec1, best_loss = 0, 1e10
    start_epoch = 0

    end_epoch = args.epochs
    weight = None
    train_loss, test_loss = np.zeros(end_epoch), np.zeros(end_epoch)
    train_acc, test_acc = np.zeros(end_epoch), np.zeros(end_epoch)
    train_time, data_time = np.zeros(end_epoch), np.zeros(end_epoch)

    for epoch in range(start_epoch, end_epoch):
        # Learning rate in effect for this epoch, read before the scheduler steps.
        lr = lr_scheduler_f.optimizer.param_groups[0]['lr']
        losses, preds = get_losses_and_preds(args, train_val_loader, model, train_criterion)
        torch.save({'loss': losses, 'preds': preds}, os.path.join(args.save_dir, f'grad_epoch_{epoch}.pth'))

        data_time[epoch], train_time[epoch] = train(args, train_loader, model, train_criterion, optimizer, epoch, weight)

        # evaluate on validation set
        prec1, loss = validate(args, val_loader, model, val_criterion)
        best_prec1 = max(prec1, best_prec1)
        test_acc[epoch], test_loss[epoch] = prec1, loss

        # Only the 'plateau' schedule is stepped with the validation loss.
        if args.lr_schedule == 'plateau':
            lr_scheduler_f.step(loss)
        else:
            lr_scheduler_f.step()

        # Same evaluation on the training split.
        ta, tl = validate(args, train_val_loader, model, val_criterion)
        best_loss = min(tl, best_loss)
        train_acc[epoch], train_loss[epoch] = ta, tl

        if end_epoch == args.epochs:
            args.writer.add_scalar('test/1.test_acc_by_epoch', prec1, epoch)
            args.writer.add_scalar('test/2.test_loss_by_epoch', loss, epoch)
            args.writer.add_scalar('train/1.train_acc_by_epoch', ta, epoch)
            args.writer.add_scalar('train/2.train_loss_by_epoch', tl, epoch)

            # Sum of the recorded train and data times so far, used as the step axis.
            elapsed = int(np.sum(train_time) + np.sum(data_time))
            args.writer.add_scalar('test/3.test_acc_by_time', prec1, elapsed)
            args.writer.add_scalar('test/4.test_loss_by_time', loss, elapsed)
            args.writer.add_scalar('train/3.train_acc_by_time', ta, elapsed)
            args.writer.add_scalar('train/4.train_loss_by_time', tl, elapsed)

        args.logger.info(f'epoch: {epoch}, prec1: {prec1}, loss: {tl:.3f}, '
            f'lr: {lr:.3f}, best_prec1: {best_prec1}, best_loss: {best_loss:.3f}')

    # 'epoch' stores the number of completed epochs (end_epoch).
    save_checkpoint({'epoch': end_epoch,
        'state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': lr_scheduler_f.state_dict(),
        'best_prec1': best_prec1,
        'best_loss': best_loss,
        }, filename=os.path.join(args.save_dir, 'final.pth'))

    # Flush pending TensorBoard events before returning.
    args.writer.close()

if __name__ == '__main__':
    # Parse the command line first: the GPU selection is exported to the
    # environment before main() makes any CUDA call.
    args = parser.parse_args()
    os.environ.update({"CUDA_VISIBLE_DEVICES": args.gpu})
    main(args)
