# rohban_openphenom_demo.py
"""
Tiny evaluation script:
  • loads the MorphoDiff‑12 (+CONTROL) subset of the Rohban dataset
  • remaps 5‑channel crops to an RxRx1‑style 6‑channel stack and pushes them
    through OpenPhenom’s CA‑MAE (encoder‑decoder)
  • logs per‑batch reconstruction MSE and writes RGB input/reconstruction comparisons

Run with:  python rohban_openphenom_demo.py
"""

import os
from collections import Counter
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from dataset import rescale_intensity, to_rgb
from omegaconf import OmegaConf
from openphenom import OpenPhenomEncoder  # pip install maes_microscopy
from rohbandatamodule import RohbanDataModule  # ← the DataModule we built
from torchvision.utils import save_image


# -----------------------------------------------------------------------------
def rohban_to_6ch(batch_5ch: torch.Tensor) -> torch.Tensor:
    """
    Reorder a Rohban 5-channel batch into an RxRx1-style 6-channel stack.

    Input channels along dim 1 are (DNA, Mito, AGP, ER, RNA). Output order:

        w1  nuclei / DNA     <- DNA
        w2  ER               <- ER
        w3  actin            <- AGP
        w4  nucleoli / RNA   <- RNA
        w5  mitochondria     <- Mito
        w6  Golgi            <- AGP  (same plane as w3, duplicated)

    batch_5ch: (B,5,H,W) float32/16 in [0,1]
    returns   : (B,6,H,W), a new tensor with the input's dtype and device

    A ValueError is raised (by the tuple unpacking below) when dim 1 does
    not hold exactly 5 channels.
    """
    # Drop the channel axis to get one plane per stain...
    planes = batch_5ch.unbind(dim=1)
    dna, mito, agp, er, rna = planes
    # ...then re-insert it in RxRx1 order; AGP fills both actin and Golgi slots.
    rxrx_order = (dna, er, agp, rna, mito, agp)
    return torch.stack(rxrx_order, dim=1)


def main():
    """
    Run a quick OpenPhenom reconstruction check on the Rohban subset.

    1. Build a ``RohbanDataModule`` from an inline OmegaConf config and take
       its train loader (``holdout_ratio=0.0``: the whole subset is scanned).
    2. Instantiate ``OpenPhenomEncoder`` on CUDA when available (else CPU)
       and put it in eval mode.
    3. For at most ``max_steps + 1`` batches: remap to 6 channels with
       ``rohban_to_6ch``, reconstruct with ``model.encode_decode``, and compute
       the MSE between the reconstruction and the instance-normalized input
       (per-sample, per-channel normalization over H, W). Every
       ``save_every`` steps an RGB grid is written to ``ophenom_check_output/``.
       The grid holds the first 4 inputs followed by their 4 reconstructions,
       4 images per row.

    Prints each batch's MSE and the average over all evaluated batches. If
    the loader yields nothing, a notice is printed instead of dividing by
    zero.
    """
    # ------------------------------------------------------------------ #
    # 1.  Data‑module configuration
    # ------------------------------------------------------------------ #
    cfg = OmegaConf.create(
        {
            "data_dir": "/mnt/pvc/AutoSync/data/cpg0017/broad/workspace",
            "mode": "morphdiff_exp_12",     # or "full", "morphdiff_exp_5"
            "resize": 512,                  # 256 is OpenPhenom’s native crop
            "keep_controls": True,
            "one_control_class": True,
            "collapse_variants": True,
            "holdout_ratio": 0.0,           # use the whole set for a quick scan
            "batch_size": 8,
            "num_workers": 4,
            "seed": 42,
            "shuffle": True,
        }
    )

    dm = RohbanDataModule(cfg)
    loader = dm.get_train_loader()

    # ------------------------------------------------------------------ #
    # 2.  Model (evaluation only — no loss module or optimizer is built)
    # ------------------------------------------------------------------ #
    model = OpenPhenomEncoder()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Presumably the model ships with pretrained weights, so we only evaluate —
    # confirm against OpenPhenomEncoder.
    model.eval()

    # ------------------------------------------------------------------ #
    # 3.  Evaluation loop
    # ------------------------------------------------------------------ #
    out_dir = Path("ophenom_check_output")
    out_dir.mkdir(exist_ok=True)

    save_every = 20  # write an RGB comparison grid every N steps
    max_steps = 60   # last step index processed (61 batches) — keeps the demo short

    # Target normalization. With affine=False and no running stats, the module
    # holds no parameters or buffers, so it needs no .to(device).
    instance_norm = torch.nn.InstanceNorm2d(6, affine=False, eps=1e-6)

    total_loss, n_batches = 0.0, 0
    with torch.no_grad():
        for step, (imgs, _, _) in enumerate(loader):
            imgs = rohban_to_6ch(imgs.to(device))
            # NOTE(review): the model receives the un-normalized stack, while the
            # MSE target below is instance-normalized. This presumably relies on
            # encode_decode normalizing its input internally — confirm.
            recons = model.encode_decode(imgs)          # same shape
            target = instance_norm(imgs)

            # MSE over all 6 channels (AGP contributes twice: w3 and w6).
            batch_mse = F.mse_loss(recons, target).item()
            total_loss += batch_mse
            n_batches += 1
            print(f"[{step:03d}]  batch MSE={batch_mse:.5f}")

            if step % save_every == 0:
                # Convert to RGB only when a grid is actually written.
                rgb_in = to_rgb(target.cpu())
                rgb_out = to_rgb(recons.cpu())
                comparison = torch.cat([rgb_in[:4], rgb_out[:4]])
                save_image(
                    comparison,
                    out_dir / f"recon_step{step:03d}.png",
                    nrow=4,
                )

            # ---- shorten demo ----
            if step == max_steps:
                break

    if n_batches == 0:
        print("\nNo batches were evaluated (empty loader); no average MSE to report.")
        return

    print(
        f"\nAverage reconstruction MSE (6-channel, instance-normalized target) "
        f"over {n_batches} batches: {total_loss / n_batches:.5f}"
    )


if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()
