import os
import argparse
import sys; sys.path.append("./ANODE") # import hack
from tqdm import tqdm
import torch
import torch.nn.functional as F
from models.model import OurModel, Simulator
from models.mlp_model import ODENet
from omegaconf import OmegaConf
import plotly.express as px
from utils import *
import wandb
from torch.func import vmap, jacrev, jacfwd, functional_call

def learnable(z0, z1, t):
    """Placeholder dynamics for the ``learnable`` option.

    Accepts the same ``(z0, z1, t)`` signature as the closed-form
    interpolants in this file but ignores its arguments and returns ``None``.
    """
    return

def linear(z0, z1, t):
    """Straight-line interpolant between ``z0`` (at ``t=0``) and ``z1`` (at ``t=1``).

    Returns ``(1 - t) * z0 + t * z1``; works on anything supporting ``*``/``+``
    with the usual broadcasting rules of the operand types.
    """
    start_weight = 1 - t
    end_weight = t
    return start_weight * z0 + end_weight * z1

def half_circle(z0, z1, t):
    """Interpolate from ``z0`` to ``z1`` along a half circle.

    The circle is centred at the midpoint of ``z0`` and ``z1`` and has
    ``(z0 - z1) / 2`` as its radius vector, so ``t=0`` yields ``z0`` and
    ``t=1`` yields ``z1``. The second in-plane axis is the part of the
    midpoint that is orthogonal to the radius, rescaled to the radius length.

    Inputs are brought to 2-D (rows x features) before the computation and the
    result is reshaped back to the original shape of ``z0``.

    Note: the two normalisations divide by vector norms, so the output is not
    finite when ``z0 == z1`` or when the midpoint is parallel to ``z0 - z1``.
    """
    def _as_rows(x):
        # >2-D: keep the leading dim and flatten the rest; <2-D: one single row.
        n_dims = len(x.shape)
        if n_dims > 2:
            return x.view(x.shape[0], -1)
        if n_dims < 2:
            return x.view(1, -1)
        return x

    out_shape = z0.shape
    z0, z1, t = _as_rows(z0), _as_rows(z1), _as_rows(t)

    angle = t * np.pi
    center = (z0 + z1) / 2
    radius = (z0 - z1) / 2
    radius_len = torch.norm(radius, dim=-1, keepdim=True)
    radius_dir = radius / radius_len

    # Remove from the midpoint its component along the radius direction.
    along_radius = (center * radius_dir).sum(dim=-1, keepdim=True) * radius_dir
    ortho = center - along_radius
    # Rescale the orthogonal axis so it has the same length as the radius.
    radius_ortho = ortho / torch.norm(ortho, dim=-1, keepdim=True) * radius_len

    point = center + radius * torch.cos(angle) + radius_ortho * torch.sin(angle)
    return point.reshape(out_shape)

def cos(z0, z1, t):
    """Trigonometric interpolant ``cos(pi*t/2) * z0 + sin(pi*t/2) * z1``.

    ``t`` must be a tensor (it is passed to ``torch.cos`` / ``torch.sin``).
    Gives ``z0`` at ``t=0`` and ``z1`` at ``t=1``.
    """
    angle = t * np.pi / 2.
    weight_z0 = torch.cos(angle)
    weight_z1 = torch.sin(angle)
    return weight_z0 * z0 + weight_z1 * z1

def sin(z0, z1, t):
    """Interpolant ``(1 - sin(pi*t/2)) * z0 + (1 - cos(pi*t/2)) * z1``.

    ``t`` must be a tensor (it is passed to ``torch.sin`` / ``torch.cos``).
    Gives ``z0`` at ``t=0`` and ``z1`` at ``t=1``.
    """
    angle = t * np.pi / 2.
    weight_z0 = 1 - torch.sin(angle)
    weight_z1 = 1 - torch.cos(angle)
    return weight_z0 * z0 + weight_z1 * z1

def vp_ode(z0, z1, t):
    """Interpolant ``alpha(t) * z1 + sqrt(1 - alpha(t)^2) * z0``.

    ``alpha(t) = exp(-0.25 * 19.9 * (1 - t)^2 - 0.5 * 0.1 * (1 - t))``, so
    ``alpha(1) = 1`` and the result equals ``z1`` at ``t=1``. ``t`` must be a
    tensor. (The constants look like a variance-preserving schedule with
    beta_min=0.1 and beta_max=20 -- TODO confirm against the intended paper.)
    """
    remaining = 1 - t
    log_alpha = -0.25 * 19.9 * remaining**2 - 0.5 * 0.1 * remaining
    alpha = torch.exp(log_alpha)
    beta = torch.sqrt(1 - alpha**2)
    return alpha * z1 + beta * z0

def const_vp_ode(z0, z1, t):
    """Interpolant ``t * z1 + sqrt(1 - t^2) * z0``.

    The two weights satisfy ``alpha^2 + beta^2 = 1``. ``t`` must be a tensor
    (it is passed to ``torch.sqrt``). Gives ``z0`` at ``t=0`` and ``z1`` at
    ``t=1``.
    """
    weight_z1 = t
    weight_z0 = torch.sqrt(1 - weight_z1**2)
    return weight_z1 * z1 + weight_z0 * z0

def train(net, train_dataloader, total_epoch, optimizer='adam', scheduler='none', lr=1e-3, wd=[0., 0., 0.], loss_start_epoch=[0, 0, 0], lambdas=[1.0, 1.0, 0.0],
          task_criterion=torch.nn.MSELoss(), force_zero_prob=0., val_dataloader=None, test_every=2, test_dataloader=None,
          label_ae_noise=0., method='ours', f_sg_target=False, augment_t=1, t_transform='identity', label_ae_mse=True,
          task_dec=False, f_jac_clamp=(-1, -1), g_jac_clamp=(-1,-1), train_alter=False, alter_order=['fgh'], alter_epoch=[1],
          fgh_lr=None, sync_t=False, label_flow_noise=0., label_flow_noise_0=0., invert_transform_t=False, dynamics=linear,
          dynamic_model=None, label_scaler=None):
    """Run the full (optionally alternating) training loop and write checkpoints.

    One optimizer is built per entry of ``alter_order``. Each entry is a string
    whose letters select parameter groups:

    * ``'f'`` -> ``net.in_projection``
    * ``'g'`` -> ``net.out_projection`` and ``net.label_projection``
    * ``'h'`` -> ``net.odeblock`` (plus ``dynamic_model`` when given)

    The i-th optimizer is used for ``alter_epoch[i]`` consecutive epochs, and
    the sequence repeats until at least ``total_epoch`` epochs have run. A
    started cycle is always finished, so the real count can exceed
    ``total_epoch``.

    Args:
        optimizer: ``'adam'`` (``torch.optim.AdamW``) or ``'radam'``.
        scheduler: ``'none'``, ``'cos'`` (cosine annealing stepped per batch,
            single optimizer only), ``None`` or a ready scheduler object.
        lr, fgh_lr, wd: base learning rate, per-group ``[f, g, h]`` learning
            rates (default: ``lr`` for all) and per-group weight decay.
        loss_start_epoch: epoch from which the matching entry of ``lambdas``
            becomes active. Entries of ``lambdas`` without a start epoch are
            active from epoch 0.
        lambdas: loss weights forwarded to ``train_epoch``.
        test_every: evaluation period in epochs. Evaluation only runs when
            ``test_dataloader`` is given.
        method: ``'node'`` uses ``train_epoch_node``; anything else uses
            ``train_epoch``.
        label_scaler: forwarded to ``test`` for un-normalised RMSE.
        Remaining keyword arguments are forwarded unchanged to ``train_epoch``.

    Side effects:
        Logs to wandb. Saves ``best.ckpt`` / ``last.ckpt`` (and the
        ``*_dynamic.ckpt`` counterparts when ``dynamic_model`` is given) into
        ``wandb.run.dir``. "Best" is selected on ``test/rmse_dopri``.

    Raises:
        ValueError: unsupported optimizer or scheduler name, or an
            ``alter_epoch`` without any positive entry.
    """
    if fgh_lr is None:
        fgh_lr = [lr, lr, lr]

    if optimizer == 'adam':
        opt_cls = torch.optim.AdamW
    elif optimizer == 'radam':
        opt_cls = torch.optim.RAdam
    else:
        raise ValueError(f'Optimizer {optimizer} not supported')

    if not train_alter:
        assert alter_order == ['fgh'] and alter_epoch == [1]

    optimizers = []
    for target in alter_order:
        assert isinstance(target, str), 'alter_order should be a list of strings'
        params = []
        if 'f' in target:
            params.append({'params': net.in_projection.parameters(), 'lr': fgh_lr[0], 'weight_decay': wd[0]})
        if 'g' in target:
            params.append({'params': net.out_projection.parameters(), 'lr': fgh_lr[1], 'weight_decay': wd[1]})
            params.append({'params': net.label_projection.parameters(), 'lr': fgh_lr[1], 'weight_decay': wd[1]})
        if 'h' in target:
            params.append({'params': net.odeblock.parameters(), 'lr': fgh_lr[2], 'weight_decay': wd[2]})
            if dynamic_model is not None:
                params.append({'params': dynamic_model.parameters(), 'lr': fgh_lr[2], 'weight_decay': wd[2]})
        optimizers.append(opt_cls(params, lr=lr))
    assert len(optimizers) == len(alter_epoch), 'optimizers and train_alter_epoch should have same length'

    # Without a positive epoch count the outer while-loop would never advance.
    if total_epoch > 0 and all(count <= 0 for count in alter_epoch):
        raise ValueError('alter_epoch must contain at least one positive epoch count')

    if scheduler == 'none':
        scheduler = None
    elif scheduler == 'cos':
        total_steps = total_epoch * len(train_dataloader)
        assert len(optimizers) == 1, 'Cosine annealing scheduler only supports single optimizer'
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizers[0], T_max=total_steps, eta_min=0)
    elif isinstance(scheduler, str):
        # Fail early instead of crashing later on ``scheduler.step()``.
        raise ValueError(f'Scheduler {scheduler} not supported')

    # ``lambdas`` may be longer than ``loss_start_epoch`` (e.g. a 4th boundary
    # weight). Pad the start epochs with 0 so zip() does not drop those weights.
    start_epochs = list(loss_start_epoch) + [0] * max(0, len(lambdas) - len(loss_start_epoch))

    ckpt_dir = wandb.run.dir # os.path.join(wandb.run.dir, 'files')
    os.makedirs(ckpt_dir, exist_ok=True)
    best_test_rmse = float('inf')  # any finite RMSE counts as an improvement
    epoch = 0
    logs = {}
    while epoch < total_epoch:
        for active_optimizer, epoch_count, alter_type in zip(optimizers, alter_epoch, alter_order):
            ### single unit of alternating training
            for _ in range(epoch_count):
                ### single epoch training
                net.zero_grad() # clears grads of every parameter, whichever optimizer owns it
                # a loss weight is active once its start epoch has been reached
                effective_lambdas = [
                    lam * int(epoch >= start_e) for start_e, lam in zip(start_epochs, lambdas)
                ]
                if method == 'node':
                    loss_func = torch.nn.MSELoss()
                    logs = train_epoch_node(net, train_dataloader, active_optimizer, scheduler, task_criterion=loss_func, verbose=(epoch % test_every == 0), label_ae_noise=label_ae_noise)
                else:
                    logs = train_epoch(net, train_dataloader, active_optimizer, scheduler, lambdas=effective_lambdas, task_criterion=task_criterion,
                                       force_zero_prob=force_zero_prob, verbose=(epoch % test_every == 0), label_ae_noise=label_ae_noise,
                                       f_sg_target=f_sg_target, augment_t=augment_t, t_transform=t_transform, label_ae_mse=label_ae_mse,
                                       task_dec=task_dec, f_jac_clamp=f_jac_clamp, g_jac_clamp=g_jac_clamp, sync_t=sync_t,
                                       label_flow_noise=label_flow_noise, label_flow_noise_0=label_flow_noise_0, invert_transform_t=invert_transform_t,
                                       alter_type=alter_type, dynamics=dynamics, dynamic_model=dynamic_model)
                # per-epoch logging
                wandb.log({'train/'+k+'_epoch': v for k, v in logs.items()}, commit=False)
                wandb.log({'epoch': epoch}, commit=False)

                # test logging
                if (epoch+1) % test_every == 0 and test_dataloader is not None:
                    metric_dict = test(net, val_dataloader=None, test_dataloader=test_dataloader, label_scaler=label_scaler)
                    if val_dataloader is not None:
                        # same metrics on the validation split, stored under 'val/...'
                        metric_dict.update(test(net, test_dataloader=val_dataloader, metric_key='val', label_scaler=label_scaler))
                    rmse_val = metric_dict.get('val/rmse_dopri', float('nan'))
                    rmse1 = metric_dict['test/rmse_1']
                    rmse2 = metric_dict['test/rmse_2']
                    rmseinf = metric_dict['test/rmse_dopri']
                    straight = metric_dict['test/straightness']
                    latent_mse = metric_dict['test/latent_mse']
                    data_mse = metric_dict['test/data_mse']
                    if rmseinf < best_test_rmse:
                        best_test_rmse = rmseinf
                        torch.save(net.state_dict(), os.path.join(ckpt_dir, 'best.ckpt'))
                        if dynamic_model is not None:
                            torch.save(dynamic_model.state_dict(), os.path.join(ckpt_dir, 'best_dynamic.ckpt'))
                    else:
                        torch.save(net.state_dict(), os.path.join(ckpt_dir, 'last.ckpt'))
                        if dynamic_model is not None:
                            torch.save(dynamic_model.state_dict(), os.path.join(ckpt_dir, 'last_dynamic.ckpt'))
                    metric_dict['test/best_rmse'] = best_test_rmse

                    print(f'Epoch {epoch}, Val RMSE {rmse_val:.4f}, Test RMSE 1/2/inf {rmse1:.4f}/{rmse2:.4f}/{rmseinf:.4f}, Straightness {straight:.4f}, Latent MSE {latent_mse:.4f}, Data MSE {data_mse:.4f}')
                    wandb.log(metric_dict, commit=False)
                epoch += 1
                ### end of single epoch training
            ### end of single unit of alternating training
        if epoch >= total_epoch:
            break # a started cycle is always completed, so more than total_epoch epochs may have run

    # save last checkpoint(s) and finish
    torch.save(net.state_dict(), os.path.join(ckpt_dir, 'last.ckpt'))
    if dynamic_model is not None:
        # keep the dynamics checkpoint in sync with the final network weights
        torch.save(dynamic_model.state_dict(), os.path.join(ckpt_dir, 'last_dynamic.ckpt'))
    return


def train_epoch_node(net, train_dataloader, optimizer, scheduler=None, task_criterion=torch.nn.MSELoss(), verbose=False, **kwargs):
    """Train ``net`` for one epoch with a direct prediction loss (baseline mode).

    For every batch: forward ``net(X)``, compute ``task_criterion(pred, Y)``,
    backpropagate, step ``optimizer`` and, if given, ``scheduler``. The
    per-batch loss and ``net.odeblock.odefunc.nfe`` are logged to wandb.

    Args:
        net: model exposing ``device`` and ``odeblock.odefunc.nfe``.
        train_dataloader: iterable of ``(X, Y)`` batches.
        optimizer: optimizer stepped once per batch.
        scheduler: optional scheduler stepped once per batch.
        task_criterion: loss applied to ``(pred, Y)``.
        verbose: whether the tqdm bar is left on screen after the epoch.
        **kwargs: accepted for call-site compatibility and ignored.

    Returns:
        dict with the running means ``'loss'`` and ``'nfe'`` over the epoch,
        or an empty dict when the dataloader yields no batch.
    """
    net.train()
    device = net.device
    loss_epoch = 0
    nfe_epoch = 0
    count = 0
    logs = {}  # stays empty (instead of unbound) if the loader is empty
    pbar = tqdm(train_dataloader, leave=verbose)
    for X, Y in pbar:
        optimizer.zero_grad()
        X, Y = X.to(device), Y.to(device)
        pred = net(X)
        loss = task_criterion(pred, Y)
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        loss_epoch += loss.item()
        nfe_epoch += net.odeblock.odefunc.nfe
        count += 1
        wandb.log({
            'train/loss': loss.item(),
            'train/nfe': net.odeblock.odefunc.nfe,
        })
        logs = {
            'loss': loss_epoch / count,
            'nfe': nfe_epoch / count,
        }
        pbar.set_description(
            ', '.join([f'{k}: {v:.3e}' for k, v in logs.items()])
        )
    return logs


@torch.inference_mode()
def test_rmse(net, test_dataloader, method='dopri5', num_timesteps=1+1, return_mse=False, label_scaler=None):
    """Measure the RMSE of ``net`` over ``test_dataloader``.

    With ``method='dopri5'`` the prediction comes from
    ``net(X, return_features=True, method='dopri5')``; any other method goes
    through ``net.get_traj(X, method=..., timesteps=num_timesteps)``.

    When ``label_scaler`` is given, the squared error for the RMSE is computed
    on ``label_scaler.inverse_transform``-ed labels and predictions; otherwise
    it is computed on the raw tensors. All per-batch means are weighted by the
    batch size before averaging.

    Returns:
        ``rmse``, or ``(rmse, latent_mse, data_mse)`` when ``return_mse`` is
        true. ``latent_mse`` compares the final trajectory state with
        ``net.label_projection(Y)``; ``data_mse`` is the MSE on the raw
        (un-rescaled) predictions.
    """
    net.eval()
    n_samples = 0
    sq_err_total = 0
    latent_total = 0
    data_total = 0
    batches = tqdm(enumerate(test_dataloader), leave=False, total=len(test_dataloader), desc='Measure rmse')
    for _, (X, Y) in batches:
        X, Y = X.cuda(), Y.cuda()
        if method != 'dopri5':
            traj, pred = net.get_traj(X, method=method, timesteps=num_timesteps)
        else:
            feat, pred = net(X, return_features=True, method='dopri5')
            traj = [feat]

        batch_size = Y.size(0)
        n_samples += batch_size
        raw_mse = F.mse_loss(pred, Y).item()
        if label_scaler is None:
            sq_err_total += raw_mse * batch_size
        else:
            # error in the original label units
            Y_unnorm = label_scaler.inverse_transform(Y.cpu().numpy())
            pred_unnorm = label_scaler.inverse_transform(pred.cpu().numpy())
            sq_err_total += np.mean((Y_unnorm - pred_unnorm)**2) * batch_size
        latent_total += F.mse_loss(traj[-1], net.label_projection(Y)).item() * batch_size
        data_total += raw_mse * batch_size

    latent_mse = latent_total / n_samples
    data_mse = data_total / n_samples
    rmse = (sq_err_total / n_samples) ** 0.5

    if not return_mse:
        return rmse
    return rmse, latent_mse, data_mse


@torch.inference_mode()
def test_straightness(net, test_dataloader, normalize=True):
    """Return the sample-weighted mean of ``straightness(net, X, ...)``.

    Each batch value returned by the external ``straightness`` helper is
    weighted by the batch size, and the sum is divided by the number of
    samples seen. ``normalize`` is forwarded to the helper unchanged.
    """
    net.eval()
    weighted_sum = 0.
    n_samples = 0
    batches = tqdm(enumerate(test_dataloader), leave=False, total=len(test_dataloader), desc='Measure Straightness')
    for _, (X, Y) in batches:
        X, Y = X.cuda(), Y.cuda()
        batch_size = Y.size(0)
        n_samples += batch_size
        weighted_sum += straightness(net, X, normalize=normalize) * batch_size
    weighted_sum /= n_samples
    return weighted_sum


def test_norm_avg(net, val_dataloader=None):
    """Average ``norm_avg`` of the input and label embeddings over a loader.

    Returns ``(z0_norm_avg, z1_norm_avg)`` where ``z0 = net.in_projection(X)``
    and ``z1 = net.label_projection(Y)``; each batch value from the external
    ``norm_avg`` helper is weighted by the batch size. Returns ``(0, 0)`` when
    no dataloader is given.
    """
    if val_dataloader is None:
        return 0, 0

    net.eval()
    z0_total = 0
    z1_total = 0
    n_samples = 0
    batches = tqdm(enumerate(val_dataloader), leave=False, total=len(val_dataloader), desc='Measure Norm Avg')
    for _, (X, Y) in batches:
        X = X.cuda()
        batch_size = X.size(0)
        z0 = net.in_projection(X)
        z1 = net.label_projection(Y.cuda())
        z0_total += norm_avg(z0) * batch_size
        z1_total += norm_avg(z1) * batch_size
        n_samples += batch_size
    z0_total /= n_samples
    z1_total /= n_samples
    return z0_total, z1_total


@torch.inference_mode()
def test(net, test_dataloader, val_dataloader=None, metric_key='test', do_dopri=True, label_scaler=None):
    """Evaluate ``net`` and return a dict of metrics.

    Keys (``<k>`` is ``metric_key``):
    - val/rmse (only if val_dataloader is not None; default solver of test_rmse)
    - <k>/straightness
    - <k>/rmse_1, <k>/rmse_2, <k>/rmse_10, <k>/rmse_20 (Euler, that many steps)
    - <k>/rmse_dopri, <k>/dopri_nfe, <k>/latent_mse, <k>/data_mse (only if do_dopri)
    - <k>/z0_norm_avg, <k>/z1_norm_avg

    ``label_scaler`` is forwarded to every ``test_rmse`` call.
    """
    ret = {}
    if val_dataloader is not None:
        ret['val/rmse'] = test_rmse(net, val_dataloader, label_scaler=label_scaler)

    ret[f'{metric_key}/straightness'] = test_straightness(net, test_dataloader)

    # fixed-step Euler with 1, 2, 10 and 20 steps (num_timesteps counts the grid points)
    for n_steps in (1, 2, 10, 20):
        ret[f'{metric_key}/rmse_{n_steps}'] = test_rmse(
            net, test_dataloader, method='euler', num_timesteps=n_steps + 1, label_scaler=label_scaler)

    if do_dopri:
        rmse_dopri, latent_mse, data_mse = test_rmse(
            net, test_dataloader, method='dopri5', return_mse=True, label_scaler=label_scaler)
        ret[f'{metric_key}/rmse_dopri'] = rmse_dopri
        # Read right after the dopri run, before any other solve. Inaccurate
        # as an average, since the counter only reflects the last batch.
        ret[f'{metric_key}/dopri_nfe'] = net.odeblock.odefunc.nfe
        ret[f'{metric_key}/latent_mse'] = latent_mse
        ret[f'{metric_key}/data_mse'] = data_mse

    z0_norm_avg, z1_norm_avg = test_norm_avg(net, test_dataloader)
    ret[f'{metric_key}/z0_norm_avg'] = z0_norm_avg
    ret[f'{metric_key}/z1_norm_avg'] = z1_norm_avg
    return ret


def _jac_clamp_penalty(Q, clamp):
    """Penalise entries of ``Q`` outside ``[clamp[0], clamp[1]]`` (squared distance, batch mean)."""
    lam_min, lam_max = clamp  # assume a tuple of min and max values.
    target_min = torch.ones_like(Q) * lam_min
    target_max = torch.ones_like(Q) * lam_max
    L_min = F.mse_loss(torch.minimum(Q, target_min), target_min, reduction='mean')
    L_max = F.mse_loss(torch.maximum(Q, target_max), target_max, reduction='mean')
    return L_min + L_max


def train_epoch(net, train_dataloader, optimizer, scheduler=None, lambdas=[1.0, 1.0, 0.0], task_criterion=torch.nn.MSELoss(),
                force_zero_prob=0., verbose=False, label_ae_noise=0., f_sg_target=False, augment_t=1, sync_t=False,
                t_transform='identity', label_ae_mse=True, task_dec=False, f_jac_clamp=(-1, -1), g_jac_clamp=(-1, -1),
                label_flow_noise=0., label_flow_noise_0=0., invert_transform_t=False, alter_type='fgh', dynamics=linear, dynamic_model=None):
    """Train ``net`` for one epoch with the flow-matching objective.

    Per batch, ``z0 = net.in_projection(X)`` and ``z1 = net.label_projection(Y)``
    are joined by an interpolant ``zt = dynamics(z0, z1, t)`` at a random
    ``t``. ``net.pred_v(zt, t)`` is regressed onto the time derivative of the
    interpolant, obtained with ``jacfwd`` under ``vmap``. The total loss is the
    sum of:

    * flow loss, weighted by ``lambdas[0]``
    * label auto-encoding loss ``mse(net.out_projection(z1), Y)``, ``lambdas[1]``
    * task loss from a one-step extrapolation ``zt + v_pred * (1 - t)``,
      active only when ``lambdas[2] > 0``
    * boundary losses tying the learned dynamics to ``z0`` / ``z1`` at
      ``t=0`` / ``t=1``, weighted by ``lambdas[3]`` (only with ``dynamic_model``)
    * optional Jacobian-clamp penalties on the difference quotients of the
      ``in_projection`` / ``label_projection`` maps

    Args:
        lambdas: loss weights ``[flow, label_ae, task(, boundary)]``. The 4th
            entry is required when ``dynamic_model`` is given.
        t_transform: ``'identity'``, ``'square'``, ``'cubic'`` or
            ``'one_minus_cos'``; reshapes the sampling density of ``t``.
        invert_transform_t: use ``1 - t`` after the transform.
        sync_t: draw a single ``t`` shared by the whole batch.
        augment_t: number of ``t`` samples per instance.
        force_zero_prob: probability of overwriting a sampled ``t`` with 0.
        label_ae_noise / label_flow_noise / label_flow_noise_0: std of the
            Gaussian noise added to ``z1`` (auto-encoding), ``z1`` (flow) and
            ``z0`` (flow).
        f_jac_clamp / g_jac_clamp: ``(min, max)`` bounds; a negative entry
            disables the penalty.
        alter_type: if it lacks ``'h'``, ``net.odeblock`` is put in eval mode.
        dynamics: closed-form interpolant ``f(z0, z1, t)``; ignored when
            ``dynamic_model`` is given.
        dynamic_model: optional module called as ``module(z0, z1, t)`` through
            ``functional_call``.
        f_sg_target: currently unused.

    Returns:
        dict with the running means of ``loss``, ``flow_loss``,
        ``label_ae_loss``, ``task_loss`` and the current ``lr``; an empty dict
        when the dataloader yields no batch.

    Raises:
        ValueError: ``dynamic_model`` is given but ``lambdas`` has no 4th entry.
    """
    if dynamic_model is not None and len(lambdas) < 4:
        raise ValueError('lambdas needs a 4th (boundary loss) entry when dynamic_model is given')

    net.train()
    if 'h' not in alter_type:
        net.odeblock.eval()
    device = net.device
    loss_epoch = 0
    flow_loss_epoch = 0
    label_ae_loss_epoch = 0
    task_loss_epoch = 0
    count = 0
    logs = {}  # stays empty (instead of unbound) if the loader is empty
    pbar = tqdm(train_dataloader, leave=verbose)
    flow_loss_timestep_bin = torch.zeros(11)
    timestep_bin_count = torch.zeros(11)

    if dynamic_model is None:
        # per-sample d(dynamics)/dt
        dyn_v = vmap(jacfwd(dynamics, argnums=2))
    else:
        # From here on ``dynamics`` is the learned model, called functionally so
        # that its parameters stay explicit inputs for jacfwd/vmap.
        def dynamics(params, z0, z1, t):
            return functional_call(dynamic_model, params, (z0, z1, t))
        def dyn_v(params, z0, z1, t):
            return vmap(jacfwd(dynamics, argnums=3), in_dims=(None, 0, 0, 0))(params, z0, z1, t).squeeze()

    for X, Y in pbar:
        X, Y = X.to(device), Y.to(device)
        z0 = net.in_projection(X)
        z1 = net.label_projection(Y)

        # jacobian clamping: difference quotients between consecutive batch items
        f_jac_clamp_loss = g_jac_clamp_loss = torch.zeros(1, device=device)
        delta_x = X[1:] - X[:-1] # drop one instance but maybe okay
        delta_z = z0[1:] - z0[:-1]
        Q = delta_z.reshape(delta_z.size(0), -1).norm(dim=-1) / delta_x.reshape(delta_x.size(0), -1).norm(dim=-1)
        if Q.numel() > 0:  # a batch of size 1 has no pairs; mean/mse would be NaN
            wandb.log({"train/f_jac_Q": Q.mean().item()}, commit=False)
            if f_jac_clamp[0] >= 0 and f_jac_clamp[1] >= 0:
                f_jac_clamp_loss = _jac_clamp_penalty(Q, f_jac_clamp)

        delta_y = Y[1:] - Y[:-1]
        delta_z = z1[1:] - z1[:-1]
        Q = delta_z.reshape(delta_z.size(0), -1).norm(dim=-1) / (1e-6 + delta_y.reshape(delta_y.size(0), -1).norm(dim=-1)) # avoid zero division
        # pass if delta_z is zero
        mask = Q > 1e-9
        Q = Q[mask]
        if Q.numel() > 0:
            wandb.log({"train/g_jac_Q": Q.mean().item()}, commit=False)
            if g_jac_clamp[0] >= 0 and g_jac_clamp[1] >= 0:
                g_jac_clamp_loss = _jac_clamp_penalty(Q, g_jac_clamp)

        # sampling timestep
        if sync_t: # sample one t and use it for all instances (and all augmented copies)
            t = torch.rand(1).to(device).repeat(augment_t * z0.size(0))
        else:
            t = torch.rand(augment_t * z0.size(0)).to(device)
        # Strategies to give more sampling chance to certain timesteps
        if t_transform == 'square':
            t = t**2
        elif t_transform == 'cubic':
            t = t**3
        elif t_transform == 'one_minus_cos':
            t = 1 - torch.cos(t * np.pi / 2)
        else:
            assert t_transform == 'identity'
        if invert_transform_t:
            t = 1 - t
        t = t * net.t_final # scale t to [0, t_final]
        t = append_dims(t, z0.ndim)
        # make some portion of sampled t to zero
        if force_zero_prob > 0.:
            mask = (torch.rand_like(t) < force_zero_prob).float()
            t = t * (1. - mask)

        z0_aug = z0
        z1_aug = z1
        if augment_t > 1:
            z0_aug = repeat(z0_aug, 'B ... -> (B a) ...', a=augment_t)
            z1_aug = repeat(z1_aug, 'B ... -> (B a) ...', a=augment_t)
        if label_flow_noise > 0.: # add noise to label embedding for flow prediction
            z1_aug = z1_aug + label_flow_noise * torch.randn_like(z1_aug)
        if label_flow_noise_0 > 0.: # add noise to input embedding for flow prediction
            z0_aug = z0_aug + label_flow_noise_0 * torch.randn_like(z0_aug)

        # interpolated state and its time derivative (the regression target)
        if dynamic_model is None:
            zt = dynamics(z0_aug, z1_aug, t)
            v_target = dyn_v(z0_aug, z1_aug, t).squeeze()
        else:
            params = dict(dynamic_model.named_parameters())
            zt = dynamics(params, z0_aug, z1_aug, t)
            v_target = dyn_v(params, z0_aug, z1_aug, t)

        task_loss = torch.tensor(0., device=device)
        boundary_loss_0 = torch.tensor(0., device=device)
        boundary_loss_1 = torch.tensor(0., device=device)
        if dynamic_model is not None:
            # the learned path must start at z0 and end at z1; the embeddings
            # are detached so only the dynamics model is pulled towards them
            z0_aug_det = z0_aug.detach()
            z1_aug_det = z1_aug.detach()
            z0_hat = dynamics(params, z0_aug_det, z1_aug_det, torch.zeros_like(t))
            z1_hat = dynamics(params, z0_aug_det, z1_aug_det, torch.ones_like(t))
            boundary_loss_0 = F.mse_loss(z0_hat, z0_aug_det) * lambdas[3]
            boundary_loss_1 = F.mse_loss(z1_hat, z1_aug_det) * lambdas[3]

        # flow loss
        v_pred = net.pred_v(zt, t)

        # pred_v variance across features, averaged over the batch
        v_pred_var = v_pred.detach().reshape(v_pred.size(0), -1).var(dim=-1).mean().item()
        flow_loss = F.mse_loss(v_pred, v_target) * lambdas[0]
        normalized_flow_loss = flow_loss.item() / (v_target.detach().reshape(v_target.size(0), -1).norm(dim=-1).mean() + 1e-9)

        # label autoencoding loss
        z1_noised = z1
        if label_ae_noise > 0.:
            z1_noised = z1 + label_ae_noise * torch.randn_like(z1)
        y_pred = net.out_projection(z1_noised)
        if label_ae_mse:
            label_ae_loss = F.mse_loss(y_pred, Y) * lambdas[1]
        else:
            assert 0, 'Not allowed'

        # task loss
        if lambdas[2] > 0:
            assert augment_t == 1
            z1_hat = zt + v_pred * (1-t)
            if task_dec:
                # train only the decoder with the task loss
                z1_hat = z1_hat.detach()
            y_hat = net.out_projection(z1_hat)
            task_loss = task_criterion(y_hat, Y) * lambdas[2]

        # optimizer step
        net.zero_grad()
        loss = flow_loss + label_ae_loss + task_loss + f_jac_clamp_loss + g_jac_clamp_loss +\
            boundary_loss_0 + boundary_loss_1
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        # logging
        wandb.log({
            'learning_rate': optimizer.param_groups[0]['lr'],
            'train/loss': loss.item(),
            'train/flow_loss': flow_loss.item(),
            # key spelling kept as-is so existing wandb series stay continuous
            'train/norm_flow_loss: ': normalized_flow_loss.item(), # normalized flow loss
            'train/label_ae_loss': label_ae_loss.item(),
            'train/task_loss': task_loss.item(),
            'train/f_jac_clamp_loss': f_jac_clamp_loss.item(),
            'train/g_jac_clamp_loss': g_jac_clamp_loss.item(),
            'train/flow_pred_var': v_pred_var,
            'train/z0_norm': z0.detach().reshape(z0.size(0), -1).norm(dim=-1).mean().item(),
            'train/z1_norm': z1.detach().reshape(z1.size(0), -1).norm(dim=-1).mean().item(),
            'train/boundary_loss_0': boundary_loss_0.item(),
            'train/boundary_loss_1': boundary_loss_1.item(),
        })

        # flow loss timestep bin
        # NOTE(review): bin edges span [0, 1] although t was scaled by
        # net.t_final above -- confirm this is intended for t_final != 1.
        t_detach = t.detach().squeeze()
        t_candidates = torch.linspace(0, 1, len(timestep_bin_count)).to(t_detach.device)
        flow_loss_detach = F.mse_loss(v_pred.detach(), v_target.detach(), reduction='none') * lambdas[0]
        flow_loss_detach = reduce(flow_loss_detach, 'B ... -> B', 'mean')
        for bin_idx in range(len(t_candidates) - 1):
            t_start, t_end = t_candidates[bin_idx], t_candidates[bin_idx+1]
            in_bin = (t_detach >= t_start) & (t_detach < t_end)
            timestep_bin_count[bin_idx] += in_bin.sum().cpu()
            flow_loss_timestep_bin[bin_idx] += flow_loss_detach[in_bin].sum().cpu()

        loss_epoch += loss.item()
        flow_loss_epoch += flow_loss.item()
        label_ae_loss_epoch += label_ae_loss.item()
        task_loss_epoch += task_loss.item()
        count += 1
        logs = {
            'loss': loss_epoch / count,
            'flow_loss': flow_loss_epoch / count,
            'label_ae_loss': label_ae_loss_epoch / count,
            'task_loss': task_loss_epoch / count,
            'lr': optimizer.param_groups[0]['lr'],
        }
        pbar.set_description(
            ', '.join([f'{k}: {v:.3e}' for k, v in logs.items()])
        )
    # log timestep bin (empty bins give 0/0 = NaN, which nan_to_num maps to 0)
    flow_loss_timestep_bin = torch.nan_to_num(flow_loss_timestep_bin / timestep_bin_count, posinf=0)
    fig = wandb.Plotly(px.bar(x=list(range(len(flow_loss_timestep_bin))), y=flow_loss_timestep_bin.tolist()))
    wandb.log({'train/flow_loss_timestep_bin': fig}, commit=False)
    if dynamic_model is not None:
        for n, p in dynamic_model.named_parameters():
            hist = wandb.Histogram(p.detach().view(-1).cpu().numpy())
            wandb.log({'parameters/dynamic_model.'+n: hist}, commit=False)

    return logs



# Script entry point: parse CLI options, start a wandb run, build the UCI
# datasets, the ODENet and (optionally) a learnable dynamics model, then train
# and print the final test metrics.
if __name__ == '__main__':
    # argparse
    # NOTE(review): several options below use ``type=eval``, which executes the
    # command-line string as Python code. Only run this with trusted input.
    parser = argparse.ArgumentParser()
    parser.add_argument('--name', type=str, required=True)
    parser.add_argument('--seed', type=int, default=123)
    parser.add_argument('--task', type=str, default='bostonHousing',)
    parser.add_argument('--split_num', type=int, default=0)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--lr_scheduler', type=str, default='none', choices=['none', 'cos'])
    parser.add_argument('--fgh_lr', type=eval, default=None, help='Learning rate for f, g, h. e.g [1e-3, 1e-3, 1e-3]')
    parser.add_argument('--fgh_lr_rel', type=eval, default=None, help='Learning rate for f, g, h. Relative to lr. e.g [1, 1, 1]')
    parser.add_argument('--epochs', type=int, default=500)
    parser.add_argument('--lambdas', type=eval, default=[1., 1., 0., 0.])
    parser.add_argument('--f_jac_clamp', type=eval, default=[-1, -1], help='Clamp jacobian norm to this range, -1 means no clamping')
    parser.add_argument('--g_jac_clamp', type=eval, default=[-1, -1], help='Clamp jacobian norm to this range, -1 means no clamping')
    parser.add_argument('--label_proj_type', type=str, default='linear', choices=['linear', 'mlp', 'identity'])
    parser.add_argument('--test_every', type=int, default=50)
    parser.add_argument('--f_sg_target', action='store_true', help='Detach f in velocity target from computational graph')
    parser.add_argument('--latent_chan', type=int, default=64)
    parser.add_argument('--h_dim', type=int, default=0)
    parser.add_argument('--h_add_blocks', type=int, default=0)
    parser.add_argument('--f_add_blocks', type=int, default=0)
    parser.add_argument('--fixnorm', action='store_true', help='Use fixed architecture not ending with norm')
    parser.add_argument('--no_final_norm', action='store_true', help='Do not use final norm layer at odefunc')
    parser.add_argument('--no_out_norm', action='store_true', help='Do not use starting GroupNorm layer at out_projection')
    parser.add_argument('--t_transform', type=str, default='identity', choices=['identity', 'square', 'one_minus_cos', 'cubic'])
    parser.add_argument('--invert_transform_t', action='store_true', help='transform t in inverse way (use 1-t instead of t)')
    parser.add_argument('--in_proj_type', type=str, default='linear', choices=['linear', 'identity', 'mlp', 'conv1x1', 'conv3x3'])
    parser.add_argument('--out_proj_type', type=str, default='linear', choices=['linear', 'mlp', 'mlp2'])
    parser.add_argument('--mid_conv', type=int, default=1, help='Number of conv layers in the middle of ConvODENet')
    parser.add_argument('--train_alter', action='store_true', help='Train modules in alternating order')
    parser.add_argument('--train_alter_order', type=str, default='fgh', help='Alternating update order, underscore-sepatated. e.g. fg_h')
    parser.add_argument('--train_alter_epoch', type=str, default='1', help='Alternating update epochs, underscore-separated. e.g. 1_3')
    parser.add_argument('--sync_t', action='store_true', help='Use same t for all instances in a batch')
    parser.add_argument('--augment_t', type=int, default=1)
    parser.add_argument('--label_flow_noise', type=float, default=0., help='Add noise to z1 for flow prediction')
    parser.add_argument('--label_flow_noise_0', type=float, default=0., help='Add noise to z0 for flow prediction')
    parser.add_argument('--t_final', type=float, default=1., help='Train and test with [0, t_final] instead of [0, 1]')
    parser.add_argument('--in_latent_chan', type=int, default=64, help='Input latent channel for OurModel')
    parser.add_argument('--f_act', type=str, default='relu', choices=['relu', 'swish'])
    parser.add_argument('--h_act', type=str, default='relu', choices=['relu', 'swish'])
    parser.add_argument('--dyn_use_norm', action='store_true', help='Use norm in dynamic model')
    parser.add_argument('--dyn_layers', type=int, default=3, help='Number of layers in dynamic model')
    parser.add_argument('--h_dropout', type=float, default=0.0)
    # NOTE(review): the chosen name is resolved with eval() further down, but
    # no function called 'lin_cos' or 'lin_sin' is defined in the visible part
    # of this file -- confirm they are provided elsewhere (e.g. by utils).
    parser.add_argument('--dynamics', type=str, default='linear', choices=['linear', 'cos', 'vp_ode', 'lin_cos', 'const_vp_ode', 'learnable', 'half_circle', 'lin_sin', 'sin'])
    parser.add_argument('--mlp_hidden_dim', type=int, default=512)
    parser.add_argument('--time_modulation', type=str, choices=['none', 'fourier', 'adaln'], default='none')
    ### not so frequently used...
    parser.add_argument('--in_proj_scale', type=float, default=None)
    parser.add_argument('--label_proj_scale', type=float, default=None)
    parser.add_argument('--proj_norm', type=str, default='none', choices=['none', 'ln', 'bn'])
    parser.add_argument('--debug', action='store_true', help='debug mode with 10 instance val set')
    parser.add_argument('--force_zero_prob', type=float, default=0.)
    parser.add_argument('--anode_net', action='store_true', help='Use ConvODENet used in ANODE')
    parser.add_argument('--aug_dim', type=int, default=1)
    parser.add_argument('--label_ae_noise', type=float, default=0.)
    parser.add_argument('--dataset', type=str, choices=['cifar10', 'mnist'], default='cifar10')
    parser.add_argument('--optimizer', type=str, default='adam')
    parser.add_argument('--total_steps', type=int, default=-1)
    parser.add_argument('--weight_decay', type=eval, default=[0., 0., 0.])
    parser.add_argument('--method', type=str, choices=['ours', 'node'], default='ours')
    parser.add_argument('--label_ae_criterion', type=str, choices=['ce', 'mse'], default='mse')
    parser.add_argument('--task_dec', action='store_true', help='Use task loss only for decoder')
    ### deprecated
    # parser.add_argument('--train_fg_epoch', type=int, default=5)
    # parser.add_argument('--train_h_epoch', type=int, default=5)
    # parser.add_argument('--train_fg_h_epoch', type=eval, default=(-1, -1), help='Same as train_fg_epoch and train_h_epoch but set with one argument')
    args = parser.parse_args()

    # The wandb run must exist before anything below reads wandb.run.dir.
    wandb.init(project="SFNO_UCI", name=args.name, entity='denoising_representation')

    args.train_alter_order = args.train_alter_order.split('_') # f_gh -> ['f', 'gh'], fgh -> ['fgh']
    args.train_alter_epoch = [int(e) for e in args.train_alter_epoch.split('_')]
    print('Alternating:', args.train_alter_order, args.train_alter_epoch)
    assert len(args.train_alter_order) == len(args.train_alter_epoch), 'train_alter_order and train_alter_epoch should have same length'

    # Relative per-module learning rates are converted to absolute ones.
    if args.fgh_lr_rel is not None:
        assert args.fgh_lr is None, 'fgh_lr and fgh_lr_rel cannot be set at the same time'
        args.fgh_lr = [args.lr * lr_rel for lr_rel in args.fgh_lr_rel]

    # configs: store the resolved arguments in the run directory and in wandb
    config = OmegaConf.create(vars(args))
    OmegaConf.save(config, os.path.join(wandb.run.dir, 'exp_config.yaml'))
    wandb.config.update(OmegaConf.to_container(config))

    seed = args.seed
    optim = args.optimizer
    lr = args.lr
    force_zero_prob = args.force_zero_prob
    weight_decay = args.weight_decay
    lambdas = args.lambdas
    assert not args.task_dec or lambdas[2] > 0, 'task_dec requires task loss'

    # The label auto-encoding loss is always MSE here; --label_ae_criterion is
    # not read in the visible part of this script.
    label_ae_mse = True
    fix_random_seeds(seed, strict=True)

    train_dataset = UCI('UCI_Datasets', args.task, args.split_num, 'train')
    val_dataset = UCI('UCI_Datasets', args.task, args.split_num, 'val')
    test_dataset = UCI('UCI_Datasets', args.task, args.split_num, 'test')

    # Scaler and dimensions are taken from the training split.
    label_scaler = train_dataset.scaler_y
    data_dim = train_dataset.train_dim_x
    output_dim = train_dataset.train_dim_y

    net = ODENet(device='cuda', data_dim=data_dim, hidden_dim=args.mlp_hidden_dim, output_dim=output_dim, augment_dim=0, time_dependent=True,
                in_proj=args.in_proj_type, out_proj=args.out_proj_type, label_proj=args.label_proj_type, proj_norm=args.proj_norm,
                in_proj_scale=args.in_proj_scale, label_proj_scale=args.label_proj_scale, t_final=args.t_final,
                time_modulation=args.time_modulation).cuda()

    # A Simulator is only created for the learnable dynamics; it is passed to
    # train() as dynamic_model.
    # NOTE(review): it is sized with data_dim although it is applied to the
    # embeddings z0/z1 during training -- confirm the dimensions agree.
    if args.dynamics == 'learnable':
        simulator = Simulator(data_dim, args.dyn_layers, use_norm=args.dyn_use_norm).cuda()
    else:
        simulator = None

    # Debug mode overrides batch size, epoch count and evaluation period.
    if args.debug:
        args.batch_size = 100
        args.epochs = 10000
        args.test_every = 50

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)

    # how long to train: --total_steps, when set, overrides --epochs (rounded up to whole epochs)
    if args.total_steps != -1:
        args.epochs = args.total_steps // len(train_loader) + int(args.total_steps % len(train_loader) > 0)

    wandb.watch(net, log='all', log_freq=50, idx=0)
    # eval(args.dynamics) looks up the interpolant function whose name was
    # given on the command line (restricted by the argparse ``choices``).
    # custom_ce is not defined in the visible part of this file; presumably it
    # comes from ``from utils import *`` -- TODO confirm.
    train(net, train_loader, optimizer=optim, scheduler=args.lr_scheduler, lr=lr, total_epoch=args.epochs, force_zero_prob=force_zero_prob, lambdas=lambdas, wd=weight_decay,
          val_dataloader=val_loader, test_dataloader=test_loader, test_every=args.test_every, task_criterion=custom_ce,
          label_ae_noise=args.label_ae_noise, method=args.method, f_sg_target=args.f_sg_target, augment_t=args.augment_t,
          t_transform=args.t_transform, label_ae_mse=label_ae_mse, task_dec=args.task_dec, f_jac_clamp=args.f_jac_clamp,
          g_jac_clamp=args.g_jac_clamp, train_alter=args.train_alter, alter_order=args.train_alter_order, alter_epoch=args.train_alter_epoch,
          fgh_lr=args.fgh_lr, sync_t=args.sync_t, label_flow_noise=args.label_flow_noise,
          label_flow_noise_0=args.label_flow_noise_0, invert_transform_t=args.invert_transform_t, dynamics=eval(args.dynamics),
          dynamic_model=simulator, label_scaler=label_scaler)

    # Final evaluation of the weights as they are at the end of training.
    metric_dict = test(net, test_dataloader=test_loader, label_scaler=label_scaler)
    rmse1 = metric_dict['test/rmse_1']
    rmse2 = metric_dict['test/rmse_2']
    rmseinf = metric_dict['test/rmse_dopri']
    straight = metric_dict['test/straightness']
    latent_mse = metric_dict['test/latent_mse']
    data_mse = metric_dict['test/data_mse']
    print(f'Test RMSE 1/2/inf {rmse1:.4f}/{rmse2:.4f}/{rmseinf:.4f}, Straightness {straight:.4f}, Latent MSE {latent_mse:.4f}, Data MSE {data_mse:.4f}')
    wandb.finish()
    print(f'Done.')