import os, sys, time, warnings
import argparse
import random
import numpy as np


import torch
import torch.nn as nn
import torchvision

from dataset import MnistRotated, MnistRotated_CF
from solver import Solver

import wandb

model_name = {'auto_full': 'full', 'auto_spa_can': 'cano', 'auto_soft_can': 'soft-cano', 'auto_spa': 'relax-cano'}

KLD_SCHEDULER_PRESETS = {
    1: {'kld_scheduler': 'linear', 'lamb_kld_start': 0.001, 'lamb_kld_end': 1e5,
         'kld_step_delay_frac': 0.2, 'kld_step_end_frac': 1},

    2: {'kld_scheduler': 'linear', 'lamb_kld_start': 100, 'lamb_kld_end': 1e5,
         'kld_step_delay_frac': 0.5, 'kld_step_end_frac': 1},

    3: {'kld_scheduler': 'linear', 'lamb_kld_start': 1, 'lamb_kld_end': 1e-5,
         'kld_step_delay_frac': 0.3, 'kld_step_end_frac': 1},

    4: {'kld_scheduler': 'linear', 'lamb_kld_start': 1, 'lamb_kld_end': 1e10,
         'kld_step_delay_frac': 0, 'kld_step_end_frac': 1},

    5: {'kld_scheduler': 'linear', 'lamb_kld_start': 0.01, 'lamb_kld_end': 1e4,
         'kld_step_delay_frac': 0, 'kld_step_end_frac': 1},

    6: {'kld_scheduler': 'linear', 'lamb_kld_start': 0.01, 'lamb_kld_end': 1e4,
         'kld_step_delay_frac': 0.1, 'kld_step_end_frac': 1},
}

MODEL_PRESETS = {
    1: {'f_type': 'auto_spa', 'g_type': 'beta'},
    2: {'f_type': 'auto_spa_can', 'g_type': 'beta'},
    3: {'f_type': 'auto_full', 'g_type': 'beta'},
    4: {'f_type': 'auto_full', 'g_type': 'independent'},
}

MODEL_PRESET_TO_NAME = {
    1: 'relax-cano',
    2: 'cano',
    3: 'dense',
    4: 'independent',
}



def main():

    parser = argparse.ArgumentParser(description='Train CausalSpa')

    # basics
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--seed', type=int, default=0,
                        help='random seed (default: 0)')  # set to None for a random seed

    # data
    parser.add_argument('--list_train_domains', '--list', nargs='+', default=['0', '15', '30', '45', '60'],
                       help='domains used during training')
    parser.add_argument('--mnist_subset', type=str, default='med')
    parser.add_argument('--label_list',type=list,default=list(range(10)))
    parser.add_argument('--image_shape',type=list,default=[1,28,28])

    # training params
    parser.add_argument('--lr_g', type=float, default=1e-4)
    parser.add_argument('--lr_f', type=float, default=0.001)
    parser.add_argument('--kld_scheduler', type=str, default='none',  # if none, lamb_kld := lamb_kld_start
                        choices=['none', 'linear'])
    parser.add_argument('--lamb_kld_start', type=float, default=4)  # this is the kld lambda at the start of training
    parser.add_argument('--lamb_kld_end', type=float, default=4)
    parser.add_argument('--kld_step_delay_frac', type=int, default=0.2)  # the percentage of total steps before scheduler starts
    parser.add_argument('--kld_step_end_frac', type=int, default=1)  # the percentage of total steps when the scheduler should end)
    parser.add_argument('--batch_size', type=int, default=1024)
    parser.add_argument('--beta1', type=float, default=0.5)
    parser.add_argument('--beta2', type=float, default=0.999)
    parser.add_argument('--num_iters', type=int, default=20_000)
    # parser.add_argument('--spa_iters', default=None)  # can be either an int or set to False/None to turn off sparsity
    # sweep stuff
    parser.add_argument('--kld_preset_value', type=int, default=-1)
    parser.add_argument('--model_preset', type=int, default=-1)

    # algorithm

    # model params
    parser.add_argument('--f_type', type=str, default='auto_spa_can',
                        choices=['mlp', 'auto_full', 'auto_spa', 'auto_spa_can', 'auto_soft_can'])
    parser.add_argument('--g_type', type=str, default='beta', help='type of generator', 
                        choices=['deep', 'simple', 'mlp', 'beta', 'independent'])
    parser.add_argument('--inter_dim', type=int, default=25, help='dimension of intermediate layer of models')
    parser.add_argument('--latent_dim', type=int, default=10)  # size of latent space (i.e. dim of z)
    parser.add_argument('--k_spa', type=int, default=2)

    # log and dir
    parser.add_argument('--data_dir', type=str, default='../data')
    parser.add_argument('--save_dir', default='./saved')
    parser.add_argument('--no_wandb', action='store_true', default=False)
    parser.add_argument('--step_validate', type=int, default=500)  # check the validation loss every this many steps
    parser.add_argument('--step_vis', type=int, default=5_000)  # plot rmnist generated samples and counterfactuals
    parser.add_argument('--step_save', type=int, default=5_00000)  # save model every this many steps
    parser.add_argument('--note', default='')
    parser.add_argument('--project_prefix', default='')

    # sweep (tells us whether this is being called as part of WandB sweep or not)
    parser.add_argument('--sweep', action='store_true', default=False)

    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    seed = args.seed
    if str(seed).lower() in ['none', 'false', 'rand', 'random']:
        seed = np.random.randint(1000)
        args.seed = seed
    else:
        args.seed = int(seed)
    os.environ['PYTHONHASHSEED'] = str(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.benchmark = False
    kwargs = {'num_workers': 1, 'pin_memory': False} if args.cuda else {}
    #TODO: check if num workers should be increased.


    # ======================== #
    #         misc check       #
    # ======================== #
    assert args.f_type != 'mlp', 'mlp is not supported yet (since it does not have an true inverse)'
    if args.g_type == 'independent':
        assert args.f_type == 'auto_full', 'independent generator only works with full f type'

    # ======================== #
    #         params           #
    # ======================== #
    args.lamb = dict()
    args.n_domains = len(args.list_train_domains)
    args.no_f_mlp = True

    if args.kld_preset_value > 0:
        kld_preset = KLD_SCHEDULER_PRESETS[args.kld_preset_value]
        print(f'Using KLD scheduler preset {kld_preset}')
        for k, v in kld_preset.items():
            setattr(args, k, v)  # set the values for the kld scheduler
    if args.model_preset > 0:
        model_preset = MODEL_PRESETS[args.model_preset]
        print(f'Using model preset {model_preset}')
        for k, v in model_preset.items():
            setattr(args, k, v)  # set the values for the model params

    # ======================== #
    #         log              #
    # ======================== #
    args.wandb = not args.no_wandb
    if args.model_preset > 0:
        args.run_name = f'[{MODEL_PRESET_TO_NAME[args.model_preset]}]_{args.note}_latent_{args.latent_dim}_seed_{args.seed}_kspa_{args.k_spa}'
    else:
        args.run_name = f'[{model_name[args.f_type]}]_{args.note}_latent_{args.latent_dim}_seed_{args.seed}_kspa_{args.k_spa}'

    # print(f'Run: {args.run_name} with args:\n{vars(args)}')
    if args.wandb or args.sweep:
        wandb.init(project=f"[rmnist" if not args.sweep else None,
                   entity="" if not args.sweep else None,
                   name=args.run_name,
                   config=vars(args))
        wandb.run.log_code()
        wandb.define_metric('Test/CF-total', summary='min')
        # wandb.define_metric('Val-loss/recon', summary='min')
        # wandb.define_metric('Val-loss/align', summary='min')


    # ======================== #
    #         data             #
    # ======================== #
    print('Loading data...')
    all_data = MnistRotated(args.list_train_domains, args.data_dir, train=True, mnist_subset=args.mnist_subset)
    train_size = int(0.9 * len(all_data))
    val_size = len(all_data) -  train_size
    train_set, val_set= torch.utils.data.random_split(all_data, [train_size, val_size],
                                                      generator=torch.Generator().manual_seed(args.seed))
    train_loader = torch.utils.data.DataLoader(train_set,
                                         batch_size=args.batch_size,
                                         shuffle=True, **kwargs)
    val_loader = torch.utils.data.DataLoader(val_set,
                                       batch_size=args.batch_size,
                                       shuffle=False, **kwargs)
    test_set = MnistRotated_CF(args.list_train_domains, args.data_dir)
    test_loader = torch.utils.data.DataLoader(test_set,
                                        batch_size=args.batch_size,
                                        shuffle=False, **kwargs)

   # TODO: check if there are any problems with splitting the data this way.
    
    args.domain_list = list(range(len(args.list_train_domains)))
    print(f'Dataset sizes:\n\ttrain:{all_data.__len__()}, val:{val_set.__len__()}, test:{test_set.__len__()}')

    # ======================== #
    #         training         #
    # ======================== #
    solver = Solver(train_loader, val_loader, test_loader, args)
    solver.train()



if __name__ == "__main__":
    main()