import os
import random
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from sc_perturb.dataset import CellDataModule
from sc_perturb.openphenom import OpenPhenomEncoder
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
from tqdm import tqdm

# Suppress all warnings process-wide to keep the console output readable.
# NOTE(review): this also hides library warnings (e.g. raised while fitting or
# scoring the sklearn models below) that may matter when interpreting results.
warnings.filterwarnings("ignore")


def find_generated_files_by_perturbation_and_celltype(
    generated_path, perturbation_id, cell_type_id
):
    """
    Find all generated numpy files for a specific perturbation ID and cell type ID.

    Files are looked up in ``<generated_path>/p<perturbation_id>/`` and are
    expected to be named ``p<pid>_c<cell_type_id>_sample<sample_id>.npy``.

    Args:
        generated_path: Root directory containing one ``p<pid>`` folder per
            perturbation.
        perturbation_id: Perturbation ID selecting the ``p<pid>`` sub-folder.
        cell_type_id: Cell type ID matched against the ``_c<id>_sample`` part
            of each file name.

    Returns:
        Sorted list of matching ``.npy`` file paths. The list is empty if the
        perturbation folder does not exist.
    """
    import glob

    pert_path = os.path.join(generated_path, f"p{perturbation_id}")
    if not os.path.exists(pert_path):
        return []

    # Pattern to match cell type in filenames (p<pid>_c<cell_type_id>_sample<sample_id>.npy).
    # The trailing "_sample" keeps e.g. "_c1" from matching "_c10".
    pattern = f"_c{cell_type_id}_sample"

    # Escape the directory so characters such as "[" in the path are not
    # interpreted as glob syntax; only "*.npy" is a wildcard.
    npy_files = glob.glob(os.path.join(glob.escape(pert_path), "*.npy"))

    # Match on the file name only: if a directory component of the path
    # contained the pattern, every file would otherwise match. Sort because
    # glob order is filesystem-dependent and callers may randomly sample
    # from this list, which would make seeded runs irreproducible.
    return sorted(f for f in npy_files if pattern in os.path.basename(f))


def extract_features_batch(images_tensor, encoder, device, batch_size=64):
    """
    Run ``encoder`` over ``images_tensor`` chunk by chunk and stack the outputs.

    Args:
        images_tensor: Tensor whose first dimension indexes samples, or None.
        encoder: Callable mapping a batch tensor to a feature tensor.
        device: Device each chunk is moved to before it is encoded.
        batch_size: Maximum number of samples passed to ``encoder`` per call.

    Returns:
        ``np.vstack`` of the per-chunk encoder outputs (moved to CPU). Returns
        an empty 1-D array when ``images_tensor`` is None or has no samples.
    """
    if images_tensor is None or images_tensor.size(0) == 0:
        return np.array([])

    n_samples = images_tensor.size(0)
    chunks = []
    # Gradients are never needed for feature extraction.
    with torch.no_grad():
        for start in range(0, n_samples, batch_size):
            chunk = images_tensor[start : start + batch_size].to(device)
            chunks.append(encoder(chunk).cpu().numpy())

    if not chunks:
        return np.array([])
    return np.vstack(chunks)


def _load_real_subset(dataset, max_samples):
    """
    Randomly pick at most ``max_samples`` items of ``dataset`` and load only those.

    Returns:
        (images, cell_types): values read from position 0 and position 2 of
        each selected item, in sampling order.
    """
    n_items = len(dataset)
    if n_items > max_samples:
        # Sample indices *before* loading so unselected items are never loaded.
        indices = random.sample(range(n_items), max_samples)
    else:
        indices = range(n_items)

    images, cell_types = [], []
    for idx in indices:
        # Index each item once. Assumes items are indexable with the image at
        # [0] and the cell type at [2] -- TODO confirm against the dataset.
        item = dataset[idx]
        images.append(item[0])
        cell_types.append(item[2])
    return images, cell_types


def _load_generated_subset(generated_path, pert_id, cell_type_ids, max_samples):
    """
    Load at most ``max_samples`` generated ``.npy`` images for one perturbation.

    Files for all ``cell_type_ids`` are pooled and then randomly subsampled.
    Files that fail to load are reported and skipped.

    Returns:
        (images, cell_types): float tensors and the cell type ID of each.
    """
    file_infos = [
        (f_path, ct_id)
        for ct_id in cell_type_ids
        for f_path in find_generated_files_by_perturbation_and_celltype(
            generated_path, pert_id, ct_id
        )
    ]
    if len(file_infos) > max_samples:
        file_infos = random.sample(file_infos, max_samples)

    images, cell_types = [], []
    for f_path, ct_id in file_infos:
        try:
            images.append(torch.from_numpy(np.load(f_path)).float())
        except Exception as e:  # best-effort: one bad file must not abort the run
            print(f"Error loading {f_path}: {e}")
            continue
        cell_types.append(ct_id)
    return images, cell_types


def load_and_extract_features(
    datamodule,
    generated_path,
    perturbation_ids,
    cell_type_ids,
    encoder,
    device,
    max_samples_per_pert=500,
    batch_size=64,
):
    """
    Load real and generated data and extract features for specified perturbations and cell types.

    For each perturbation ID:
      * Real samples of every cell type come from ``datamodule.filter_samples``.
        They are randomly subsampled to ``max_samples_per_pert`` before loading.
      * Generated samples are the ``.npy`` files for ``cell_type_ids`` under
        ``generated_path``. They are pooled, then randomly subsampled to
        ``max_samples_per_pert``.
    Both sets are encoded with ``encoder`` via ``extract_features_batch``.

    NOTE(review): real data is not restricted to ``cell_type_ids``
    (``cell_type_id=None``), while generated data is. Real cell-type labels may
    therefore include classes that never appear in the generated set.

    Returns:
        real_features, real_pert_labels, real_cell_labels,
        generated_features, generated_pert_labels, generated_cell_labels

        Feature arrays are ``np.vstack`` of the per-perturbation features, or
        an empty 1-D array if nothing was found. The label lists are aligned
        row-for-row with their feature array.
    """
    all_real_features = []
    all_real_pert_labels = []
    all_real_cell_labels = []

    all_generated_features = []
    all_generated_pert_labels = []
    all_generated_cell_labels = []

    for pert_id in tqdm(perturbation_ids, desc="Processing perturbations"):
        print(f"\nProcessing perturbation ID: {pert_id}")

        # Real data
        real_filtered_dataset = datamodule.filter_samples(
            perturbation_id=pert_id,
            cell_type_id=None,  # Get all cell types
        )
        if real_filtered_dataset is not None and len(real_filtered_dataset) > 0:
            real_images, real_cell_types = _load_real_subset(
                real_filtered_dataset, max_samples_per_pert
            )
            real_features = extract_features_batch(
                torch.stack(real_images), encoder, device, batch_size
            )
            if real_features.size > 0:
                all_real_features.append(real_features)
                all_real_pert_labels.extend([pert_id] * len(real_features))
                all_real_cell_labels.extend(real_cell_types)
                print(f"  Real: {len(real_features)} samples")

        # Generated data
        gen_images, gen_cell_types = _load_generated_subset(
            generated_path, pert_id, cell_type_ids, max_samples_per_pert
        )
        if gen_images:
            generated_features = extract_features_batch(
                torch.stack(gen_images), encoder, device, batch_size
            )
            if generated_features.size > 0:
                all_generated_features.append(generated_features)
                all_generated_pert_labels.extend([pert_id] * len(generated_features))
                all_generated_cell_labels.extend(gen_cell_types)
                print(f"  Generated: {len(generated_features)} samples")

    # Combine all features
    real_features = np.vstack(all_real_features) if all_real_features else np.array([])
    generated_features = (
        np.vstack(all_generated_features) if all_generated_features else np.array([])
    )

    return (
        real_features,
        all_real_pert_labels,
        all_real_cell_labels,
        generated_features,
        all_generated_pert_labels,
        all_generated_cell_labels,
    )


def train_and_evaluate_model(X_train, y_train, X_test, y_test, model_name, task_name):
    """
    Train a standardized logistic-regression classifier and evaluate it.

    Features are standardized with a ``StandardScaler`` fit on ``X_train`` only.
    Test accuracy is measured on ``(X_test, y_test)``. Cross-validation accuracy
    is measured on the training set, with the scaler refit inside every fold.

    Args:
        X_train, y_train: Training features and labels.
        X_test, y_test: Held-out features and labels.
        model_name: Name of the train→test setting (printed and stored).
        task_name: Name of the prediction task (printed and stored).

    Returns:
        dict with keys model_name, task, accuracy, cv_mean, cv_std, n_train,
        n_test, n_classes, model, scaler, y_true, y_pred. ``cv_mean``/``cv_std``
        are NaN when the training labels are too sparse for 2-fold CV.
    """
    from sklearn.pipeline import make_pipeline

    # Standardize features
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)

    # Train model
    model = LogisticRegression(max_iter=1000, random_state=42)
    model.fit(X_train_scaled, y_train)

    # Predict
    y_pred = model.predict(X_test_scaled)
    accuracy = accuracy_score(y_test, y_pred)

    _, class_counts = np.unique(y_train, return_counts=True)
    n_classes = len(class_counts)

    print(f"\n{model_name} - {task_name}")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Number of classes: {n_classes}")
    print(f"Train samples: {len(X_train)}, Test samples: {len(X_test)}")

    # Cross-validate on the *unscaled* training features, with the scaler
    # inside the pipeline. Each validation fold is then standardized with
    # statistics from its own training folds only, avoiding data leakage.
    # Stratified k-fold raises when every class has fewer than k members,
    # so cap the number of folds accordingly.
    n_folds = min(5, int(class_counts.max()))
    if n_folds >= 2:
        cv_model = make_pipeline(
            StandardScaler(), LogisticRegression(max_iter=1000, random_state=42)
        )
        cv_scores = cross_val_score(cv_model, X_train, y_train, cv=n_folds)
    else:
        cv_scores = np.array([np.nan])
    print(f"CV Accuracy: {cv_scores.mean():.4f} (+/- {cv_scores.std() * 2:.4f})")

    return {
        "model_name": model_name,
        "task": task_name,
        "accuracy": accuracy,
        "cv_mean": cv_scores.mean(),
        "cv_std": cv_scores.std(),
        "n_train": len(X_train),
        "n_test": len(X_test),
        "n_classes": n_classes,
        "model": model,
        "scaler": scaler,
        "y_true": y_test,
        "y_pred": y_pred,
    }


def plot_confusion_matrix(y_true, y_pred, labels, title, save_path):
    """
    Plot a confusion matrix as an annotated heatmap and save it to ``save_path``.

    Args:
        y_true: True labels. Assumed to be integer indices into ``labels``,
            as produced by ``LabelEncoder.transform`` in this script.
        y_pred: Predicted labels, encoded the same way as ``y_true``.
        labels: Display names of all classes, in encoded-index order.
        title: Figure title.
        save_path: Output image path.
    """
    # Fix rows/columns to every class index. Otherwise classes absent from both
    # y_true and y_pred are dropped, and the matrix no longer lines up with
    # the tick labels.
    cm = confusion_matrix(y_true, y_pred, labels=np.arange(len(labels)))

    fig = plt.figure(figsize=(10, 8))
    try:
        sns.heatmap(
            cm,
            annot=True,
            fmt="d",
            cmap="Blues",
            xticklabels=labels,
            yticklabels=labels,
        )
        plt.title(title, fontsize=16)
        plt.xlabel("Predicted", fontsize=14)
        plt.ylabel("True", fontsize=14)
        plt.tight_layout()
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)


def _plot_performance_comparison(summary_df, save_path):
    """Save side-by-side test-accuracy vs. CV-accuracy bar charts, one panel per task."""
    width = 0.35
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    panels = [
        ("Cell Type Prediction", "Cell Type Prediction Performance", "skyblue", "orange"),
        (
            "Perturbation Prediction",
            "Perturbation Prediction Performance",
            "lightcoral",
            "gold",
        ),
    ]
    for ax, (task, title, test_color, cv_color) in zip(axes, panels):
        task_df = summary_df[summary_df["Task"] == task]
        x_pos = np.arange(len(task_df))
        # Offset the two series by half a bar width so they sit side by side
        # instead of being drawn on top of each other.
        ax.bar(
            x_pos - width / 2,
            task_df["Test_Accuracy"],
            width,
            label="Test Accuracy",
            alpha=0.8,
            color=test_color,
        )
        ax.bar(
            x_pos + width / 2,
            task_df["CV_Mean"],
            width,
            label="CV Mean",
            alpha=0.6,
            color=cv_color,
        )
        ax.set_xlabel("Model Type")
        ax.set_ylabel("Accuracy")
        ax.set_title(title)
        ax.set_xticks(x_pos)
        ax.set_xticklabels(task_df["Model"], rotation=45)
        ax.legend()
        ax.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.close(fig)


def _write_classification_reports(results, cell_classes, pert_classes, save_path):
    """Write an sklearn classification report for every trained model to a text file."""
    # Titles contain non-ASCII characters ("→"), so do not rely on the platform encoding.
    with open(save_path, "w", encoding="utf-8") as f:
        for result in results:
            f.write(f"\n{'='*60}\n")
            f.write(f"{result['model_name']} - {result['task']}\n")
            f.write(f"{'='*60}\n")
            if result["task"] == "Cell Type Prediction":
                classes = cell_classes
            else:
                classes = pert_classes

            # Pass every encoded class index explicitly. Without `labels`, a
            # class missing from both y_true and y_pred makes sklearn raise on
            # the target_names length mismatch.
            report = classification_report(
                result["y_true"],
                result["y_pred"],
                labels=np.arange(len(classes)),
                target_names=[str(c) for c in classes],
                zero_division=0,
            )
            f.write(report)
            f.write("\n")


def main():
    """
    Linear-probe real and generated cell images for cell type and perturbation.

    Features are extracted with ``OpenPhenomEncoder``. Logistic-regression
    models are then trained within each domain (Real→Real,
    Generated→Generated) and across domains (Generated→Real). Results are
    written under ``linear_model_results/``.
    """
    # Set random seeds
    seed = 42
    seed_everything(seed)
    random.seed(seed)
    np.random.seed(seed)

    # Setup device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load configuration and data module
    config_path = "/mnt/pvc/MorphGen/sc_perturb/cfgs/diffusion_sit_full.yaml"
    generated_path = "/mnt/pvc/REPA/fulltrain_model_74_all_perts_NEW/numpy_data"

    config = OmegaConf.load(config_path)
    datamodule = CellDataModule(config)

    # Initialize encoder. Presumably a torch nn.Module (it supports
    # .to(device)); modules start in training mode, so switch to eval mode
    # for deterministic feature extraction.
    encoder = OpenPhenomEncoder().to(device)
    encoder.eval()

    # Define perturbations and cell types to analyze
    # Use a subset for faster processing, increase as needed
    perturbation_ids = [1138, 1137, 1108, 1124]
    cell_type_ids = [0, 1, 2, 3]

    print(f"Analyzing perturbations: {perturbation_ids}")
    print(f"Cell types: {cell_type_ids}")

    # Create output directories
    output_dir = "linear_model_results"
    cm_dir = f"{output_dir}/confusion_matrices"
    os.makedirs(cm_dir, exist_ok=True)

    # Load and extract features
    print("Loading data and extracting features...")
    (
        real_features,
        real_pert_labels,
        real_cell_labels,
        generated_features,
        generated_pert_labels,
        generated_cell_labels,
    ) = load_and_extract_features(
        datamodule,
        generated_path,
        perturbation_ids,
        cell_type_ids,
        encoder,
        device,
        max_samples_per_pert=300,
        batch_size=64,
    )

    print("\nFeature extraction complete!")
    print(f"Real features shape: {real_features.shape}")
    print(f"Generated features shape: {generated_features.shape}")

    # Convert labels to numpy arrays
    real_pert_labels = np.array(real_pert_labels)
    real_cell_labels = np.array(real_cell_labels)
    generated_pert_labels = np.array(generated_pert_labels)
    generated_cell_labels = np.array(generated_cell_labels)

    # Fit encoders on all labels (real + generated) so both domains share one index space
    pert_encoder = LabelEncoder().fit(
        np.concatenate([real_pert_labels, generated_pert_labels])
    )
    cell_encoder = LabelEncoder().fit(
        np.concatenate([real_cell_labels, generated_cell_labels])
    )

    real_pert_encoded = pert_encoder.transform(real_pert_labels)
    real_cell_encoded = cell_encoder.transform(real_cell_labels)
    generated_pert_encoded = pert_encoder.transform(generated_pert_labels)
    generated_cell_encoded = cell_encoder.transform(generated_cell_labels)

    print(f"Unique perturbations: {pert_encoder.classes_}")
    print(f"Unique cell types: {cell_encoder.classes_}")

    have_real = len(real_features) > 0
    have_generated = len(generated_features) > 0
    results = []

    # Within-domain experiments: (model name, task, features, labels, class names, plot file)
    within_domain = []
    if have_real:
        within_domain += [
            ("Real→Real", "Cell Type Prediction", real_features, real_cell_encoded,
             cell_encoder.classes_, "real_real_celltype.png"),
            ("Real→Real", "Perturbation Prediction", real_features, real_pert_encoded,
             pert_encoder.classes_, "real_real_perturbation.png"),
        ]
    if have_generated:
        within_domain += [
            ("Generated→Generated", "Cell Type Prediction", generated_features,
             generated_cell_encoded, cell_encoder.classes_, "gen_gen_celltype.png"),
            ("Generated→Generated", "Perturbation Prediction", generated_features,
             generated_pert_encoded, pert_encoder.classes_, "gen_gen_perturbation.png"),
        ]

    for model_name, task, features, labels, class_names, filename in within_domain:
        X_train, X_test, y_train, y_test = train_test_split(
            features,
            labels,
            test_size=0.3,
            random_state=42,
            stratify=labels,
        )
        result = train_and_evaluate_model(
            X_train, y_train, X_test, y_test, model_name, task
        )
        results.append(result)
        plot_confusion_matrix(
            y_test,
            result["y_pred"],
            class_names,
            f"{model_name} {task}",
            f"{cm_dir}/{filename}",
        )

    # Cross-domain experiments: train on all generated data, test on all real data
    if have_real and have_generated:
        cross_domain = [
            ("Cell Type Prediction", generated_cell_encoded, real_cell_encoded,
             cell_encoder.classes_, "gen_real_celltype.png"),
            ("Perturbation Prediction", generated_pert_encoded, real_pert_encoded,
             pert_encoder.classes_, "gen_real_perturbation.png"),
        ]
        for task, gen_labels, real_labels, class_names, filename in cross_domain:
            result = train_and_evaluate_model(
                generated_features,
                gen_labels,
                real_features,
                real_labels,
                "Generated→Real",
                task,
            )
            results.append(result)
            plot_confusion_matrix(
                result["y_true"],
                result["y_pred"],
                class_names,
                f"Generated→Real {task}",
                f"{cm_dir}/{filename}",
            )

    # Summary table; explicit columns keep column lookups valid even with no results
    summary_columns = [
        "Model",
        "Task",
        "Test_Accuracy",
        "CV_Mean",
        "CV_Std",
        "N_Train",
        "N_Test",
        "N_Classes",
    ]
    summary_df = pd.DataFrame(
        [
            {
                "Model": r["model_name"],
                "Task": r["task"],
                "Test_Accuracy": r["accuracy"],
                "CV_Mean": r["cv_mean"],
                "CV_Std": r["cv_std"],
                "N_Train": r["n_train"],
                "N_Test": r["n_test"],
                "N_Classes": r["n_classes"],
            }
            for r in results
        ],
        columns=summary_columns,
    )
    summary_df.to_csv(f"{output_dir}/model_performance_summary.csv", index=False)

    print("\n" + "=" * 80)
    print("FINAL RESULTS SUMMARY")
    print("=" * 80)
    print(summary_df.to_string(index=False))

    _plot_performance_comparison(summary_df, f"{output_dir}/performance_comparison.png")
    _write_classification_reports(
        results,
        cell_encoder.classes_,
        pert_encoder.classes_,
        f"{output_dir}/detailed_reports.txt",
    )

    print(f"\nResults saved to {output_dir}/")
    print("- model_performance_summary.csv: Summary of all model performances")
    print("- performance_comparison.png: Visualization of results")
    print("- confusion_matrices/: Individual confusion matrices")
    print("- detailed_reports.txt: Detailed classification reports")


# Run the full analysis only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
