import os
import sys
import time
import glob
import torch
import utils
import logging
import argparse
import torch.nn as nn
import torch.utils

from torch.autograd import Variable
from networks import get_model, get_mask_model
from dataset import get_num_class, get_dataloaders, get_img_size
from utils import reproducibility, load_model
from cam import CamModel
import networks.masknet as mn
from KMixAugmentor import KMixAugmentor

def _str2bool(value):
    """Convert a command-line token to a real boolean.

    ``type=bool`` is a well-known argparse trap: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so the flag could never be
    switched off with a readable value. This converter accepts the usual
    spellings instead. The empty string maps to ``False`` because that was the
    only way to disable the flag under the old ``type=bool`` behavior.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', 'off', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value (true/false), got %r' % (value,))


parser = argparse.ArgumentParser("tmix")

# --- data ---
parser.add_argument('--dataroot', type=str,
                    default='./data', help='location of the data corpus')
parser.add_argument('--dataset', type=str, default='cifar10', help='name of dataset')
parser.add_argument('--train_ratio', type=float,
                    default=0.5, help='ratio of training data, 1 means no validation')
parser.add_argument('--batch_size', type=int, default=96, help='batch size')
parser.add_argument('--num_workers', type=int, default=0, help="num_workers")

# --- optimisation ---
parser.add_argument('--learning_rate', type=float,
                    default=0.025, help='init learning rate')
parser.add_argument('--learning_rate_min', type=float,
                    default=0.0001, help='min learning rate')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--weight_decay', type=float,
                    default=3e-4, help='weight decay')
parser.add_argument('--grad_clip', type=float,
                    default=5, help='gradient clipping')

# --- devices ---
# _str2bool (not bool) so that "--use_cuda False" really disables the flag.
parser.add_argument('--use_cuda', type=_str2bool, default=True,
                    help="use cuda default True")
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
parser.add_argument('--use_parallel',  action='store_true', default=False, help="use data parallel default False")

# --- task model ---
parser.add_argument('--layers', type=int, default=20,
                    help='total number of layers')
parser.add_argument('--model_name', type=str,
                    default='resnet18', help="model_name")
parser.add_argument('--model_path', type=str,
                    default='saved_models', help='path to save the model')

# --- regularisation ---
parser.add_argument('--cutout', action='store_true',
                    default=False, help='use cutout')
parser.add_argument('--cutout_length', type=int,
                    default=16, help='cutout length')
parser.add_argument('--drop_path_prob', type=float,
                    default=0.2, help='drop path probability')

# --- run control ---
parser.add_argument('--epochs', type=int, default=600,
                    help='num of training epochs')
parser.add_argument('--report_freq', type=float,
                    default=50, help='report frequency')
parser.add_argument('--save', type=str, default='EXP', help='experiment name')
parser.add_argument('--seed', type=int, default=0, help='random seed')

# --- mixing ---
parser.add_argument('--k', type=int, default=2, help='mixing coeificcient')
parser.add_argument('--mix_alpha', type=float, default=1.0, help='mixing coeificcient')
parser.add_argument('--layer_mix', type=int, default=4, help='mixing layer')
parser.add_argument('--s_size', type=int, default=-1, help='saliency size, -1 means match image size')

# --- checkpoints / mask network ---
parser.add_argument('--pretrained_model_path', type=str, default='./', help='mask model path for camix')
parser.add_argument('--mask_net_path', type=str, default='./', help='mask net path for transfer')
parser.add_argument('--restore_path', type=str, default='./', help='model path')
parser.add_argument('--restore', action='store_true', default=False, help='restore model')
parser.add_argument('--mask_n_channel', type=int, default=4, help='mask net n_channel')
parser.add_argument('--masknet_model', type=str, default='./', help='mask net model name')

# The previous help text ('restore model') was copied from --restore; the flag
# actually selects mixaugmentor.mix_data_background in train().
parser.add_argument('--mix_background', action='store_true', default=False,
                    help='mix with mix_data_background during training')

# Parse the command line once; the rest of the module reads `args` as a global.
args = parser.parse_args()
debug = args.save == "debug"

# Experiment directory is "<timestamp>-<name>", filed under "debug/" for debug
# runs and under "augmentor/<dataset>/" otherwise.
args.save = f'{time.strftime("%Y%m%d-%H%M%S")}-{args.save}'
args.save = (os.path.join('debug', args.save) if debug
             else os.path.join('augmentor', args.dataset, args.save))
# Hands the directory and the local *.py files to the project helper
# (presumably creates the directory and snapshots the scripts -- see utils).
utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))

# Crop size depends on the dataset; imagenet_search additionally forces the
# data root back to './data', overriding any --dataroot given on the CLI.
if args.dataset != 'imagenet_search':
    args.crop_size = 32
else:
    args.crop_size = 128
    args.dataroot = './data'

# Log to stdout and mirror every record into <experiment dir>/log.txt.
log_format = '%(asctime)s %(message)s'
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format=log_format,
    datefmt='%m/%d %I:%M:%S %p',
)
fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)

# CAM variant passed to CamModel in main(); 'saliency' was the alternative
# value previously toggled here by hand.
CAM_METHOD = 'simcam'


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Return the two-way mixed loss for ``pred``.

    ``criterion`` is evaluated once against ``y_a`` and once against ``y_b``;
    the results are blended with weights ``lam`` and ``1 - lam``.
    """
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b


def k_mixup_criterion(criterion, pred, y, index_list, lam_list):
    """Return the weighted sum of losses over the mixed targets.

    For each ``(lam, index)`` pair, taken in lockstep from ``lam_list`` and
    ``index_list``, the loss of ``pred`` against ``y[index]`` is scaled by
    ``lam`` and added to the running total. Pairing stops at the shorter of
    the two lists; with no pairs the result is the integer ``0``.
    """
    total = 0
    for weight, perm in zip(lam_list, index_list):
        term = weight * criterion(pred, y[perm])
        total += term
    return total

def main():
    """Build the task model and k-mix augmentor, then train and evaluate.

    Steps, in order: check for CUDA, seed via ``reproducibility``, build the
    task model (loading weights from ``args.pretrained_model_path`` unless the
    dataset name contains ``'imagenet'``), wrap it in a ``CamModel``, build the
    mask network and ``KMixAugmentor``, create the data loaders and an SGD
    optimizer over ``mixaugmentor.parameters()``, then run ``args.epochs``
    epochs of ``train`` followed by ``infer`` on the validation and test
    loaders, saving a checkpoint after every epoch and once more at the end.

    Raises:
        SystemExit: with status 1 when no CUDA device is available.
        NotImplementedError: when ``--restore`` is passed; resuming from a
            checkpoint is not implemented in this script.
    """
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    if args.restore:
        # The restore branch used to be a bare `pass`, which left
        # trained_epoch / n_epoch unbound and crashed with a NameError only
        # after the model, the data and a full test pass had been paid for.
        # Fail immediately with an explicit message instead.
        raise NotImplementedError(
            '--restore is not supported: checkpoint restoring is not '
            'implemented in this script')

    torch.cuda.set_device(args.gpu)
    reproducibility(args.seed)

    logging.info('gpu device = %d', args.gpu)
    logging.info("args = %s", args)

    n_class = get_num_class(args.dataset)
    pretrained_model = get_model(
        model_name=args.model_name,
        num_class=n_class,
        datamixer=None,
        use_cuda=args.use_cuda,
        data_parallel=args.use_parallel)
    if 'imagenet' not in args.dataset:
        pretrained_model = load_model(pretrained_model, args.pretrained_model_path, location=args.gpu)
    # The task model is only ever used in eval mode by this script.
    task_model = pretrained_model.cuda()
    task_model.eval()
    logging.info(f'Loading mask model from {args.pretrained_model_path}')

    cam_model = CamModel(pretrained_model, args.layer_mix, n_class, cam_method=CAM_METHOD)
    masknet = get_mask_model(args.masknet_model, args.mask_n_channel, args.k)
    # Only imagenet_search uses an explicit saliency size; -1 is passed for
    # everything else (and always for stn_size).
    stn_size = -1
    s_size = 64 if args.dataset == 'imagenet_search' else -1
    mixaugmentor = KMixAugmentor(
        cam_model, masknet, k=args.k, img_size=get_img_size(args.dataset),
        s_size=s_size, stn_size=stn_size).cuda()
    logging.info("param size = %fMB", utils.count_parameters_in_MB(mixaugmentor))
    criterion = nn.CrossEntropyLoss()

    train_queue, valid_queue, test_queue = get_dataloaders(
        dataset=args.dataset,
        batch=args.batch_size,
        num_workers=args.num_workers,
        dataroot=args.dataroot,
        cutout=args.cutout,
        cutout_length=args.cutout_length,
        train_ratio=args.train_ratio,
        split_idx=0,
        target_lb=-1,
        use_autoDA=False,
        crop_size=args.crop_size)
    logging.info(f'Dataset: {args.dataset}')
    logging.info(f'  |total: {len(train_queue.dataset)}')
    # Batch count times batch size: an upper bound when the last batch is short.
    logging.info(f'  |train: {len(train_queue)*args.batch_size}')
    logging.info(f'  |valid: {len(valid_queue)*args.batch_size}')

    optimizer = torch.optim.SGD(
        mixaugmentor.parameters(),
        args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
        nesterov=True)

    trained_epoch = 0
    n_epoch = args.epochs

    # Baseline evaluation before any training step.
    test_acc, _, test_acc5, _ = infer(test_queue, task_model, criterion, mixaugmentor)
    logging.info('test_acc %f %f', test_acc, test_acc5)
    logging.info(f'save to {args.save}')

    # Index of the last completed epoch. Initialised here so the final
    # checkpoint below still has a value when n_epoch is 0 (previously a
    # NameError); -1 then means "nothing trained yet".
    epoch = trained_epoch - 1
    for i_epoch in range(n_epoch):
        epoch = trained_epoch + i_epoch
        lr = optimizer.param_groups[0]['lr']
        logging.info('epoch %d lr %e', epoch, lr)

        train_acc, _ = train(
            train_queue, task_model, criterion, optimizer, epoch, mixaugmentor)
        logging.info('train_acc %f', train_acc)

        valid_acc, _, _, _ = infer(valid_queue, task_model, criterion, mixaugmentor)
        logging.info('valid_acc %f', valid_acc)

        test_acc, _, test_acc5, _ = infer(test_queue, task_model, criterion, mixaugmentor)
        logging.info('test_acc %f %f', test_acc, test_acc5)
        utils.save_ckpt(mixaugmentor, optimizer, None, epoch, os.path.join(args.save, f'e{epoch}_weights.pt'))

    test_acc, _, test_acc5, _ = infer(test_queue, task_model, criterion, mixaugmentor)
    utils.save_ckpt(mixaugmentor, optimizer, None, epoch, os.path.join(args.save, 'weights.pt'))
    logging.info('test_acc %f %f', test_acc, test_acc5)
    logging.info(f'{args.model_name} {args.epochs} {args.mix_alpha} {args.masknet_model}@{args.mask_n_channel} \
        {args.learning_rate} {test_acc} {test_acc5}')
    logging.info(f'save to {args.save}')


def train(train_queue, model, criterion, optimizer, epoch, mixaugmentor):
    """Run one training epoch of the mix augmentor.

    ``model`` is put in eval mode and ``mixaugmentor`` in train mode. Each
    batch is mixed by the augmentor (``mix_data_background`` when
    ``args.mix_background`` is set, otherwise ``mix_data`` with
    ``args.mix_alpha``), scored by ``model``, and the resulting loss is
    back-propagated before ``optimizer.step()``.

    Args:
        train_queue: iterable of ``(input, target)`` batches; must support ``len``.
        model: task network used to score the mixed inputs.
        criterion: loss callable applied to ``(logits, targets)``.
        optimizer: optimizer that is zeroed and stepped once per batch.
        epoch: zero-based epoch index, used only to compute the global step.
        mixaugmentor: module providing ``mix_data`` / ``mix_data_background``.

    Returns:
        Tuple ``(top1_avg, loss_avg)`` of the running meters for this epoch.
    """
    objs = utils.AvgrageMeter()
    top1 = utils.AvgrageMeter()
    top5 = utils.AvgrageMeter()
    model.eval()
    mixaugmentor.train()
    steps_per_epoch = len(train_queue)
    for step, (images, labels) in enumerate(train_queue):
        labels = labels.cuda(non_blocking=True)
        images = images.cuda()
        optimizer.zero_grad()
        if args.mix_background:
            mixed_input, targets = mixaugmentor.mix_data_background(images, labels)
            logits = model(mixed_input)
            loss = criterion(logits, targets)
            prec1, prec5 = utils.accuracy(logits, targets, topk=(1, 5))
        else:
            mixed_input, targets, index_list, lam_list = mixaugmentor.mix_data(images, labels, args.mix_alpha)
            logits = model(mixed_input)
            loss = k_mixup_criterion(criterion, logits, targets, index_list, lam_list)
            prec1, prec5 = utils.mix_k_accuracy(logits, targets, index_list, lam_list, topk=(1, 5))

        loss.backward()
        # Clip the gradients of the module being trained. The previous code
        # clipped model.parameters(), i.e. the task model, whereas main()
        # builds the optimizer over mixaugmentor.parameters().
        nn.utils.clip_grad_norm_(mixaugmentor.parameters(), args.grad_clip)
        optimizer.step()

        n = images.size(0)
        objs.update(loss.detach().item(), n)
        top1.update(prec1.detach().item(), n)
        top5.update(prec5.detach().item(), n)

        global_step = step + epoch * steps_per_epoch

        if global_step % args.report_freq == 0:
            logging.info('train %03d %e %f %f', global_step, objs.avg, top1.avg, top5.avg)

    return top1.avg, objs.avg


def infer(valid_queue, model, criterion, mixaugmentor):
    """Evaluate ``model`` on batches mixed by ``mixaugmentor``.

    Both modules are put in eval mode. Every batch is mixed with
    ``mixaugmentor.mix_data`` (using ``args.mix_alpha``), scored by ``model``,
    and the k-mix loss and top-1/top-5 accuracies are accumulated.

    Args:
        valid_queue: iterable of ``(input, target)`` batches.
        model: task network used to score the mixed inputs.
        criterion: loss callable applied to ``(logits, targets)``.
        mixaugmentor: module providing ``mix_data``.

    Returns:
        Tuple ``(top1_avg, loss_avg, top5_avg, loss_avg)``. The loss average
        appears twice; the 4-tuple shape is kept because callers unpack four
        values.
    """
    objs = utils.AvgrageMeter()
    top1 = utils.AvgrageMeter()
    top5 = utils.AvgrageMeter()
    mixaugmentor.eval()
    model.eval()
    # NOTE(review): deliberately not wrapped in torch.no_grad() -- mix_data
    # may rely on autograd for its CAM computation; confirm before adding it.
    for images, labels in valid_queue:
        # Plain tensors are moved directly; the deprecated Variable wrapper
        # is a no-op in current PyTorch and train() does not use it either.
        images = images.cuda()
        labels = labels.cuda(non_blocking=True)

        mixed_input, targets, index_list, lam_list = mixaugmentor.mix_data(images, labels, args.mix_alpha)
        logits = model(mixed_input)

        loss = k_mixup_criterion(criterion, logits, targets, index_list, lam_list)

        prec1, prec5 = utils.mix_k_accuracy(logits, targets, index_list, lam_list, topk=(1, 5))
        n = images.size(0)
        objs.update(loss.detach().item(), n)
        top1.update(prec1.detach().item(), n)
        top5.update(prec5.detach().item(), n)

    return top1.avg, objs.avg, top5.avg, objs.avg


# Script entry point: start training only when this file is executed directly.
if __name__ == '__main__':
    main()
