import os
import sys

# Put the ``src`` directory that sits next to this file at the front of the
# import search path, and echo the directory that was added.
path = os.path.split(os.path.abspath(__file__))[0]
_src_dir = f'{path}/src/'
sys.path.insert(0, _src_dir)
print(_src_dir)

import numpy as np
import argparse
from os.path import join
import torch
from torchvision.utils import make_grid
from tqdm import tqdm
import time

from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter
from generalization_study.test import test_epoch
from generalization_study.datasets import get_dataset, \
    custom_collate
from generalization_study.modify_datasets import \
    modify_dataset, convert_to_pairs_dataset
from generalization_study.utils import DummyClass, RSquared, \
    save_checkpoint, Tracker, get_exp_name
from generalization_study.train_readout_model import train_mlp_on_readout

from generalization_study.models import get_model
# from generalization_study.evaluate_disentanglement import evaluate_dislib
from torch.utils.data import DataLoader
from datetime import datetime
from pytz import timezone

# Timezone used by ``main`` to build the timestamp in the TensorBoard writer
# directory name.
tz = timezone("Europe/Berlin")


def _iterations_between(total_iterations, number_events):
    """Return the iteration gap between ``number_events`` evenly spaced events.

    The first event happens at iteration 0, so ``number_events`` events span
    ``number_events - 1`` gaps. Both the number of gaps and the result are
    clamped to at least 1, so the value is always safe to use as the
    right-hand side of ``%`` (short runs or ``number_events == 1`` would
    otherwise lead to a division or modulo by zero).
    """
    gaps = max(number_events - 1, 1)
    return max(total_iterations // gaps, 1)


def main(args):
    """Train the model selected by ``args`` and evaluate it periodically.

    The dataset is loaded with ``get_dataset`` and split by ``modify_dataset``
    into a train set, a test set and a dict of named evaluation sets. Models
    that are neither supervised nor ``betavae`` are trained on a pairs dataset
    built by ``convert_to_pairs_dataset``. The model is optimized with Adam
    for ``args.max_number_iterations`` steps. At regular intervals (and on the
    last step) the model is evaluated with ``test_epoch`` on the test set, the
    full dataset and every evaluation set; for non-supervised models an MLP
    readout is first trained with ``train_mlp_on_readout`` and evaluated
    instead. Checkpoints are written with ``save_checkpoint``.

    Args:
        args: Namespace as returned by ``parse_args`` (it must carry the
            derived ``supervised`` and ``transition_prior`` attributes).
            ``args.test_ratio_per_factor`` is set here as a side effect so
            that it is logged together with the other arguments.

    Raises:
        RuntimeError: If the training data loader yields no batch at all.
    """
    start_time = time.time()

    # paths
    save_path = join(args.project_path, args.save_path, args.name)
    data_path = join(args.project_path, args.data_path)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # data
    dataset, number_factors, number_channels, test_ratio_per_factor = \
        get_dataset(args.dataset, data_path)
    args.test_ratio_per_factor = test_ratio_per_factor  # for logging
    dataset_train, dataset_test, datasets_evaluation = \
        modify_dataset(dataset,
                       args.modification,
                       args.test_ratio_per_factor)

    data_loader_train = DataLoader(dataset_train, batch_size=args.batch_size,
                                   num_workers=args.number_worker, shuffle=True)
    data_loader_test = DataLoader(dataset_test, batch_size=args.batch_size * 10,
                                  num_workers=args.number_worker, shuffle=True)
    data_loader_full = DataLoader(dataset, batch_size=args.batch_size,
                                  num_workers=args.number_worker, shuffle=True)
    # the readout MLP is always trained on the plain (unpaired) train loader,
    # even when ``data_loader_train`` is replaced by the pairs loader below
    data_loader_train_readout = data_loader_train

    if not args.supervised and args.model != 'betavae':
        # train with pairs of images
        dataset_train_pairs = \
            convert_to_pairs_dataset(dataset_train,
                                     args.transition_prior,
                                     args.max_number_changing_factors,
                                     args.modification)
        data_loader_train = DataLoader(dataset_train_pairs,
                                       batch_size=args.batch_size,
                                       num_workers=args.number_worker,
                                       shuffle=True,
                                       collate_fn=custom_collate)

    # model
    model = get_model(args.model, number_factors, number_channels,
                      args.number_latents, args, dataset).to(device)
    print('model', model)

    # for the transfer learning we might only want to train the last layer
    parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
    if args.only_train_last_layer:
        # sanity check: exactly two trainable tensors are expected here
        assert len(parameters) == 2
    optimizer = Adam(parameters, lr=args.learning_rate, weight_decay=args.weight_decay)

    # eval functions
    labels_01 = dataset.get_normalized_labels()
    r_squared = RSquared(labels_01, device)

    # bookkeeping
    if args.writer:
        time_string = datetime.now(tz=tz).strftime("%Y-%m-%d-%H-%M-%S")
        writer = SummaryWriter(save_path + f'/writer_{time_string}.tb')
    else:
        writer = DummyClass()
    for arg, val in args.__dict__.items():
        writer.add_text(arg, str(val), 0)
    variance_per_factor = r_squared.variance_per_factor
    writer.add_text('variances_per_factor', str(variance_per_factor), 0)
    number_samples = dataset.data.shape[0]
    number_test = dataset_test.data.shape[0]
    number_train = dataset_train.data.shape[0]
    # share of all samples that ended up in the test split (this used to be
    # computed from the train count although it is logged as the test ratio)
    test_ratio = number_test / number_samples
    writer.add_text('data/number_samples_total', str(number_samples), 0)
    writer.add_text('data/number_samples_test', str(number_test), 0)
    # NOTE: this key holds the number of *training* samples; the key is kept
    # unchanged so existing logs stay comparable
    writer.add_text('data/number_samples', str(number_train), 0)
    writer.add_text('data/ratio_test_total', str(test_ratio), 0)

    # train model
    last = False
    infos = {}
    iteration = -1
    tracker = Tracker(writer)
    pbar = tqdm(total=args.max_number_iterations)
    # log number_evals times
    eval_modulo = _iterations_between(args.max_number_iterations,
                                      args.number_evals)
    # only referenced by the disabled dislib evaluation below
    eval_dislib_modulo = _iterations_between(args.max_number_iterations, 4)
    save_steps_modulo = _iterations_between(args.max_number_iterations, 3)
    flush_steps = _iterations_between(args.max_number_iterations, 51)  # cheap
    while iteration < args.max_number_iterations - 1:
        saw_batch = False
        for batch, targets in data_loader_train:
            saw_batch = True
            iteration += 1
            if iteration == args.max_number_iterations - 1:
                last = True

            # train
            model.train()
            batch = batch.to(device)
            targets = targets.to(device)

            # get loss
            if args.model == 'pcl':
                latents = model(batch)
                loss, infos = model.loss_f(latents)
            elif args.model == 'slowvae' or args.model == 'betavae':
                x_recon, mu, logvar = model(batch, train=True)
                loss, infos = model.loss_f(batch, x_recon, mu, logvar)
            elif args.model == 'adagvae':
                loss, infos = model(batch, train=True)
            else:
                latents = model(batch)
                squared_diff = (targets - latents).pow(2)
                loss = squared_diff.sum(dim=1).mean()  # mse
                infos['mse_loss'] = loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # one optimizer step done; tqdm throttles the actual redraws
            pbar.update(1)

            # bookkeeping
            tracker.track(infos)
            if iteration == 0:
                grid = make_grid(batch[:64], pad_value=1)
                writer.add_image('train/batch', grid, iteration)
            if iteration % flush_steps == 0:
                tracker.write(iteration)

            # fully-supervised evaluations
            if iteration % eval_modulo == 0 or last:
                model.eval()
                if not args.supervised:
                    print('training readout')
                    # rsquared with readout classifier
                    supervised_model = train_mlp_on_readout(
                        model,
                        data_loader_train_readout,
                        number_latents=args.number_latents,
                        number_factors=number_factors,
                        device=device, writer=writer,
                        current_iteration=iteration)
                else:
                    supervised_model = model

                test_epoch(iteration, supervised_model, data_loader_test,
                           writer, device,
                           r_squared)
                test_epoch(iteration, supervised_model, data_loader_full,
                           writer, device,
                           r_squared, mode='full')
                # compute loss with only one factor ood, applicable for
                # inter- and extrapolation
                for name, dataset_i in datasets_evaluation.items():
                    tmp_dataloader = DataLoader(dataset_i,
                                                batch_size=args.batch_size * 10,
                                                num_workers=args.number_worker,
                                                shuffle=True)
                    test_epoch(iteration, supervised_model, tmp_dataloader,
                               writer, device,
                               r_squared, mode=f'eval/{name}')

            # evaluations for super- and unsupervised models
            # if (iteration % eval_dislib_modulo == 0 or last) \
            #         and not args.skip_dislib_eval:
            #     model.eval()
            #     # dislib evaluation
            #     dis_lib_metrics = args.dis_lib_metrics
            #     evaluate_dislib(iteration, model, dataset, 'full', writer,
            #                     dis_lib_metrics,
            #                     last, supervised=args.supervised)
            #     evaluate_dislib(iteration, model, dataset_test, 'test', writer,
            #                     dis_lib_metrics, last,
            #                     supervised=args.supervised)
            #     evaluate_dislib(iteration, model, dataset_train, 'train',
            #                     writer,
            #                     dis_lib_metrics, last,
            #                     supervised=args.supervised)
            #     for name, dataset_i in datasets_evaluation.items():
            #         evaluate_dislib(iteration, model, dataset_i,
            #                         f'test_one_ood/{name}',
            #                         writer, dis_lib_metrics, last,
            #                         supervised=args.supervised)

            # save model
            if iteration % save_steps_modulo == 0 or last:
                if last:
                    name = 'last_epoch.pt'
                else:
                    name = f'checkpoint_{iteration}.pt'
                save_checkpoint(model, optimizer, args, iteration,
                                save_path, name)

            if last:
                print('training done', iteration)
                break

        if not saw_batch:
            # without this guard an empty loader would spin here forever
            raise RuntimeError('training data loader did not yield any batch')
    pbar.close()

    total_time = (time.time() - start_time) / 60
    writer.add_text('run_time/total_minutes', str(total_time))
    print('time total minutes: ', total_time)

    # close the writer so pending events (e.g. the run time above) are written
    # out; looked up defensively because the DummyClass stand-in may not
    # provide ``close``
    close_writer = getattr(writer, 'close', None)
    if callable(close_writer):
        close_writer()


def parse_args(ipynb=False):
    """Parse the command line and derive the remaining run settings.

    Besides the declared options, the returned namespace carries attributes
    that are derived here: ``dis_lib_metrics``, ``supervised`` and
    ``transition_prior``. ``save_path`` gets a ``weakly/`` or ``supervised/``
    suffix depending on the model, ``learning_rate`` is overwritten for
    supervised models other than ``densenet``, and ``name`` is finally set
    from ``get_exp_name``.

    Side effects: seeds ``torch`` and ``numpy`` unless ``--seed`` is ``-1``
    and prints the experiment name.

    Args:
        ipynb: ``False`` (default) or ``None`` to parse ``sys.argv``.
            Otherwise a list of argument strings to parse instead, e.g. when
            calling from a notebook; an empty list means "all defaults".

    Returns:
        argparse.Namespace: The parsed and completed arguments.
    """
    parser = argparse.ArgumentParser(description='')

    parser.add_argument('--name', type=str, default='auto',
                        help='name of the project folder')
    parser.add_argument('--name-suffix', type=str, default='',
                        help='additional info to auto naming')
    parser.add_argument('--project-path', type=str,
                        default='/home/anonymous/lanonymous/src/GeneralizationStudy/src/GeneralizationStudy/',
                        help='project_path')
    parser.add_argument('--save-path', type=str, default='exp/trash/',
                        help='path within project-path')
    parser.add_argument('--data-path', type=str, default='data/',
                        help='path within project-path')
    parser.add_argument('--writer', action='store_true', default=False,
                        help='Whether to use a writer')
    parser.add_argument('--seed', default=1, type=int, help='random seed')

    # data
    parser.add_argument('--dataset', type=str.lower, default='dsprites',
                        help='Dataset to use',
                        choices=['dsprites', 'shapes3d', 'mpi3d', 'smalldummy', 'dsprites_rot90', 'mpi3d_toy',
                                 'celeb_glow'])
    parser.add_argument('--batch-size', type=int, default=64,
                        help='Batch size')
    parser.add_argument('--modification', type=str, default='none',
                        choices=['extrapolation', 'interpolation',
                                 'composition', 'none', 'random'],
                        help='data set modification')
    parser.add_argument('--max-number-changing-factors', type=int, default=1,
                        help='k from locatello paper')
    parser.add_argument('--number-worker', type=int, default=8,
                        help='number of workers')
    # train
    parser.add_argument('--max-number-iterations', type=int, default=500000,
                        help='Number of training iterations')
    parser.add_argument('--number-evals', type=int, default=4,
                        help='Number of evaluations during training')

    parser.add_argument('-lr', '--learning-rate', type=float, default=1e-4)
    parser.add_argument('--weight-decay', type=float, default=0)

    # model
    parser.add_argument('--model', type=str, default='vanilla',
                        choices=['vanilla', 'deeper_cnn', 'rotation',
                                 'implicit', 'transformer', 'mlp',
                                 'coordconv', 'betavae', 'slowvae', 'pcl',
                                 'adagvae', 'densenet',
                                 'big_transfer_rn50',
                                 'big_transfer_rn101'],
                        help='Which architecture to use')
    parser.add_argument('--number-latents', type=int, default=10,
                        help='Only for non-supervised models')
    # resnet models
    parser.add_argument('--only-train-last-layer', action='store_true',
                        default=False,
                        help='Whether to only train the last layer '
                             '(transfer learning)')
    parser.add_argument('--pretrained', action='store_true',
                        default=False,
                        help='Whether to use a pretrained transfer net')
    # model - rotation equivariant model
    parser.add_argument('--number-rotations', type=int, default=8,
                        help='Only works for --model = rotation. Number of '
                             'discrete rotations to be equivariant under')
    parser.add_argument('--feature-reduce-factor', type=int, default=8,
                        help='For rotation equivariant CNN reduce feature maps '
                             'factor')
    # vae based models
    parser.add_argument('--vae-beta', type=float, default=1.,
                        help='Weighting factor for the KL[q(z|x)]||p(z) in '
                             'the elbo')
    parser.add_argument('--slowvae-gamma', type=float, default=10.,
                        help='Weighting factor for the Laplacian transition '
                             'prior')
    # NOTE(review): this help text duplicates --slowvae-gamma; presumably the
    # option is the rate of the Laplacian prior - confirm against the model
    parser.add_argument('--slowvae-rate', type=float, default=6.,
                        help='Weighting factor for the Laplacian transition '
                             'prior')

    # logging
    parser.add_argument('--steps-supervised-logging', type=int, default=10,
                        help='Number of evaluations')
    # NOTE(review): this help text looks copied from --vae-beta; the option is
    # not read anywhere in this file - confirm its meaning before rewording
    parser.add_argument('--steps-weakly-logging', type=int, default=4,
                        help='Weighting factor for the KL[q(z|x)]||p(z) in '
                             'the elbo')
    parser.add_argument('--skip-dislib-eval', action='store_true',
                        default=False,
                        help='Whether to skip the dislib evaluation')

    # ``None`` makes argparse read ``sys.argv``; every explicitly passed list,
    # including an empty one, is parsed as given
    argv = None if ipynb is None or ipynb is False else ipynb
    args = parser.parse_args(args=argv)

    # defaults of the declared options only (derived attributes are added below)
    all_defaults = {key: parser.get_default(key) for key in vars(args)}

    args.dis_lib_metrics = ['dci', 'mig', 'mcc']

    if args.model in ['slowvae', 'pcl', 'betavae', 'adagvae']:
        args.supervised = False
        args.save_path = os.path.join(args.save_path, 'weakly/')
    else:
        # larger lr for supervised models, except densenet; this overrides
        # any value passed via --learning-rate
        # NOTE(review): the original condition additionally had
        # ``or model == 'rn50' or model == 'rn101'``. Those clauses never
        # changed the result (and the names are not valid --model choices),
        # so the effective behavior is kept here; confirm whether the
        # big_transfer models were meant to be excluded as well.
        if args.model != 'densenet':
            args.learning_rate = 0.0005
        args.supervised = True
        args.save_path = os.path.join(args.save_path, 'supervised/')

    # each model comes with assumptions about the data generative process
    if args.model in ['slowvae', 'pcl']:
        args.transition_prior = 'laplace'
    elif args.model == 'adagvae':
        args.transition_prior = 'locatello'
    else:
        args.transition_prior = None

    if args.name == 'auto':
        args.name = f'{args.dataset}_{args.modification}_{args.model}'
        if args.feature_reduce_factor != 8:
            args.name = args.name + f'_frf{args.feature_reduce_factor}'
        if args.model == 'slowvae':
            args.name = args.name + f'_gamma{args.slowvae_gamma:0.0f}'
        elif args.model == 'betavae' or args.model == 'adagvae':
            args.name = args.name + f'_beta{args.vae_beta:0.0f}'
        args.name = args.name + f'_seed{args.seed:02d}'

    if args.name_suffix != '':
        args.name = args.name + '_' + args.name_suffix

    # a seed of -1 leaves the random number generators unseeded
    if args.seed != -1:
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)

    # the final name always comes from get_exp_name
    args.name = get_exp_name(all_defaults, args)
    # os.makedirs(join(args.project_path, args.save_path, args.name))
    print('name', args.name)
    return args


# Script entry point: parse the command line, then start the training run.
if __name__ == "__main__":
    main(parse_args())
