from __future__ import annotations

import argparse
import glob
import logging
import os
import random
import re
import shutil
import subprocess
import sys
import tempfile
import textwrap
from collections import defaultdict
from pathlib import Path
from typing import List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tifffile as tiff
import torch
import torch_fidelity
import torchvision.transforms.functional as TF
from dataset import CellDataModule, to_rgb
from diffusers.models import AutoencoderKL
from metrics_utils import calculate_metrics_from_scratch
from models.sit import SiT_models
from omegaconf import OmegaConf
from PIL import Image
from pytorch_lightning import seed_everything
from tqdm import tqdm
from train import generate_perturbation_matched_samples
from utils import load_encoders


def compare_columns(
    ref_cols: List[str],
    my_cols: List[str],
    *,
    channel_tokens=(
        "green",
        "red",
        "blue",  # image channels
        "actin",
        "dna",
        "er",
        "golgi",
        "mito",
        "rna",
        "yellow",
    ),
    object_tokens=("cells", "cytoplasm", "nuclei"),  # object suffixes
) -> Tuple[List[str], List[str], List[str]]:
    """
    Compare *ref_cols* (e.g. CellProfiler pipeline columns) with *my_cols* after
    normalising names so that **only the measurement itself remains**:
        • a trailing ",Cells" is removed
        • the compartment part before "|" (e.g. "Cytoplasm|") is discarded
        • underscore-delimited channel tokens (Green, Mito …) are discarded
        • a trailing object-type token (Cells, Cytoplasm, Nuclei) is discarded
    The comparison is case-insensitive.

    Parameters
    ----------
    ref_cols, my_cols : list[str]
        Column names to compare.
    channel_tokens    : iterable[str]
        Tokens to throw away wherever they appear as separate '_' parts.
    object_tokens     : iterable[str]
        Tokens to drop *only* if they appear as the last '_' part.

    Returns
    -------
    matched : list[str]   – original *ref_cols* names whose normalised form also
                            occurs in *my_cols*
    missing : list[str]   – original *ref_cols* names with no counterpart in *my_cols*
    extra   : list[str]   – original *my_cols* names with no counterpart in *ref_cols*

    Every original name that maps to a key is kept, so several columns may be
    reported for one normalised key. Lists are ordered by first appearance of
    the normalised key in the respective input, with names sharing a key in
    input order — the result is deterministic across runs.
    """

    chan_set = {t.lower() for t in channel_tokens}
    object_set = {t.lower() for t in object_tokens}

    def normalise(col: str) -> str:
        """Reduce a column name to its lower-cased measurement core."""
        col = col.strip()

        # strip trailing ",Cells" (any case); len(",cells") == 6
        if col.lower().endswith(",cells"):
            col = col[:-6]

        # throw away everything up to the first "|" (compartment prefix)
        if "|" in col:
            col = col.split("|", 1)[1]

        # drop any channel token wherever it occurs
        parts = [p for p in col.split("_") if p.lower() not in chan_set]

        # if the *last* part is an object token (Cells, Cytoplasm, …) erase it
        if parts and parts[-1].lower() in object_set:
            parts = parts[:-1]

        return "_".join(parts).lower()

    # Group original names by normalised key. Dicts keep insertion order, so
    # iterating them (rather than sets of keys) gives a stable output order.
    norm_ref = defaultdict(list)
    norm_my = defaultdict(list)
    for c in ref_cols:
        norm_ref[normalise(c)].append(c)
    for c in my_cols:
        norm_my[normalise(c)].append(c)

    # Membership tests on a defaultdict do not insert keys.
    matched = [c for k, cols in norm_ref.items() if k in norm_my for c in cols]
    missing = [c for k, cols in norm_ref.items() if k not in norm_my for c in cols]
    extra = [c for k, cols in norm_my.items() if k not in norm_ref for c in cols]

    return matched, missing, extra


# Root logging: INFO and above, "time | LEVEL | message" lines.
logging.basicConfig(
    format="%(asctime)s | %(levelname)-8s | %(message)s",
    level=logging.INFO,
)
# Module-level logger shared by the helpers in this file.
LOGGER = logging.getLogger(__name__)


# -----------------------------------------------------------------------------
# Helpers
# -----------------------------------------------------------------------------

# Display names of the six image channels, indexed by channel position.
CHANNEL_NAMES = [
    "DNA",  # Hoechst
    "ER",
    "Actin",
    "Golgi",
    "Mito",
    "RNA",
]

# Every channel index in order: (0, 1, 2, 3, 4, 5).
DEFAULT_CHANNELS = tuple(range(6))

# Per-channel display settings, one row per channel index:
#   (index, (R, G, B) colour on a 0-255 scale, [low, high] "range" pair).
# NOTE(review): "range" is presumably a display intensity window — confirm.
_RGB_SETTINGS = (
    (0, (19, 0, 249), (0, 51)),  # DNA   -> blue
    (1, (42, 255, 31), (0, 107)),  # ER    -> green
    (2, (255, 0, 25), (0, 64)),  # Actin -> red
    (3, (45, 255, 252), (0, 191)),  # Golgi -> cyan
    (4, (250, 0, 253), (0, 89)),  # Mito  -> magenta
    (5, (254, 255, 40), (0, 191)),  # RNA   -> yellow
)
RGB_MAP = {
    index: {"rgb": np.array(colour), "range": list(window)}
    for index, colour, window in _RGB_SETTINGS
}


def rescale_intensity(
    arr: torch.Tensor, bounds=(0.5, 99.5), out_range=(0.0, 1.0)
) -> torch.Tensor:
    """Percentile-based contrast stretch of a tensor.

    The input is cast to float32 and divided by 255. The low/high percentiles
    (*bounds*, in percent) are estimated on every 100th element. Values are
    clipped to that window and mapped linearly onto *out_range*. The
    percentiles are taken after the division, so the overall scale of the
    input does not affect the result.

    Parameters
    ----------
    arr : torch.Tensor
        Any-shaped, non-empty tensor (CPU or accelerator).
    bounds : tuple[float, float]
        Lower/upper percentiles in [0, 100].
    out_range : tuple[float, float]
        Target (low, high) values.

    Returns
    -------
    torch.Tensor
        float32 tensor with the shape and device of *arr*. A degenerate window
        (e.g. a constant image) maps to ``out_range[0]`` instead of NaN.
    """
    arr = arr.float() / 255
    # Subsample to keep the quantile input small and the computation cheap.
    sample = arr.flatten()[::100]
    # torch.quantile requires q on the same device/dtype as the input.
    q = torch.tensor(
        [bounds[0] / 100.0, bounds[1] / 100.0],
        dtype=sample.dtype,
        device=sample.device,
    )
    lo, hi = torch.quantile(sample, q)
    arr = torch.clamp(arr, lo, hi)
    # Avoid 0/0 -> NaN when the percentile window collapses.
    span = (hi - lo).clamp_min(torch.finfo(arr.dtype).eps)
    arr = (arr - lo) / span
    return arr * (out_range[1] - out_range[0]) + out_range[0]


def to_rgb(img: torch.Tensor, dtype=torch.float32) -> torch.Tensor:  # type: ignore[no-untyped-def]
    """Project a batch of 6-channel Cell Painting images onto RGB for display.

    Reference: https://github.com/recursionpharma/rxrx1-utils/blob/d34b2b0db0af1cb4fe357573bb8de76bd042b34f/rxrx/io.py#L61

    Channel c is painted with a fixed colour: 0 blue, 1 green, 2 red, 3 cyan,
    4 magenta, 5 yellow. Each RGB component receives exactly three channel
    contributions, so the sum is divided by 3.

    Parameters
    ----------
    img : torch.Tensor
        Batch shaped (B, C, H, W). With fewer than 6 channels the missing ones
        are zero; channels beyond the 6th are ignored.
    dtype : torch.dtype
        Dtype used for the channel-to-colour projection (einsum).

    Returns
    -------
    torch.Tensor
        (B, 3, H, W) float32 tensor, contrast-stretched to [0, 1] by
        ``rescale_intensity`` using the 0.1/99.9 percentiles of the whole batch.

    NOTE(review): this definition shadows the ``to_rgb`` imported from
    ``dataset`` at the top of the module.
    """
    num_channels_required = 6
    b, num_channels, height, width = img.shape  # b x c x h x w
    prepped_img = torch.zeros(
        b, num_channels_required, height, width, dtype=img.dtype, device=img.device
    )
    # Zero-pad missing channels / drop surplus ones.
    n = min(num_channels, num_channels_required)
    prepped_img[:, :n] = img[:, :n]

    red, green, blue = [1, 0, 0], [0, 1, 0], [0, 0, 1]
    yellow, magenta, cyan = [1, 1, 0], [1, 0, 1], [0, 1, 1]
    # Row c is the RGB colour assigned to input channel c.
    rgb_map = torch.tensor(
        [blue, green, red, cyan, magenta, yellow],
        dtype=dtype,
        device=img.device,
    )
    rgb_img = (
        torch.einsum("nchw,ct->nthw", prepped_img.to(dtype=dtype), rgb_map) / 3.0
    )
    return rescale_intensity(rgb_img, bounds=(0.1, 99.9))


def save_channels_to_tiff(npy_path: Path, work_dir: Path, prefix: str) -> List[Path]:
    """Load a 6-channel ``.npy`` array and write six single-plane uint16 TIFFs.

    Parameters
    ----------
    npy_path : Path
        ``.npy`` file holding a 3-D array laid out as (6, H, W) or (H, W, 6).
        Values are clipped to [0, 1] (presumably float data in that range —
        confirm against the producer) and NaN is written as 0.
    work_dir : Path
        Output directory; created (with parents) if missing.
    prefix : str
        Filename prefix; files are named ``{prefix}_{channel name}_{index}.tiff``.

    Returns
    -------
    list[Path]
        Written TIFF paths in channel order.

    Raises
    ------
    ValueError
        If the array is not 3-D or has no length-6 axis at either end.
    """
    arr = np.load(npy_path)
    if arr.ndim != 3:
        raise ValueError(f"{npy_path} should have 3 dimensions, got {arr.shape}")

    # Standardise to (6, H, W); a leading axis of 6 takes precedence.
    if arr.shape[0] == 6:
        pass
    elif arr.shape[-1] == 6:
        arr = np.transpose(arr, (2, 0, 1))
    else:
        raise ValueError(
            f"{npy_path} has incompatible shape {arr.shape}; expected 6 channels"
        )

    LOGGER.info("Array shape: %s, dtype: %s", arr.shape, arr.dtype)
    LOGGER.info("Array min: %s, max: %s", arr.min(), arr.max())

    work_dir.mkdir(parents=True, exist_ok=True)
    out_paths: List[Path] = []
    for idx, plane in enumerate(arr):
        name = CHANNEL_NAMES[idx]
        # NaN would make the uint16 cast undefined, so map it to 0. Clip to
        # [0, 1], then widen to float32 before scaling: float16 tops out at
        # 65504, so multiplying by 65535 would overflow to inf.
        plane_norm = np.clip(np.nan_to_num(plane, nan=0.0), 0, 1).astype(np.float32)
        # Round to nearest instead of truncating, which would bias every
        # value downward by up to one grey level.
        plane_u16 = np.rint(plane_norm * 65535).astype(np.uint16)

        LOGGER.info(
            "Channel %d (%s): min=%.4f, max=%.4f, scaled min=%d, scaled max=%d",
            idx,
            name,
            plane.min(),
            plane.max(),
            plane_u16.min(),
            plane_u16.max(),
        )

        fpath = work_dir / f"{prefix}_{name}_{idx}.tiff"
        tiff.imwrite(fpath, plane_u16)
        out_paths.append(fpath)
        LOGGER.debug("Saved %s", fpath)
    return out_paths


def save_channel_as_rgb(channel_data, channel_idx, output_path, colormap=None):
    """
    Save a single channel as an RGB PNG, tinted with one colour.

    Parameters
    ----------
    channel_data : numpy.ndarray
        The single channel image data (2D array, H x W); values are clipped
        to [0, 1].
    channel_idx : int
        Channel index used to look up the default colour in RGB_MAP when no
        *colormap* is given.
    output_path : Path or str
        Path where to save the output PNG.
    colormap : sequence of 3 numbers, optional
        (R, G, B) colour on the same 0-255 scale as the RGB_MAP entries.
        If None, RGB_MAP[channel_idx]["rgb"] is used, or white when
        *channel_idx* is not in RGB_MAP.

    Returns
    -------
    The *output_path* argument, unchanged.
    """
    channel_norm = np.clip(channel_data, 0, 1).astype(np.float32)

    if colormap is not None:
        # Honour the caller's colour; reshape(3) rejects malformed input early.
        rgb_color = np.asarray(colormap, dtype=np.float64).reshape(3) / 255.0
    elif channel_idx in RGB_MAP:
        rgb_color = RGB_MAP[channel_idx]["rgb"] / 255.0
    else:
        # No colour available: render as white (grey levels).
        rgb_color = np.array([1.0, 1.0, 1.0])

    # Broadcast (H, W, 1) intensity against the (3,) colour -> (H, W, 3).
    rgb_img = (channel_norm[:, :, np.newaxis] * rgb_color).astype(np.float32)

    # Convert to 8-bit for saving
    rgb_img = (rgb_img * 255).astype(np.uint8)

    Image.fromarray(rgb_img).save(output_path)
    return output_path


def save_channels_to_png(npy_path: Path, work_dir: Path, prefix: str) -> List[Path]:
    """Write each channel of a 6-channel ``.npy`` array as a colour-tinted PNG.

    The array must be 3-D with the six channels on the first or last axis.
    Each plane is clipped to [0, 1] and written through ``save_channel_as_rgb``
    as ``{prefix}_{channel name}_{index}.png`` inside *work_dir*, which is
    created if needed. Returns the written paths in channel order and raises
    ValueError for arrays of the wrong rank or without a length-6 end axis.
    """
    data = np.load(npy_path)
    if data.ndim != 3:
        raise ValueError(f"{npy_path} should have 3 dimensions, got {data.shape}")

    # Bring the channel axis to the front -> (6, H, W).
    if data.shape[0] != 6:
        if data.shape[-1] != 6:
            raise ValueError(
                f"{npy_path} has incompatible shape {data.shape}; expected 6 channels"
            )
        data = np.transpose(data, (2, 0, 1))

    LOGGER.info(f"Array shape: {data.shape}, dtype: {data.dtype}")
    LOGGER.info(f"Array min: {data.min()}, max: {data.max()}")

    work_dir.mkdir(parents=True, exist_ok=True)
    written: List[Path] = []
    # zip stops at the shorter sequence, so only named channels are exported.
    for channel, (label, plane) in enumerate(zip(CHANNEL_NAMES, data)):
        target = work_dir / f"{prefix}_{label}_{channel}.png"
        save_channel_as_rgb(np.clip(plane, 0, 1), channel, target)
        LOGGER.info(
            f"Channel {channel} ({label}): min={plane.min():.4f}, max={plane.max():.4f}"
        )
        written.append(target)
        LOGGER.debug("Saved %s", target)
    return written


def create_rgb_composite(channels, output_path):
    """
    Create an RGB composite image from multichannel data.

    Channels 0-2 (DNA, ER, Actin per CHANNEL_NAMES) each drive a single output
    component (blue, green, red) at full weight. Channels 3-5 add their full
    RGB_MAP colour at 50% intensity. The blend is divided by its global
    maximum (when positive) and saved as an 8-bit PNG.

    Parameters
    ----------
    channels : numpy.ndarray
        The multichannel image data (C, H, W) with at least 6 channels;
        values are clipped to [0, 1].
    output_path : Path or str
        Path where to save the RGB composite PNG.

    Returns
    -------
    The *output_path* argument, unchanged.
    """
    channels_norm = np.clip(channels, 0, 1).astype(np.float32)

    h, w = channels_norm.shape[1], channels_norm.shape[2]
    rgb_img = np.zeros((h, w, 3), dtype=np.float32)

    # Primary channels: (channel index, output component). The weight is read
    # from the same component it is written to — previously DNA used the red
    # component (19) of its colour for the blue output, hiding the nuclei.
    for ch, comp in ((0, 2), (1, 1), (2, 0)):  # DNA->blue, ER->green, Actin->red
        rgb_img[:, :, comp] += channels_norm[ch] * (RGB_MAP[ch]["rgb"][comp] / 255.0)

    # Remaining channels contribute their full colour at 50% intensity.
    for ch in range(3, 6):
        color = RGB_MAP[ch]["rgb"] / 255.0
        for i in range(3):
            rgb_img[:, :, i] += channels_norm[ch] * color[i] * 0.5

    # Normalize the composite so the brightest pixel maps to 1.
    rgb_max = np.max(rgb_img)
    if rgb_max > 0:
        rgb_img = rgb_img / rgb_max

    # Convert to 8-bit for saving
    rgb_img = (rgb_img * 255).astype(np.uint8)

    Image.fromarray(rgb_img).save(output_path)
    return output_path


def process_dataset_images(dataset, output_dir, max_images=None):
    """
    Export dataset images as per-channel PNGs plus one RGB composite each.

    Parameters
    ----------
    dataset : indexable with len()
        ``dataset[i][0]`` is taken as the image (torch tensor or numpy array,
        channels first — presumably (6, H, W); confirm with the dataset).
    output_dir : str or Path
        Root directory. Composites go to ``rgb_composites/rgb_composite_{i}.png``;
        per-channel PNGs go to ``channel_images/img_{i}/{channel name}.png``.
    max_images : int, optional
        Process at most this many leading images. None for all images.

    Returns
    -------
    (Path, Path)
        The composite directory and the per-channel directory.
    """
    root = Path(output_dir)
    rgb_dir = root / "rgb_composites"
    channels_dir = root / "channel_images"
    for folder in (rgb_dir, channels_dir):
        folder.mkdir(parents=True, exist_ok=True)

    total = len(dataset) if max_images is None else min(len(dataset), max_images)
    LOGGER.info(f"Processing {total} images from dataset")

    for i in range(total):
        img = dataset[i][0]
        if isinstance(img, torch.Tensor):
            img = img.cpu().numpy()

        img_dir = channels_dir / f"img_{i}"
        img_dir.mkdir(exist_ok=True)

        # Only channels that have a display name get their own PNG.
        for ch in range(min(img.shape[0], len(CHANNEL_NAMES))):
            label = CHANNEL_NAMES[ch]
            channel_path = img_dir / f"{label}.png"
            save_channel_as_rgb(img[ch], ch, channel_path)
            LOGGER.info(
                f"Saved channel {ch} ({label}) for image {i} to {channel_path}"
            )

        rgb_path = rgb_dir / f"rgb_composite_{i}.png"
        create_rgb_composite(img, rgb_path)
        LOGGER.info(f"Saved RGB composite for image {i} to {rgb_path}")

    return rgb_dir, channels_dir


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Process images from real_filtered_dataset"
    )
    parser.add_argument(
        "--config",
        type=str,
        default="diffusion_sit_full.yaml",
        help="OmegaConf YAML used to build the CellDataModule "
        "(default: diffusion_sit_full.yaml)",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./cell_images",
        help="Directory where to save the processed images",
    )
    parser.add_argument(
        "--max_images",
        type=int,
        default=None,
        help="Maximum number of images to process (default: all)",
    )
    parser.add_argument(
        "--perturbation_id",
        type=int,
        default=1138,
        help="Perturbation ID to filter (default: 1138)",
    )
    parser.add_argument(
        "--cell_type_id",
        type=int,
        default=0,
        help="Cell type ID to filter (default: 0)",
    )
    args = parser.parse_args()
    # A negative limit would silently process nothing; reject it up front.
    if args.max_images is not None and args.max_images < 0:
        parser.error("--max_images must be non-negative")

    # Load configuration and build the datamodule from it.
    filename = args.config
    print(f"Loading configuration from {filename}")
    config = OmegaConf.load(filename)
    datamodule = CellDataModule(config)

    pert_id = args.perturbation_id
    cell_type_id = args.cell_type_id

    print(
        f"Filtering dataset with perturbation_id={pert_id}, cell_type_id={cell_type_id}"
    )
    real_filtered_dataset = datamodule.filter_samples(
        perturbation_id=pert_id, cell_type_id=cell_type_id
    )

    print(f"Filtered dataset contains {len(real_filtered_dataset)} images")

    # One output folder per (perturbation, cell type) combination.
    output_dir = Path(args.output_dir) / f"p{pert_id}_c{cell_type_id}"
    output_dir.mkdir(parents=True, exist_ok=True)
    print(f"Created output directory: {output_dir}")

    rgb_dir, channels_dir = process_dataset_images(
        real_filtered_dataset, output_dir, args.max_images
    )

    print("\nProcessing complete! Output saved to:")
    print(f"  - RGB composites: {rgb_dir}")
    print(f"  - Individual channels: {channels_dir}")
