# %%
from tqdm import tqdm
from utils import fix_random_seeds, custom_ce, area_scaling_factor
import torch.nn.functional as F
from ANODE.experiments.dataloaders import ConcentricSphere
from ANODE.viz.plots import decision_boundary, single_feature_plt_plotly, trajectory_plt_plotly
from models.mlp_model import ODENet, PreservingLinear
import argparse
import torch
import os
NUM_POINTS = 3000

# %%

# Build the toy dataset once, at import time. The seed is pinned to 42 here so
# the data is the same for every run; per-run seeds are applied later in __main__.
fix_random_seeds(42)

# One third of the points go to the "inner" region, the remainder to the "outer" one.
num_points_inner = NUM_POINTS // 3
num_points_outer = NUM_POINTS - num_points_inner
data_concentric = ConcentricSphere(
    2,
    inner_range=(0., .5),
    outer_range=(1., 1.5),
    num_points_inner=num_points_inner,
    num_points_outer=num_points_outer,
)
# Inputs: the per-sample tensors stacked into a single batch tensor.
Z_0 = torch.stack(data_concentric.data)
# Targets: labels clamped to [0, 1], cast to integer ids, then one-hot encoded as
# float. NOTE(review): the clamp presumably maps a negative label to class 0 -- confirm
# against ConcentricSphere.
Z_1 = F.one_hot(
    torch.stack(data_concentric.targets).clamp(0, 1).long().squeeze(),
    num_classes=2,
).float()


# %%
def train_net(net, z_0, z_1, num_epochs=100, lambdas=(1.0, 1.0, 1.0), criterion=torch.nn.MSELoss(), lr=1e-3):
    """Train ``net`` end-to-end by regressing ``net(z_0)`` onto ``z_1`` (full batch).

    Args:
        net: model called as ``net(z_0)``; must expose ``net.odeblock.odefunc.nfe``,
            which is read after every backward pass for logging.
        z_0: input tensor, moved to CUDA.
        z_1: target tensor, moved to CUDA.
        num_epochs: number of optimisation steps; each step uses the whole batch.
        lambdas: unused; kept only for signature parity with the other trainers.
        criterion: loss callable ``criterion(prediction, target)``.
        lr: Adam learning rate.

    Returns:
        ``(net, nfe_log, loss_log)``: the trained model, the per-step NFE readings and
        the per-step loss values.
    """
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)

    z_0 = z_0.cuda()
    z_1 = z_1.cuda()

    nfe_log = []
    loss_log = []
    pbar = tqdm(range(num_epochs))
    for _ in pbar:
        optimizer.zero_grad()
        z_pred = net(z_0)
        loss = criterion(z_pred, z_1)
        loss.backward()
        # Read scalars once: .item() forces a device sync. NFE is read after backward
        # so that any function evaluations made during backward are included.
        loss_val = loss.item()
        nfe = net.odeblock.odefunc.nfe
        loss_log.append(loss_val)
        nfe_log.append(nfe)
        optimizer.step()
        pbar.set_description(f'Loss: {loss_val:.4f} NFE: {nfe}')
    return net, nfe_log, loss_log


def train_flow(net, x_0, x_1, num_epochs=100, lambdas=(1.0, 1.0, 1.0), criterion=torch.nn.MSELoss(), lr=1e-3):
    """Train ``net`` with a pure flow-matching objective between embeddings.

    Each step embeds the inputs (``net.in_projection``) and the labels
    (``net.label_projection``). It samples one ``t ~ U(0, 1)`` per sample and regresses
    ``net.pred_v(z_t, t)`` onto the straight-line velocity ``z1 - z0``, where
    ``z_t = (1 - t) * z0 + t * z1``.

    Args:
        net: model exposing ``in_projection``, ``label_projection``, ``pred_v`` and
            ``odeblock.odefunc.nfe``.
        x_0: input tensor, moved to CUDA.
        x_1: label tensor, moved to CUDA.
        num_epochs: number of full-batch optimisation steps.
        lambdas: unused; kept only for signature parity with the other trainers.
        criterion: loss callable ``criterion(v_pred, v_target)``.
        lr: Adam learning rate (defaults to the previously hard-coded 1e-3).

    Returns:
        ``(net, nfe_log, loss_log)``.
    """
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)

    x_0 = x_0.cuda()
    x_1 = x_1.cuda()

    nfe_log = []
    loss_log = []
    pbar = tqdm(range(num_epochs))
    for _ in pbar:
        optimizer.zero_grad()
        z0 = net.in_projection(x_0)
        z1 = net.label_projection(x_1)
        t = torch.rand(x_0.size(0), 1).to(x_0)
        zt = (1-t) * z0 + t * z1
        v_target = z1 - z0
        v_pred = net.pred_v(zt, t)
        loss = criterion(v_pred, v_target)
        loss.backward()
        loss_val = loss.item()
        # NOTE(review): no ODE solve happens in this loop, so this counter only
        # reflects whatever odefunc last recorded -- confirm it is meaningful here.
        nfe = net.odeblock.odefunc.nfe
        loss_log.append(loss_val)
        nfe_log.append(nfe)
        optimizer.step()
        pbar.set_description(f'Loss: {loss_val:.4f} NFE: {nfe}')
    return net, nfe_log, loss_log


def _save_train_logs(save_dir, nfe_log, loss_log, flow_loss_log, label_ae_loss_log, ratio_log, variance_log):
    """Write the per-epoch training logs of ``train_ours`` to ``<save_dir>/logs.pt``."""
    torch.save({'nfe': nfe_log,
                'loss': loss_log,
                'flow_loss': flow_loss_log,
                'label_ae_loss': label_ae_loss_log,
                'ratio': ratio_log,
                'v_pred_var': variance_log,
                }, os.path.join(save_dir, 'logs.pt'))


def train_ours(net, x_0, x_1, num_epochs=100, lambdas=(1.0, 1.0, 1.0, 0.0), criterion=torch.nn.MSELoss(),
               wd=0.0, flow_f_sg='no', lr=1e-3, zero_out_prob=0., loss_start_epoch=(0, 0, 0, 0),
               val_every=1000, save_dir=None):
    """Train ``net`` with a weighted sum of flow, task and label-autoencoder losses.

    Loss terms, in the order used by ``lambdas`` and ``loss_start_epoch``:
        0. flow loss: MSE between ``net.pred_v(z_t, t)`` and ``z1 - z0``;
        1. "inter" loss: a constant 0 placeholder (currently unused);
        2. task loss: ``criterion`` applied to the decoded one-step prediction and ``x_1``;
        3. label autoencoder loss: MSE between ``net.out_projection(z1)`` and ``x_1``.
    Term ``i`` is weighted by ``lambdas[i]``. It only counts once ``epoch >= loss_start_epoch[i]``.

    Args:
        net: model exposing ``in_projection`` (with a ``.weight``), ``label_projection``,
            ``out_projection``, ``pred_v`` and ``get_traj``.
        x_0: input tensor, moved to CUDA.
        x_1: label tensor, moved to CUDA.
        num_epochs: number of full-batch optimisation steps.
        lambdas: four loss weights (see above).
        criterion: task loss callable ``criterion(x_hat, x_1)``.
        wd: Adam weight decay.
        flow_f_sg: ``'first'`` detaches ``z0`` in the interpolant ``z_t``;
            ``'second'`` detaches ``z0`` in the target velocity; any other value applies
            no stop-gradient.
        lr: Adam learning rate.
        zero_out_prob: despite its name, this is the probability that a sample KEEPS its
            random ``t``. Otherwise ``t`` is forced to 0, so the default ``0.`` gives
            ``t == 0`` for every sample. To force ``t = 0`` with probability ``p``, pass
            ``1 - p``.
        loss_start_epoch: per-term epoch from which each loss term is enabled.
        val_every: checkpoint/plot/log period in epochs (only used with ``save_dir``).
        save_dir: output directory; when ``None`` nothing is written to disk.

    Returns:
        ``(net, nfe_log, loss_log)``.
    """
    optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=wd)

    x_0 = x_0.cuda()
    x_1 = x_1.cuda()

    nfe_log = []
    loss_log = []
    flow_loss_log = []
    label_ae_loss_log = []
    ratio_log = []
    variance_log = []
    pbar = tqdm(range(num_epochs))
    for epoch in pbar:
        optimizer.zero_grad()
        z0 = net.in_projection(x_0)
        z1 = net.label_projection(x_1)
        # Per-sample t ~ U(0, 1), kept with probability `zero_out_prob`, else set to 0.
        t = torch.rand(x_0.size(0), 1).to(x_0)
        mask = (torch.rand(x_0.size(0), 1) < zero_out_prob).float().to(x_0)
        t = t * mask
        # Velocity prediction on the linear interpolant between z0 (t=0) and z1 (t=1).
        z0_interp = z0.detach() if flow_f_sg == 'first' else z0
        zt = (1-t) * z0_interp + t * z1
        v_pred = net.pred_v(zt, t)
        v_pred_var = v_pred.detach().reshape(v_pred.shape[0], -1).var(dim=0).mean().item()
        # Target velocity of the straight-line path.
        z0_target = z0.detach() if flow_f_sg == 'second' else z0
        v_target = z1 - z0_target
        flow_loss = F.mse_loss(v_pred, v_target)
        inter_loss = 0.  # placeholder term, weighted by lambdas[1]; currently unused
        # Explicit Euler step of size (1 - t) from the interpolant to t = 1, then decode.
        x_hat = net.out_projection((1-t) * z0 + t * z1 + v_pred * (1-t))
        task_loss = criterion(x_hat, x_1)
        label_autoencoder_loss = F.mse_loss(net.out_projection(z1), x_1)
        effective_lambdas = [lambdas[i] if epoch >= loss_start_epoch[i] else 0. for i in range(4)]
        loss = effective_lambdas[0] * flow_loss + effective_lambdas[1] * inter_loss + effective_lambdas[2] * task_loss \
            + effective_lambdas[3] * label_autoencoder_loss
        loss.backward()
        # Read each scalar once (.item() forces a device sync).
        loss_val = loss.item()
        flow_loss_val = flow_loss.item()
        label_ae_val = label_autoencoder_loss.item()
        loss_log.append(loss_val)
        flow_loss_log.append(flow_loss_val)
        label_ae_loss_log.append(label_ae_val)
        nfe_log.append(1)  # training never calls an ODE solver; log a single evaluation
        optimizer.step()
        ratio_log.append(area_scaling_factor(net.in_projection.weight))
        variance_log.append(v_pred_var)
        pbar.set_description(f'Loss: {loss_val:.3e} Flow Loss: {flow_loss_val:.3e} Label AE: {label_ae_val:.3e} Task Loss: {task_loss.item():.3e}')
        if (epoch + 1) % val_every == 0 and save_dir is not None:
            # checkpoint
            torch.save(net.state_dict(), os.path.join(save_dir, f'epoch_{epoch}.pt'))
            with torch.no_grad():
                # decision boundary
                decision_boundary(net, save_fig=f'{save_dir}/epoch_{epoch}.png')
                # embedding space and trajectory visualisation
                feat = net.in_projection(x_0)
                single_feature_plt_plotly(feat.to(x_1.device), x_1, color_targets=x_0, stride=10, save_fig=f'{save_dir}/epoch_{epoch}_feat.png')
                batch_stride = 20
                x_0_sub = x_0[::batch_stride]
                traj, _ = net.get_traj(x_0_sub, timesteps=2, method='dopri5')
                # NOTE(review): the time grid has 100 points while get_traj is asked for
                # 2 timesteps -- confirm trajectory_plt_plotly expects this.
                t_plot = torch.linspace(0, 1, 100).view(-1, 1).repeat(1, x_0_sub.shape[0])
                trajectory_plt_plotly(traj, t_plot, x_0_sub, x_1[::batch_stride], net=net, save_fig=f'{save_dir}/epoch_{epoch}_traj.png')
            _save_train_logs(save_dir, nfe_log, loss_log, flow_loss_log, label_ae_loss_log, ratio_log, variance_log)

    # Final logs; skipped when no output directory was given (previously this crashed).
    if save_dir is not None:
        _save_train_logs(save_dir, nfe_log, loss_log, flow_loss_log, label_ae_loss_log, ratio_log, variance_log)

    return net, nfe_log, loss_log


# %%

if __name__ == "__main__":
    # Command-line options.
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, nargs='+', default=[42])
    parser.add_argument('--force_zero_prob', type=float, default=0.,
                        help='Probability of forcing the sampled flow time t to 0 (must be in [0, 1]).')
    parser.add_argument('--name', required=True)
    parser.add_argument('--loss', type=str, nargs='+', choices=['flow', 'label_ae', 'task'], default=['flow', 'label_ae'])
    parser.add_argument('--preserve_f', action='store_true')
    parser.add_argument('--flow_f_sg', choices=['no', 'first', 'second'], default='no',
                        help='Where to apply stop-gradient in flow loss.')
    parser.add_argument('--epochs', type=int, default=50000)
    parser.add_argument('--val_every', type=int, default=1000)
    parser.add_argument('--ortho', action='store_true')
    args = parser.parse_args()
    # The value is turned into a keep-probability (1 - p) below; reject nonsense early.
    if not 0. <= args.force_zero_prob <= 1.:
        parser.error('--force_zero_prob must be in [0, 1]')

    LATENT_DIM = 3
    # Loss weights in train_ours order: [flow, inter (unused), task, label autoencoder].
    lambdas = [
        int('flow' in args.loss),
        0,  # inter
        int('task' in args.loss),
        int('label_ae' in args.loss)
    ]
    print('lambdas:', lambdas)
    for seed in args.seed:
        # Run directory; makedirs without exist_ok raises if it already exists, which
        # presumably guards against overwriting an earlier run.
        save_dir = os.path.join('logs', 'concentric', args.name, f'seed_{seed}')
        os.makedirs(save_dir)
        print(f"==== Run with seed {seed}, Save at {save_dir} ====")

        # Model: input encoder f, label encoder g and decoder back to label space.
        fix_random_seeds(seed)
        if args.preserve_f:
            enc_f = PreservingLinear(2, LATENT_DIM)  # --ortho has no effect on this branch
        else:
            enc_f = torch.nn.Linear(2, LATENT_DIM)
            if args.ortho:
                torch.nn.init.orthogonal_(enc_f.weight.data)
        dec_g = torch.nn.Linear(LATENT_DIM, 2)
        enc_g = torch.nn.Linear(2, LATENT_DIM)
        net = ODENet('cuda', LATENT_DIM, 64, 2, 0, True, 'relu', in_proj=enc_f, out_proj=dec_g, label_proj=enc_g).cuda()
        net.train()

        # Training; train_ours keeps t with probability zero_out_prob, so pass 1 - p
        # to force t = 0 with probability p.
        fix_random_seeds(seed, strict=True)
        net, nfe_log, loss_log = train_ours(net, Z_0, Z_1, args.epochs, lambdas=lambdas, criterion=custom_ce,
                                            wd=0., lr=1e-3, loss_start_epoch=[0, 0, 0, 0], zero_out_prob=1 - args.force_zero_prob,
                                            flow_f_sg=args.flow_f_sg, val_every=args.val_every, save_dir=save_dir)

    print('Done.')
