"""Lightning module for letters dataset."""

import os
import time
import numpy as np
import torch
import torch.optim as optim
import torch.nn.functional as F
import pytorch_lightning as pl
import matplotlib.pyplot as plt
import wandb

from typing import *
from util.ot_sampler import *
from torchdyn.core import NeuralODE
from util.distribution_distances import compute_distribution_distances
from models.components.mlp import MLP_cond, Flow, Flow2, torch_wrapper, torch_wrapper_flow_cond
from src.models.components.gnn import GlobalGNN, GlobalGNN2

from pathlib import Path

# Absolute location of this module; the project root sits two directories
# above the folder that contains it.
_this_file = Path(__file__).absolute()
script_directory = _this_file.parent
sccfm_directory = script_directory.parent.parent

class LettersMLPFM(pl.LightningModule):
    """Unconditional flow-matching module for the letters dataset.

    A ``Flow2`` network is trained as a time-dependent vector field and, for
    evaluation, integrated over ``t in [0, 1]`` with a torchdyn ``NeuralODE``.
    Batches are ``(idx, x0, x1)`` triples holding source and target samples.

    Args:
        lr: Learning rate of the Adam optimizer.
        dim: Sample dimensionality, forwarded to the network as ``D``.
        num_hidden: Hidden width forwarded to the network.
        base: ``"source"`` interpolates between source and target samples,
            ``"gaussian"`` interpolates between Gaussian noise and targets.
        integrate_time_steps: Number of points on the ODE time grid.
        name: Free-form model name stored on the module.
    """

    def __init__(
        self,
        lr=1e-3,
        dim=2,
        num_hidden=512,
        base="source",
        integrate_time_steps=100,
        name="mlp_fm",
    ) -> None:
        super().__init__()

        # Lightning drives the optimizer; no manual optimization here.
        self.automatic_optimization = True

        self.save_hyperparameters()

        # Build the network from the constructor arguments instead of
        # hard-coded sizes; the defaults reproduce the former D=2 / 512.
        model = Flow2(D=dim, num_hidden=num_hidden)
        # Keep the eager GPU placement, but do not crash on CPU-only hosts.
        self.model = model.cuda() if torch.cuda.is_available() else model
        self.lr = lr
        self.dim = dim
        self.num_hidden = num_hidden
        self.integrate_time_steps = integrate_time_steps

        assert base in [
            "source",
            "gaussian",
        ], "Invalid base. Must be either 'source' or 'gaussian'"
        self.base = base
        self.name = name

        # Training batches kept aside for the end-of-training evaluation.
        self.num_train_evals = 6
        self.train_evals_count = 0
        self.train_eval_batches = []

        self.predict_count = 0

    def configure_optimizers(self):
        """Return a single Adam optimizer over the flow network."""
        return optim.Adam(self.model.parameters(), lr=self.lr)

    def forward(self, t, x):
        """Evaluate the vector field at time ``t`` (trailing axis dropped) and state ``x``."""
        return self.model(t.squeeze(-1), x)

    def compute_loss(self, source_samples, target_samples):
        """Return the mean flow-matching loss for one pair of sample sets.

        One random time is drawn per sample. With ``base == "source"`` the
        loss is ``|b|^2 - 2 <b, u>`` (the squared error without its
        ``|u|^2`` term); with ``base == "gaussian"`` it is ``|b - u|^2``.

        Raises:
            ValueError: If ``self.base`` is neither ``"source"`` nor ``"gaussian"``.
        """
        t = torch.rand_like(source_samples[..., 0, None])

        if self.base == "source":
            x = (1.0 - t) * source_samples + t * target_samples
            u = target_samples - source_samples
            b = self.forward(t, x)
            loss = b.norm(dim=-1) ** 2 - 2.0 * (b * u).sum(dim=-1)
        elif self.base == "gaussian":
            z = torch.randn_like(target_samples)
            x = (1.0 - t) * z + t * target_samples
            u = target_samples - z
            b = self.forward(t, x)
            loss = ((b - u) ** 2).sum(dim=-1)
        else:
            raise ValueError(f"unknown base: {self.base}")

        return loss.mean()

    def _wants_train_eval_batch(self):
        """Tell whether the current training batch should be kept for evaluation."""
        if self.train_evals_count >= self.num_train_evals:
            return False
        # ``check_val_every_n_epoch`` may be 1 (or None); the former
        # ``epoch % (n - 1)`` then divided by zero. Clamp the period to 1.
        every = self.trainer.check_val_every_n_epoch or 1
        period = max(every - 1, 1)
        return self.current_epoch % period == 0

    def training_step(self, batch, batch_idx):
        """Compute and log the loss; optionally stash the batch for later evaluation."""
        _, x0, x1 = batch
        assert (
            len(x0.shape) == 3
        ), "This was a temporary fix for the dataloader -- TODO: Make the code more gener."
        loss = self.compute_loss(x0.squeeze(0), x1.squeeze(0))
        self.log("train/loss", loss, on_step=False, on_epoch=True)
        if self._wants_train_eval_batch():
            self.train_eval_batches.append(batch)
            self.train_evals_count += 1
        return loss

    def training_epoch_end(self, outputs):
        """On the last epoch, evaluate and plot the stashed training batches."""
        if self.current_epoch != self.trainer.max_epochs - 1:
            return

        batches = self.train_eval_batches
        if batches:  # nothing to average or plot otherwise
            trajs, per_batch = [], []
            for batch in batches:
                traj, pred, true = self.eval_batch(batch)
                trajs.append(traj)
                names, dd = compute_distribution_distances(
                    pred.squeeze(0).unsqueeze(1).to(true),
                    true.squeeze(0).unsqueeze(1),
                )
                per_batch.append(dict(zip(names, dd)))

            for key in per_batch[0]:
                self.log(
                    f"train/{key}",
                    np.mean([m[key] for m in per_batch]),
                    on_step=False,
                    on_epoch=True,
                )
            for rows in (6, 2):
                self.plot(
                    trajs,
                    batches,
                    num_row=rows,
                    num_step=3,
                    tag=f"fm_train_samples_{rows}_plots",
                    shuffle=True,
                )

        self.train_eval_batches = []
        self.train_evals_count = 0

    def predict_step(self, batch, batch_idx):
        """Return ``(idx, trajs, x0, pred, true)`` for the first 60 calls, then ``None``."""
        if self.predict_count >= 60:
            return None
        idx, x0, _ = batch
        trajs, pred, true = self.eval_batch(batch)
        self.predict_count += 1
        return idx, trajs, x0, pred, true

    def _evaluate_and_log(self, batch, stage):
        """Integrate one batch, log mean metrics under ``stage/`` and save both plots."""
        trajs, pred, true = self.eval_batch(batch)
        metrics = self.compute_eval_metrics(pred, true, aggregate=True)
        for key, value in metrics.items():
            self.log(f"{stage}/{key}", value, on_step=False, on_epoch=True)
        for rows in (6, 2):
            self.plot(
                trajs,
                batch,
                num_row=rows,
                num_step=3,
                tag=f"fm_{stage}_samples_{rows}_plots",
                shuffle=True,
            )

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch (metrics under ``val/`` plus plots)."""
        self._evaluate_and_log(batch, "val")

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch (metrics under ``test/`` plus plots)."""
        self._evaluate_and_log(batch, "test")

    def compute_eval_metrics(self, pred_samples, target_samples, aggregate=True):
        """Compute distribution distances for every entry along the first axis.

        Returns:
            A list with one ``{name: value}`` dict per entry, or, when
            ``aggregate`` is true, a single dict of per-metric means (empty
            when there are no entries).
        """
        metrics = []
        for i in range(target_samples.shape[0]):
            names, dd = compute_distribution_distances(
                pred_samples[i].unsqueeze(1).to(target_samples),
                target_samples[i].unsqueeze(1),
            )
            metrics.append(dict(zip(names, dd)))
        if not aggregate:
            return metrics
        if not metrics:
            return {}
        return {k: np.mean([m[k] for m in metrics]) for k in metrics[0]}

    def eval_batch(self, batch):
        """Integrate every source set of ``batch`` through the learned flow.

        Returns:
            ``(trajs, pred, true)``: the stacked trajectories, their final
            states, and the (possibly squeezed) targets ``x1``.
        """
        _, x0, x1 = batch

        node = NeuralODE(
            torch_wrapper(self.model),
            solver='dopri5',  # 'rk4'
            sensitivity="adjoint",
            atol=1e-4,
            rtol=1e-4,
        )

        time_span = torch.linspace(0, 1, self.integrate_time_steps)

        with torch.no_grad():
            if len(x0.shape) > 3:
                x0 = x0.squeeze()
                x1 = x1.squeeze()

            trajs, pred = [], []
            for i in range(x0.shape[0]):
                traj = node.trajectory(x0[i], t_span=time_span)
                trajs.append(traj)
                pred.append(traj[-1, :, :])

            trajs = torch.stack(trajs, dim=0)
            pred = torch.stack(pred, dim=0)
            true = x1
        return trajs, pred, true

    def plot(
        self, trajs, samples, num_row=6, num_step=5, tag="fm_samples", shuffle=True
    ):
        """Save (and, with an active wandb run, log) a grid of flow snapshots.

        Each row shows the source samples, intermediate trajectory states and
        the target samples of one example.

        Args:
            trajs: Either one stacked tensor of trajectories or a list with
                one trajectory tensor (leading singleton axis) per example.
            samples: The batch matching ``trajs``: a single ``(idx, x0, x1)``
                batch for a tensor, a list of such batches for a list.
            num_row: Maximum number of examples to draw; clamped to the
                number of examples actually available.
            num_step: Number of time steps used to pick intermediate panels.
            tag: File stem of the PDF and key suffix of the wandb image.
            shuffle: Unused; kept for call-site compatibility.

        Raises:
            ValueError: If ``self.base`` is neither ``"source"`` nor ``"gaussian"``.
        """
        stacked = not isinstance(trajs, list)
        if stacked:
            trajs = trajs.cpu().detach().numpy()
            _, source, target = samples
            available = min(len(trajs), len(source), len(target))
        else:
            trajs = [traj.cpu().detach().numpy().squeeze(0) for traj in trajs]
            available = min(len(trajs), len(samples))

        if self.base == "source":
            num_col = 1 + num_step
        elif self.base == "gaussian":
            num_col = 2 + num_step
        else:
            raise ValueError(f"unknown base: {self.base}")

        # Never index past the examples that are actually there.
        num_row = min(num_row, available)
        if num_row < 1:
            return

        max_points = 1500
        from_source = self.base == "source"
        first_panel = 1 if from_source else 0
        col_shift = 0 if from_source else 1

        fig, axs = plt.subplots(
            num_row,
            num_col,
            figsize=(num_col, num_row),
            squeeze=False,  # always a 2-D array of axes
            gridspec_kw={"wspace": 0.0, "hspace": 0.0},
        )
        try:
            for i in range(num_row):
                if stacked:
                    src = source[i].cpu().numpy()
                    tgt = target[i].cpu().numpy()
                else:
                    src = samples[i][1].squeeze(0).cpu().numpy()
                    tgt = samples[i][2].squeeze(0).cpu().numpy()

                # Fixed seed per row keeps the plotted subset reproducible.
                rng = np.random.default_rng(42)
                idcs = rng.choice(
                    np.arange(src.shape[0]),
                    size=min(max_points, src.shape[0]),
                    replace=False,
                )

                axs[i, 0].scatter(*src[idcs].T, s=1, c="#3283FB", rasterized=True)
                axs[i, -1].scatter(*tgt[idcs].T, s=1, c="#3283FB", rasterized=True)

                traj = trajs[i]
                t_step = int(traj.shape[0] / max(num_step - 1, 1))
                ts = np.arange(t_step, t_step * num_step, t_step)
                horizon = t_step * (num_step - 1)
                for j in range(first_panel, num_step):
                    # j == 0 only occurs for the gaussian base and is the
                    # initial state; ``ts[j - 1]`` would wrap to the end.
                    t = 0 if j == 0 else ts[j - 1] - 1
                    ax = axs[i, j + col_shift]
                    ax.scatter(*traj[t, idcs].T, s=1, c="#3283FB", rasterized=True)
                    if i == 0:
                        t_frac = t / horizon if horizon else 0.0
                        ax.set_title(f"t={t_frac:.2f}")

            axs[0, 0].set_title("source")
            axs[0, -1].set_title("target")

            for ax in axs.ravel():
                ax.set_facecolor((206 / 256, 206 / 256, 229 / 256))
                ax.set_xlim(-4, 4)
                ax.set_ylim(-4, 4)
                ax.set_xticks([])
                ax.set_yticks([])

            fig.tight_layout()
            fname = f"{sccfm_directory}/figs/{tag}.pdf"
            # The output folder is not guaranteed to exist yet.
            os.makedirs(os.path.dirname(fname), exist_ok=True)
            fig.savefig(
                fname, bbox_inches="tight", pad_inches=0.0, transparent=True, dpi=300
            )
            # wandb.log raises when no run has been initialised.
            if wandb.run is not None:
                wandb.log({f"imgs/{tag}": wandb.Image(fig)})
        finally:
            # Release the figure even if saving or logging failed.
            plt.close(fig)


class LettersCondMLPFM(pl.LightningModule):
    """Condition-aware flow-matching module for the letters dataset.

    Works like an unconditional flow-matching module, except that a
    condition vector is concatenated to every sample before it is fed to the
    ``Flow2`` network. Batches are ``(idx, x0, x1, cond)`` tuples.

    Args:
        lr: Learning rate of the Adam optimizer.
        dim: Sample dimensionality, forwarded to the network as ``D``.
        num_hidden: Hidden width forwarded to the network.
        base: ``"source"`` interpolates between source and target samples,
            ``"gaussian"`` interpolates between Gaussian noise and targets.
        integrate_time_steps: Number of points on the ODE time grid.
        num_conditions: Length of the condition vector.
        name: Free-form model name stored on the module.
    """

    def __init__(
        self,
        lr=1e-3,
        dim=2,
        num_hidden=512,
        base="source",
        integrate_time_steps=100,
        num_conditions=262,
        name="mlp_fm",
    ) -> None:
        super().__init__()

        # Lightning drives the optimizer; no manual optimization here.
        self.automatic_optimization = True

        self.save_hyperparameters()

        # Build the network from the constructor arguments instead of
        # hard-coded sizes; the defaults reproduce the former D=2 / 512.
        model = Flow2(D=dim, num_hidden=num_hidden, num_conditions=num_conditions)
        # Keep the eager GPU placement, but do not crash on CPU-only hosts.
        self.model = model.cuda() if torch.cuda.is_available() else model
        self.lr = lr
        self.dim = dim
        self.num_hidden = num_hidden
        self.integrate_time_steps = integrate_time_steps
        self.num_conditions = num_conditions

        assert base in [
            "source",
            "gaussian",
        ], "Invalid base. Must be either 'source' or 'gaussian'"
        self.base = base
        self.name = name

        # Training batches kept aside for the end-of-training evaluation.
        self.num_train_evals = 6
        self.train_evals_count = 0
        self.train_eval_batches = []

        self.predict_count = 0

    def configure_optimizers(self):
        """Return a single Adam optimizer over the flow network."""
        return optim.Adam(self.model.parameters(), lr=self.lr)

    def forward(self, t, x):
        """Evaluate the vector field at time ``t`` (trailing axis dropped) and input ``x``."""
        return self.model(t.squeeze(-1), x)

    def compute_loss(self, source_samples, target_samples, cond):
        """Return the mean conditional flow-matching loss.

        One random time is drawn per sample and ``cond`` is concatenated to
        the interpolated point along the last axis. With ``base == "source"``
        the loss is ``|b|^2 - 2 <b, u>``; with ``base == "gaussian"`` it is
        ``|b - u|^2``.

        Raises:
            ValueError: If ``self.base`` is neither ``"source"`` nor ``"gaussian"``.
        """
        t = torch.rand_like(source_samples[..., 0, None])

        if self.base == "source":
            x = (1.0 - t) * source_samples + t * target_samples
            u = target_samples - source_samples
            b = self.forward(t, torch.cat((x, cond), dim=-1).float())
            loss = b.norm(dim=-1) ** 2 - 2.0 * (b * u).sum(dim=-1)
        elif self.base == "gaussian":
            z = torch.randn_like(target_samples)
            x = (1.0 - t) * z + t * target_samples
            u = target_samples - z
            b = self.forward(t, torch.cat((x, cond), dim=-1).float())
            loss = ((b - u) ** 2).sum(dim=-1)
        else:
            raise ValueError(f"unknown base: {self.base}")

        return loss.mean()

    def _wants_train_eval_batch(self):
        """Tell whether the current training batch should be kept for evaluation."""
        if self.train_evals_count >= self.num_train_evals:
            return False
        # ``check_val_every_n_epoch`` may be 1 (or None); the former
        # ``epoch % (n - 1)`` then divided by zero. Clamp the period to 1.
        every = self.trainer.check_val_every_n_epoch or 1
        period = max(every - 1, 1)
        return self.current_epoch % period == 0

    def training_step(self, batch, batch_idx):
        """Compute and log the loss; optionally stash the batch for later evaluation."""
        _, x0, x1, cond = batch
        assert (
            len(x0.shape) == 3
        ), "This was a temporary fix for the dataloader -- TODO: Make the code more gener."
        loss = self.compute_loss(x0.squeeze(0), x1.squeeze(0), cond.squeeze(0))
        self.log("train/loss", loss, on_step=False, on_epoch=True)
        if self._wants_train_eval_batch():
            self.train_eval_batches.append(batch)
            self.train_evals_count += 1
        return loss

    def training_epoch_end(self, outputs):
        """On the last epoch, evaluate and plot the stashed training batches."""
        if self.current_epoch != self.trainer.max_epochs - 1:
            return

        batches = self.train_eval_batches
        if batches:  # nothing to average or plot otherwise
            trajs, per_batch = [], []
            for batch in batches:
                traj, pred, true = self.eval_batch(batch)
                trajs.append(traj)
                names, dd = compute_distribution_distances(
                    pred.squeeze(0).unsqueeze(1).to(true),
                    true.squeeze(0).unsqueeze(1),
                )
                per_batch.append(dict(zip(names, dd)))

            for key in per_batch[0]:
                self.log(
                    f"train/{key}",
                    np.mean([m[key] for m in per_batch]),
                    on_step=False,
                    on_epoch=True,
                )
            for rows in (6, 2):
                self.plot(
                    trajs,
                    batches,
                    num_row=rows,
                    num_step=3,
                    tag=f"cond_fm_train_samples_{rows}_plots",
                    shuffle=True,
                )

        self.train_eval_batches = []
        self.train_evals_count = 0

    def predict_step(self, batch, batch_idx):
        """Return ``(idx, trajs, x0, pred, true)`` for the first 60 calls, then ``None``."""
        if self.predict_count >= 60:
            return None
        idx, x0, _ = batch[:3]
        trajs, pred, true = self.eval_batch(batch)
        self.predict_count += 1
        return idx, trajs, x0, pred, true

    def _evaluate_and_log(self, batch, stage):
        """Integrate one batch, log mean metrics under ``stage/`` and save both plots."""
        trajs, pred, true = self.eval_batch(batch, prefix=stage)
        metrics = self.compute_eval_metrics(pred, true, aggregate=True)
        for key, value in metrics.items():
            self.log(f"{stage}/{key}", value, on_step=False, on_epoch=True)
        for rows in (6, 2):
            self.plot(
                trajs,
                batch,
                num_row=rows,
                num_step=3,
                tag=f"cond_fm_{stage}_samples_{rows}_plots",
                shuffle=True,
            )

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch (metrics under ``val/`` plus plots)."""
        self._evaluate_and_log(batch, "val")

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch (metrics under ``test/`` plus plots)."""
        self._evaluate_and_log(batch, "test")

    def compute_eval_metrics(self, pred_samples, target_samples, aggregate=True):
        """Compute distribution distances for every entry along the first axis.

        Returns:
            A list with one ``{name: value}`` dict per entry, or, when
            ``aggregate`` is true, a single dict of per-metric means (empty
            when there are no entries).
        """
        metrics = []
        for i in range(target_samples.shape[0]):
            names, dd = compute_distribution_distances(
                pred_samples[i].unsqueeze(1).to(target_samples),
                target_samples[i].unsqueeze(1),
            )
            metrics.append(dict(zip(names, dd)))
        if not aggregate:
            return metrics
        if not metrics:
            return {}
        return {k: np.mean([m[k] for m in metrics]) for k in metrics[0]}

    def eval_batch(self, batch, prefix='train'):
        """Integrate every source set of ``batch`` through the conditional flow.

        For any ``prefix`` other than ``'train'`` the batch's own condition is
        replaced by a uniform average over the first
        ``num_conditions - 20`` condition slots.

        Returns:
            ``(trajs, pred, true)``: the stacked trajectories and their final
            states (sample coordinates only), and the targets ``x1``.
        """
        _, x0, x1, cond = batch

        if prefix != 'train':
            # Average over the train one-hot conditions.
            # NOTE(review): the trailing 10 + 10 slots are presumably the
            # held-out validation and test conditions -- confirm.
            num_held_out = 10 + 10
            num_train = self.num_conditions - num_held_out
            lead = (cond.shape[0], cond.shape[1])
            cond = torch.cat(
                [
                    torch.ones(lead + (num_train,)),
                    torch.zeros(lead + (num_held_out,)),
                ],
                dim=-1,
            ).to(x0) / num_train

        node = NeuralODE(
            torch_wrapper_flow_cond(self.model),
            solver='dopri5',  # rk4
            sensitivity="adjoint",
            atol=1e-4,
            rtol=1e-4,
        )

        time_span = torch.linspace(0, 1, self.integrate_time_steps)

        with torch.no_grad():
            if len(x0.shape) > 3:
                x0 = x0.squeeze()
                x1 = x1.squeeze()

            trajs, pred = [], []
            for i in range(x0.shape[0]):
                state = torch.cat((x0[i], cond[i]), dim=-1).float()
                traj = node.trajectory(state, t_span=time_span)
                # Keep only the sample coordinates, not the condition part.
                trajs.append(traj[:, :, :self.model.D])
                pred.append(traj[-1, :, :self.model.D])

            trajs = torch.stack(trajs, dim=0)
            pred = torch.stack(pred, dim=0)
            true = x1
        return trajs, pred, true

    def plot(
        self, trajs, samples, num_row=6, num_step=5, tag="cond_fm_samples", shuffle=True
    ):
        """Save (and, with an active wandb run, log) a grid of flow snapshots.

        Each row shows the source samples, intermediate trajectory states and
        the target samples of one example.

        Args:
            trajs: Either one stacked tensor of trajectories or a list with
                one trajectory tensor (leading singleton axis) per example.
            samples: The batch matching ``trajs``: a single
                ``(idx, x0, x1, cond)`` batch for a tensor, a list of such
                batches for a list.
            num_row: Maximum number of examples to draw; clamped to the
                number of examples actually available.
            num_step: Number of time steps used to pick intermediate panels.
            tag: File stem of the PDF and key suffix of the wandb image.
            shuffle: Unused; kept for call-site compatibility.

        Raises:
            ValueError: If ``self.base`` is neither ``"source"`` nor ``"gaussian"``.
        """
        stacked = not isinstance(trajs, list)
        if stacked:
            trajs = trajs.cpu().detach().numpy()
            # Three names for three items: the old four-name unpack of
            # ``samples[:3]`` always raised ValueError.
            _, source, target = samples[:3]
            available = min(len(trajs), len(source), len(target))
        else:
            trajs = [traj.cpu().detach().numpy().squeeze(0) for traj in trajs]
            available = min(len(trajs), len(samples))

        if self.base == "source":
            num_col = 1 + num_step
        elif self.base == "gaussian":
            num_col = 2 + num_step
        else:
            raise ValueError(f"unknown base: {self.base}")

        # Never index past the examples that are actually there.
        num_row = min(num_row, available)
        if num_row < 1:
            return

        max_points = 750
        from_source = self.base == "source"
        first_panel = 1 if from_source else 0
        col_shift = 0 if from_source else 1

        fig, axs = plt.subplots(
            num_row,
            num_col,
            figsize=(num_col, num_row),
            squeeze=False,  # always a 2-D array of axes
            gridspec_kw={"wspace": 0.0, "hspace": 0.0},
        )
        try:
            for i in range(num_row):
                if stacked:
                    src = source[i].cpu().numpy()
                    tgt = target[i].cpu().numpy()
                else:
                    src = samples[i][1].squeeze(0).cpu().numpy()
                    tgt = samples[i][2].squeeze(0).cpu().numpy()

                # Fixed seed per row keeps the plotted subset reproducible.
                rng = np.random.default_rng(42)
                idcs = rng.choice(
                    np.arange(src.shape[0]),
                    size=min(max_points, src.shape[0]),
                    replace=False,
                )

                axs[i, 0].scatter(*src[idcs].T, s=1, c="#3283FB", rasterized=True)
                axs[i, -1].scatter(*tgt[idcs].T, s=1, c="#3283FB", rasterized=True)

                traj = trajs[i]
                t_step = int(traj.shape[0] / max(num_step - 1, 1))
                ts = np.arange(t_step, t_step * num_step, t_step)
                horizon = t_step * (num_step - 1)
                for j in range(first_panel, num_step):
                    # j == 0 only occurs for the gaussian base and is the
                    # initial state; ``ts[j - 1]`` would wrap to the end.
                    t = 0 if j == 0 else ts[j - 1] - 1
                    ax = axs[i, j + col_shift]
                    ax.scatter(*traj[t, idcs].T, s=1, c="#3283FB", rasterized=True)
                    if i == 0:
                        t_frac = t / horizon if horizon else 0.0
                        ax.set_title(f"t={t_frac:.2f}")

            axs[0, 0].set_title("source")
            axs[0, -1].set_title("target")

            for ax in axs.ravel():
                ax.set_facecolor((206 / 256, 206 / 256, 229 / 256))
                ax.set_xlim(-4, 4)
                ax.set_ylim(-4, 4)
                ax.set_xticks([])
                ax.set_yticks([])

            fig.tight_layout()
            fname = f"{sccfm_directory}/figs/{tag}.pdf"
            # The output folder is not guaranteed to exist yet.
            os.makedirs(os.path.dirname(fname), exist_ok=True)
            fig.savefig(
                fname, bbox_inches="tight", pad_inches=0.0, transparent=True, dpi=300
            )
            # wandb.log raises when no run has been initialised.
            if wandb.run is not None:
                wandb.log({f"imgs/{tag}": wandb.Image(fig)})
        finally:
            # Release the figure even if saving or logging failed.
            plt.close(fig)


class LettersGNNFM(pl.LightningModule):
    def __init__(
        self,
        flow_lr=1e-3,
        gnn_lr=5e-4,
        update_embedding_epochs_freq=20,
        update_embedding_epochs=10,
        dim=2,
        num_hidden=512,
        num_hidden_gnn=512,
        knn_k=100,
        base="source",
        integrate_time_steps=100,
        name="gnn_fm",
    ) -> None:
        """Build the GNN-conditioned flow model and its training bookkeeping.

        Args:
            flow_lr: Learning rate stored for the flow (decoder) optimizer.
            gnn_lr: Learning rate stored for the GNN optimizer.
            update_embedding_epochs_freq: Stored spacing, in epochs, between
                phases in which the GNN is trained.
            update_embedding_epochs: Stored length, in epochs, of each such phase.
            dim: Sample dimensionality (stored only).
            num_hidden: Hidden width of the decoder.
            num_hidden_gnn: Hidden width of the GNN.
            knn_k: Neighbour count forwarded to the GNN as ``knn_k``.
            base: Either ``"source"`` or ``"gaussian"``.
            integrate_time_steps: Number of points on the ODE time grid.
            name: Free-form model name stored on the module.
        """
        super().__init__()

        # Manual optimization: the two optimizers are stepped by hand.
        self.automatic_optimization = False

        self.save_hyperparameters()

        model = GlobalGNN2(
            num_hidden_decoder=num_hidden, num_hidden_gnn=num_hidden_gnn, knn_k=knn_k
        )
        # Keep the eager GPU placement, but do not crash on CPU-only hosts.
        self.model = model.cuda() if torch.cuda.is_available() else model

        # The decoder and the GCN layers must account for every parameter,
        # otherwise some weights would belong to neither optimizer.
        assert len(list(self.model.parameters())) == len(
            list(self.model.decoder.parameters())
        ) + len(list(self.model.gcn_convs.parameters()))

        self.flow_lr = flow_lr
        self.gnn_lr = gnn_lr
        self.update_embedding_epochs_freq = update_embedding_epochs_freq
        self.update_embedding_epochs = update_embedding_epochs
        self.dim = dim
        self.knn_k = knn_k
        self.num_hidden = num_hidden
        self.integrate_time_steps = integrate_time_steps

        assert base in [
            "source",
            "gaussian",
        ], "Invalid base. Must be either 'source' or 'gaussian'"
        self.base = base
        self.name = name

        # Training batches kept aside for the end-of-training evaluation.
        self.num_train_evals = 6
        self.train_evals_count = 0
        self.train_eval_batches = []
        # Cache of detached source embeddings, keyed by sample index.
        self.embeddings = {}

        self.predict_count = 0

    def configure_optimizers(self):
        """Schedule the GNN epochs and create the two Adam optimizers.

        ``self.gnn_epochs`` receives, for every multiple of
        ``update_embedding_epochs_freq`` below ``max_epochs`` except 0, the
        ``update_embedding_epochs`` consecutive epochs starting there.

        Returns:
            ``(flow_optimizer, gnn_optimizer)``.
        """
        assert self.update_embedding_epochs_freq > self.update_embedding_epochs

        # Start epochs of the GNN phases; the one at epoch 0 is skipped.
        phase_starts = range(
            0, self.trainer.max_epochs, self.update_embedding_epochs_freq
        )[1:]
        gnn_epochs = []
        for start in phase_starts:
            for offset in range(self.update_embedding_epochs):
                gnn_epochs.append(offset + start)
        self.gnn_epochs = gnn_epochs

        # One optimizer per sub-network so each can be stepped on its own.
        decoder_params = self.model.decoder.parameters()
        self.flow_optimizer = torch.optim.Adam(decoder_params, lr=self.flow_lr)
        gcn_params = self.model.gcn_convs.parameters()
        self.gnn_optimizer = torch.optim.Adam(gcn_params, lr=self.gnn_lr)
        return self.flow_optimizer, self.gnn_optimizer

    def compute_loss(self, embedding, source_samples, target_samples):
        """Return the mean flow-matching loss given a source embedding.

        One random time is drawn per sample. With ``base == "source"`` the
        path starts at the source samples and the loss is
        ``|b|^2 - 2 <b, u>``; with ``base == "gaussian"`` it starts at
        Gaussian noise and the loss is ``|b - u|^2``.

        Raises:
            ValueError: If ``self.base`` is neither ``"source"`` nor ``"gaussian"``.
        """
        t = torch.rand_like(source_samples[..., 0, None])

        use_noise = self.base == "gaussian"
        if use_noise:
            start = torch.randn_like(target_samples)
        elif self.base == "source":
            start = source_samples
        else:
            raise ValueError(f"unknown base: {self.base}")

        # Point on the straight path at time t and the velocity along it.
        y = (1.0 - t) * start + t * target_samples
        u = target_samples - start
        b = self.model.flow(embedding, t.squeeze(-1), y)

        if use_noise:
            per_sample = ((b - u) ** 2).sum(dim=-1)
        else:
            per_sample = b.norm(dim=-1) ** 2 - 2.0 * (b * u).sum(dim=-1)
        return per_sample.mean()

    def get_embeddings(self, idx, source_samples):
        """Return the source embedding for ``idx``, computing it on a cache miss.

        A cached embedding is returned detached; a freshly computed one is
        returned attached to the graph while its detached copy is cached.
        """
        several = idx.shape[0] > 1
        # NOTE(review): with several indices only the first one is looked up
        # or embedded -- the original loop returned on its first iteration.
        # Confirm whether a batch of embeddings was intended.
        key = idx[0].item() if several else idx.item()

        try:
            return self.embeddings[key]
        except KeyError:
            pass

        points = source_samples[0] if several else source_samples
        embedding = self.model.embed_source(points)
        self.embeddings[key] = embedding.detach()
        return embedding

    def flow_step(self, batch):
        """Run one manual optimization step of the flow (decoder) optimizer.

        Args:
            batch: An ``(idx, x0, x1)`` triple.

        Returns:
            The loss tensor of this step.
        """
        idx, x0, x1 = batch
        source, target = x0.squeeze(0), x1.squeeze(0)
        # Drop only the leading axis, as everywhere else in this class; a
        # bare ``squeeze()`` would also remove any other singleton axis.
        embedding = self.get_embeddings(idx, source)
        loss = self.compute_loss(embedding, source, target)
        self.flow_optimizer.zero_grad()
        self.manual_backward(loss)
        self.flow_optimizer.step()
        return loss
    
    def gnn_step(self, batch):
        """Run one manual optimization step of the GNN optimizer.

        The source embedding is recomputed and its detached copy replaces the
        cached entry for this sample index.

        Returns:
            The loss tensor of this step.
        """
        idx, x0, x1 = batch
        source = x0.squeeze(0)
        target = x1.squeeze(0)

        embedding = self.model.embed_source(source)
        self.embeddings[idx.item()] = embedding.detach()

        loss = self.compute_loss(embedding, source, target)
        optimizer = self.gnn_optimizer
        optimizer.zero_grad()
        self.manual_backward(loss)
        optimizer.step()
        return loss
        
    def training_step(self, batch, batch_idx):
        """Train the GNN or the flow for one batch, depending on the epoch.

        Epochs listed in ``self.gnn_epochs`` update the GNN, all others the
        flow decoder. Up to ``num_train_evals`` batches are also stashed for
        the end-of-training evaluation.

        Returns:
            The loss tensor of this step.
        """
        if self.current_epoch in self.gnn_epochs:
            loss = self.gnn_step(batch)
        else:
            loss = self.flow_step(batch)
        self.log("train/loss", loss, on_step=False, on_epoch=True)

        if self.train_evals_count < self.num_train_evals:
            # ``check_val_every_n_epoch`` may be 1 (or None); the former
            # ``epoch % (n - 1)`` then divided by zero. Clamp the period to 1.
            every = self.trainer.check_val_every_n_epoch or 1
            period = max(every - 1, 1)
            if self.current_epoch % period == 0:
                self.train_eval_batches.append(batch)
                self.train_evals_count += 1
        return loss
    
    def training_epoch_end(self, outputs):
        """On the last epoch, evaluate and plot the stashed training batches.

        Mean distribution distances are logged under ``train/`` and two
        figures (up to 6 and up to 2 rows) are produced. Nothing is computed
        when no batch was stashed; the stash is reset either way.
        """
        if self.current_epoch != self.trainer.max_epochs - 1:
            return

        batches = self.train_eval_batches
        if batches:  # nothing to average or plot otherwise
            trajs, per_batch = [], []
            for batch in batches:
                traj, pred, true = self.eval_batch(batch)
                trajs.append(traj)
                names, dd = compute_distribution_distances(
                    pred.squeeze(0).unsqueeze(1).to(true),
                    true.squeeze(0).unsqueeze(1),
                )
                per_batch.append(dict(zip(names, dd)))

            for key in per_batch[0]:
                self.log(
                    f"train/{key}",
                    np.mean([m[key] for m in per_batch]),
                    on_step=False,
                    on_epoch=True,
                )
            for rows in (6, 2):
                self.plot(
                    trajs,
                    batches,
                    # Do not request more rows than batches were collected.
                    num_row=min(rows, len(batches)),
                    num_step=3,
                    tag=f"gnn_train_samples_{rows}_plots",
                    shuffle=True,
                )

        self.train_eval_batches = []
        self.train_evals_count = 0
            
    def predict_step(self, batch, batch_idx):
        """Return ``(idx, trajs, x0, pred, true)`` for the first 60 calls, then ``None``."""
        idx, x0, _ = batch
        if self.predict_count >= 60:
            # Enough predictions gathered for the final plots.
            return None
        trajs, pred, true = self.eval_batch(batch)
        self.predict_count += 1
        return idx, trajs, x0, pred, true

    def validation_step(self, batch, batch_idx):
        """Integrate one validation batch, log mean metrics under ``val/`` and plot."""
        trajs, pred, true = self.eval_batch(batch)
        # Per-metric mean over the entries of the batch.
        mean_metrics = self.compute_eval_metrics(pred, true, aggregate=True)
        for name in mean_metrics:
            self.log(f"val/{name}", mean_metrics[name], on_step=False, on_epoch=True)
        for rows in (6, 2):
            self.plot(
                trajs,
                batch,
                num_row=rows,
                num_step=3,
                tag=f"gnn_val_samples_{rows}_plots",
                shuffle=True,
            )
        
    def on_validation_end(self):
        """Save a checkpoint after validation when manual optimization is on.

        The file is written to ``<log_dir>/checkpoints/ckpt.pt``; when the
        trainer reports no log directory, its default root directory is used.
        """
        if not self.automatic_optimization:
            # ``log_dir`` can be None (e.g. a logger without a save
            # directory), which would make ``os.path.join`` raise TypeError.
            log_dir = self.trainer.log_dir
            if log_dir is None:
                log_dir = self.trainer.default_root_dir
            ckpt_path = os.path.join(log_dir, "checkpoints", "ckpt.pt")
            self.trainer.save_checkpoint(ckpt_path)
        return super().on_validation_end()

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch, log averaged metrics and plot samples.

        Per-sample metrics from ``compute_eval_metrics`` are averaged and
        logged under ``test/<name>``; two figures (6 rows and 2 rows) are then
        drawn through ``plot``.

        Args:
            batch: Batch forwarded to ``eval_batch`` and ``plot``.
            batch_idx: Unused.
        """
        trajs, pred, true = self.eval_batch(batch)
        per_sample = self.compute_eval_metrics(pred, true, aggregate=False)

        # Average every metric over the samples before anything is logged.
        averaged = {}
        for name in per_sample[0]:
            averaged[name] = np.mean([entry[name] for entry in per_sample])

        for name, mean_value in averaged.items():
            self.log(f"test/{name}", mean_value, on_step=False, on_epoch=True)

        for rows in (6, 2):
            self.plot(
                trajs,
                batch,
                num_row=rows,
                num_step=3,
                tag=f"gnn_test_samples_{rows}_plots",
                shuffle=True,
            )
    
    def compute_eval_metrics(self, pred_samples, target_samples, aggregate=True):
        """Compute distribution distances between predictions and targets.

        ``compute_distribution_distances`` is called once per index along the
        first axis of ``target_samples``; each side gets an extra axis inserted
        at position 1, and the prediction is converted with
        ``.to(target_samples)`` first.

        Args:
            pred_samples: Predicted samples, indexed along the first axis.
            target_samples: Target samples, indexed along the first axis.
            aggregate: When true, return the per-metric mean over samples.

        Returns:
            A list with one ``{metric_name: value}`` dict per sample, or a
            single dict of means when ``aggregate`` is true.
        """
        per_sample = []
        num_samples = target_samples.shape[0]
        for sample_idx in range(num_samples):
            prediction = pred_samples[sample_idx].unsqueeze(1).to(target_samples)
            reference = target_samples[sample_idx].unsqueeze(1)
            names, dd = compute_distribution_distances(prediction, reference)
            per_sample.append(dict(zip(names, dd)))

        if not aggregate:
            return per_sample
        return {
            name: np.mean([entry[name] for entry in per_sample])
            for name in per_sample[0]
        }

    def eval_batch(self, batch):
        """Integrate the learned flow from every source sample in ``batch``.

        A ``NeuralODE`` wrapping ``self.model`` is solved with ``dopri5`` over
        ``self.integrate_time_steps`` evenly spaced times in ``[0, 1]``. Before
        each sample is integrated the model's embedding is refreshed through
        ``update_embedding_for_inference``. Everything runs under
        ``torch.no_grad()``.

        Args:
            batch: Tuple unpacked as ``(_, x0, x1)``.

        Returns:
            ``(trajs, pred, true)``: the stacked trajectories, the stacked
            final state of each trajectory, and ``x1`` (squeezed when the
            inputs had more than three dimensions).
        """
        _, x0, x1 = batch

        t_span = torch.linspace(0, 1, self.integrate_time_steps)
        ode = NeuralODE(
            torch_wrapper(self.model),
            solver="dopri5",  # rk4
            sensitivity="adjoint",
            atol=1e-4,
            rtol=1e-4,
        )

        with torch.no_grad():
            if x0.dim() > 3:
                # Collapse every singleton axis of both tensors.
                x0, x1 = x0.squeeze(), x1.squeeze()

            per_sample_trajs = []
            for sample_idx in range(x0.shape[0]):
                start = x0[sample_idx]
                self.model.update_embedding_for_inference(start)
                per_sample_trajs.append(ode.trajectory(start, t_span=t_span))

            trajs = torch.stack(per_sample_trajs, dim=0)
            # The last time slice of each trajectory is the prediction.
            pred = torch.stack([traj[-1, :, :] for traj in per_sample_trajs], dim=0)
        return trajs, pred, x1

    def plot(self, trajs, samples, num_row=6, num_step=5, tag="gnn_samples", shuffle=True):
        """Draw a grid of source, intermediate and target point clouds and log it.

        Each row shows one sample: its source points in the first column, its
        target points in the last column and snapshots of the trajectory in
        between. At most 750 points per sample are drawn. The figure is saved
        to ``<sccfm_directory>/figs/<tag>_<knn_k>.pdf`` and sent to wandb under
        ``imgs/<tag>_<knn_k>``; it is always closed afterwards, even when
        drawing, saving or logging fails.

        Args:
            trajs: Either a tensor indexed as ``trajs[row][time, point]`` or a
                list of tensors (one per entry of ``samples``) whose leading
                singleton axis is squeezed away.
            samples: A ``(_, source, target)`` tuple when ``trajs`` is a
                tensor; otherwise a list whose items are read as ``item[1]``
                (source) and ``item[2]`` (target).
            num_row: Number of rows (samples) to draw.
            num_step: Number of trajectory snapshots; must be at least 2
                because the snapshot spacing divides by ``num_step - 1``.
            tag: Name used for the output file and the wandb key.
            shuffle: Unused; kept for interface compatibility.

        Raises:
            ValueError: If ``self.base`` is neither ``"source"`` nor
                ``"gaussian"``.
        """
        batched = not isinstance(trajs, list)
        if batched:
            trajs = trajs.cpu().detach().numpy()
            _, source, target = samples
        else:
            trajs = [traj.cpu().detach().numpy().squeeze(0) for traj in trajs]

        if self.base == "source":
            num_col = 1 + num_step
        elif self.base == "gaussian":
            num_col = 2 + num_step
        else:
            raise ValueError(f"unknown base: {self.base}")

        # For a "source" base snapshots start at j=1 and sit in column j; for
        # any other base they start at j=0 and are shifted one column right.
        from_source = self.base == "source"
        start_j = 1 if from_source else 0
        offset = 0 if from_source else 1

        max_points = 750
        point_color = "#3283FB"
        panel_color = (206/256, 206/256, 229/256)

        fig, axs = plt.subplots(
            num_row,
            num_col,
            figsize=(num_col, num_row),
            gridspec_kw={"wspace": 0.0, "hspace": 0.0},
        )
        try:
            axs = axs.reshape(num_row, num_col)

            for i in range(num_row):
                if batched:
                    source_samples = source[i].cpu().numpy()
                    target_samples = target[i].cpu().numpy()
                else:
                    source_samples = samples[i][1].squeeze(0).cpu().numpy()
                    target_samples = samples[i][2].squeeze(0).cpu().numpy()

                # Re-seeding on every row keeps the drawn subset reproducible.
                rng = np.random.default_rng(42)
                num_points = source_samples.shape[0]
                idcs = rng.choice(
                    np.arange(num_points),
                    size=min(max_points, num_points),
                    replace=False,
                )
                source_samples = source_samples[idcs]
                target_samples = target_samples[idcs]

                ax = axs[i, 0]
                ax.scatter(*source_samples.T, s=1, c=point_color, rasterized=True)
                ax.set_facecolor(panel_color)

                ax = axs[i, -1]
                ax.scatter(*target_samples.T, s=1, c=point_color, rasterized=True)
                ax.set_facecolor(panel_color)

                traj = trajs[i]
                t_step = int(traj.shape[0] / (num_step - 1))
                ts = np.arange(t_step, t_step * num_step, t_step)
                for j in range(start_j, num_step):
                    # j == 0 only happens for a non-"source" base and is the
                    # initial state; indexing ts[j - 1] there would wrap around
                    # to the final time step.
                    t = 0 if j == 0 else ts[j - 1] - 1
                    ax = axs[i, j + offset]
                    ax.scatter(*traj[t, idcs].T, s=1, c=point_color, rasterized=True)
                    ax.set_facecolor(panel_color)
                    if i == 0:
                        t_frac = t / (t_step * (num_step - 1))
                        ax.set_title(f"t={t_frac:.2f}")

            axs[0, 0].set_title("source")
            axs[0, -1].set_title("target")

            for ax in axs.ravel():
                ax.set_xlim(-4, 4)
                ax.set_ylim(-4, 4)
                ax.set_xticks([])
                ax.set_yticks([])

            fig.tight_layout()
            fname = f"{sccfm_directory}/figs/{tag}_{self.knn_k}.pdf"
            os.makedirs(os.path.dirname(fname), exist_ok=True)
            fig.savefig(
                fname, bbox_inches="tight", pad_inches=0.0, transparent=True, dpi=300
            )
            wandb.log({f"imgs/{tag}_{self.knn_k}": wandb.Image(fig)})
        finally:
            # Release the figure even if drawing, saving or logging raised.
            plt.close(fig)
        

class LettersContrastiveGNNFM(LettersGNNFM):
    def __init__(
        self,
        flow_lr=1e-3,
        gnn_lr=5e-4,
        update_embedding_epochs_freq=20,
        update_embedding_epochs=10,
        dim=2,
        num_hidden=512,
        num_hidden_gnn=512,
        knn_k=100,
        base="source",
        integrate_time_steps=100,
        name="contrast_gnn_fm",
    ) -> None:
        """Forward every constructor argument unchanged to the parent class.

        The only difference from the parent signature visible here is the
        default ``name`` of ``"contrast_gnn_fm"``.
        """
        parent_kwargs = dict(
            flow_lr=flow_lr,
            gnn_lr=gnn_lr,
            update_embedding_epochs_freq=update_embedding_epochs_freq,
            update_embedding_epochs=update_embedding_epochs,
            dim=dim,
            num_hidden=num_hidden,
            num_hidden_gnn=num_hidden_gnn,
            knn_k=knn_k,
            base=base,
            integrate_time_steps=integrate_time_steps,
            name=name,
        )
        super().__init__(**parent_kwargs)
    
    def configure_optimizers(self):
        """Build the GNN epoch schedule and the two Adam optimizers.

        ``self.gnn_epochs`` receives, for every multiple of
        ``update_embedding_epochs_freq`` below ``trainer.max_epochs``, the
        ``update_embedding_epochs`` consecutive epochs starting there.

        The flow optimizer covers ``self.model.decoder`` and the GNN optimizer
        covers ``self.model.gcn_convs``. Learning rates are read from
        ``self.flow_lr`` / ``self.gnn_lr`` when the instance defines them
        (assumed to be stored by the parent constructor -- TODO confirm);
        otherwise the previous fixed values ``1e-3`` / ``5e-4`` are used.

        Returns:
            ``(flow_optimizer, gnn_optimizer)``; both are also stored on
            ``self``.
        """
        assert self.update_embedding_epochs_freq > self.update_embedding_epochs, (
            "update_embedding_epochs_freq must exceed update_embedding_epochs"
        )
        # create epoch list for gnn_step for training
        cycle_starts = range(
            0, self.trainer.max_epochs, self.update_embedding_epochs_freq
        )
        self.gnn_epochs = [
            start + step
            for start in cycle_starts
            for step in range(self.update_embedding_epochs)
        ]

        # Honour configured learning rates instead of silently ignoring them.
        flow_lr = getattr(self, "flow_lr", 1e-3)
        gnn_lr = getattr(self, "gnn_lr", 5e-4)

        # init optimizers
        self.flow_optimizer = torch.optim.Adam(
            self.model.decoder.parameters(), lr=flow_lr
        )
        self.gnn_optimizer = torch.optim.Adam(
            self.model.gcn_convs.parameters(), lr=gnn_lr
        )
        return self.flow_optimizer, self.gnn_optimizer

    def info_nce_loss(self, e_i, e_j, temperature=0.07):
        """Return the temperature-scaled contrastive term for two embeddings.

        Cosine similarity between ``e_i`` and ``e_j`` is taken along the last
        axis (after a leading axis is added to each), divided by
        ``temperature``, and turned into
        ``-logits + logsumexp(logits, dim=-1)``.

        No self-masking, positive-pair mask or mean reduction is applied, so
        the result keeps the shape of the scaled similarities.

        Args:
            e_i: First embedding tensor.
            e_j: Second embedding tensor; must broadcast against ``e_i``.
            temperature: Divisor applied to the cosine similarities.

        Returns:
            The un-reduced loss tensor.
        """
        # TODO: expand to work for more than batch size 2 (mask out
        # self-similarity and select positives explicitly).
        cos_sim = F.cosine_similarity(e_i.unsqueeze(0), e_j.unsqueeze(0), dim=-1)
        logits = cos_sim.squeeze(-1) / temperature
        return -logits + torch.logsumexp(logits, dim=-1)

    def flow_step(self, batch):
        """Run one manual optimization step of the flow per sample in ``batch``.

        For every index along the first axis of ``idx`` the embedding is
        fetched with ``get_embeddings``, the loss computed with
        ``compute_loss`` and ``self.flow_optimizer`` stepped.

        Args:
            batch: Tuple unpacked as ``(idx, x0, x1)``.

        Returns:
            The loss of the last sample processed. With an empty batch no loss
            is ever assigned and the ``return`` raises ``UnboundLocalError``.
        """
        idx, x0, x1 = batch
        num_samples = idx.shape[0]
        for sample in range(num_samples):
            source = x0[sample]
            embedding = self.get_embeddings(idx[sample], source.squeeze())
            loss = self.compute_loss(
                embedding, source.squeeze(0), x1[sample].squeeze(0)
            )
            self.flow_optimizer.zero_grad()
            self.manual_backward(loss)
            self.flow_optimizer.step()
        return loss

    def gnn_step(self, batch):
        """Run one manual optimization step of the GNN on a pair of samples.

        Both sources are embedded with ``self.model.embed_source``; detached
        copies are cached in ``self.embeddings`` under their dataset indices.
        The contrastive loss from ``info_nce_loss`` is then back-propagated and
        ``self.gnn_optimizer`` stepped.

        Args:
            batch: Tuple unpacked as ``(idx, x0, _)``; must hold exactly two
                samples.

        Returns:
            The loss returned by ``info_nce_loss``.
        """
        idx, x0, _ = batch
        assert idx.shape[0] == 2  # TODO: rn only works for batch size 2

        e_i = self.model.embed_source(x0[0].squeeze(0))
        self.embeddings[idx[0].item()] = e_i.detach()
        e_j = self.model.embed_source(x0[1].squeeze(0))
        self.embeddings[idx[1].item()] = e_j.detach()

        # NOTE(review): info_nce_loss applies no reduction; confirm the loss is
        # a scalar for the embedding shapes used here before relying on backward.
        loss = self.info_nce_loss(e_i, e_j)

        self.gnn_optimizer.zero_grad()
        self.manual_backward(loss)
        self.gnn_optimizer.step()

        # This class's configure_optimizers does not create a scheduler, so
        # only step one when it has been set up elsewhere.
        scheduler = getattr(self, "gnn_lr_scheduler", None)
        if scheduler is not None:
            scheduler.step()
        return loss
    
    def training_epoch_end(self, outputs):
        """Evaluate the stored training batches, log metrics and plot samples.

        Runs when ``current_epoch`` is a multiple of
        ``trainer.check_val_every_n_epoch - 1`` (clamped to at least 1, so
        validating every epoch means evaluating every epoch). Every batch in
        ``self.train_eval_batches`` is integrated with ``eval_batch``;
        distribution distances are computed per sample, averaged and logged
        under ``train/<name>``. A 6-row figure is drawn and the stored batches
        and their counter are reset.

        Args:
            outputs: Unused.
        """
        val_every = self.trainer.check_val_every_n_epoch
        # `val_every - 1` is 0 when validation runs every epoch; clamping to 1
        # keeps the modulo below from dividing by zero.
        period = max((val_every or 1) - 1, 1)
        if self.current_epoch % period != 0:
            return
        if not self.train_eval_batches:
            # Nothing was collected this round, so there is nothing to report.
            return

        trajs, preds, trues = [], [], []
        for batch in self.train_eval_batches:
            traj, pred, true = self.eval_batch(batch)
            trajs.append(traj)
            preds.append(pred)
            trues.append(true)

        eval_metrics = []
        for pred_samples, target_samples in zip(preds, trues):
            for j in range(pred_samples.shape[0]):
                names, dd = compute_distribution_distances(
                    pred_samples[j].squeeze(0).unsqueeze(1).to(target_samples),
                    target_samples[j].squeeze(0).unsqueeze(1),
                )
                eval_metrics.append(dict(zip(names, dd)))

        if eval_metrics:
            mean_metrics = {
                k: np.mean([m[k] for m in eval_metrics]) for k in eval_metrics[0]
            }
            for key, value in mean_metrics.items():
                self.log(f"train/{key}", value, on_step=False, on_epoch=True)

        self.plot(
            trajs,
            self.train_eval_batches,
            num_row=6,
            num_step=5,
            tag="contrast_gnn_train_samples",
            shuffle=True,
        )
        self.train_eval_batches = []
        self.train_evals_count = 0
            
    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch, log averaged metrics and plot samples.

        Per-sample metrics from ``compute_eval_metrics`` are averaged and
        logged under ``val/<name>``; one 6-row figure is drawn through
        ``plot``.

        Args:
            batch: Batch forwarded to ``eval_batch`` and ``plot``.
            batch_idx: Unused.
        """
        trajs, pred, true = self.eval_batch(batch)
        per_sample = self.compute_eval_metrics(pred, true, aggregate=False)

        # Average every metric over the samples before anything is logged.
        averaged = {}
        for name in per_sample[0]:
            averaged[name] = np.mean([entry[name] for entry in per_sample])

        for name, mean_value in averaged.items():
            self.log(f"val/{name}", mean_value, on_step=False, on_epoch=True)

        self.plot(
            trajs,
            batch,
            num_row=6,
            num_step=5,
            tag="contrast_gnn_val_samples",
            shuffle=True,
        )

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch, log averaged metrics and plot samples.

        Per-sample metrics from ``compute_eval_metrics`` are averaged and
        logged under ``test/<name>``; one 6-row figure is drawn through
        ``plot``.

        Args:
            batch: Batch forwarded to ``eval_batch`` and ``plot``.
            batch_idx: Unused.
        """
        trajs, pred, true = self.eval_batch(batch)
        per_sample = self.compute_eval_metrics(pred, true, aggregate=False)

        # Average every metric over the samples before anything is logged.
        averaged = {}
        for name in per_sample[0]:
            averaged[name] = np.mean([entry[name] for entry in per_sample])

        for name, mean_value in averaged.items():
            self.log(f"test/{name}", mean_value, on_step=False, on_epoch=True)

        self.plot(
            trajs,
            batch,
            num_row=6,
            num_step=5,
            tag="contrast_gnn_test_samples",
            shuffle=True,
        )

    def plot(
        self, trajs, samples, num_row=6, num_step=5, tag="contrast_gnn_samples", shuffle=True
    ):
        """Draw a grid of source, intermediate and target point clouds and log it.

        Each row shows one sample: its source points in the first column, its
        target points in the last column and snapshots of the trajectory in
        between. The figure is saved to ``<sccfm_directory>/figs/<tag>.pdf``
        and sent to wandb under ``imgs/<tag>``; it is always closed afterwards,
        even when drawing, saving or logging fails.

        Args:
            trajs: Either a tensor indexed as ``trajs[row][time]`` or a list
                with one tensor per entry of ``samples``; list entries are
                flattened along their first axis so that every sample becomes
                one row.
            samples: A ``(_, source, target)`` tuple when ``trajs`` is a
                tensor; otherwise a list of such tuples, one per entry of
                ``trajs``.
            num_row: Number of rows (samples) to draw.
            num_step: Number of trajectory snapshots; must be at least 2
                because the snapshot spacing divides by ``num_step - 1``.
            tag: Name used for the output file and the wandb key.
            shuffle: Unused; kept for interface compatibility.

        Raises:
            ValueError: If ``self.base`` is neither ``"source"`` nor
                ``"gaussian"``.
        """

        def to_points(points):
            # Drop a leading singleton axis, when there is one, before
            # converting to numpy.
            if points.dim() > 2:
                points = points.squeeze(0)
            return points.cpu().numpy()

        batched = not isinstance(trajs, list)
        if batched:
            trajs = trajs.cpu().detach().numpy()
            _, source, target = samples
        else:
            # Flatten trajectories AND their source/target samples together so
            # row i always pairs a trajectory with the sample it started from,
            # including when a stored batch holds more than one sample.
            flat_trajs, flat_sources, flat_targets = [], [], []
            for (_, batch_source, batch_target), traj in zip(samples, trajs):
                for j in range(traj.shape[0]):
                    flat_trajs.append(traj[j].cpu().detach().numpy())
                    flat_sources.append(batch_source[j])
                    flat_targets.append(batch_target[j])
            trajs = flat_trajs

        if self.base == "source":
            num_col = 1 + num_step
        elif self.base == "gaussian":
            num_col = 2 + num_step
        else:
            raise ValueError(f"unknown base: {self.base}")

        # For a "source" base snapshots start at j=1 and sit in column j; for
        # any other base they start at j=0 and are shifted one column right.
        from_source = self.base == "source"
        start_j = 1 if from_source else 0
        offset = 0 if from_source else 1

        fig, axs = plt.subplots(
            num_row,
            num_col,
            figsize=(num_col, num_row),
            gridspec_kw={"wspace": 0.0, "hspace": 0.0},
        )
        try:
            axs = axs.reshape(num_row, num_col)

            for i in range(num_row):
                if batched:
                    source_samples = source[i].cpu().numpy()
                    target_samples = target[i].cpu().numpy()
                else:
                    source_samples = to_points(flat_sources[i])
                    target_samples = to_points(flat_targets[i])

                ax = axs[i, 0]
                ax.scatter(*source_samples.T, s=1, c="steelblue")

                ax = axs[i, -1]
                ax.scatter(*target_samples.T, s=1, c="steelblue")

                traj = trajs[i]
                t_step = int(traj.shape[0] / (num_step - 1))
                ts = np.arange(t_step, t_step * num_step, t_step)
                for j in range(start_j, num_step):
                    # j == 0 only happens for a non-"source" base and is the
                    # initial state; indexing ts[j - 1] there would wrap around
                    # to the final time step.
                    t = 0 if j == 0 else ts[j - 1] - 1
                    ax = axs[i, j + offset]
                    ax.scatter(*traj[t].T, s=1, c="steelblue")
                    if i == 0:
                        t_frac = t / (t_step * (num_step - 1))
                        ax.set_title(f"t={t_frac:.2f}")

            axs[0, 0].set_title("source")
            axs[0, -1].set_title("target")

            for ax in axs.ravel():
                ax.set_xlim(-4, 4)
                ax.set_ylim(-4, 4)
                ax.set_xticks([])
                ax.set_yticks([])

            fig.tight_layout()
            fname = f"{sccfm_directory}/figs/{tag}.pdf"
            os.makedirs(os.path.dirname(fname), exist_ok=True)
            fig.savefig(
                fname, bbox_inches="tight", pad_inches=0.0, transparent=True, dpi=300
            )
            wandb.log({f"imgs/{tag}": wandb.Image(fig)})
        finally:
            # Release the figure even if drawing, saving or logging raised.
            plt.close(fig)