# %%
import os
import argparse
import sys; sys.path.append("./ANODE") # import hack
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.model import OurModel, Simulator
from models.conv_models import ConvODENet, MNISTConvODENet
from models.mlp_model import ODENet
from omegaconf import OmegaConf
import plotly.express as px
from utils import *
import wandb
from torch.func import vmap, jacrev, jacfwd, functional_call
from trainer import *


if __name__ == '__main__':
    # Command-line interface of the training script.
    #
    # NOTE(review): --fgh_lr, --fgh_lr_rel, --lambdas, --f_jac_clamp and --g_jac_clamp use
    # ``type=eval``, i.e. the raw command-line text is evaluated as a Python expression.
    # That is only acceptable while the command line is fully trusted; consider
    # ``ast.literal_eval`` if these flags can ever be filled from an untrusted source.
    parser = argparse.ArgumentParser()
    parser.add_argument('--name', type=str, required=True)
    parser.add_argument('--seed', type=int, default=123)
    parser.add_argument('--task', type=str, default='bostonHousing')
    parser.add_argument('--split_num', type=int, default=0)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--lr_scheduler', type=str, default='none', choices=['none', 'cos', 'step'])
    parser.add_argument('--fgh_lr', type=eval, default=None, help='Learning rate for f, g, h. e.g [1e-3, 1e-3, 1e-3]')
    parser.add_argument('--fgh_lr_rel', type=eval, default=None, help='Learning rate for f, g, h. Relative to lr. e.g [1, 1, 1]')
    parser.add_argument('--epochs', type=int, default=300)
    parser.add_argument('--total_steps', type=int, default=-1)
    parser.add_argument('--lambdas', type=eval, default=[1., 1., 0., 0.])
    parser.add_argument('--f_jac_clamp', type=eval, default=[-1, -1], help='Clamp jacobian norm to this range, -1 means no clamping')
    parser.add_argument('--g_jac_clamp', type=eval, default=[-1, -1], help='Clamp jacobian norm to this range, -1 means no clamping')
    parser.add_argument('--label_proj_strategy', type=str, default='repeat', choices=['repeat', 'reshape', 'mlp'])
    parser.add_argument('--test_every', type=float, default=5,
                        help='Evaluation interval: counted in epochs if total_steps is -1; otherwise a fraction of '
                             'total_steps when < 1, else used as given')
    parser.add_argument('--f_sg_target', action='store_true', help='Detach f in velocity target from computational graph')
    parser.add_argument('--latent_chan', type=int, default=64)
    parser.add_argument('--h_dim', type=int, default=0)
    parser.add_argument('--h_add_blocks', type=int, default=0)
    parser.add_argument('--f_add_blocks', type=int, default=0)
    parser.add_argument('--g_add_blocks', type=int, default=0)
    parser.add_argument('--nonlinearity', type=str, default='relu', choices=['relu', 'softplus', 'swish'])
    parser.add_argument('--fixnorm', action='store_true', help='Use fixed architecture not ending with norm')
    parser.add_argument('--no_final_norm', action='store_true', help='Do not use final norm layer at odefunc')
    parser.add_argument('--no_out_norm', action='store_true', help='Do not use starting GroupNorm layer at out_projection')
    parser.add_argument('--t_transform', type=str, default='identity', choices=['identity', 'square', 'one_minus_cos', 'cubic'])
    parser.add_argument('--invert_transform_t', action='store_true', help='transform t in inverse way (use 1-t instead of t)')
    parser.add_argument('--in_proj_type', type=str, default='linear', choices=['linear', 'identity', 'conv1x1', 'conv3x3',
                                                                               'mlp', 'padding', 'mlp2', 'anode'])
    parser.add_argument('--out_proj_type', type=str, default='linear', choices=['linear', 'mlp', 'mlp2', 'padding'])
    parser.add_argument('--label_proj_type', type=str, default='linear', choices=['linear', 'mlp', 'mlp2', 'padding'])
    parser.add_argument('--conv_mnist', action='store_true', help='Use ConvODENet for MNIST')
    parser.add_argument('--mid_conv', type=int, default=1, help='Number of conv layers in the middle of ConvODENet')
    parser.add_argument('--train_alter', action='store_true', help='Train modules in alternating order')
    parser.add_argument('--train_alter_order', type=str, default='fgh', help='Alternating update order, underscore-separated. e.g. fg_h')
    parser.add_argument('--train_alter_epoch', type=str, default='1', help='Alternating update epochs, underscore-separated. e.g. 1_3')
    parser.add_argument('--sync_t', action='store_true', help='Use same t for all instances in a batch')
    parser.add_argument('--augment_t', type=int, default=1)
    parser.add_argument('--label_flow_noise', type=float, default=0., help='Add noise to z1 for flow prediction')
    parser.add_argument('--label_flow_noise_0', type=float, default=0., help='Add noise to z0 for flow prediction')
    parser.add_argument('--t_final', type=float, default=1., help='Train and test with [0, t_final] instead of [0, 1]')
    parser.add_argument('--in_latent_chan', type=int, default=64, help='Input latent channel for OurModel')
    parser.add_argument('--f_act', type=str, default='relu', choices=['relu', 'swish'])
    parser.add_argument('--h_act', type=str, default='relu', choices=['relu', 'swish'])
    parser.add_argument('--dyn_use_norm', action='store_true', help='Use norm in dynamic model')
    parser.add_argument('--dyn_layers', type=int, default=3, help='Number of layers in dynamic model')
    parser.add_argument('--h_dropout', type=float, default=0.0)
    parser.add_argument('--dynamics', type=str, default='linear', choices=['linear', 'cos', 'inv_cos', 'vp_ode', 'lin_cos', 'const_vp_ode', 'learnable', 'half_circle', 'lin_sin'])
    parser.add_argument('--mlp_hidden_dim', type=int, default=512)
    parser.add_argument('--time_modulation', type=str, choices=['none', 'fourier', 'adaln'], default='none')
    parser.add_argument('--adjoint', action='store_true', help='Use Adjoint Sensitivity Method')
    # The help text here used to be a copy of --adjoint's; the flag is only forwarded to
    # OurModel(h_norm=...) in the cifar10/svhn branch of the model construction.
    parser.add_argument('--h_norm', action='store_true', help='Enable the h_norm option of OurModel (cifar10/svhn only)')
    parser.add_argument('--augment_dim', type=int, default=0)
    parser.add_argument('--steer', type=float, default=0.)
    parser.add_argument('--ema', type=float, default=0., help='Exponential moving average for parameters')
    parser.add_argument('--ke_reg', type=float, default=0.01, help='Regularization for kinetic energy')
    parser.add_argument('--jf_reg', type=float, default=0.01, help='Regularization for jacobian')
    ### not so frequently used...
    parser.add_argument('--in_proj_scale', type=float, default=None)
    parser.add_argument('--label_proj_scale', type=float, default=None)
    parser.add_argument('--proj_norm', type=str, default='none', choices=['none', 'ln', 'bn'])
    parser.add_argument('--debug', action='store_true', help='debug mode with 10 instance val set')
    parser.add_argument('--force_zero_prob', type=float, default=0.)
    parser.add_argument('--anode_net', action='store_true', help='Use ConvODENet used in ANODE')
    parser.add_argument('--aug_dim', type=int, default=1)
    parser.add_argument('--label_ae_noise', type=float, default=0.)
    parser.add_argument('--dataset', type=str, choices=['cifar10', 'mnist', 'uci', 'svhn'], default='cifar10')
    parser.add_argument('--optimizer', type=str, default='adam')
    parser.add_argument('--weight_decay', type=float, default=0.)
    parser.add_argument('--method', type=str, choices=['ours', 'node'], default='ours')
    parser.add_argument('--label_ae_criterion', type=str, choices=['ce', 'mse'], default='mse')
    parser.add_argument('--task_dec', action='store_true', help='Use task loss only for decoder')
    parser.add_argument('--save_every', type=int, default=24, help='save and evaluate every n hours')
    parser.add_argument('--patience', type=int, default=-1, help='Early stopping patience')
    parser.add_argument('--watch', action='store_true', help='Watch model with wandb')

    ### deprecated
    # parser.add_argument('--train_fg_epoch', type=int, default=5)
    # parser.add_argument('--train_h_epoch', type=int, default=5)
    # parser.add_argument('--train_fg_h_epoch', type=eval, default=(-1, -1), help='Same as train_fg_epoch and train_h_epoch but set with one argument')
    args = parser.parse_args()

    # Each dataset logs to its own wandb project; run name and entity are shared.
    wandb_projects = {
        'cifar10': 'SFNO',
        'mnist': 'SFNO_MNIST',
        'uci': 'SFNO_UCI_new',
        'svhn': 'SFNO_SVHN',
    }
    if args.dataset not in wandb_projects:
        raise ValueError(f'Dataset {args.dataset} not supported')
    wandb.init(project=wandb_projects[args.dataset], name=args.name, entity='denoising_representation')

    # Alternating-update schedule: both options are underscore-separated strings that are
    # split into lists and must pair up one-to-one.
    args.train_alter_order = args.train_alter_order.split('_')  # f_gh -> ['f', 'gh'], fgh -> ['fgh']
    args.train_alter_epoch = [int(e) for e in args.train_alter_epoch.split('_')]
    print('Alternating:', args.train_alter_order, args.train_alter_epoch)
    # Explicit raises instead of `assert`: asserts are stripped under `python -O`, which
    # would let invalid command lines through silently.
    if len(args.train_alter_order) != len(args.train_alter_epoch):
        raise ValueError('train_alter_order and train_alter_epoch should have same length')

    # Relative per-module learning rates are converted to absolute ones (scaled by --lr).
    if args.fgh_lr_rel is not None:
        if args.fgh_lr is not None:
            raise ValueError('fgh_lr and fgh_lr_rel cannot be set at the same time')
        args.fgh_lr = [args.lr * lr_rel for lr_rel in args.fgh_lr_rel]

    # Dataset-specific overrides; these replace whatever was given on the command line.
    if args.dataset == 'uci':
        args.mlp_hidden_dim = args.latent_chan
        args.proj_norm = 'none'
    if args.dataset == 'mnist':
        args.proj_norm = 'bn'

    # Debug overrides are applied BEFORE the config snapshot so that the saved/logged
    # config reflects the batch size, epoch count and test interval actually used.
    if args.debug:
        args.batch_size = 100
        args.epochs = 10000
        args.test_every = 50

    # configs: snapshot the (post-processed) arguments to the wandb run directory and to
    # the wandb run config.
    config = OmegaConf.create(vars(args))
    OmegaConf.save(config, os.path.join(wandb.run.dir, 'exp_config.yaml'))
    wandb.config.update(OmegaConf.to_container(config))

    # Short aliases used when building the model and the Trainer further down.
    seed = args.seed
    optim = args.optimizer
    lr = args.lr
    force_zero_prob = args.force_zero_prob
    weight_decay = args.weight_decay
    lambdas = args.lambdas
    label_proj_strategy = args.label_proj_strategy
    # Explicit raise instead of `assert` so the check survives `python -O`.
    if args.task_dec and not (lambdas[2] > 0):
        raise ValueError('task_dec requires task loss')

    # Any criterion other than 'ce' selects the MSE label auto-encoder loss.
    label_ae_mse = args.label_ae_criterion != 'ce'

    fix_random_seeds(seed, strict=True)
    simulator = None  # not referenced anywhere else in the visible script
    
    # Build train/test/val loaders, the optional train-subset loader, the label scaler
    # and the task loss for the selected dataset.
    if args.dataset in ('cifar10', 'mnist', 'svhn'):
        # The image datasets share the same loader arguments; only CIFAR-10 additionally
        # receives the debug flag.
        loader_kwargs = dict(data_aug=True, onehot=True, download=False,
                             batch_size=args.batch_size, nw=8)
        if args.dataset == 'cifar10':
            loaders = get_cifar_loaders(**loader_kwargs, debug=args.debug)
        elif args.dataset == 'mnist':
            loaders = get_mnist_loaders(**loader_kwargs)
        else:
            loaders = get_svhn_loaders(**loader_kwargs)
        train_loader, test_loader, val_loader = loaders
        subset_loader, label_scaler = None, None
        task_criterion = nn.CrossEntropyLoss()
    elif args.dataset == 'uci':
        # One UCI dataset object per split, created in train/val/test order.
        train_dataset, val_dataset, test_dataset = (
            UCI('UCI_Datasets', args.task, args.split_num, split)
            for split in ('train', 'val', 'test')
        )
        # Random subset of at most 100 training instances, served by its own loader.
        subset_indices = torch.randperm(len(train_dataset))[:100]
        train_subset = torch.utils.data.Subset(train_dataset, subset_indices)

        common_kwargs = dict(batch_size=args.batch_size, num_workers=0, pin_memory=False)
        train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, **common_kwargs)
        val_loader = DataLoader(val_dataset, shuffle=False, **common_kwargs)
        test_loader = DataLoader(test_dataset, shuffle=False, **common_kwargs)
        subset_loader = DataLoader(train_subset, shuffle=False, drop_last=False, **common_kwargs)

        label_scaler = train_dataset.scaler_y
        task_criterion = nn.MSELoss()
    else:
        raise ValueError(f'Dataset {args.dataset} not supported')
    
    # Build the network for the selected dataset and move it to the GPU.
    # Unsupported combinations raise explicitly: the previous bare `assert 0` is stripped
    # under `python -O`, which would leave `net` undefined and fail later with a NameError.
    if args.dataset == 'mnist':
        if args.conv_mnist:
            # NOTE(review): this branch hard-codes norm_fix=False and passes None for both
            # projection scales, so --fixnorm, --in_proj_scale and --label_proj_scale are
            # not forwarded here — confirm this is intentional.
            net = OurModel(label_proj_strategy=label_proj_strategy, emb_res=(6, 6), latent_chan=args.latent_chan, norm_fix=False,
                           final_norm=(not args.no_final_norm), in_proj_scale=None, label_proj_scale=None, channel_in=1,
                           proj_norm=args.proj_norm, out_norm=(not args.no_out_norm), t_final=args.t_final,
                           f_act=args.f_act, h_act=args.h_act, h_dropout=args.h_dropout,
                           h_dim=args.h_dim, h_add_blocks=args.h_add_blocks, f_add_blocks=args.f_add_blocks, g_add_blocks=args.g_add_blocks,
                           adjoint=args.adjoint, augment_dim=args.augment_dim).cuda()
        else:
            # MLP ODE network on flattened 28x28 inputs with 10 outputs.
            net = ODENet(device='cuda', data_dim=28*28, hidden_dim=args.mlp_hidden_dim, output_dim=10, augment_dim=args.augment_dim, time_dependent=True,
                        in_proj=args.in_proj_type, out_proj=args.out_proj_type, label_proj=args.label_proj_type, proj_norm=args.proj_norm,
                        in_proj_scale=args.in_proj_scale, label_proj_scale=args.label_proj_scale, t_final=args.t_final,
                        time_modulation=args.time_modulation, adjoint=args.adjoint).cuda()

    elif args.anode_net:
        # Checked after the mnist branch, so --anode_net is ignored for mnist.
        raise NotImplementedError('--anode_net is not supported by this script')

    elif args.dataset == 'uci':
        # Input/output dimensions come from the UCI training split.
        data_dim = train_dataset.train_dim_x
        output_dim = train_dataset.train_dim_y
        net = ODENet(device='cuda', data_dim=data_dim, hidden_dim=args.mlp_hidden_dim, output_dim=output_dim,
                     latent_dim=args.latent_chan, augment_dim=args.augment_dim, time_dependent=True,
                    in_proj=args.in_proj_type, out_proj=args.out_proj_type, label_proj=args.label_proj_type, proj_norm=args.proj_norm,
                    in_proj_scale=args.in_proj_scale, label_proj_scale=args.label_proj_scale, t_final=args.t_final,
                    time_modulation=args.time_modulation, non_linearity=args.nonlinearity,
                    h_add_blocks=args.h_add_blocks, f_add_blocks=args.f_add_blocks, g_add_blocks=args.g_add_blocks).cuda()
    elif args.dataset in ['cifar10', 'svhn']:
        net = OurModel(label_proj_strategy=label_proj_strategy, latent_chan=args.latent_chan, norm_fix=args.fixnorm,
                       final_norm=(not args.no_final_norm), in_proj_scale=args.in_proj_scale, label_proj_scale=args.label_proj_scale,
                       proj_norm=args.proj_norm, out_norm=(not args.no_out_norm), t_final=args.t_final, 
                       in_latent_chan=args.in_latent_chan, f_act=args.f_act, h_act=args.h_act, h_dropout=args.h_dropout,
                       h_dim=args.h_dim, h_add_blocks=args.h_add_blocks, f_add_blocks=args.f_add_blocks,
                       g_add_blocks=args.g_add_blocks,
                       adjoint=args.adjoint, augment_dim=args.augment_dim, h_norm=args.h_norm).cuda()
    else:
        raise ValueError(f'Dataset {args.dataset} not supported')
    
    print(net)

    # count trainable parameters
    n_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print(f'Number of parameters: {n_params}')

    # Resolve the step budget and the evaluation interval:
    #   * no explicit --total_steps: the budget is epochs * batches-per-epoch and
    #     --test_every is scaled by batches-per-epoch;
    #   * explicit --total_steps with --test_every < 1: a fraction of the step budget;
    #   * otherwise --test_every is used as given, normalised to int so that every branch
    #     hands the Trainer an integer (it previously stayed a float here).
    if args.total_steps == -1:
        args.total_steps = args.epochs * len(train_loader)
        args.test_every = int(len(train_loader) * args.test_every)
    elif args.test_every < 1.0:
        args.test_every = int(args.total_steps * args.test_every)
    else:
        args.test_every = int(args.test_every)

    if args.watch:
        wandb.watch(net, log='all', log_freq=50, idx=0)

    # `eval(args.dynamics)` turns the option string into the object of the same name.
    # argparse `choices` restricts the string to a fixed set of identifiers, so no
    # arbitrary expression can reach eval here.
    # NOTE(review): those names are presumably provided by the star-imports
    # (utils / trainer) — confirm.
    trainer = Trainer(net, optimizer=optim, scheduler=args.lr_scheduler, lr=lr, total_steps=args.total_steps, force_zero_prob=force_zero_prob, lambdas=lambdas, wd=weight_decay,
          test_every=args.test_every, task_criterion=task_criterion, label_ae_noise=args.label_ae_noise, method=args.method, f_sg_target=args.f_sg_target, augment_t=args.augment_t,
          t_transform=args.t_transform, label_ae_mse=label_ae_mse, task_dec=args.task_dec, f_jac_clamp=args.f_jac_clamp,
          g_jac_clamp=args.g_jac_clamp, train_alter=args.train_alter, alter_order=args.train_alter_order, alter_epoch=args.train_alter_epoch,
          fgh_lr=args.fgh_lr, sync_t=args.sync_t, label_flow_noise=args.label_flow_noise,
          label_flow_noise_0=args.label_flow_noise_0, invert_transform_t=args.invert_transform_t, dynamics=eval(args.dynamics),
          label_scaler=label_scaler, dataset=args.dataset, steer=args.steer,
          save_every=args.save_every, patience=args.patience, ema=args.ema, ke_reg=args.ke_reg, jf_reg=args.jf_reg)

    # subset_loader is None for the image datasets and a train-subset loader for UCI.
    net = trainer.fit(train_loader, val_loader, test_loader, subset_loader)
    wandb.finish()
    print('Done.')