# %%
# %%
import os
import argparse
import sys; sys.path.append("./ANODE") # import hack
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.model import OurModel, Simulator
from models.conv_models import ConvODENet, MNISTConvODENet
from models.mlp_model import ODENet
from omegaconf import OmegaConf
import plotly.express as px
from utils import *
import wandb
from torch.func import vmap, jacrev, jacfwd, functional_call
from trainer import *
from glob import glob

# Optional GPU pinning, kept disabled: os.environ['CUDA_VISIBLE_DEVICES'] = '6,'
# This script only evaluates models, so autograd is switched off for the whole process.
torch.autograd.set_grad_enabled(False)

# %%
def get_default_args(argv=()):
    """Build the training/eval argument namespace.

    By default the parser ignores ``sys.argv`` and returns the defaults. This is
    so the function is safe to call from a notebook or a multiprocessing worker.

    Args:
        argv: Optional sequence of CLI-style tokens to parse on top of the
            defaults, e.g. ``['--seed', '7']``. Defaults to no tokens.

    Returns:
        argparse.Namespace with every option below populated.
    """
    parser = argparse.ArgumentParser()
    # NOTE(review): several options use ``type=eval``, which executes arbitrary
    # Python from the command line. That is acceptable for local experiments
    # only; never feed untrusted input here.
    parser.add_argument('--seed', type=int, default=123)
    parser.add_argument('--task', type=str, default='bostonHousing')
    parser.add_argument('--split_num', type=int, default=0)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--lr', type=float, default=3e-3)
    parser.add_argument('--lr_scheduler', type=str, default='none', choices=['none', 'cos', 'step'])
    parser.add_argument('--fgh_lr', type=eval, default=None, help='Learning rate for f, g, h. e.g [1e-3, 1e-3, 1e-3]')
    parser.add_argument('--fgh_lr_rel', type=eval, default=None, help='Learning rate for f, g, h. Relative to lr. e.g [1, 1, 1]')
    parser.add_argument('--epochs', type=int, default=300)
    parser.add_argument('--total_steps', type=int, default=-1)
    parser.add_argument('--lambdas', type=eval, default=[1., 1., 0., 0.])
    parser.add_argument('--f_jac_clamp', type=eval, default=[-1, -1], help='Clamp jacobian norm to this range, -1 means no clamping')
    parser.add_argument('--g_jac_clamp', type=eval, default=[-1, -1], help='Clamp jacobian norm to this range, -1 means no clamping')
    parser.add_argument('--label_proj_strategy', type=str, default='repeat', choices=['repeat', 'reshape', 'mlp'])
    parser.add_argument('--test_every', type=float, default=5)
    parser.add_argument('--f_sg_target', action='store_true', help='Detach f in velocity target from computational graph')
    parser.add_argument('--latent_chan', type=int, default=64)
    parser.add_argument('--h_dim', type=int, default=0)
    parser.add_argument('--h_add_blocks', type=int, default=4)
    parser.add_argument('--f_add_blocks', type=int, default=0)
    parser.add_argument('--g_add_blocks', type=int, default=0)
    parser.add_argument('--nonlinearity', type=str, default='relu', choices=['relu', 'softplus', 'swish'])
    parser.add_argument('--fixnorm', action='store_true', help='Use fixed architecture not ending with norm')
    parser.add_argument('--no_final_norm', action='store_true', help='Do not use final norm layer at odefunc')
    parser.add_argument('--no_out_norm', action='store_true', help='Do not use starting GroupNorm layer at out_projection')
    parser.add_argument('--t_transform', type=str, default='identity', choices=['identity', 'square', 'one_minus_cos', 'cubic'])
    parser.add_argument('--invert_transform_t', action='store_true', help='transform t in inverse way (use 1-t instead of t)')
    parser.add_argument('--in_proj_type', type=str, default='mlp', choices=['linear', 'identity', 'conv1x1', 'conv3x3',
                                                                                'mlp', 'padding', 'mlp2', 'anode'])
    parser.add_argument('--out_proj_type', type=str, default='linear', choices=['linear', 'mlp', 'mlp2', 'padding'])
    parser.add_argument('--label_proj_type', type=str, default='linear', choices=['linear', 'mlp', 'mlp2', 'padding'])
    parser.add_argument('--conv_mnist', action='store_true', help='Use ConvODENet for MNIST')
    parser.add_argument('--mid_conv', type=int, default=1, help='Number of conv layers in the middle of ConvODENet')
    parser.add_argument('--train_alter', action='store_true', help='Train modules in alternating order')
    parser.add_argument('--train_alter_order', type=str, default='fgh', help='Alternating update order, underscore-separated. e.g. fg_h')
    parser.add_argument('--train_alter_epoch', type=str, default='1', help='Alternating update epochs, underscore-separated. e.g. 1_3')
    parser.add_argument('--sync_t', action='store_true', help='Use same t for all instances in a batch')
    parser.add_argument('--augment_t', type=int, default=1)
    parser.add_argument('--label_flow_noise', type=float, default=0., help='Add noise to z1 for flow prediction')
    parser.add_argument('--label_flow_noise_0', type=float, default=0., help='Add noise to z0 for flow prediction')
    parser.add_argument('--t_final', type=float, default=1., help='Train and test with [0, t_final] instead of [0, 1]')
    parser.add_argument('--in_latent_chan', type=int, default=64, help='Input latent channel for OurModel')
    parser.add_argument('--f_act', type=str, default='relu', choices=['relu', 'swish'])
    parser.add_argument('--h_act', type=str, default='relu', choices=['relu', 'swish'])
    parser.add_argument('--dyn_use_norm', action='store_true', help='Use norm in dynamic model')
    parser.add_argument('--dyn_layers', type=int, default=3, help='Number of layers in dynamic model')
    parser.add_argument('--h_dropout', type=float, default=0.0)
    parser.add_argument('--dynamics', type=str, default='linear', choices=['linear', 'cos', 'inv_cos', 'vp_ode', 'lin_cos', 'const_vp_ode', 'learnable', 'half_circle', 'lin_sin'])
    parser.add_argument('--mlp_hidden_dim', type=int, default=64)
    parser.add_argument('--time_modulation', type=str, choices=['none', 'fourier', 'adaln'], default='none')
    parser.add_argument('--adjoint', action='store_true', help='Use Adjoint Sensitivity Method')
    # NOTE(review): this help text used to duplicate --adjoint's. The exact meaning of
    # h_norm lives in the training code, which this file does not show.
    parser.add_argument('--h_norm', action='store_true', help='Enable the h_norm option')
    parser.add_argument('--augment_dim', type=int, default=0)
    parser.add_argument('--steer', type=float, default=0.)
    parser.add_argument('--ema', type=float, default=0., help='Exponential moving average for parameters')
    parser.add_argument('--ke_reg', type=float, default=0.01, help='Regularization for kinetic energy')
    parser.add_argument('--jf_reg', type=float, default=0.01, help='Regularization for jacobian')
    ### not so frequently used...
    parser.add_argument('--in_proj_scale', type=float, default=None)
    parser.add_argument('--label_proj_scale', type=float, default=None)
    parser.add_argument('--proj_norm', type=str, default='none', choices=['none', 'ln', 'bn'])
    parser.add_argument('--force_zero_prob', type=float, default=0.)
    parser.add_argument('--label_ae_noise', type=float, default=0.)
    parser.add_argument('--dataset', type=str, choices=['cifar10', 'mnist', 'uci', 'svhn'], default='uci')
    parser.add_argument('--optimizer', type=str, default='adam')
    parser.add_argument('--weight_decay', type=float, default=0.)
    parser.add_argument('--method', type=str, choices=['ours', 'node'], default='ours')
    parser.add_argument('--label_ae_criterion', type=str, choices=['ce', 'mse'], default='mse')
    parser.add_argument('--task_dec', action='store_true', help='Use task loss only for decoder')
    parser.add_argument('--save_every', type=int, default=24, help='save and evaluate every n hours')
    parser.add_argument('--patience', type=int, default=-1, help='Early stopping patience')
    parser.add_argument('--watch', action='store_true', help='Watch model with wandb')
    # An explicit list (never None) keeps argparse from reading sys.argv.
    return parser.parse_args(list(argv))

def adjust_args(row, args):
    """Copy one experiment row's metadata onto ``args`` and resolve its checkpoint.

    Args:
        row: Mapping (e.g. a pandas row) with ``'task'``, ``'split_num'``,
            ``'ID'`` and ``'Name'`` entries.
        args: Namespace that is updated in place and also returned.

    Returns:
        ``args``, with ``task``, ``split_num``, ``ckpt_path`` and ``name`` set.

    Raises:
        FileNotFoundError: If no ``best_val.ckpt`` exists under ``ckpts/*<ID>/``.
    """
    args.task = row['task']
    args.split_num = row['split_num']
    pattern = f'ckpts/*{row["ID"]}/**/best_val.ckpt'
    # glob() returns paths in filesystem order. Sorting makes the choice
    # reproducible when several checkpoints match the same ID.
    matches = sorted(glob(pattern, recursive=True))
    if not matches:
        raise FileNotFoundError(f'No checkpoint matches {pattern!r}')
    args.ckpt_path = matches[0]
    args.name = row['Name']
    return args

def setup_from_row(row):
    """Rebuild the UCI model for one experiment row and load its checkpoint.

    Args:
        row: Experiment metadata consumed by ``adjust_args`` (``'task'``,
            ``'split_num'``, ``'ID'``, ``'Name'``).

    Returns:
        Tuple ``(net, test_loader, label_scaler, args)``:
        - ``net``: the ``ODENet`` moved to CUDA, with checkpoint weights loaded
          non-strictly.
        - ``test_loader``: an unshuffled ``DataLoader`` over the test split.
        - ``label_scaler``: the training split's ``scaler_y``.
        - ``args``: the resolved argument namespace.

    Raises:
        ValueError: If ``args.dataset`` is not ``'uci'``.
    """
    args = adjust_args(row, get_default_args())
    print(args.ckpt_path)

    if args.dataset != 'uci':
        raise ValueError(f'Dataset {args.dataset} not supported')

    # Fixed UCI architecture. These values override the parser defaults.
    args.in_proj_type = 'mlp'
    args.out_proj_type = 'linear'
    args.label_proj_type = 'linear'
    args.nonlinearity = 'relu'
    args.latent_chan = 64
    args.mlp_hidden_dim = args.latent_chan  # the hidden width is tied to the latent width for UCI
    args.f_add_blocks = 0
    args.h_add_blocks = 4
    args.g_add_blocks = 0
    args.proj_norm = 'none'

    fix_random_seeds(args.seed, strict=True)

    # The train split is still loaded: it provides the label scaler and the data/label dims.
    train_dataset = UCI('UCI_Datasets', args.task, args.split_num, 'train', device='cpu')
    test_dataset = UCI('UCI_Datasets', args.task, args.split_num, 'test', device='cpu')
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0,
                             pin_memory=False)
    label_scaler = train_dataset.scaler_y

    net = ODENet(device='cuda', data_dim=train_dataset.train_dim_x, hidden_dim=args.mlp_hidden_dim,
                 output_dim=train_dataset.train_dim_y, latent_dim=args.latent_chan,
                 augment_dim=args.augment_dim, time_dependent=True,
                 in_proj=args.in_proj_type, out_proj=args.out_proj_type, label_proj=args.label_proj_type,
                 proj_norm=args.proj_norm, in_proj_scale=args.in_proj_scale,
                 label_proj_scale=args.label_proj_scale, t_final=args.t_final,
                 time_modulation=args.time_modulation, non_linearity=args.nonlinearity,
                 h_add_blocks=args.h_add_blocks, f_add_blocks=args.f_add_blocks,
                 g_add_blocks=args.g_add_blocks).cuda()

    # NOTE(review): torch.load unpickles the file. Only load trusted checkpoints;
    # consider weights_only=True since a plain state_dict is expected here.
    missing, unexpected = net.load_state_dict(torch.load(args.ckpt_path, 'cpu'), strict=False)
    if missing:
        print('Missing keys:')
        print(missing)
    if unexpected:
        print('Unexpected keys:')
        print(unexpected)
    return net, test_loader, label_scaler, args

# %%
@torch.inference_mode()
def test_rmse(net, test_dataloader, method='dopri5', num_timesteps=1+1, label_scaler=None):
    """Compute the root-mean-squared error of ``net`` over a loader.

    Args:
        net: Model switched to eval mode. With ``method == 'dopri5'`` it is
            called as ``net(X, return_features=True, method='dopri5')``.
            Otherwise ``net.get_traj(X, method=method, timesteps=num_timesteps)``
            is used. Both calls must return ``(something, pred)``.
        test_dataloader: Iterable of ``(X, Y)`` tensor batches that supports ``len()``.
        method: ODE solver name; ``'dopri5'`` selects the adaptive path.
        num_timesteps: Number of time points handed to ``get_traj`` (unused for dopri5).
        label_scaler: Object with ``inverse_transform``. When given, targets and
            predictions are mapped back to the original label scale before the
            error is computed. When ``None``, the error is measured in the
            normalized label space.

    Returns:
        The RMSE, i.e. the square root of the sample-weighted mean squared error.

    Raises:
        ValueError: If the loader yields no samples.
    """
    net.eval()
    count = 0
    sq_err_sum = 0.0
    for X, Y in tqdm(test_dataloader, leave=False, total=len(test_dataloader), desc='Measure rmse'):
        X = X.cuda()  # only the input goes to the GPU; targets are compared on the CPU
        if method == 'dopri5':
            _, pred = net(X, return_features=True, method='dopri5')
        else:
            _, pred = net.get_traj(X, method=method, timesteps=num_timesteps)
        y_np = Y.cpu().numpy()
        pred_np = pred.cpu().numpy()
        if label_scaler is not None:
            y_np = label_scaler.inverse_transform(y_np)
            pred_np = label_scaler.inverse_transform(pred_np)
        batch_size = Y.size(0)
        # Weight each batch's mean by its size so uneven last batches average correctly.
        sq_err_sum += np.mean((y_np - pred_np) ** 2) * batch_size
        count += batch_size
    if count == 0:
        raise ValueError('test_dataloader yielded no samples; cannot compute RMSE')
    return (sq_err_sum / count) ** 0.5

@torch.inference_mode()
def test(net, name, test_dataloader, label_scaler, steps=(1, 2, 10, 20, 100, 1000)):
    """Measure test RMSE for several Euler step counts plus adaptive dopri5.

    Args:
        net: Model passed through to ``test_rmse``.
        name: Unused; kept for call compatibility.
        test_dataloader: Loader passed through to ``test_rmse``.
        label_scaler: Label scaler passed through to ``test_rmse``.
        steps: Euler step counts to evaluate. Each count ``n`` uses ``n + 1``
            time points.

    Returns:
        Dict mapping ``'test_on_best_val/rmse_<n>'`` for each ``n`` in ``steps``,
        then ``'test_on_best_val/rmse_dopri'``, to the corresponding RMSE.
    """
    prefix = 'test_on_best_val/rmse'
    ret = {}
    for n in steps:
        ret[f'{prefix}_{n}'] = test_rmse(net, test_dataloader, method='euler',
                                         num_timesteps=n + 1, label_scaler=label_scaler)
    ret[f'{prefix}_dopri'] = test_rmse(net, test_dataloader, method='dopri5', label_scaler=label_scaler)
    return ret

# %%
@torch.inference_mode()
def evaluate_from_row(row):
    """Evaluate one experiment and write its metrics next to its checkpoints.

    Args:
        row: ``(index, record)`` pair as produced by ``DataFrame.iterrows()``;
            ``record`` must provide ``'ID'``, ``'Name'``, ``'task'`` and ``'split_num'``.

    Returns:
        A one-row DataFrame holding the RMSE metrics and the record's identifiers.
        The same frame is written to ``ckpts/<Name>/best_val_log.csv``.
    """
    _, record = row
    net, loader, scaler, args = setup_from_row(record)

    metrics = test(net, record['Name'], loader, scaler)
    for column in ('ID', 'Name', 'task', 'split_num'):
        metrics[column] = record[column]
    result = pd.DataFrame(metrics, index=[0])

    out_dir = os.path.join('ckpts', args.name)
    os.makedirs(out_dir, exist_ok=True)
    out_file = f'{out_dir}/best_val_log.csv'
    result.to_csv(out_file)
    print('Saved to', out_file)
    return result

# %%
# multiprocess evaluation
from multiprocessing import Pool
from tqdm import tqdm

def evaluate_all(summary_df, eval_name, n_jobs=16):
    """Evaluate every experiment row in parallel and save a combined CSV.

    Args:
        summary_df: DataFrame with ``'Name'``, ``'ID'``, ``'task'`` and
            ``'split_num'`` columns, one row per experiment.
        eval_name: Output stem; results are written to ``ckpts/<eval_name>.csv``.
        n_jobs: Number of worker processes.

    Returns:
        The concatenated per-experiment result DataFrame.

    Raises:
        ValueError: If ``summary_df`` has no rows.
    """
    if len(summary_df) == 0:
        raise ValueError('summary_df is empty; nothing to evaluate')
    # NOTE(review): workers move models to CUDA. With the default 'fork' start
    # method, the parent process must not have initialised CUDA before this point.
    with Pool(n_jobs) as pool:
        results = list(tqdm(pool.imap(evaluate_from_row, summary_df.iterrows()), total=len(summary_df)))
    combined = pd.concat(results)
    out_file = f'ckpts/{eval_name}.csv'
    combined.to_csv(out_file)
    print('Saved to', out_file)
    return combined

if __name__ == '__main__':
    # Entry point: read csvs/<eval_name>.csv, normalise its metadata, then evaluate every row.
    parser = argparse.ArgumentParser()
    parser.add_argument('--eval_name', type=str, default='NODE_all_UCI')
    parser.add_argument('--n_jobs', type=int, default=8)
    args = parser.parse_args()
    exp_csv = pd.read_csv(f'csvs/{args.eval_name}.csv')

    # Fill missing split numbers before casting. Casting NaN to int is
    # platform-dependent (and warns), so the old code only worked by accident.
    split_num = exp_csv['split_num'].fillna(0).to_numpy().astype(int)
    # Negative split numbers are treated as split 0.
    split_num[split_num < 0] = 0

    summary_df = pd.DataFrame({
        'Name': exp_csv['Name'],
        'ID': exp_csv['ID'],
        # Rows without a task are assumed to be YearPredictionMSD runs.
        'task': exp_csv['task'].fillna('YearPredictionMSD'),
        'split_num': split_num,
    })
    print(summary_df.head())
    evaluate_all(summary_df, args.eval_name, n_jobs=args.n_jobs)


