import time
from typing import Any, Callable, Optional, Union

import lightning as L
import torch
import wandb
from tqdm import tqdm

from src.models.fm.r3n_fm import R3NFlowMatcher
from src.models.ot_samplers import BaseSampler


class DynamicsCFM(L.LightningModule):
    """Conditional flow-matching model for autoregressive trajectory generation.

    Training (``model_step``) regresses ``velocity_net`` onto the velocity of
    the path from a noise sample to the target frame ``batch["xt"]``,
    conditioned on an embedding of the frame ``batch["x0"]`` produced by
    ``structure_net``. ``sample`` integrates the learned velocity field with
    ``ode_solver`` to predict the next frame from a given one, and
    ``generate_trajectory`` chains those predictions.
    """

    # Probability that the "mix_up02_beta" time schedule replaces a Beta
    # sample with a Uniform(0, 1) sample.
    _MIXUP_UNIFORM_PROB = 0.02
    # State-dict prefix of the conditioning network. Its weights are not
    # checkpointed while it is frozen (``structure_net.finetune == 0``).
    _STRUCTURE_PREFIX = "structure_net."

    def __init__(
        self,
        fm: R3NFlowMatcher,
        velocity_net: torch.nn.Module,
        structure_net: torch.nn.Module,
        optimizer: Callable[..., torch.optim.Optimizer],
        ot_sampler: BaseSampler,
        ode_solver: Callable[..., torch.nn.Module],
        t_distribution: Any,
        lr_scheduler: Optional[Callable[..., Any]] = None,
    ):
        """Build the module.

        Args:
            fm: Flow matcher used for noise sampling (``sample_noise``),
                interpolation (``interpolate``) and masking / COM handling
                (``_mask_and_zero_com``).
            velocity_net: Network whose output dict contains ``"coors_pred"``,
                the predicted velocity.
            structure_net: Conditioning network whose output dict contains
                ``"out_feat"``. It must expose a ``finetune`` attribute; ``0``
                means its weights are treated as frozen and are not
                checkpointed.
            optimizer: Factory called as ``optimizer(params=...)``
                (presumably a partially applied optimizer class).
            ot_sampler: Coupling sampler pairing noise samples with targets.
            ode_solver: Factory called as
                ``ode_solver(velocity_model=..., fm=...)``.
            t_distribution: Config object exposing ``name``, ``p1`` and ``p2``
                via attribute access (see ``sample_t``).
            lr_scheduler: Optional factory called as
                ``lr_scheduler(optimizer=...)``.
        """
        super().__init__()

        # optimizer / lr_scheduler are read back from self.hparams in
        # configure_optimizers.
        self.save_hyperparameters()
        self.velocity_net = velocity_net
        self.structure_net = structure_net
        self.fm = fm
        self.t_distribution = t_distribution
        self.ot_sampler = ot_sampler
        self.ode_solver = ode_solver(velocity_model=self.velocity_net, fm=self.fm)

    def get_velocity_field(self, x0, x1):
        """Return the constant velocity ``x1 - x0`` of the straight path x0 -> x1."""
        return x1 - x0

    def sample_t(self, shape, device: torch.device):
        """Draw flow-matching times of the given ``shape`` on ``device``.

        Supported ``t_distribution.name`` values:

        * ``"uniform"``: ``U(0, p2)``.
        * ``"logit-normal"``: ``sigmoid(N(p1, p2 ** 2))``.
        * ``"beta"``: ``Beta(p1, p2)``.
        * ``"mix_up02_beta"``: ``Beta(p1, p2)``, with each sample replaced by
          a ``U(0, 1)`` sample with probability ``_MIXUP_UNIFORM_PROB``.

        Raises:
            NotImplementedError: For any other distribution name.
        """
        cfg = self.t_distribution
        name = cfg.name
        if name == "uniform":
            return torch.rand(shape, device=device) * cfg.p2
        if name == "logit-normal":
            noise = torch.randn(shape, device=device) * cfg.p2 + cfg.p1
            return torch.sigmoid(noise)
        if name in ("beta", "mix_up02_beta"):
            # Beta is sampled on the CPU, then moved to the requested device.
            samples = torch.distributions.Beta(cfg.p1, cfg.p2).sample(shape).to(device)
            if name == "beta":
                return samples
            uniform = torch.rand(shape, device=device)
            u = torch.rand(shape, device=device)
            return torch.where(u < self._MIXUP_UNIFORM_PROB, uniform, samples)
        raise NotImplementedError(f"Sampling mode for t {name} not implemented.")

    def model_step(self, batch: dict):
        """Compute the flow-matching loss for one batch.

        ``batch`` must hold ``"x0"`` (conditioning frame) and ``"xt"`` (target
        frame). It may hold ``"mask"``; when that entry is absent or ``None``,
        an all-ones mask of shape ``x0.shape[:2]`` is used. Every other entry
        is forwarded to both networks as an extra condition.

        Returns:
            ``{"loss": scalar tensor}``.
        """
        mask = batch.get("mask")
        if mask is None:
            mask = torch.ones(
                batch["x0"].shape[:2], dtype=torch.bool, device=batch["x0"].device
            )

        x0 = self.fm._mask_and_zero_com(batch["x0"], mask=mask)
        xt = self.fm._mask_and_zero_com(batch["xt"], mask=mask)

        # One time per sample; assumes xt is (batch, n, coord_dim), so
        # shape[:-2] == (batch,) — TODO confirm for other layouts.
        t = self.sample_t(xt.shape[:-2], device=xt.device)

        x_noise = self.fm.sample_noise(
            n=xt.shape[1],
            b=xt.shape[0],
            device=xt.device,
            mask=mask,
        )

        # "mask" is excluded here and passed explicitly below, so a missing or
        # None mask in the batch can never override the default one.
        conditions = {
            key: value
            for key, value in batch.items()
            if key not in ("xt", "x0", "mask")
        }

        reorders_targets = self.ot_sampler.modifies_target_indices
        x_noise, xt, j_indices = self.ot_sampler.transport(x_noise, xt)
        if reorders_targets:
            # The coupling permuted the targets: permute every per-sample
            # tensor the same way so conditions stay aligned with xt.
            t = t[j_indices]
            x0 = x0[j_indices]
            mask = mask[j_indices]
            conditions = {
                key: None if value is None else value[j_indices]
                for key, value in conditions.items()
            }

        x, _, _ = self.fm.interpolate(x0=x_noise, x1=xt, t=t, mask=mask)
        ut = self.get_velocity_field(x0=x_noise, x1=xt)

        x0_emb = self.structure_net(
            {
                "x_t": x0,
                "mask": mask,
                **conditions,
            }
        )["out_feat"]

        nn_out = self.velocity_net(
            {
                "x_t": x,
                "t": t,
                "mask": mask,
                "x0": x0_emb,
                **conditions,
            }
        )

        vt = self.fm._mask_and_zero_com(nn_out["coors_pred"], mask=mask)
        return {"loss": self.compute_fm_loss(vt, ut, t, mask)}

    def training_step(self, batch: dict, batch_idx: int):
        """Run one training step; log epoch progress and the loss."""
        loss = self.model_step(batch)["loss"]
        # Fraction of the current epoch completed (monitoring only).
        progress = batch_idx / len(self.trainer.train_dataloader)
        self.log("train/progress", progress, on_step=True, prog_bar=False)

        self.log_dict(
            {"train/loss": loss},
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            batch_size=batch["x0"].shape[0],
            sync_dist=True,
        )
        return loss

    def sample(
        self,
        x0: torch.Tensor,
        ode_steps: int = 20,
        ode_method: str = "rk4",
        mask: Optional[torch.Tensor] = None,
        **rest_conditions,
    ) -> torch.Tensor:
        """Predict the next frame from ``x0`` by integrating the learned ODE.

        Args:
            x0: Conditioning frame; its first two dims are used as
                (batch, n) for the default mask and the noise.
            ode_steps: Number of uniform steps on the time grid [0, 1].
            ode_method: Integration method forwarded to ``ode_solver.sample``.
            mask: Optional mask; defaults to all ones on ``self.device``.
            **rest_conditions: Extra conditions forwarded to both networks.

        Returns:
            Predicted coordinates reshaped to ``(batch, -1, 3)``.
        """
        if mask is None:
            mask = torch.ones(
                x0.shape[0], x0.shape[1], dtype=torch.bool, device=self.device
            )
        x0 = self.fm._mask_and_zero_com(x0, mask=mask)
        t_span = torch.linspace(0, 1, ode_steps + 1).to(self.device)
        x_noise = self.fm.sample_noise(
            n=x0.shape[1], b=x0.shape[0], device=self.device, mask=mask
        )

        with torch.no_grad():
            x0_emb = self.structure_net(
                {
                    "x_t": x0,
                    "mask": mask,
                    **rest_conditions,
                }
            )["out_feat"]
            x_pred = self.ode_solver.sample(
                x_init=x_noise,
                time_grid=t_span,
                method=ode_method,
                compute_divergence=False,
                exact_divergence=False,
                return_intermediates=False,
                mask=mask,
                x0=x0_emb,
                **rest_conditions,
            )
            x_pred = self.fm._mask_and_zero_com(x_pred, mask=mask)
            # reshape (not view): the masked tensor is not guaranteed to be
            # contiguous.
            x_pred = x_pred.reshape(x_pred.shape[0], -1, 3)

        return x_pred

    def compute_fm_loss(
        self,
        vt: torch.Tensor,
        ut: torch.Tensor,
        t: torch.Tensor,
        mask: torch.Tensor,
    ) -> torch.Tensor:
        """Masked mean-squared error between predicted and target velocities.

        The squared error is averaged per sample over the unmasked coordinates,
        then averaged over the batch. ``t`` is accepted for interface
        compatibility but unused.
        """
        # Number of valid scalar coordinates per sample; clamped so that a
        # fully masked sample yields 0 instead of NaN (0 / 0).
        n_coords = (mask.sum(dim=-1) * vt.shape[-1]).clamp(min=1)
        err = (vt - ut) * mask[..., None]
        loss = err.pow(2).sum(dim=(-1, -2)) / n_coords
        return loss.mean()

    def generate_trajectory(
        self,
        x0: torch.Tensor,
        trajectory_steps: int,
        ode_steps: int = 20,
        ode_method: str = "rk4",
        mask: Optional[torch.Tensor] = None,
        return_intermediates: bool = False,
        **rest_conditions,
    ):
        """Roll out ``trajectory_steps`` frames autoregressively with ``sample``.

        Each predicted frame becomes the conditioning frame of the next step.
        When a wandb run is active, per-step throughput is logged under
        ``"generation/samples_per_sec"``.

        Returns:
            If ``return_intermediates``, all frames (the masked, COM-handled
            input first) stacked along dim 1. Otherwise the last frame, which
            is the processed input itself when ``trajectory_steps == 0``.
        """
        if mask is None:
            # Same default as ``sample``, built once so the initial frame is
            # processed with an explicit mask as well.
            mask = torch.ones(
                x0.shape[0], x0.shape[1], dtype=torch.bool, device=x0.device
            )
        x0 = self.fm._mask_and_zero_com(x0, mask=mask)
        all_frames = [x0]
        x_t = x0

        for _ in tqdm(
            range(trajectory_steps), desc="Generating trajectory", unit="step"
        ):
            start_time = time.perf_counter()
            x_t = self.sample(
                x0=x0,
                **rest_conditions,
                ode_steps=ode_steps,
                ode_method=ode_method,
                mask=mask,
            )
            if x_t.is_cuda:
                # CUDA kernels run asynchronously; wait so the timing covers
                # the actual compute.
                torch.cuda.synchronize(x_t.device)
            elapsed_time = time.perf_counter() - start_time
            if wandb.run is not None:
                samples_per_sec = x0.shape[0] / max(elapsed_time, 1e-8)
                wandb.log({"generation/samples_per_sec": samples_per_sec})
            if return_intermediates:
                all_frames.append(x_t)
            x0 = x_t.clone()

        if return_intermediates:
            return torch.stack(all_frames, dim=1)
        return x_t

    def _structure_net_frozen(self) -> bool:
        """Whether ``structure_net`` is frozen (``finetune == 0``)."""
        return self.structure_net.finetune == 0

    def load_state_dict(self, state_dict, strict=True, **kwargs):
        """Load weights, tolerating missing ``structure_net`` weights when frozen.

        While ``structure_net`` is frozen, ``on_save_checkpoint`` strips its
        weights from checkpoints, so their absence is expected. Every other
        key still honours ``strict``. A missing non-structure key or an
        unexpected key raises ``RuntimeError``, as ``torch.nn.Module`` does.
        Extra keyword arguments (e.g. ``assign``) are forwarded to torch.
        """
        if not self._structure_net_frozen():
            return super().load_state_dict(state_dict, strict=strict, **kwargs)

        result = super().load_state_dict(state_dict, strict=False, **kwargs)
        if strict:
            missing = [
                key
                for key in result.missing_keys
                if not key.startswith(self._STRUCTURE_PREFIX)
            ]
            unexpected = list(result.unexpected_keys)
            if missing or unexpected:
                raise RuntimeError(
                    f"Error(s) in loading state_dict for {type(self).__name__}: "
                    f"missing keys {missing}, unexpected keys {unexpected}"
                )
        return result

    def on_save_checkpoint(self, checkpoint):
        """Drop frozen ``structure_net`` weights from the checkpoint.

        They are frozen and large, so they are not re-saved on every
        checkpoint; ``load_state_dict`` tolerates their absence.
        """
        if not self._structure_net_frozen():
            return
        state_dict = checkpoint["state_dict"]
        for key in [k for k in state_dict if k.startswith(self._STRUCTURE_PREFIX)]:
            del state_dict[key]

    def configure_optimizers(self):
        """Build the optimizer and, if configured, an LR scheduler.

        The scheduler monitors ``train/loss``.
        """
        # Parameters come from trainer.model (presumably so strategy-wrapped
        # models are handled) — confirm against the training strategy.
        optimizer = self.hparams.optimizer(params=self.trainer.model.parameters())
        if self.hparams.lr_scheduler is None:
            return {"optimizer": optimizer}

        scheduler = self.hparams.lr_scheduler(optimizer=optimizer)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "train/loss",
            },
        }
