import argparse
import json
import os
import numpy as np
import logging

import utilities
from perturbation_learning import cvae, perturbations, datasets

import torch
from torch import optim
from torchvision.utils import save_image

# Registry mapping the ``config.training.optimizer`` string to a torch
# optimizer class; ``train`` looks the class up here and instantiates it.
optimizers = dict(
    adam=optim.Adam,
    sgd=optim.SGD,
)

def save_chkpt(model, optimizer, epoch, test_loss, name, dp):
    """Serialize model/optimizer state plus bookkeeping to ``name``.

    When ``dp`` is true the model is unwrapped with ``model.undataparallel()``
    before its ``state_dict`` is taken, and re-wrapped afterwards. ``train``
    loads checkpoints into the model before calling ``model.dataparallel()``,
    so saving from the unwrapped model keeps the two consistent.

    Args:
        model: model exposing ``state_dict`` and, when ``dp`` is set,
            ``undataparallel`` / ``dataparallel``.
        optimizer: optimizer whose ``state_dict`` is stored.
        epoch: epoch index, stored under ``"epoch"``.
        test_loss: loss value, stored under ``"test_loss"``.
        name: destination path passed to ``torch.save``.
        dp: whether the model is currently wrapped for data parallelism.

    Raises:
        Whatever ``torch.save`` raises (e.g. a missing parent directory);
        the model is re-wrapped before the exception propagates.
    """
    if dp:
        model.undataparallel()
    try:
        torch.save({
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "epoch": epoch,
            "test_loss": test_loss
        }, name)
    finally:
        # Restore the wrapping even if saving failed, so a failed write does
        # not leave the caller training on an unwrapped model.
        if dp:
            model.dataparallel()

def _save_samples(model, data, hdata, output, config, output_dir, epoch):
    """Write reconstruction and sample image grids for one test batch.

    Produces three PNGs under ``<output_dir>/images``:
    ``hreconstruction_{epoch}.png`` (input / target / reconstruction rows),
    ``hsample_{epoch}.png`` (one ``model.sample(data)`` draw) and
    ``repeat_hsample_{epoch}.png`` (8 draws for each of the first 8 inputs).
    """
    # take maximum estimate as reconstruction
    if config.model.output_distribution == 'softmax':
        # NOTE(review): ``max(1)[0]`` keeps the max *values* over dim 1, not
        # the argmax indices; confirm this is the intended reconstruction.
        output = output.max(1)[0]

    n = min(data.size(0), 8)
    hcomparison = torch.cat([data[:n],
                             hdata[:n],
                             output.view(*hdata.size())[:n]])
    save_image(hcomparison.cpu(),
               os.path.join(output_dir, 'images', f'hreconstruction_{epoch}.png'),
               nrow=n)

    n_grid = min(64, config.eval.batch_size)
    hsample = model.sample(data)
    save_image(hsample[:n_grid],
               os.path.join(output_dir, 'images', f'hsample_{epoch}.png'))

    # Stack 8 calls to model.sample along dim 1, then flatten, so the grid
    # lists input 0's 8 samples, then input 1's, and so on.
    repeat_hsample = torch.cat(
        [model.sample(data)[:8].unsqueeze(1) for _ in range(8)], dim=1)
    repeat_hsample = repeat_hsample.view(-1, *hdata.size()[1:])
    save_image(repeat_hsample[:n_grid],
               os.path.join(output_dir, 'images', f'repeat_hsample_{epoch}.png'))


def _train_one_epoch(model, optimizer, lr_schedule, h_train, train_loader,
                     config, epoch, logger):
    """Run one training epoch; return ``(summed batch loss, last lr used)``."""
    model.train()
    train_loss = 0
    # Defined up front so the caller can log it even if the loader is empty.
    lr = lr_schedule(epoch)
    for batch_idx, batch in enumerate(train_loader):
        data = batch[0]
        # Fractional epoch, so the schedule is interpolated per batch.
        epoch_idx = epoch + (batch_idx + 1) / len(train_loader)
        lr = lr_schedule(epoch_idx)
        for group in optimizer.param_groups:
            group.update(lr=lr)

        # The perturbation sees the whole (CPU) batch; its output is the target.
        hdata = h_train(batch)
        data = data.to(config.device)
        hdata = hdata.to(config.device)

        model.pregradient(batch)

        optimizer.zero_grad()
        output = model(data)
        loss = cvae.glo_loss(hdata, output,
                             distribution=config.model.output_distribution)
        loss.backward()
        optimizer.step()

        model.postgradient(batch)

        train_loss += loss.item()
        if batch_idx % config.training.log_interval == 0:
            logger.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(data)))
    return train_loss, lr


def _evaluate(model, h_test, test_loader, config, output_dir, epoch):
    """Return the test loss summed over batches; dump images when scheduled."""
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for batch_idx, batch in enumerate(test_loader):
            data = batch[0]
            hdata = h_test(batch)
            data = data.to(config.device)
            hdata = hdata.to(config.device)

            model.preevaluate(data, hdata)

            output = model(data)

            loss = cvae.glo_loss(hdata, output,
                                 distribution=config.model.output_distribution)
            test_loss += loss.item()
            if batch_idx == 0 and (epoch + 1) % config.eval.sample_interval == 0:
                _save_samples(model, data, hdata, output, config,
                              output_dir, epoch)
    return test_loss


def train(config, output_dir):
    """Train ``cvae.models[config.model.type]`` and checkpoint it.

    The model maps clean inputs ``batch[0]`` to perturbed targets produced by
    ``perturbations.hs[...]``, scored with ``cvae.glo_loss``. The learning rate
    is set every batch from ``config.training.step_size_schedule``.

    Args:
        config: attribute-access config (built by
            ``utilities.config_to_namedtuple`` in this script).
        output_dir: run directory; ``output.log`` is written here, and it must
            already contain ``images/`` and ``checkpoints/`` subdirectories.

    Side effects:
        Writes ``checkpoint_{epoch}.pth``, ``checkpoint_best.pth`` and
        ``checkpoint_latest.pth`` under ``checkpoints/``, and image grids under
        ``images/``. The ``test_loss`` stored in checkpoints is the sum over
        test batches; the logged test loss is divided by the dataset size.
    """
    logger = logging.getLogger(__name__)
    logging.basicConfig(
        format='[%(asctime)s] - %(message)s',
        datefmt='%Y/%m/%d %H:%M:%S',
        level=logging.DEBUG,
        handlers=[
            logging.FileHandler(os.path.join(output_dir, 'output.log')),
            logging.StreamHandler()
        ])

    model = cvae.models[config.model.type](config)
    model.to(config.device)

    h_train = perturbations.hs[config.perturbation.train_type](config.perturbation)
    h_test = perturbations.hs[config.perturbation.test_type](config.perturbation)
    train_loader, test_loader = datasets.loaders[config.dataset.type](config)

    # lr=1 is a placeholder: every batch overwrites it from lr_schedule.
    # NOTE(review): config.training.momentum is not passed to the optimizer.
    optimizer = optimizers[config.training.optimizer](
        model.parameters(), lr=1, weight_decay=config.training.weight_decay)

    def lr_schedule(t):
        # step_size_schedule supplies the positional args after x
        # (xp, fp, ...) of np.interp.
        return np.interp([t], *config.training.step_size_schedule)[0]

    checkpoint_dir = os.path.join(output_dir, 'checkpoints')
    # Start at +inf (not a finite sentinel): test_loss is summed over the
    # whole test set and can be arbitrarily large.
    best_test_loss = float('inf')

    start_epoch = 0
    if config.resume is not None:
        # map_location lets a checkpoint written on another device be resumed.
        d = torch.load(config.resume, map_location=config.device)
        logger.info(f"Resume model checkpoint {d['epoch']}...")
        optimizer.load_state_dict(d["optimizer_state_dict"])
        model.load_state_dict(d["model_state_dict"])
        start_epoch = d["epoch"] + 1

        try:
            best = torch.load(os.path.join(checkpoint_dir, 'checkpoint_best.pth'),
                              map_location='cpu')
            best_test_loss = best["test_loss"]
        except Exception as err:  # best-effort: keep +inf if unavailable
            logger.info("No best checkpoint to resume test loss from (%s)", err)

    if config.dataparallel:
        model.dataparallel()

    for epoch in range(start_epoch, config.training.epochs):
        train_loss, lr = _train_one_epoch(model, optimizer, lr_schedule,
                                          h_train, train_loader, config,
                                          epoch, logger)
        logger.info('====> Epoch: {} Average loss: {:.4f} lr {:.8f}'.format(
              epoch, train_loss / len(train_loader.dataset), lr))

        if (epoch + 1) % config.eval.test_interval != 0:
            continue

        test_loss = _evaluate(model, h_test, test_loader, config,
                              output_dir, epoch)

        if (epoch + 1) % config.training.checkpoint_interval == 0:
            save_chkpt(model, optimizer, epoch, test_loss,
                       os.path.join(checkpoint_dir, f'checkpoint_{epoch}.pth'),
                       config.dataparallel)

        if test_loss < best_test_loss:
            save_chkpt(model, optimizer, epoch, test_loss,
                       os.path.join(checkpoint_dir, 'checkpoint_best.pth'),
                       config.dataparallel)
            best_test_loss = test_loss

        save_chkpt(model, optimizer, epoch, test_loss,
                   os.path.join(checkpoint_dir, 'checkpoint_latest.pth'),
                   config.dataparallel)

        logger.info('====> Test set loss: {:.4f}'.format(
            test_loss / len(test_loader.dataset)))



if __name__ == "__main__":
    # Command-line entry point: load the JSON config, prepare the run
    # directory <output_dir>/<model_dir> (with images/ and checkpoints/),
    # snapshot the config there, then train.
    parser = argparse.ArgumentParser(
                        description='Train script options',
                        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-c', '--config', type=str,
                        help='path to config file',
                        default='config.json', required=False)
    parser.add_argument('-dp', '--dataparallel',
                        help='data parallel flag', action='store_true')
    parser.add_argument('--resume', default=None, help='path to checkpoint')
    args = parser.parse_args()
    config_dict = utilities.get_config(args.config)
    config_dict['dataparallel'] = args.dataparallel

    # The run directory is named after model.model_dir; require it to match
    # the config file's basename. Use parser.error rather than assert so the
    # check is not stripped under `python -O`.
    config_name = os.path.splitext(os.path.basename(args.config))[0]
    if config_name != config_dict['model']['model_dir']:
        parser.error(f"config file name '{config_name}' does not match "
                     f"model.model_dir '{config_dict['model']['model_dir']}'")

    torch.manual_seed(1)
    torch.cuda.manual_seed(1)

    output_dir = os.path.join(config_dict['output_dir'],
                              config_dict['model']['model_dir'])

    # makedirs creates output_dir as a parent; exist_ok avoids the
    # check-then-create race.
    for s in ['images', 'checkpoints']:
        os.makedirs(os.path.join(output_dir, s), exist_ok=True)

    # keep the configuration file with the model for reproducibility
    with open(os.path.join(output_dir, 'config.json'), 'w') as f:
        json.dump(config_dict, f, sort_keys=True, indent=4)

    # Added after the snapshot so the saved config.json does not record it.
    config_dict['resume'] = args.resume

    config = utilities.config_to_namedtuple(config_dict)
    train(config, output_dir)
