import os
import argparse
import logging

import torch
import torch.nn.functional as F
import torch.optim as optim

import random
import numpy as np

from utils.check_dataset import check_dataset
from utils.check_model_l2t import check_model
from utils.common import AverageMeter, accuracy, set_logging_config
from utils.meta_optimizers import MetaSGD

# Process-wide setting applied at import time. Per the PyTorch docs, benchmark
# mode lets cuDNN time several convolution algorithms and keep the fastest one.
# NOTE(review): this can make runs non-deterministic despite the fixed seeds
# set in main() -- confirm that is acceptable.
torch.backends.cudnn.benchmark = True


def _get_num_features(model):
    if model.startswith('resnet'):
        n = int(model[6:])
        if n in [18, 34, 50, 101, 152]:
            return [64, 64, 128, 256, 512]
        else:
            n = (n-2) // 6
            return [16]*n+[32]*n+[64]*n
    elif model.startswith('vgg'):
        n = int(model[3:].split('_')[0])
        if n == 9:
            return [64, 128, 256, 512, 512]
        elif n == 11:
            return [64, 128, 256, 512, 512]

    raise NotImplementedError




def main():
    """Train the target network with cross-entropy and log its accuracy.

    Steps performed:
      1. Parse command-line options and seed the Python, NumPy and torch RNGs.
      2. Load the source model (ImageNet-pretrained, or from a checkpoint
         under ``--source-path``). It is only put in eval mode here; the
         training loss below does not use it.
      3. Build the data loaders via ``check_dataset`` and the target model
         via ``check_model``.
      4. Run ``--epochs`` epochs of SGD (``MetaSGD`` when ``--meta-lr`` is
         non-zero), logging per-iteration loss/accuracy and per-epoch
         train/val/test accuracy.

    Exits through ``parser.error`` when a non-ImageNet source domain is
    requested without ``--source-path``.
    """
    # Local import: the module only imports ``os``; ``os.sys`` works by
    # accident of the stdlib's internals and should not be relied upon.
    import sys

    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--dataroot', required=True, help='Path to the dataset')
    parser.add_argument('--dataset', default='cub200')
    parser.add_argument('--datasplit', default='cub200')
    parser.add_argument('--datanoise', action='store_true', default=False)
    parser.add_argument('--batchSize', type=int, default=64, help='Input batch size')
    parser.add_argument('--workers', type=int, default=4)

    parser.add_argument('--source-model', default='resnet34', type=str)
    parser.add_argument('--source-domain', default='imagenet', type=str)
    parser.add_argument('--source-path', type=str, default=None)
    parser.add_argument('--target-model', default='resnet18', type=str)

    parser.add_argument('--epochs', type=int, default=200)
    parser.add_argument('--lr', type=float, default=0.1,help='Initial learning rate')
    parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
    parser.add_argument('--wd', type=float, default=0.0001, help='Weight decay')
    # NOTE(review): --nesterov is parsed but never forwarded to an optimizer.
    parser.add_argument('--nesterov', action='store_true')
    parser.add_argument('--schedule', action='store_true', default=True)
    # --schedule defaults to True and is store_true, so on its own it could
    # never be switched off; this companion flag makes that possible.
    parser.add_argument('--no-schedule', dest='schedule', action='store_false',
                        help='Disable the cosine learning-rate schedule')
    parser.add_argument('--pairs', type=str, default='4-4,4-3,4-2,4-1,3-4,3-3,3-2,3-1,2-4,2-3,2-2,2-1,1-4,1-3,1-2,1-1')
    parser.add_argument('--numTrain', type=int, default=100, help='Train sample size. 100 means use all')

    parser.add_argument('--meta-lr', type=float, default=0, help='Initial learning rate for meta networks')
    # opt.T is read when building MetaSGD below; without this declaration any
    # run with a non-zero --meta-lr failed with AttributeError.
    parser.add_argument('--T', type=int, default=2,
                        help='Only used with --meta-lr: MetaSGD is built with cpu=True when T > 2')
    parser.add_argument('--optimizer', type=str, default='sgd')

    parser.add_argument('--experiment', default='logs', help='Where to store models')

    opt = parser.parse_args()

    # Seeds
    random.seed(opt.seed)
    torch.manual_seed(opt.seed)
    np.random.seed(opt.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    #os.makedirs(opt.experiment)
    set_logging_config(opt.experiment)
    logger = logging.getLogger('main')
    logger.info(' '.join(sys.argv))
    logger.info(opt)

    # load source model
    if opt.source_domain == 'imagenet':
        from utils import resnet_ilsvrc
        source_model = resnet_ilsvrc.__dict__[opt.source_model](pretrained=True).to(device)
    else:
        if opt.source_path is None:
            # Fail with a usage message instead of a TypeError from os.path.join.
            parser.error("--source-path is required when --source-domain is not 'imagenet'")
        opt.model = opt.source_model
        source_path = os.path.join(
            opt.source_path, '{}-{}'.format(opt.source_domain, opt.source_model),
            '0',
            'model_best.pth.tar'
        )
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        ckpt = torch.load(source_path, map_location=device)
        opt.num_classes = ckpt['num_classes']
        source_model = check_model(opt).to(device)
        source_model.load_state_dict(ckpt['state_dict'], strict=False)

    # load dataloaders; indexed below as 0 = train, 1 = val, 2 = test
    loaders = check_dataset(opt)

    # load target model
    opt.model = opt.target_model
    opt.pretrained_model = False

    target_model = check_model(opt).to(device)

    target_params = list(target_model.parameters())
    if opt.meta_lr == 0:
        target_optimizer = optim.SGD(target_params, lr=opt.lr, momentum=opt.momentum, weight_decay=opt.wd)
    else:
        target_optimizer = MetaSGD(target_params,
                                   [target_model],
                                   lr=opt.lr,
                                   momentum=opt.momentum,
                                   weight_decay=opt.wd, rollback=True, cpu=opt.T > 2)

    state = {
        'target_model': target_model.state_dict(),
        'target_optimizer': target_optimizer.state_dict(),
        'best': (0.0, 0.0)
    }

    scheduler = optim.lr_scheduler.CosineAnnealingLR(target_optimizer, opt.epochs)

    def validate(model, loader):
        """Return the sample-weighted top-1 accuracy of ``model`` on ``loader``."""
        acc = AverageMeter()
        model.eval()
        # Evaluation needs no gradients; skip building autograd graphs.
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(device), y.to(device)
                y_pred, _ = model(x)
                acc.update(accuracy(y_pred, y, topk=(1,))[0].item(), x.size(0))
        return acc.avg

    def train_objective(data):
        """Compute the cross-entropy loss for one batch.

        Also records the batch loss and top-1 accuracy in ``state``.
        """
        x, y = data[0].to(device), data[1].to(device)
        y_pred, _ = target_model(x)
        state['accuracy'] = accuracy(y_pred.detach(), y, topk=(1,))[0].item()
        loss = F.cross_entropy(y_pred, y)
        state['loss'] = loss.item()
        return loss

    # target model training
    state['iter'] = 0
    for epoch in range(opt.epochs):
        state['epoch'] = epoch
        target_model.train()
        source_model.eval()
        train_acc = AverageMeter()
        for i, data in enumerate(loaders[0]):
            target_optimizer.zero_grad()
            train_objective(data).backward()
            target_optimizer.step()

            train_acc.update(state['accuracy'], data[0].size(0))
            logger.info('[Epoch {:3d}] [Iter {:3d}] [Loss {:.4f}] [Acc {:.4f}]'.format(
                state['epoch'], state['iter'],
                state['loss'], state['accuracy']))
            state['iter'] += 1

        # Step the schedule after this epoch's optimizer updates. Since
        # PyTorch 1.1, stepping before optimizer.step() skips the first
        # learning-rate value of the schedule.
        if opt.schedule:
            scheduler.step()

        state['accuracy'] = train_acc.avg
        acc = (validate(target_model, loaders[1]),
               validate(target_model, loaders[2]))

        # Model selection is on validation accuracy (acc[0]); the "best"
        # value logged below is the test accuracy from that same epoch.
        if state['best'][0] < acc[0]:
            state['best'] = acc

        # if state['epoch'] % 10 == 0:
        #     torch.save(state, os.path.join(opt.experiment, 'ckpt-{}.pth'.format(state['epoch']+1)))

        logger.info('             [Epoch {}] [train {:.4f}] [val {:.4f}] [test {:.4f}] [best {:.4f}]'
                    .format(epoch, state['accuracy'], acc[0], acc[1], state['best'][1]))


# Run training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
