import glob
import os
import random
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from sc_perturb.dataset import CellDataModule
from sc_perturb.openphenom import OpenPhenomEncoder
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
from tqdm import tqdm

# Install a blanket "ignore" filter: every warning category from every module
# is suppressed for the lifetime of the process (presumably to keep the console
# output of the experiment readable — note this also hides deprecation notices).
warnings.filterwarnings("ignore")


def get_perturbation_cell_type_distribution_custom(datamodule):
    """
    Count training samples for every (perturbation ID, cell type ID) pair.

    Rows of ``datamodule.metadata`` are restricted to those selected by the
    boolean ``train_index`` column. If the metadata has no ``cell_type_id``
    column, one is derived from the ``cell_type`` name column using a fixed
    name-to-ID table.

    Returns:
        dict mapping ``(sirna_id, cell_type_id)`` (both cast to ``int``) to the
        number of training rows with that combination.

    Raises:
        ValueError: if neither ``cell_type_id`` nor ``cell_type`` is present.
    """
    metadata = datamodule.metadata
    train_rows = metadata[metadata["train_index"]]
    available_columns = train_rows.columns

    if "cell_type_id" not in available_columns:
        if "cell_type" not in available_columns:
            raise ValueError(
                "Neither 'cell_type_id' nor 'cell_type' column found in metadata"
            )
        # Derive numeric IDs from cell type names; work on a copy so the
        # datamodule's own metadata frame is left untouched.
        name_to_id = {"HEPG2": 0, "HUVEC": 1, "RPE": 2, "U2OS": 3}
        train_rows = train_rows.copy()
        train_rows["cell_type_id"] = train_rows["cell_type"].map(name_to_id)

    pair_counts = train_rows.groupby(["sirna_id", "cell_type_id"]).size()

    distribution = {}
    for (sirna, cell_type), n_samples in pair_counts.items():
        distribution[(int(sirna), int(cell_type))] = n_samples
    return distribution


def get_top_perturbations_for_cell_type(
    datamodule, cell_type_id=0, top_k=10, excluded_perturbation_ids=(1116,)
):
    """
    Get the top K perturbation IDs with the most samples for a specific cell type.

    Args:
        datamodule: object passed through to
            ``get_perturbation_cell_type_distribution_custom``.
        cell_type_id: cell type whose training sample counts are ranked.
        top_k: maximum number of perturbation IDs to return.
        excluded_perturbation_ids: perturbation IDs that must never be
            selected. Defaults to ``(1116,)``, the class this script has
            always excluded.

    Returns:
        List of at most ``top_k`` perturbation IDs, ordered by descending
        sample count (ties keep the order of the underlying distribution).
    """
    distribution = get_perturbation_cell_type_distribution_custom(datamodule)
    excluded = set(excluded_perturbation_ids)

    # Keep only the requested cell type and drop excluded classes up front so
    # that the printed ranking and the returned list are the same selection.
    cell_type_counts = {
        pert_id: count
        for (pert_id, ct_id), count in distribution.items()
        if ct_id == cell_type_id and pert_id not in excluded
    }

    # sorted() is stable, so filtering before sorting yields the same order as
    # sorting first and filtering afterwards.
    ranked = sorted(
        cell_type_counts.items(), key=lambda item: item[1], reverse=True
    )[:top_k]

    print(f"Top {top_k} perturbations for cell_type_id {cell_type_id}:")
    for rank, (pert_id, count) in enumerate(ranked, start=1):
        print(f"  {rank}. Perturbation {pert_id}: {count} samples")

    return [pert_id for pert_id, _ in ranked]


def find_generated_files_by_perturbation_and_celltype(
    generated_path, perturbation_id, cell_type_id
):
    """
    Find all generated numpy files for a specific perturbation ID and cell type ID.

    Files are looked up in ``<generated_path>/p<perturbation_id>/`` and are
    expected to be named ``p<pid>_c<cell_type_id>_sample<sample_id>.npy``.

    Args:
        generated_path: root directory holding one ``p<ID>`` folder per
            perturbation.
        perturbation_id: perturbation whose folder is searched.
        cell_type_id: cell type that must appear in the file name.

    Returns:
        Sorted list of matching ``.npy`` paths; empty if the perturbation
        folder does not exist or nothing matches.
    """
    pert_path = os.path.join(generated_path, f"p{perturbation_id}")
    if not os.path.isdir(pert_path):
        return []

    # Marker identifying the cell type inside a file name.
    pattern = f"_c{cell_type_id}_sample"

    # Escape the directory part so glob metacharacters in the path are literal.
    npy_files = glob.glob(os.path.join(glob.escape(pert_path), "*.npy"))

    # Match on the file name only: a parent directory that happens to contain
    # the marker must not make every file match. Sort because glob order is
    # filesystem-dependent and callers subsample this list with a seeded RNG.
    return sorted(f for f in npy_files if pattern in os.path.basename(f))


def save_image_as_png(image_tensor, save_path):
    """
    Save a single image tensor as an RGB PNG file.

    Args:
        image_tensor: ``torch.Tensor`` or array-like holding one image, either
            2-D (H, W) or 3-D. A 3-D input whose first axis has length <= 6 is
            treated as channels-first and transposed to (H, W, C).
        save_path: destination file path; missing parent directories are
            created.

    Raises:
        ValueError: if the image is neither 2-D nor 3-D.

    Notes:
        Values are divided by 255 when the maximum exceeds 1.0, then clipped
        to [0, 1]. Only the first three channels are written; single-channel
        images are replicated to grey RGB and two-channel images get an
        all-zero third channel.
    """
    # Imported lazily, as in the rest of this file, so PIL is only required
    # when images are actually written.
    from PIL import Image

    if isinstance(image_tensor, torch.Tensor):
        # detach() so tensors that still track gradients can be converted.
        image_np = image_tensor.detach().cpu().numpy()
    else:
        image_np = np.asarray(image_tensor)

    if image_np.ndim == 2:
        # Plain grayscale (H, W): add an explicit channel axis.
        image_np = image_np[:, :, np.newaxis]
    elif image_np.ndim == 3:
        if image_np.shape[0] <= 6:  # Channels first format
            image_np = np.transpose(image_np, (1, 2, 0))
    else:
        raise ValueError(
            f"Expected a 2-D or 3-D image, got array with shape {image_np.shape}"
        )

    # Normalize to [0, 1] if not already
    if image_np.max() > 1.0:
        image_np = image_np / 255.0

    num_channels = image_np.shape[2]
    if num_channels > 3:
        # Use first 3 channels for RGB visualization
        image_np = image_np[:, :, :3]
    elif num_channels == 2:
        # Two channels cannot be stored as "L" or "RGB"; pad with an empty
        # third channel so both remain visible.
        padding = np.zeros(image_np.shape[:2] + (1,), dtype=image_np.dtype)
        image_np = np.concatenate([image_np, padding], axis=2)
    elif num_channels == 1:
        # Convert grayscale to RGB
        image_np = np.repeat(image_np, 3, axis=2)

    # Ensure values are in [0, 1] before quantising to 8 bits.
    image_uint8 = (np.clip(image_np, 0, 1) * 255).astype(np.uint8)

    # os.makedirs("") raises, so only create a directory when the path has one.
    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # A (H, W, 3) uint8 array is interpreted as RGB without an explicit mode.
    Image.fromarray(np.ascontiguousarray(image_uint8)).save(save_path)


def _encode_in_batches(images, encoder, device, batch_size):
    """Stack ``images``, run ``encoder`` on them batch by batch without
    gradients, and return the features as one vertically stacked numpy array."""
    stacked = torch.stack(images)
    feature_chunks = []
    with torch.no_grad():
        for start in range(0, stacked.size(0), batch_size):
            batch = stacked[start : start + batch_size].to(device)
            feature_chunks.append(encoder(batch).cpu().numpy())
    return np.vstack(feature_chunks)


def extract_features_and_save_images(
    datamodule,
    generated_path,
    perturbation_ids,
    cell_type_id,
    encoder,
    device,
    output_dir,
    max_samples_per_pert=100,
    batch_size=64,
):
    """
    Load real and generated data, extract features, and save images organized by perturbation.

    For every perturbation ID, up to ``max_samples_per_pert`` real images
    (from ``datamodule.filter_samples``) and generated images (``.npy`` files
    under ``generated_path``) are written as PNGs to
    ``<output_dir>/real_images/perturbation_<ID>/`` and
    ``<output_dir>/generated_images/perturbation_<ID>/`` and encoded with
    ``encoder``.

    Args:
        datamodule: object exposing ``filter_samples(perturbation_id=...,
            cell_type_id=...)``; the result is indexed as ``dataset[i][0]``
            to obtain an image tensor.
        generated_path: root directory of the generated ``.npy`` files.
        perturbation_ids: iterable of perturbation IDs to process.
        cell_type_id: cell type used to filter both real and generated data.
        encoder: callable mapping an image batch tensor to a feature tensor.
        device: torch device the batches are moved to before encoding.
        output_dir: directory under which the PNG folders are created.
        max_samples_per_pert: cap on samples per perturbation and source;
            larger sets are subsampled with ``random.sample``.
        batch_size: number of images encoded per forward pass.

    Returns:
        Tuple ``(real_features, real_labels, generated_features,
        generated_labels)``. Features are 2-D numpy arrays (or an empty 1-D
        array when nothing was found); labels are lists holding the
        perturbation ID of each feature row.
    """
    all_real_features = []
    all_real_pert_labels = []
    all_generated_features = []
    all_generated_pert_labels = []

    # Create directories for saving images
    real_images_dir = os.path.join(output_dir, "real_images")
    generated_images_dir = os.path.join(output_dir, "generated_images")
    os.makedirs(real_images_dir, exist_ok=True)
    os.makedirs(generated_images_dir, exist_ok=True)

    for pert_id in tqdm(perturbation_ids, desc="Processing perturbations"):
        print(f"\nProcessing perturbation ID: {pert_id}")

        # Create perturbation-specific directories
        real_pert_dir = os.path.join(real_images_dir, f"perturbation_{pert_id}")
        gen_pert_dir = os.path.join(generated_images_dir, f"perturbation_{pert_id}")
        os.makedirs(real_pert_dir, exist_ok=True)
        os.makedirs(gen_pert_dir, exist_ok=True)

        # Load real data for specific cell type
        real_filtered_dataset = datamodule.filter_samples(
            perturbation_id=pert_id,
            cell_type_id=cell_type_id,
        )

        if real_filtered_dataset is not None and len(real_filtered_dataset) > 0:
            # Pick the indices first and only then load images, so that at
            # most max_samples_per_pert items are read from the dataset.
            # NOTE(review): if the dataset's __getitem__ itself draws from
            # `random`, the drawn values differ from loading everything first.
            num_real = len(real_filtered_dataset)
            if num_real > max_samples_per_pert:
                indices = random.sample(range(num_real), max_samples_per_pert)
            else:
                indices = range(num_real)
            real_images = [real_filtered_dataset[i][0] for i in indices]

            # Save real images
            for i, img in enumerate(real_images):
                img_path = os.path.join(
                    real_pert_dir, f"real_{pert_id}_sample_{i:04d}.png"
                )
                save_image_as_png(img, img_path)

            # Guard against an empty selection (e.g. max_samples_per_pert=0),
            # which torch.stack cannot handle.
            if real_images:
                real_features = _encode_in_batches(
                    real_images, encoder, device, batch_size
                )
                all_real_features.append(real_features)
                all_real_pert_labels.extend([pert_id] * len(real_features))
                print(f"  Real: {len(real_features)} samples saved and processed")

        # Load generated data for specific cell type
        gen_files = find_generated_files_by_perturbation_and_celltype(
            generated_path, pert_id, cell_type_id
        )

        if gen_files:
            # Sample if too many files
            if len(gen_files) > max_samples_per_pert:
                sampled_gen_files = random.sample(gen_files, max_samples_per_pert)
            else:
                sampled_gen_files = gen_files

            temp_generated_images = []

            for i, file_path in enumerate(sampled_gen_files):
                # Best effort: a single unreadable file is reported and
                # skipped instead of aborting the whole run.
                try:
                    img = np.load(file_path)
                    img_tensor = torch.from_numpy(img).float()
                    temp_generated_images.append(img_tensor)

                    # Save generated image
                    img_path = os.path.join(
                        gen_pert_dir, f"generated_{pert_id}_sample_{i:04d}.png"
                    )
                    save_image_as_png(img_tensor, img_path)

                except Exception as e:
                    print(f"Error loading {file_path}: {e}")

            if temp_generated_images:
                generated_features = _encode_in_batches(
                    temp_generated_images, encoder, device, batch_size
                )
                all_generated_features.append(generated_features)
                all_generated_pert_labels.extend(
                    [pert_id] * len(generated_features)
                )
                print(
                    f"  Generated: {len(generated_features)} samples saved and processed"
                )

    # Combine all features
    real_features = np.vstack(all_real_features) if all_real_features else np.array([])
    generated_features = (
        np.vstack(all_generated_features) if all_generated_features else np.array([])
    )

    return (
        real_features,
        all_real_pert_labels,
        generated_features,
        all_generated_pert_labels,
    )


def train_and_evaluate_simple(X_train, y_train, X_test, y_test, model_name):
    """
    Fit a logistic-regression probe on standardized features and score it.

    The scaler is fitted on the training features only and then applied to
    both splits. A short report (accuracy, split sizes, number of distinct
    training classes) is printed under the given ``model_name``.

    Returns:
        Accuracy of the fitted classifier on ``(X_test, y_test)``.
    """
    # Standardization statistics come from the training split alone.
    standardizer = StandardScaler().fit(X_train)
    train_scaled = standardizer.transform(X_train)
    test_scaled = standardizer.transform(X_test)

    classifier = LogisticRegression(max_iter=1000, random_state=42)
    classifier.fit(train_scaled, y_train)

    predictions = classifier.predict(test_scaled)
    accuracy = accuracy_score(y_test, predictions)

    n_train, n_test = len(X_train), len(X_test)
    n_classes = len(np.unique(y_train))
    print(f"{model_name}: Accuracy = {accuracy:.4f}")
    print(f"  Train samples: {n_train}, Test samples: {n_test}")
    print(f"  Number of classes: {n_classes}")

    return accuracy


def main():
    """
    Run the linear-probe comparison between real and generated cell images.

    Steps:
      1. Seed the RNGs, pick the device, load the config and data module,
         and build the encoder.
      2. Select the ``top_k`` most frequent perturbations for ``cell_type_id``.
      3. Save real/generated images as PNGs and extract encoder features.
      4. Train logistic-regression probes for Real→Real, Generated→Real and
         Real+Generated→Real, always evaluating on the same held-out 30 %
         of the real features.
      5. Print the accuracies and write a text summary and a bar plot into
         the output directory.
    """
    # Set random seeds
    seed = 42
    seed_everything(seed)
    random.seed(seed)
    np.random.seed(seed)

    # Setup device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Configuration
    config_path = "/mnt/pvc/MorphGen/sc_perturb/cfgs/diffusion_sit_full.yaml"
    generated_path = "/mnt/pvc/REPA/fulltrain_model_74_all_perts_NEW/numpy_data"

    # Parameters
    cell_type_id = 0
    top_k = 10  # Number of perturbation classes used by the probe
    max_samples_per_pert = 500  # Cap per perturbation and per source

    # Load configuration and data module
    config = OmegaConf.load(config_path)
    datamodule = CellDataModule(config)

    # Initialize encoder
    encoder = OpenPhenomEncoder().to(device)

    # Get top perturbations
    perturbation_ids = get_top_perturbations_for_cell_type(
        datamodule, cell_type_id=cell_type_id, top_k=top_k
    )

    print(f"\nAnalyzing perturbations: {perturbation_ids}")
    print(f"Cell type: {cell_type_id}")

    # Create output directory
    output_dir = f"linear_probe_results_with_images_celltype_{cell_type_id}_top_{top_k}"
    os.makedirs(output_dir, exist_ok=True)

    # Load data, extract features, and save images
    print("Loading data, extracting features, and saving images...")
    (
        real_features,
        real_pert_labels,
        generated_features,
        generated_pert_labels,
    ) = extract_features_and_save_images(
        datamodule,
        generated_path,
        perturbation_ids,
        cell_type_id,
        encoder,
        device,
        output_dir,
        max_samples_per_pert=max_samples_per_pert,
        batch_size=64,
    )

    print("\nFeature extraction complete!")
    print(f"Real features shape: {real_features.shape}")
    print(f"Generated features shape: {generated_features.shape}")

    # Convert labels to numpy arrays
    real_pert_labels = np.array(real_pert_labels)
    generated_pert_labels = np.array(generated_pert_labels)

    # Encode labels for sklearn. The encoder is fitted on the union of both
    # label sets so real and generated samples share one class index space.
    pert_encoder = LabelEncoder()
    all_pert_labels = np.concatenate([real_pert_labels, generated_pert_labels])
    pert_encoder.fit(all_pert_labels)

    # Transform labels
    real_pert_encoded = pert_encoder.transform(real_pert_labels)
    generated_pert_encoded = pert_encoder.transform(generated_pert_labels)

    print(f"Unique perturbations: {pert_encoder.classes_}")

    # Store results
    results = {}

    # Split real data once so every experiment is scored on the same test set
    if len(real_features) > 0:
        real_train_features, real_test_features, real_train_labels, real_test_labels = (
            train_test_split(
                real_features,
                real_pert_encoded,
                test_size=0.3,
                random_state=42,
                stratify=real_pert_encoded,
            )
        )

        # 1. Real→Real
        print("\n" + "=" * 50)
        print("EXPERIMENT 1: Real→Real")
        print("=" * 50)
        accuracy = train_and_evaluate_simple(
            real_train_features,
            real_train_labels,
            real_test_features,
            real_test_labels,
            "Real→Real",
        )
        results["Real→Real"] = accuracy

        # 2. Generated→Real
        if len(generated_features) > 0:
            print("\n" + "=" * 50)
            print("EXPERIMENT 2: Generated→Real")
            print("=" * 50)
            accuracy = train_and_evaluate_simple(
                generated_features,
                generated_pert_encoded,
                real_test_features,
                real_test_labels,
                "Generated→Real",
            )
            results["Generated→Real"] = accuracy

            # 3. Real+Generated→Real
            print("\n" + "=" * 50)
            print("EXPERIMENT 3: Real+Generated→Real")
            print("=" * 50)
            combined_train_features = np.vstack(
                [real_train_features, generated_features]
            )
            combined_train_labels = np.concatenate(
                [real_train_labels, generated_pert_encoded]
            )

            accuracy = train_and_evaluate_simple(
                combined_train_features,
                combined_train_labels,
                real_test_features,
                real_test_labels,
                "Real+Generated→Real",
            )
            results["Real+Generated→Real"] = accuracy

    # Create summary
    print("\n" + "=" * 80)
    print("FINAL RESULTS SUMMARY")
    print("=" * 80)
    for experiment, accuracy in results.items():
        print(f"{experiment:25s}: {accuracy:.4f}")

    # Save results to file. The experiment names contain "→", which the
    # platform default encoding may not be able to represent, so the file is
    # written explicitly as UTF-8.
    summary_path = os.path.join(output_dir, "results_summary.txt")
    with open(summary_path, "w", encoding="utf-8") as f:
        f.write("Linear Probe Experiments with Image Saving\n")
        f.write(f"Cell Type: {cell_type_id}\n")
        f.write(f"Top {top_k} Perturbations: {perturbation_ids}\n")
        f.write(f"Max samples per perturbation: {max_samples_per_pert}\n\n")
        f.write("Results:\n")
        for experiment, accuracy in results.items():
            f.write(f"{experiment}: {accuracy:.4f}\n")

    # Create a simple visualization
    plt.figure(figsize=(10, 6))
    experiments = list(results.keys())
    accuracies = list(results.values())

    bars = plt.bar(
        experiments, accuracies, alpha=0.7, color=["skyblue", "orange", "lightgreen"]
    )
    plt.ylabel("Accuracy")
    plt.title(f"Linear Probe Results (Cell Type {cell_type_id})")
    plt.ylim(0, 1.0)

    # Add value labels on bars
    for bar, acc in zip(bars, accuracies):
        plt.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height() + 0.01,
            f"{acc:.3f}",
            ha="center",
            va="bottom",
        )

    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()
    plt.savefig(
        os.path.join(output_dir, "results_plot.png"), dpi=300, bbox_inches="tight"
    )
    plt.close()

    print(f"\nResults saved to: {output_dir}/")
    print("- Real images saved to: real_images/")
    print("- Generated images saved to: generated_images/")
    print("- Summary saved to: results_summary.txt")
    print("- Plot saved to: results_plot.png")

    print("\nImages are organized as:")
    print("  real_images/perturbation_<ID>/real_<ID>_sample_<XXXX>.png")
    print("  generated_images/perturbation_<ID>/generated_<ID>_sample_<XXXX>.png")

# Run the experiment only when this file is executed as a script; importing
# the module does not call main().
if __name__ == "__main__":
    main()
