"""Lightning module for trellis dataset."""

import os
import time
import numpy as np
import torch
import torch.optim as optim
import torch.nn.functional as F
import pytorch_lightning as pl
import matplotlib.pyplot as plt
import wandb

from typing import *
from util.ot_sampler import *
from torchdyn.core import NeuralODE
from util.distribution_distances import compute_distribution_distances
from src.models.components.mlp import (
    Flow, Flow2,
    torch_wrapper,
    torch_wrapper_flow_cond,
    torch_wrapper_gnn_flow_cond,
)
from src.models.components.gnn import GlobalGNN, GlobalGNN2

from pathlib import Path

script_directory = Path(__file__).absolute().parent
sccfm_directory = script_directory.parent.parent


class TrellisMLPFM(pl.LightningModule):
    """Treatment-conditioned flow-matching model with an MLP vector field (``Flow2``).

    The network learns a velocity field that transports source cells ``x0`` to
    target cells ``x1``; per-cell treatment features are concatenated to the
    state before every network call. Evaluation integrates the learned ODE from
    the source cells and records ``compute_distribution_distances`` between the
    prediction and the target, grouped by culture key.
    """

    # Culture keys under which evaluation metrics are grouped and logged.
    _CULTURES = ("PDO", "PDOF", "F")
    # Evaluation keeps at most this many cells per population.
    _EVAL_MAX_CELLS = 5000
    # ODE chunk size used at evaluation when ``ivp_batch_size`` is None.
    _DEFAULT_EVAL_CHUNK = 1024

    def __init__(
        self,
        lr=1e-3,
        dim=43,
        num_hidden=512,
        num_layers=3,
        num_treat_conditions=11,
        base="source",
        integrate_time_steps=100,
        ivp_batch_size=1024,
        pca=None,
        pca_space=False,
        name="mlp_fm",
        seed=0,
    ) -> None:
        """Build the vector field and evaluation bookkeeping.

        Args:
            lr: Adam learning rate.
            dim: State dimensionality, passed to ``Flow2`` as ``D``.
            num_hidden: Hidden width of the MLP.
            num_layers: Number of MLP layers.
            num_treat_conditions: Width of the treatment conditioning vector.
            base: ``"source"`` starts the probability path at the source cells,
                ``"gaussian"`` at standard normal noise.
            integrate_time_steps: Number of points in the evaluation time grid.
            ivp_batch_size: Cells kept per population during training and ODE
                chunk size at evaluation; ``None`` disables training subsampling.
            pca: Optional object with ``inverse_transform``; when set and
                ``dim == 43`` predictions are mapped back with it and compared
                against ``x1_full``.
            pca_space: Only recorded in the hyperparameters; unused here.
            name: Model name stored on the instance.
            seed: Seed for the training cell-subsampling RNG.
        """
        super().__init__()

        # Lightning performs the optimizer step. The manual checkpoint branch in
        # ``validation_epoch_end`` only runs if this is switched off.
        self.automatic_optimization = True

        self.save_hyperparameters()

        self.model = Flow2(
            D=dim,
            num_conditions=num_treat_conditions,
            num_hidden=num_hidden,
            num_layers=num_layers,
        ).cuda()
        self.lr = lr
        self.dim = dim
        self.num_hidden = num_hidden
        self.integrate_time_steps = integrate_time_steps
        self.ivp_batch_size = ivp_batch_size
        self.pca = pca  # optional projection used to map predictions back at eval

        assert base in [
            "source",
            "gaussian",
        ], "Invalid base. Must be either 'source' or 'gaussian'"
        self.base = base
        self.name = name

        # RNG used by ``unpack_batch`` to subsample cells during training.
        self.rng = np.random.default_rng(seed)

        # Training batches kept for a single evaluation pass after the last epoch.
        self.num_train_evals = 100
        self.train_evals_count = 0
        self.train_eval_batches = []

        self.train_metrics = self._empty_metrics()
        self.val_metrics = self._empty_metrics()
        self.test_metrics = self._empty_metrics()

    def _empty_metrics(self):
        """Return a fresh ``{culture: []}`` container for per-batch metric dicts."""
        return {culture: [] for culture in self._CULTURES}

    def configure_optimizers(self):
        """Optimize all ``Flow2`` parameters with Adam at ``self.lr``."""
        return optim.Adam(self.model.parameters(), lr=self.lr)

    def forward(self, t, x):
        """Evaluate the vector field, squeezing the trailing dimension of ``t``."""
        return self.model(t.squeeze(-1), x)

    def compute_loss(self, source_samples, target_samples, treat_cond):
        """Return the mean flow-matching loss for one set of paired cells.

        A time ``t ~ U[0, 1]`` is drawn per cell, the state is placed on the
        straight line from a start point to ``target_samples`` and the network
        output is regressed onto that line's velocity.

        Args:
            source_samples: Source cells (assumed ``(n_cells, dim)`` — TODO confirm).
            target_samples: Target cells with the same shape.
            treat_cond: Treatment features concatenated to the state.

        Returns:
            Scalar loss tensor.

        Raises:
            ValueError: If ``self.base`` is not a known base.
        """
        t = torch.rand_like(source_samples[..., 0, None])

        if self.base == "source":
            start = source_samples
        elif self.base == "gaussian":
            start = torch.randn_like(target_samples)
        else:
            raise ValueError(f"unknown base: {self.base}")

        x = (1.0 - t) * start + t * target_samples
        u = target_samples - start
        b = self.forward(t, torch.cat((x, treat_cond), dim=-1))

        if self.base == "source":
            # ||b||^2 - 2<b, u> equals ||b - u||^2 minus the b-independent ||u||^2.
            loss = b.norm(dim=-1) ** 2 - 2.0 * (b * u).sum(dim=-1)
        else:
            loss = ((b - u) ** 2).sum(dim=-1)
        return loss.mean()

    def unpack_batch(self, batch, eval=False):
        """Unpack a dataloader batch and limit the number of cells.

        The batch is ``(idx, culture, x0, x1, x1_full, cell_cond, treat_cond)``
        with cells along dim 1; ``cell_cond``/``treat_cond`` are indexed like
        ``x0`` and ``x1_full`` like ``x1``.

        With ``eval=True`` each population is truncated to its first
        ``_EVAL_MAX_CELLS`` cells (``x1_full`` only when ``self.pca`` is set).
        Otherwise both populations are subsampled without replacement to the
        common size ``min(n0, n1)`` (further capped by ``ivp_batch_size``), and
        ``x1_full`` becomes ``None`` unless ``self.pca`` is set.

        Returns:
            The seven batch fields in the same order.
        """
        idx, culture, x0, x1, x1_full, cell_cond, treat_cond = batch

        if eval:
            cap = self._EVAL_MAX_CELLS
            trim_x0 = x0.shape[1] > cap
            trim_x1 = x1.shape[1] > cap
            return (
                idx,
                culture,
                x0[:, :cap, :] if trim_x0 else x0,
                x1[:, :cap, :] if trim_x1 else x1,
                x1_full[:, :cap, :] if trim_x1 and self.pca is not None else x1_full,
                cell_cond[:, :cap, :] if trim_x0 else cell_cond,
                treat_cond[:, :cap, :] if trim_x0 else treat_cond,
            )

        if self.pca is None:
            x1_full = None

        n0, n1 = x0.shape[1], x1.shape[1]
        n_keep = min(n0, n1)
        if self.ivp_batch_size is not None:
            n_keep = min(n_keep, self.ivp_batch_size)
        if n_keep == n0 and n_keep == n1:
            return idx, culture, x0, x1, x1_full, cell_cond, treat_cond

        # Sample each population from all of its cells so the surplus of the
        # larger population is not systematically excluded from training.
        x0_idcs = self.rng.choice(n0, size=n_keep, replace=False)
        x1_idcs = self.rng.choice(n1, size=n_keep, replace=False)
        x0 = x0[:, x0_idcs, :]
        x1 = x1[:, x1_idcs, :]
        if x1_full is not None:
            x1_full = x1_full[:, x1_idcs, :]
        cell_cond = cell_cond[:, x0_idcs, :]
        treat_cond = treat_cond[:, x0_idcs, :]
        return idx, culture, x0, x1, x1_full, cell_cond, treat_cond

    def _should_collect_train_eval_batch(self):
        """Return True if the current training batch should be kept for evaluation."""
        # Clamp the period to >= 1: with check_val_every_n_epoch == 1 (Lightning's
        # default) ``n - 1`` is 0 and the modulus would divide by zero.
        period = max((self.trainer.check_val_every_n_epoch or 1) - 1, 1)
        return (
            self.current_epoch % period == 0
            and self.train_evals_count < self.num_train_evals
        )

    def training_step(self, batch, batch_idx):
        """Compute and log the flow-matching loss; optionally keep the batch for eval."""
        _, _, x0, x1, _, _, treat_cond = self.unpack_batch(batch)
        assert (
            len(x0.shape) == 3
        ), "This was a temporary fix for the dataloader -- TODO: Make the code more gener."
        loss = self.compute_loss(
            x0.squeeze(0).float(), x1.squeeze(0).float(), treat_cond.squeeze(0).float()
        )
        self.log("train/loss", loss, on_step=False, on_epoch=True)
        if self._should_collect_train_eval_batch():
            self.train_eval_batches.append(batch)
            self.train_evals_count += 1
        return loss

    def _log_mean_metrics(self, metrics, prefix):
        """Log the per-culture mean of each metric as ``{prefix}/{metric}-{culture}``.

        Cultures without any recorded evaluation are skipped.
        """
        for culture_key, records in metrics.items():
            if not records:
                continue
            for key in records[0]:
                value = np.mean([record[key] for record in records])
                self.log(f"{prefix}/{key}-{culture_key}", value, on_step=False, on_epoch=True)

    def training_epoch_end(self, outputs):
        """After the final epoch, evaluate the collected training batches and log means."""
        if self.current_epoch != self.trainer.max_epochs - 1:
            return
        for batch in self.train_eval_batches:
            self.eval_batch(batch, prefix="train")
        self._log_mean_metrics(self.train_metrics, "train")
        self.train_eval_batches = []
        self.train_evals_count = 0
        self.train_metrics = self._empty_metrics()

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch; metrics are aggregated at epoch end."""
        self.eval_batch(batch, prefix="val")

    def validation_epoch_end(self, outputs):
        """Log mean validation metrics, reset them, and checkpoint under manual optimization."""
        self._log_mean_metrics(self.val_metrics, "val")
        self.val_metrics = self._empty_metrics()

        if not self.automatic_optimization:
            ckpt_path = os.path.join(self.trainer.log_dir, "checkpoints", "ckpt.ckpt")
            self.trainer.save_checkpoint(ckpt_path)

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch; metrics are aggregated at epoch end."""
        self.eval_batch(batch, prefix="test")

    def test_epoch_end(self, outputs):
        """Log mean test metrics and reset them."""
        self._log_mean_metrics(self.test_metrics, "test")
        self.test_metrics = self._empty_metrics()

    def eval_batch(self, batch, prefix):
        """Integrate the source cells to ``t=1`` and record distances to the target.

        Args:
            batch: Dataloader batch (see ``unpack_batch``).
            prefix: ``"train"``, ``"val"`` or ``"test"``; selects the metrics store.

        Raises:
            ValueError: If ``prefix`` is unknown.
        """
        stores = {"train": self.train_metrics, "val": self.val_metrics, "test": self.test_metrics}
        if prefix not in stores:
            raise ValueError(f"unknown prefix: {prefix}")
        metrics = stores[prefix]

        _, culture, x0, x1, x1_full, _, treat_cond = self.unpack_batch(batch, eval=True)

        node = NeuralODE(
            torch_wrapper_flow_cond(self.model),
            solver="dopri5",
            sensitivity="adjoint",
            atol=1e-4,
            rtol=1e-4,
        )
        time_span = torch.linspace(0, 1, self.integrate_time_steps)
        # ``range`` needs an integer step, so fall back when ivp_batch_size is None.
        chunk = self.ivp_batch_size or self._DEFAULT_EVAL_CHUNK
        compare_full = self.pca is not None and self.dim == 43

        with torch.no_grad():
            if len(x0.shape) > 3:
                x0 = x0.squeeze()
                x1 = x1.squeeze()
            true = x1_full.float() if compare_full else x1.float()

            for i in range(x0.shape[0]):  # loop allows replicate batching for eval
                x0_i = x0[i].float()
                treat_cond_i = treat_cond[i].float()
                pred = torch.cat(
                    [
                        node.trajectory(
                            torch.cat(
                                (x0_i[j : j + chunk], treat_cond_i[j : j + chunk]), dim=-1
                            ),
                            t_span=time_span,
                        )[-1, :, : self.model.D]
                        for j in range(0, x0_i.shape[0], chunk)
                    ],
                    dim=0,
                )

                if compare_full:
                    pred = self.pca.inverse_transform(pred.cpu().numpy())
                # Match dtype and device of the target instead of assuming CUDA.
                pred = torch.as_tensor(pred).to(true)

                # NOTE(review): every replicate i is compared against true[0];
                # confirm this is intended when x0.shape[0] > 1.
                names, dd = compute_distribution_distances(
                    pred.unsqueeze(1),
                    true[0].unsqueeze(1),
                )
                metrics[culture[0]].append(dict(zip(names, dd)))

class TrellisCondMLPFM(pl.LightningModule):
    """Flow-matching model conditioned on experiment index and treatment.

    The conditioning vector is the one-hot encoding of the batch's experiment
    index ``idx`` (width ``num_exp_conditions``) followed by the per-cell
    treatment features. At val/test time the experiment code is replaced by a
    uniform weighting over the first ``num_exp_conditions -
    num_heldout_exp_conditions`` indices (zeros elsewhere).
    """

    # Culture keys under which evaluation metrics are grouped and logged.
    _CULTURES = ("PDO", "PDOF", "F")
    # Evaluation keeps at most this many cells per population.
    _EVAL_MAX_CELLS = 5000
    # Cells per ODE solve at evaluation (kept small to fit in memory).
    _EVAL_CHUNK = 1024

    def __init__(
        self,
        lr=1e-3,
        dim=43,
        num_hidden=512,
        num_layers=3,
        base="source",
        integrate_time_steps=100,
        ivp_batch_size=1024,
        num_exp_conditions=927,
        num_treat_conditions=11,
        pca=None,
        pca_space=False,
        name="mlp_fm",
        seed=0,
        num_heldout_exp_conditions=66,
    ) -> None:
        """Build the vector field and evaluation bookkeeping.

        Args:
            lr: Adam learning rate.
            dim: State dimensionality, passed to ``Flow2`` as ``D``.
            num_hidden: Hidden width of the MLP.
            num_layers: Number of MLP layers.
            base: ``"source"`` or ``"gaussian"`` start distribution.
            integrate_time_steps: Number of points in the evaluation time grid.
            ivp_batch_size: Cells kept per population during training; ``None``
                disables training subsampling.
            num_exp_conditions: Number of classes of the one-hot ``idx`` encoding.
            num_treat_conditions: Width of the treatment conditioning vector.
            pca: Optional object with ``inverse_transform``; when set and
                ``dim == 43`` predictions are mapped back and compared to ``x1_full``.
            pca_space: Only recorded in the hyperparameters; unused here.
            name: Model name stored on the instance.
            seed: Seed for the training cell-subsampling RNG.
            num_heldout_exp_conditions: Trailing experiment indices given zero
                weight in the val/test conditioning (66 = 33 + 33 by default;
                use 111 + 103 for the replica split).
        """
        super().__init__()

        # Lightning performs the optimizer step. The manual checkpoint branch in
        # ``validation_epoch_end`` only runs if this is switched off.
        self.automatic_optimization = True

        self.save_hyperparameters()

        self.num_exp_conditions = num_exp_conditions
        self.num_treat_conditions = num_treat_conditions
        self.num_conditions = num_exp_conditions + num_treat_conditions
        self.num_heldout_exp_conditions = num_heldout_exp_conditions
        self.model = Flow2(
            D=dim,
            num_hidden=num_hidden,
            num_layers=num_layers,
            num_conditions=num_exp_conditions,
            num_treat_conditions=num_treat_conditions,
        ).cuda()
        self.lr = lr
        self.dim = dim
        self.num_hidden = num_hidden
        self.integrate_time_steps = integrate_time_steps
        self.ivp_batch_size = ivp_batch_size
        self.pca = pca

        assert base in [
            "source",
            "gaussian",
        ], "Invalid base. Must be either 'source' or 'gaussian'"
        self.base = base
        self.name = name

        # RNG used by ``unpack_batch`` to subsample cells during training.
        self.rng = np.random.default_rng(seed)

        # Training batches kept for a single evaluation pass after the last epoch.
        self.num_train_evals = 100
        self.train_evals_count = 0
        self.train_eval_batches = []

        self.train_metrics = self._empty_metrics()
        self.val_metrics = self._empty_metrics()
        self.test_metrics = self._empty_metrics()

    def _empty_metrics(self):
        """Return a fresh ``{culture: []}`` container for per-batch metric dicts."""
        return {culture: [] for culture in self._CULTURES}

    def configure_optimizers(self):
        """Optimize all ``Flow2`` parameters with Adam at ``self.lr``."""
        return optim.Adam(self.model.parameters(), lr=self.lr)

    def forward(self, t, x):
        """Evaluate the vector field, squeezing the trailing dimension of ``t``."""
        return self.model(t.squeeze(-1), x)

    def compute_loss(self, source_samples, target_samples, cond):
        """Return the mean flow-matching loss for one set of paired cells.

        Args:
            source_samples: Source cells (assumed ``(n_cells, dim)`` — TODO confirm).
            target_samples: Target cells with the same shape.
            cond: Conditioning features concatenated to the state.

        Returns:
            Scalar loss tensor.

        Raises:
            ValueError: If ``self.base`` is not a known base.
        """
        t = torch.rand_like(source_samples[..., 0, None])

        if self.base == "source":
            start = source_samples
        elif self.base == "gaussian":
            start = torch.randn_like(target_samples)
        else:
            raise ValueError(f"unknown base: {self.base}")

        x = (1.0 - t) * start + t * target_samples
        u = target_samples - start
        b = self.forward(t, torch.cat((x, cond), dim=-1).float())

        if self.base == "source":
            # ||b||^2 - 2<b, u> equals ||b - u||^2 minus the b-independent ||u||^2.
            loss = b.norm(dim=-1) ** 2 - 2.0 * (b * u).sum(dim=-1)
        else:
            loss = ((b - u) ** 2).sum(dim=-1)
        return loss.mean()

    def _exp_cond(self, idx, x0):
        """One-hot encode ``idx`` and broadcast it to every cell of ``x0``.

        Returns a ``(x0.shape[0], x0.shape[1], num_exp_conditions)`` tensor. The
        class axis is moved past a singleton cell axis before expanding, so a
        batch with several indices broadcasts per batch element.
        """
        one_hot = F.one_hot(idx.long().reshape(-1), num_classes=self.num_exp_conditions)
        return one_hot.unsqueeze(1).expand(x0.shape[0], x0.shape[1], -1)

    def _heldout_exp_cond(self, x0):
        """Build the val/test experiment conditioning for every cell of ``x0``.

        Uniform weight ``1 / n_train`` over the first ``n_train =
        num_exp_conditions - num_heldout_exp_conditions`` indices, zero over
        the remaining ``num_heldout_exp_conditions``.
        """
        n_train = self.num_exp_conditions - self.num_heldout_exp_conditions
        lead = (x0.shape[0], x0.shape[1])
        uniform = torch.ones((*lead, n_train), device=x0.device) / n_train
        zeros = torch.zeros((*lead, self.num_heldout_exp_conditions), device=x0.device)
        return torch.cat((uniform, zeros), dim=-1)

    def unpack_batch(self, batch, eval=False):
        """Unpack a dataloader batch, limit the number of cells and add ``exp_cond``.

        The batch is ``(idx, culture, x0, x1, x1_full, cell_cond, treat_cond)``
        with cells along dim 1; ``cell_cond``/``treat_cond`` are indexed like
        ``x0`` and ``x1_full`` like ``x1``.

        With ``eval=True`` each population is truncated to its first
        ``_EVAL_MAX_CELLS`` cells (``x1_full`` only when ``self.pca`` is set).
        Otherwise both populations are subsampled without replacement to the
        common size ``min(n0, n1)`` (further capped by ``ivp_batch_size``), and
        ``x1_full`` becomes ``None`` unless ``self.pca`` is set.

        Returns:
            The seven batch fields followed by ``exp_cond`` (see ``_exp_cond``).
        """
        idx, culture, x0, x1, x1_full, cell_cond, treat_cond = batch

        if eval:
            cap = self._EVAL_MAX_CELLS
            trim_x0 = x0.shape[1] > cap
            trim_x1 = x1.shape[1] > cap
            x0_out = x0[:, :cap, :] if trim_x0 else x0
            return (
                idx,
                culture,
                x0_out,
                x1[:, :cap, :] if trim_x1 else x1,
                x1_full[:, :cap, :] if trim_x1 and self.pca is not None else x1_full,
                cell_cond[:, :cap, :] if trim_x0 else cell_cond,
                treat_cond[:, :cap, :] if trim_x0 else treat_cond,
                self._exp_cond(idx, x0_out),
            )

        if self.pca is None:
            x1_full = None

        n0, n1 = x0.shape[1], x1.shape[1]
        n_keep = min(n0, n1)
        if self.ivp_batch_size is not None:
            n_keep = min(n_keep, self.ivp_batch_size)
        if n_keep < n0 or n_keep < n1:
            # Sample each population from all of its cells so the surplus of the
            # larger population is not systematically excluded from training.
            x0_idcs = self.rng.choice(n0, size=n_keep, replace=False)
            x1_idcs = self.rng.choice(n1, size=n_keep, replace=False)
            x0 = x0[:, x0_idcs, :]
            x1 = x1[:, x1_idcs, :]
            if x1_full is not None:
                x1_full = x1_full[:, x1_idcs, :]
            cell_cond = cell_cond[:, x0_idcs, :]
            treat_cond = treat_cond[:, x0_idcs, :]

        exp_cond = self._exp_cond(idx, x0)
        return idx, culture, x0, x1, x1_full, cell_cond, treat_cond, exp_cond

    def OLD_unpack_batch(self, batch):
        """Legacy batch unpacker kept for reference.

        NOTE(review): unpacks a 6-field batch without ``x1_full``, while
        ``unpack_batch`` expects 7 fields; calling this on current batches would
        raise ValueError. It also samples with the global ``np.random`` state.
        """
        idx, culture, x0, x1, cell_cond, treat_cond = batch
        if self.ivp_batch_size is not None and (
            x0.shape[1] > self.ivp_batch_size and x1.shape[1] > self.ivp_batch_size
        ):
            x0_ivp_idcs = np.random.choice(
                np.arange(x0.shape[1]), size=self.ivp_batch_size, replace=False
            )
            x1_ivp_idcs = np.random.choice(
                np.arange(x1.shape[1]), size=self.ivp_batch_size, replace=False
            )
            x0 = x0[:, x0_ivp_idcs, :]
            x1 = x1[:, x1_ivp_idcs, :]
            cell_cond = cell_cond[:, x0_ivp_idcs, :]
            treat_cond = treat_cond[:, x0_ivp_idcs, :]
            exp_cond = torch.nn.functional.one_hot(
                idx.long(), num_classes=self.num_exp_conditions
            )
            exp_cond = exp_cond.expand(x0.shape[0], x0.shape[1], -1)
            return idx, culture, x0, x1, cell_cond, treat_cond, exp_cond
        elif x0.shape[1] != x1.shape[1]:
            ivp_batch_size = min(x0.shape[1], x1.shape[1])
            ivp_idcs = np.random.choice(
                np.arange(ivp_batch_size), size=ivp_batch_size, replace=False
            )
            x0 = x0[:, ivp_idcs, :]
            x1 = x1[:, ivp_idcs, :]
            cell_cond = cell_cond[:, ivp_idcs, :]
            treat_cond = treat_cond[:, ivp_idcs, :]
            exp_cond = torch.nn.functional.one_hot(
                idx.long(), num_classes=self.num_exp_conditions
            )
            exp_cond = exp_cond.expand(x0.shape[0], x0.shape[1], -1)
            return idx, culture, x0, x1, cell_cond, treat_cond, exp_cond
        else:
            exp_cond = torch.nn.functional.one_hot(
                idx.long(), num_classes=self.num_exp_conditions
            )
            exp_cond = exp_cond.expand(x0.shape[0], x0.shape[1], -1)
            return idx, culture, x0, x1, cell_cond, treat_cond, exp_cond

    def _should_collect_train_eval_batch(self):
        """Return True if the current training batch should be kept for evaluation."""
        # Clamp the period to >= 1: with check_val_every_n_epoch == 1 (Lightning's
        # default) ``n - 1`` is 0 and the modulus would divide by zero.
        period = max((self.trainer.check_val_every_n_epoch or 1) - 1, 1)
        return (
            self.current_epoch % period == 0
            and self.train_evals_count < self.num_train_evals
        )

    def training_step(self, batch, batch_idx):
        """Compute and log the conditioned flow-matching loss; optionally keep the batch."""
        idx, _, x0, x1, _, _, treat_cond, exp_cond = self.unpack_batch(batch)
        assert (
            len(x0.shape) == 3
        ), "This was a temporary fix for the dataloader -- TODO: Make the code more gener."
        cond = torch.cat((exp_cond, treat_cond), dim=-1)
        loss = self.compute_loss(
            x0.squeeze(0).float(), x1.squeeze(0).float(), cond.squeeze(0).float()
        )
        self.log("train/loss", loss, on_step=False, on_epoch=True)
        if self._should_collect_train_eval_batch():
            self.train_eval_batches.append(batch)
            self.train_evals_count += 1
        return loss

    def _log_mean_metrics(self, metrics, prefix):
        """Log the per-culture mean of each metric as ``{prefix}/{metric}-{culture}``.

        Cultures without any recorded evaluation are skipped.
        """
        for culture_key, records in metrics.items():
            if not records:
                continue
            for key in records[0]:
                value = np.mean([record[key] for record in records])
                self.log(f"{prefix}/{key}-{culture_key}", value, on_step=False, on_epoch=True)

    def training_epoch_end(self, outputs):
        """After the final epoch, evaluate the collected training batches and log means."""
        if self.current_epoch != self.trainer.max_epochs - 1:
            return
        for batch in self.train_eval_batches:
            self.eval_batch(batch, prefix="train")
        self._log_mean_metrics(self.train_metrics, "train")
        self.train_eval_batches = []
        self.train_evals_count = 0
        self.train_metrics = self._empty_metrics()

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch; metrics are aggregated at epoch end."""
        self.eval_batch(batch, prefix="val")

    def validation_epoch_end(self, outputs):
        """Log mean validation metrics, reset them, and checkpoint under manual optimization."""
        self._log_mean_metrics(self.val_metrics, "val")
        self.val_metrics = self._empty_metrics()

        if not self.automatic_optimization:
            ckpt_path = os.path.join(self.trainer.log_dir, "checkpoints", "ckpt.ckpt")
            self.trainer.save_checkpoint(ckpt_path)

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch; metrics are aggregated at epoch end."""
        self.eval_batch(batch, prefix="test")

    def test_epoch_end(self, outputs):
        """Log mean test metrics and reset them."""
        self._log_mean_metrics(self.test_metrics, "test")
        self.test_metrics = self._empty_metrics()

    def eval_batch(self, batch, prefix):
        """Integrate the source cells to ``t=1`` and record distances to the target.

        Training batches use their own one-hot experiment code; val/test batches
        use ``_heldout_exp_cond``.

        Args:
            batch: Dataloader batch (see ``unpack_batch``).
            prefix: ``"train"``, ``"val"`` or ``"test"``; selects the metrics store.

        Raises:
            ValueError: If ``prefix`` is unknown.
        """
        stores = {"train": self.train_metrics, "val": self.val_metrics, "test": self.test_metrics}
        if prefix not in stores:
            raise ValueError(f"unknown prefix: {prefix}")
        metrics = stores[prefix]

        _, culture, x0, x1, x1_full, _, treat_cond, exp_cond = self.unpack_batch(
            batch, eval=True
        )
        if prefix != "train":
            exp_cond = self._heldout_exp_cond(x0)

        node = NeuralODE(
            torch_wrapper_flow_cond(self.model),
            solver="dopri5",
            sensitivity="adjoint",
            atol=1e-4,
            rtol=1e-4,
        )
        time_span = torch.linspace(0, 1, self.integrate_time_steps)
        chunk = self._EVAL_CHUNK
        compare_full = self.pca is not None and self.dim == 43

        with torch.no_grad():
            if len(x0.shape) > 3:
                x0 = x0.squeeze()
                x1 = x1.squeeze()
            true = x1_full.float() if compare_full else x1.float()

            for i in range(x0.shape[0]):  # loop allows replicate batching for eval
                x0_i = x0[i].float()
                exp_cond_i = exp_cond[i]
                treat_cond_i = treat_cond[i].float()
                pred = torch.cat(
                    [
                        node.trajectory(
                            torch.cat(
                                (
                                    x0_i[j : j + chunk],
                                    exp_cond_i[j : j + chunk],
                                    treat_cond_i[j : j + chunk],
                                ),
                                dim=-1,
                            ),
                            t_span=time_span,
                        )[-1, :, : self.model.D]
                        for j in range(0, x0_i.shape[0], chunk)
                    ],
                    dim=0,
                )

                if compare_full:
                    pred = self.pca.inverse_transform(pred.cpu().numpy())
                # Match dtype and device of the target instead of assuming CUDA.
                pred = torch.as_tensor(pred).to(true)

                # NOTE(review): every replicate i is compared against true[0];
                # confirm this is intended when x0.shape[0] > 1.
                names, dd = compute_distribution_distances(
                    pred.unsqueeze(1),
                    true[0].unsqueeze(1),
                )
                metrics[culture[0]].append(dict(zip(names, dd)))


class TrellisGNNFM(pl.LightningModule):
    def __init__(
        self,
        flow_lr=1e-3,
        gnn_lr=5e-4,
        update_embedding_epochs_freq=20,
        update_embedding_epochs=10,
        dim=43,
        num_hidden=512,
        num_layers_decoder=3,
        num_hidden_gnn=512,
        knn_k=10,
        num_treat_conditions=None,
        num_cell_conditions=None,
        base="source",
        ivp_batch_size=None,
        integrate_time_steps=100,
        pca=None,
        pca_space=False,
        name="gnn_fm",
        seed=0,
    ) -> None:
        """Build the GNN-embedding flow model and its bookkeeping.

        Args:
            flow_lr: Adam learning rate for ``model.decoder``.
            gnn_lr: Adam learning rate for ``model.gcn_convs``.
            update_embedding_epochs_freq: Period (in epochs) of GNN update windows.
            update_embedding_epochs: Length (in epochs) of each GNN update window.
            dim: State dimensionality, passed to ``GlobalGNN2`` as ``D``.
            num_hidden: Hidden width of the decoder.
            num_layers_decoder: Number of decoder layers.
            num_hidden_gnn: Hidden width of the GNN.
            knn_k: ``knn_k`` argument forwarded to ``GlobalGNN2``.
            num_treat_conditions: Width of the treatment conditioning vector.
            num_cell_conditions: Width of the cell conditioning vector.
            base: ``"source"`` or ``"gaussian"`` start distribution.
            ivp_batch_size: Cells kept per population during training; ``None``
                disables the cap.
            integrate_time_steps: Number of points in the evaluation time grid.
            pca: Optional projection object stored on the instance.
            pca_space: Only recorded in the hyperparameters; unused here.
            name: Model name stored on the instance.
            seed: Seed for the cell-subsampling RNG.
        """
        super().__init__()

        # The decoder and the GNN have separate optimizers that are stepped by
        # hand in ``flow_step``/``gnn_step``, so Lightning must not step them.
        self.automatic_optimization = False

        self.save_hyperparameters()

        self.model = GlobalGNN2(
            D=dim,
            num_hidden_decoder=num_hidden,
            num_layers_decoder=num_layers_decoder,
            num_hidden_gnn=num_hidden_gnn,
            knn_k=knn_k,
            num_treat_conditions=num_treat_conditions,
            num_cell_conditions=num_cell_conditions,
        ).cuda()

        # Sanity check (by tensor count) that decoder + gcn_convs cover every
        # model parameter, since the two optimizers only see those submodules.
        total_params = sum(1 for _ in self.model.parameters())
        split_params = sum(1 for _ in self.model.decoder.parameters()) + sum(
            1 for _ in self.model.gcn_convs.parameters()
        )
        assert total_params == split_params

        # Optimization schedule.
        self.flow_lr = flow_lr
        self.gnn_lr = gnn_lr
        self.update_embedding_epochs_freq = update_embedding_epochs_freq
        self.update_embedding_epochs = update_embedding_epochs

        # Model / integration settings.
        self.dim = dim
        self.num_hidden = num_hidden
        self.ivp_batch_size = ivp_batch_size
        self.integrate_time_steps = integrate_time_steps
        self.pca = pca

        assert base in (
            "source",
            "gaussian",
        ), "Invalid base. Must be either 'source' or 'gaussian'"
        self.base = base
        self.name = name

        # RNG used by ``unpack_batch`` to subsample cells during training.
        self.rng = np.random.default_rng(seed)

        # Training batches kept for evaluation and the per-index embedding cache.
        self.num_train_evals = 100
        self.train_evals_count = 0
        self.train_eval_batches = []
        self.embeddings = {}

        self.train_metrics, self.val_metrics, self.test_metrics = (
            {culture: [] for culture in ("PDO", "PDOF", "F")} for _ in range(3)
        )

    def configure_optimizers(self):
        """Create the decoder/GNN optimizers and the list of GNN update epochs.

        ``self.gnn_epochs`` holds, for every multiple ``k * update_embedding_epochs_freq``
        (``k >= 1``) below ``trainer.max_epochs``, the ``update_embedding_epochs``
        consecutive epochs starting at that multiple.

        Returns:
            ``(flow_optimizer, gnn_optimizer)``: Adam over ``model.decoder`` at
            ``flow_lr`` and Adam over ``model.gcn_convs`` at ``gnn_lr``.

        Raises:
            ValueError: If ``update_embedding_epochs_freq`` is not larger than
                ``update_embedding_epochs`` (consecutive GNN windows would then
                leave no flow-only epochs between them). Raised explicitly so the
                check survives ``python -O``, unlike the previous ``assert``.
        """
        freq = self.update_embedding_epochs_freq
        window = self.update_embedding_epochs
        if freq <= window:
            raise ValueError(
                "update_embedding_epochs_freq must be larger than "
                f"update_embedding_epochs, got {freq} <= {window}"
            )
        # Skip epoch 0: the first GNN window starts at ``freq``.
        self.gnn_epochs = [
            start + offset
            for start in range(freq, self.trainer.max_epochs, freq)
            for offset in range(window)
        ]
        self.flow_optimizer = torch.optim.Adam(self.model.decoder.parameters(), lr=self.flow_lr)
        self.gnn_optimizer = torch.optim.Adam(
            self.model.gcn_convs.parameters(),
            lr=self.gnn_lr,
        )
        return self.flow_optimizer, self.gnn_optimizer

    def compute_loss(self, embedding, source_samples, target_samples, treat_cond):
        """Return the mean flow-matching loss given a population embedding.

        A time ``t ~ U[0, 1]`` is drawn per cell. The state is placed on the
        straight line from a start point (the source cells, or Gaussian noise)
        to ``target_samples``, and ``model.flow`` is regressed onto that line's
        velocity.

        Args:
            embedding: Population embedding passed to ``model.flow``.
            source_samples: Source cells (assumed ``(n_cells, dim)`` — TODO confirm).
            target_samples: Target cells with the same shape.
            treat_cond: Treatment features concatenated to the state.

        Returns:
            Scalar loss tensor.

        Raises:
            ValueError: If ``self.base`` is not a known base.
        """
        t = torch.rand_like(source_samples[..., 0, None])

        if self.base == "source":
            origin = source_samples
        elif self.base == "gaussian":
            origin = torch.randn_like(target_samples)
        else:
            raise ValueError(f"unknown base: {self.base}")

        y = (1.0 - t) * origin + t * target_samples
        velocity = target_samples - origin
        pred = self.model.flow(embedding, t.squeeze(-1), torch.cat((y, treat_cond), dim=-1))

        if self.base == "source":
            # ||pred||^2 - 2<pred, v> equals ||pred - v||^2 minus the constant ||v||^2.
            per_cell = pred.norm(dim=-1) ** 2 - 2.0 * (pred * velocity).sum(dim=-1)
        else:
            per_cell = ((pred - velocity) ** 2).sum(dim=-1)
        return per_cell.mean()

    def get_embeddings(self, idx, source_samples, cond=None):
        """Return the GNN embedding of a source population, cached by dataset index.

        On a cache miss the embedding is computed with ``self.model.embed_source``,
        a detached copy is stored in ``self.embeddings`` and the freshly computed
        tensor is returned; on a hit the cached (detached) tensor is returned.

        Args:
            idx: Tensor of dataset indices. If it has more than one element only
                the first one is embedded, using ``source_samples[0]`` and, when
                ``cond`` is batched like ``source_samples``, ``cond[0]``.
                TODO(review): batched embedding of every element is not supported.
            source_samples: Source cells, batched along dim 0 when ``idx`` has
                several elements.
            cond: Optional conditioning forwarded to ``embed_source``.

        Returns:
            The embedding for the (first) index.
        """
        if idx.numel() > 1:
            key = idx.reshape(-1)[0].item()
            samples = source_samples[0]
            # Keep cond aligned with the selected sample; previously the whole
            # batched cond was passed alongside a single sample.
            if cond is not None and cond.shape[0] == source_samples.shape[0]:
                cond = cond[0]
        else:
            # numel() also accepts 0-d indices, where ``idx.shape[0]`` would raise.
            key = idx.item()
            samples = source_samples

        if key in self.embeddings:
            return self.embeddings[key]
        embedding = self.model.embed_source(samples, cond=cond)
        self.embeddings[key] = embedding.detach()
        return embedding
            
    def unpack_batch(self, batch, eval=False):
        """Unpack a dataloader batch and limit the number of cells.

        The batch is ``(idx, culture, x0, x1, x1_full, cell_cond, treat_cond)``
        with cells along dim 1; ``cell_cond``/``treat_cond`` are indexed like
        ``x0`` and ``x1_full`` like ``x1``.

        With ``eval=True`` each population is truncated to its first 5000 cells
        (``x1_full`` only when ``self.pca`` is set). Otherwise both populations
        are subsampled without replacement to the common size ``min(n0, n1)``
        (further capped by ``ivp_batch_size``), and ``x1_full`` becomes ``None``
        unless ``self.pca`` is set.

        Returns:
            The seven batch fields in the same order.
        """
        idx, culture, x0, x1, x1_full, cell_cond, treat_cond = batch

        if eval:
            cap = 5000  # maximum cells per population kept for evaluation
            trim_x0 = x0.shape[1] > cap
            trim_x1 = x1.shape[1] > cap
            return (
                idx,
                culture,
                x0[:, :cap, :] if trim_x0 else x0,
                x1[:, :cap, :] if trim_x1 else x1,
                x1_full[:, :cap, :] if trim_x1 and self.pca is not None else x1_full,
                cell_cond[:, :cap, :] if trim_x0 else cell_cond,
                treat_cond[:, :cap, :] if trim_x0 else treat_cond,
            )

        if self.pca is None:
            x1_full = None

        n0, n1 = x0.shape[1], x1.shape[1]
        n_keep = min(n0, n1)
        if self.ivp_batch_size is not None:
            n_keep = min(n_keep, self.ivp_batch_size)
        if n_keep == n0 and n_keep == n1:
            return idx, culture, x0, x1, x1_full, cell_cond, treat_cond

        # Sample each population from all of its cells so the surplus of the
        # larger population is not systematically excluded from training.
        x0_idcs = self.rng.choice(n0, size=n_keep, replace=False)
        x1_idcs = self.rng.choice(n1, size=n_keep, replace=False)
        x0 = x0[:, x0_idcs, :]
        x1 = x1[:, x1_idcs, :]
        if x1_full is not None:
            x1_full = x1_full[:, x1_idcs, :]
        cell_cond = cell_cond[:, x0_idcs, :]
        treat_cond = treat_cond[:, x0_idcs, :]
        return idx, culture, x0, x1, x1_full, cell_cond, treat_cond

    def flow_step(self, batch):
        """Run one manual-optimization step that updates the flow network.

        Source embeddings are obtained through ``self.get_embeddings``, and
        only ``self.flow_optimizer`` is stepped.

        Args:
            batch: A dataloader batch, unpacked via ``self.unpack_batch``.

        Returns:
            The training loss returned by ``self.compute_loss``.
        """
        idx, _, x0, x1, _, cell_cond, treat_cond = self.unpack_batch(batch)
        # Drop only the leading batch axis, as gnn_step does. A bare squeeze()
        # would also collapse singleton cell/feature axes (e.g. a one-column
        # cell_cond) and hand the embedder a differently shaped input.
        source = x0.float().squeeze(0)
        embedding = self.get_embeddings(idx, source, cell_cond.float().squeeze(0))
        loss = self.compute_loss(
            embedding,
            source,
            x1.float().squeeze(0),
            treat_cond.float().squeeze(0),
        )
        self.flow_optimizer.zero_grad()
        self.manual_backward(loss)
        self.flow_optimizer.step()
        return loss
    
    def gnn_step(self, batch):
        """Run one manual-optimization step that updates the source embedder.

        A fresh embedding of the batch's source cells is computed with
        ``self.model.embed_source``. A detached copy is cached in
        ``self.embeddings`` under ``idx.item()``, and only
        ``self.gnn_optimizer`` is stepped.

        Args:
            batch: A dataloader batch, unpacked via ``self.unpack_batch``.

        Returns:
            The training loss returned by ``self.compute_loss``.
        """
        batch_idx, _culture, src, tgt, _full, src_cond, treatment = self.unpack_batch(batch)

        embedding = self.model.embed_source(
            src.float().squeeze(0), cond=src_cond.float().squeeze(0)
        )
        # Cache a gradient-free copy keyed by the batch index.
        self.embeddings[batch_idx.item()] = embedding.detach()

        loss = self.compute_loss(
            embedding,
            src.float().squeeze(0),
            tgt.float().squeeze(0),
            treatment.float().squeeze(0),
        )

        optimizer = self.gnn_optimizer
        optimizer.zero_grad()
        self.manual_backward(loss)
        optimizer.step()
        return loss
        
    def training_step(self, batch, batch_idx):
        """Update the GNN or the flow, and stash batches for training-set eval.

        Epochs listed in ``self.gnn_epochs`` call ``gnn_step``; all other epochs
        call ``flow_step``. On qualifying epochs, up to
        ``self.num_train_evals`` batches are appended to
        ``self.train_eval_batches`` for ``training_epoch_end``.

        Args:
            batch: A dataloader batch.
            batch_idx: Unused.

        Returns:
            The loss from the chosen step.
        """
        if self.current_epoch in self.gnn_epochs:
            loss = self.gnn_step(batch)
        else:
            loss = self.flow_step(batch)
        self.log("train/loss", loss, on_step=False, on_epoch=True)

        # Stash period is (check_val_every_n_epoch - 1). Clamp it to >= 1 so
        # that check_val_every_n_epoch == 1 (Lightning's default) or an unset
        # value does not raise ZeroDivisionError.
        val_every = self.trainer.check_val_every_n_epoch or 1
        period = max(val_every - 1, 1)
        if (
            self.current_epoch % period == 0
            and self.train_evals_count < self.num_train_evals
        ):
            self.train_eval_batches.append(batch)
            self.train_evals_count += 1
        return loss
    
    def training_epoch_end(self, outputs):
        """On the final epoch, evaluate stashed training batches and log metrics.

        Each stashed batch is scored with ``eval_batch(prefix="train")``. Then
        every metric is averaged per culture and logged as
        ``train/<metric>-<culture>``. Afterwards the stash, its counter and
        the metric buffers are reset. Cultures with no records are skipped.

        Args:
            outputs: Unused.
        """
        if self.current_epoch != self.trainer.max_epochs - 1:
            return

        for batch in self.train_eval_batches:
            self.eval_batch(batch, prefix="train")

        for culture_key, records in list(self.train_metrics.items()):
            # A culture may have had no stashed batches; indexing records[0]
            # on the empty list would raise IndexError.
            if not records:
                continue
            means = {
                metric: np.mean([record[metric] for record in records])
                for metric in records[0]
            }
            for metric, value in means.items():
                self.log(
                    f"train/{metric}-{culture_key}", value, on_step=False, on_epoch=True
                )

        self.train_eval_batches = []
        self.train_evals_count = 0
        self.train_metrics = {"PDO": [], "PDOF": [], "F": []}

    def validation_step(self, batch, batch_idx):
        """Score one validation batch; ``eval_batch`` appends to ``self.val_metrics``."""
        split = "val"
        self.eval_batch(batch, prefix=split)
        return None
            
    def validation_epoch_end(self, outputs):
        """Log per-culture mean validation metrics, reset buffers, maybe checkpoint.

        For each culture with at least one record, every metric is averaged
        and logged as ``val/<metric>-<culture>``. ``self.val_metrics`` is then
        reinitialised. When automatic optimization is off, a checkpoint is
        saved to ``<trainer.log_dir>/checkpoints/ckpt.ckpt``. The path is
        fixed, so each save overwrites the previous one.

        Args:
            outputs: Unused.
        """
        for culture_key in list(self.val_metrics):
            records = self.val_metrics[culture_key]
            if len(records) == 0:
                continue
            averaged = {
                metric: np.mean([record[metric] for record in records])
                for metric in records[0]
            }
            for metric, mean_value in averaged.items():
                self.log(
                    f"val/{metric}-{culture_key}", mean_value, on_step=False, on_epoch=True
                )

        self.val_metrics = {"PDO": [], "PDOF": [], "F": []}

        manual_optimization = not self.automatic_optimization
        if manual_optimization:
            checkpoint_file = os.path.join(
                self.trainer.log_dir, "checkpoints", "ckpt.ckpt"
            )
            self.trainer.save_checkpoint(checkpoint_file)
        # NOTE(review): this invokes the parent's on_validation_end hook from
        # inside validation_epoch_end and returns its result; confirm intent.
        return super().on_validation_end()
    
    def test_step(self, batch, batch_idx):
        self.eval_batch(batch, prefix="test")
        
    def test_epoch_end(self, outputs):
        for culture_key in list(self.test_metrics.keys()):
            if not self.test_metrics[culture_key]:
                continue
            eval_metrics_mean_ood = {
                k: np.mean([m[k] for m in self.test_metrics[culture_key]])
                for k in self.test_metrics[culture_key][0]
            }
            for key, value in eval_metrics_mean_ood.items():
                self.log(
                    f"test/{key}-{culture_key}", value, on_step=False, on_epoch=True
                )
        self.test_metrics = {"PDO": [], "PDOF": [], "F": []}

    def eval_batch(self, batch, prefix):
        """Integrate the learned flow from source cells and score the result.

        For each entry along axis 0 of ``x0``, four steps run:

        1. The model embedding is refreshed via
           ``self.model.update_embedding_for_inference``.
        2. The conditional flow is integrated over ``[0, 1]`` in chunks of
           ``self.ivp_batch_size`` cells.
        3. If PCA is set and ``self.dim == 43``, predictions are mapped back
           with ``self.pca.inverse_transform`` and compared to ``x1_full``;
           otherwise they are compared to ``x1``.
        4. Distribution distances are appended to the metrics buffer
           selected by ``prefix``, keyed by ``culture[0]``.

        Args:
            batch: Dataloader batch, unpacked with ``eval=True``.
            prefix: One of ``"train"``, ``"val"`` or ``"test"``.

        Raises:
            ValueError: If ``prefix`` is not one of the above.
        """
        metrics_attr = {
            "train": "train_metrics",
            "val": "val_metrics",
            "test": "test_metrics",
        }
        # Validate before the expensive ODE integration rather than after it.
        if prefix not in metrics_attr:
            raise ValueError(f"unknown prefix: {prefix}")

        idx, culture, x0, x1, x1_full, cell_cond, treat_cond = self.unpack_batch(
            batch, eval=True
        )

        node = NeuralODE(
            torch_wrapper_gnn_flow_cond(self.model),
            solver="dopri5",
            sensitivity="adjoint",
            atol=1e-4,
            rtol=1e-4,
        )
        time_span = torch.linspace(0, 1, self.integrate_time_steps)
        use_pca_inverse = self.pca is not None and self.dim == 43

        with torch.no_grad():
            if len(x0.shape) > 3:
                x0 = x0.squeeze()
                x1 = x1.squeeze()

            # Loop-invariant target. Only its first entry along axis 0 is
            # compared against every replicate (unchanged behavior).
            true = x1_full.float() if use_pca_inverse else x1.float()
            reference = true[0].unsqueeze(1)

            for i in range(x0.shape[0]):  # replicate batching for eval
                x0_i = x0[i].float()
                treat_cond_i = treat_cond[i].float()
                cell_cond_i = cell_cond[i].float()

                self.model.update_embedding_for_inference(x0_i, cond=cell_cond_i)

                n_cells = x0_i.shape[0]
                # ivp_batch_size may be None (unpack_batch supports that);
                # fall back to integrating all cells in one chunk.
                chunk = max(self.ivp_batch_size or n_cells, 1)
                pred_batches = []
                for start in range(0, n_cells, chunk):
                    stop = start + chunk
                    traj = node.trajectory(
                        torch.cat(
                            (x0_i[start:stop], treat_cond_i[start:stop]), dim=-1
                        ).float(),
                        t_span=time_span,
                    )
                    pred_batches.append(traj[-1, :, : self.model.D])
                pred = torch.cat(pred_batches, dim=0)

                if use_pca_inverse:
                    # No hard-coded .cuda(): the .to(reference) below already
                    # moves the prediction to the target's device and dtype.
                    pred = torch.as_tensor(
                        self.pca.inverse_transform(pred.cpu().numpy())
                    )

                names, dd = compute_distribution_distances(
                    pred.unsqueeze(1).to(reference),
                    reference,
                )
                getattr(self, metrics_attr[prefix])[culture[0]].append(
                    dict(zip(names, dd))
                )