"""
Standalone Linear Probe Analysis Script

This script reads saved real and generated cell images and performs linear probe experiments
to evaluate how useful generated samples are for downstream perturbation prediction tasks.

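Requirements:
- PIL (Pillow)
- numpy
- scikit-learn
- matplotlib
- pandas (result tables and CSV export)
- tqdm
- torch and torchvision (optional; needed for ResNet-18 features)
- scipy (needed for the simple statistical features)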

Usage:
    python standalone_linear_probe.py --data_dir /path/to/saved/images

Expected directory structure:
    data_dir/
    ├── real_images/
    │   ├── perturbation_1108/
    │   │   ├── real_1108_sample_0000.png
    │   │   └── ...
    │   └── perturbation_XXXX/
    └── generated_images/
        ├── perturbation_1108/
        │   ├── generated_1108_sample_0000.png
        │   └── ...
        └── perturbation_XXXX/
"""

import argparse
import glob
import os
import random
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
from tqdm import tqdm

# Try to import torch for feature extraction, fallback to simple features if not available
try:
    import torch
    import torchvision.transforms as transforms
    from torchvision.models import resnet18

    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False
    print("Warning: PyTorch not available. Using simple image features instead.")


def load_image(image_path):
    """Load an image file and return it as an RGB numpy array.

    Args:
        image_path: Path of the image file to open with PIL.

    Returns:
        The image converted to RGB mode as a numpy array, or None when the
        file cannot be opened or decoded (the error is printed).
    """
    try:
        # The context manager releases the underlying file handle even when
        # decoding or conversion fails part-way through.
        with Image.open(image_path) as img:
            rgb = img if img.mode == "RGB" else img.convert("RGB")
            return np.array(rgb)
    except Exception as e:
        # Best-effort loading: one unreadable file must not abort the run.
        print(f"Error loading {image_path}: {e}")
        return None


def extract_simple_features(image_array):
    """Extract a small vector of hand-crafted intensity statistics.

    The returned vector has 17 entries, in this order:
      * mean, std, min, max and median of the grayscale image (5 values);
      * a 10-bin histogram over the range [0, 255], normalized to sum to 1;
      * mean and std of the Sobel filter response (2 values).

    Args:
        image_array: Image as a 2-D (grayscale) or 3-D (channels last) array,
            or None.

    Returns:
        A 1-D float numpy array of features, or None if ``image_array`` is
        None.
    """
    if image_array is None:
        return None

    # Imported locally so scipy is only required when this feature path runs.
    from scipy import ndimage

    pixels = np.asarray(image_array)
    # Collapse the channel axis to get a grayscale image.
    gray = pixels.mean(axis=2) if pixels.ndim == 3 else pixels
    # Work in float64: on an integer image ndimage.sobel would write into an
    # integer output and its (possibly negative or > 255) gradients overflow.
    gray = gray.astype(np.float64, copy=False)

    global_stats = np.array(
        [gray.mean(), gray.std(), gray.min(), gray.max(), np.median(gray)]
    )

    # Histogram features (10 bins), normalized to fractions of pixels.
    counts, _ = np.histogram(gray.ravel(), bins=10, range=(0, 255))
    total = counts.sum()
    # All pixels can fall outside [0, 255]; avoid a 0/0 that would yield NaN.
    hist = counts / total if total > 0 else np.zeros(counts.shape, dtype=np.float64)

    # Texture proxy: statistics of the Sobel derivative (default axis).
    edges = ndimage.sobel(gray)
    edge_stats = np.array([edges.mean(), edges.std()])

    return np.concatenate([global_stats, hist, edge_stats])


def extract_resnet_features(image_array, model, transform, device):
    """Extract a feature vector for one image with a neural network.

    Args:
        image_array: Image as a numpy array, or None.
        model: Callable network applied to the preprocessed image batch.
        transform: Preprocessing callable that maps a PIL image to a tensor.
        device: Torch device the input tensor is moved to.

    Returns:
        The model output flattened to a 1-D numpy array, or None if
        ``image_array`` is None or any step fails (the error is printed).
    """
    if image_array is None:
        return None

    try:
        pixels = np.asarray(image_array)

        # Only floating-point images can be in the normalized [0, 1] range.
        # An integer image whose maximum happens to be <= 1 (e.g. nearly
        # black) is already in 0-255 units and must not be rescaled.
        if np.issubdtype(pixels.dtype, np.floating):
            if pixels.max() <= 1.0:
                pixels = pixels * 255
            # Clip before casting so out-of-range values do not wrap around.
            pixels = np.clip(pixels, 0, 255).astype(np.uint8)

        img = Image.fromarray(pixels)

        # Preprocess and add the batch dimension expected by the model.
        img_tensor = transform(img).unsqueeze(0).to(device)

        # Inference only: no gradients are needed.
        with torch.no_grad():
            features = model(img_tensor)

        return features.cpu().numpy().flatten()
    except Exception as e:
        # Best-effort extraction: a failing image is skipped by the caller.
        print(f"Error extracting ResNet features: {e}")
        return None


def _build_resnet_extractor():
    """Create the ResNet-18 backbone (eval mode), its preprocessing and device."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load pre-trained ResNet-18 and drop the final classification layer.
    # NOTE(review): `pretrained=True` is the legacy torchvision argument;
    # newer releases prefer `weights=...` - confirm against the installed version.
    backbone = resnet18(pretrained=True)
    backbone = torch.nn.Sequential(*list(backbone.children())[:-1])
    backbone.eval()
    backbone.to(device)

    # Image preprocessing
    transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),
        ]
    )
    return backbone, transform, device


def _collect_split_features(split_dir, description, extract, max_samples_per_pert):
    """Extract features for every ``perturbation_<id>/*.png`` under split_dir."""
    features, labels = [], []
    if not os.path.isdir(split_dir):
        return features, labels

    pert_dirs = sorted(
        name
        for name in os.listdir(split_dir)
        if name.startswith("perturbation_")
        and os.path.isdir(os.path.join(split_dir, name))
    )

    print(description)
    for pert_dir in tqdm(pert_dirs):
        try:
            pert_id = int(pert_dir.split("_")[1])
        except ValueError:
            # A folder such as "perturbation_foo" is not a valid class label.
            print(f"Skipping {pert_dir}: perturbation ID is not an integer")
            continue

        # glob returns files in arbitrary order; sort so that truncating to
        # max_samples_per_pert always selects the same files.
        image_files = sorted(glob.glob(os.path.join(split_dir, pert_dir, "*.png")))
        if max_samples_per_pert:
            image_files = image_files[:max_samples_per_pert]

        for img_file in image_files:
            vector = extract(load_image(img_file))
            if vector is not None:
                features.append(vector)
                labels.append(pert_id)

    return features, labels


def _stack_features(features):
    """Stack feature vectors into a 2-D array; an empty list gives shape (0, 0)."""
    # np.array([]).reshape(0, -1) raises ValueError because the -1 dimension
    # cannot be inferred for a size-0 array, so build the empty case directly.
    return np.array(features) if features else np.empty((0, 0))


def load_images_and_extract_features(
    data_dir, use_resnet=True, max_samples_per_pert=None
):
    """
    Load images from the saved directory structure and extract features.

    Images are read from ``<data_dir>/real_images/perturbation_<id>/*.png`` and
    ``<data_dir>/generated_images/perturbation_<id>/*.png``. A missing
    sub-directory simply contributes no samples. Images that fail to load or
    whose features cannot be extracted are skipped.

    Args:
        data_dir: Directory containing real_images/ and generated_images/
        use_resnet: Whether to use ResNet features (requires torch); falls
            back to simple statistical features when torch is unavailable
        max_samples_per_pert: Maximum samples per perturbation (None for all);
            files are taken in sorted filename order

    Returns:
        real_features, real_labels, generated_features, generated_labels as
        numpy arrays. Feature arrays are 2-D (one row per image) and have
        shape (0, 0) when no image was found; labels are the integer
        perturbation IDs parsed from the folder names.
    """
    # Setup feature extractor
    if use_resnet and TORCH_AVAILABLE:
        model, transform, device = _build_resnet_extractor()
        print("Using ResNet-18 features")

        def extract(img_array):
            return extract_resnet_features(img_array, model, transform, device)

    else:
        print("Using simple statistical features")

        def extract(img_array):
            return extract_simple_features(img_array)

    real_features, real_labels = _collect_split_features(
        os.path.join(data_dir, "real_images"),
        "Processing real images...",
        extract,
        max_samples_per_pert,
    )
    generated_features, generated_labels = _collect_split_features(
        os.path.join(data_dir, "generated_images"),
        "Processing generated images...",
        extract,
        max_samples_per_pert,
    )

    # Convert to numpy arrays
    real_features = _stack_features(real_features)
    real_labels = np.array(real_labels)
    generated_features = _stack_features(generated_features)
    generated_labels = np.array(generated_labels)

    print(
        f"Loaded {len(real_features)} real samples and {len(generated_features)} generated samples"
    )
    print(
        f"Feature dimension: {real_features.shape[1] if len(real_features) > 0 else 'N/A'}"
    )

    return real_features, real_labels, generated_features, generated_labels


def train_and_evaluate_model(X_train, y_train, X_test, y_test, model_name):
    """Train a logistic-regression probe and evaluate it on the test set.

    Features are standardized with statistics fitted on the training set
    only, then a LogisticRegression (max_iter=1000, random_state=42) is
    fitted and scored on the test set.

    Args:
        X_train: Training feature matrix, one row per sample.
        y_train: Training labels.
        X_test: Test feature matrix.
        y_test: Test labels.
        model_name: Human-readable experiment name used in printed output and
            stored in the result.

    Returns:
        A dict with keys "model_name", "accuracy", "n_train", "n_test",
        "n_classes", "y_true" and "y_pred", or None when the experiment is
        skipped (empty train/test set, or fewer than two training classes).
    """
    if len(X_train) == 0 or len(X_test) == 0:
        print(f"Skipping {model_name}: insufficient data")
        return None

    n_classes = len(np.unique(y_train))
    if n_classes < 2:
        # LogisticRegression.fit raises when only one class is present;
        # skip the experiment instead of aborting the whole analysis.
        print(f"Skipping {model_name}: training data contains fewer than 2 classes")
        return None

    # Standardize features using training-set statistics only.
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)

    # Train model
    model = LogisticRegression(max_iter=1000, random_state=42)
    model.fit(X_train_scaled, y_train)

    # Predict
    y_pred = model.predict(X_test_scaled)
    accuracy = accuracy_score(y_test, y_pred)

    print(f"\n{model_name}:")
    print(f"  Accuracy: {accuracy:.4f}")
    print(f"  Train samples: {len(X_train)}, Test samples: {len(X_test)}")
    print(f"  Number of classes: {n_classes}")

    return {
        "model_name": model_name,
        "accuracy": accuracy,
        "n_train": len(X_train),
        "n_test": len(X_test),
        "n_classes": n_classes,
        "y_true": y_test,
        "y_pred": y_pred,
    }


# Experiment names in the fixed column order used by the result tables.
_EXPERIMENT_NAMES = ("Real→Real", "Generated→Real", "Real+Generated→Real")


def _parse_args():
    """Parse the command-line options of the script."""
    parser = argparse.ArgumentParser(description="Standalone Linear Probe Analysis")
    parser.add_argument(
        "--data_dir",
        type=str,
        required=True,
        help="Directory containing real_images/ and generated_images/",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="linear_probe_analysis_results",
        help="Output directory for results",
    )
    parser.add_argument(
        "--use_simple_features",
        action="store_true",
        help="Use simple statistical features instead of ResNet features",
    )
    parser.add_argument(
        "--max_samples", type=int, default=None, help="Maximum samples per perturbation"
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    return parser.parse_args()


def _per_perturbation_accuracies(results, classes):
    """Return {experiment name: {perturbation ID: accuracy on that class}}."""
    per_experiment = {}
    for result in results:
        y_true = np.asarray(result["y_true"])
        y_pred = np.asarray(result["y_pred"])

        pert_accs = {}
        for class_idx, pert_id in enumerate(classes):
            mask = y_true == class_idx
            # No test samples means the accuracy is undefined, not zero.
            pert_accs[pert_id] = (
                accuracy_score(y_true[mask], y_pred[mask])
                if mask.any()
                else float("nan")
            )
        per_experiment[result["model_name"]] = pert_accs
    return per_experiment


def _build_per_perturbation_table(per_pert_accuracies, pert_ids):
    """Build the per-perturbation accuracy DataFrame with a final AVERAGE row."""
    import pandas as pd

    rows = []
    for pert_id in pert_ids:
        row = {"Perturbation_ID": pert_id}
        for exp_name in _EXPERIMENT_NAMES:
            # Skipped experiments have no entry and are reported as NaN.
            row[exp_name] = per_pert_accuracies.get(exp_name, {}).get(
                pert_id, float("nan")
            )
        rows.append(row)

    avg_row = {"Perturbation_ID": "AVERAGE"}
    for exp_name in _EXPERIMENT_NAMES:
        values = np.array([row[exp_name] for row in rows], dtype=float)
        defined = values[~np.isnan(values)]
        # Average only over perturbations that actually had test samples.
        avg_row[exp_name] = defined.mean() if defined.size else float("nan")
    rows.append(avg_row)

    return pd.DataFrame(rows)


def _plot_accuracies(results, output_path):
    """Save a bar chart of the overall accuracy of each experiment."""
    plt.figure(figsize=(10, 6))
    experiments = [r["model_name"] for r in results]
    accuracies = [r["accuracy"] for r in results]

    bars = plt.bar(
        experiments,
        accuracies,
        alpha=0.7,
        color=["skyblue", "orange", "lightgreen"][: len(experiments)],
    )
    plt.ylabel("Accuracy")
    plt.title("Linear Probe Results")
    plt.ylim(0, 1.0)

    # Add value labels on bars
    for bar, acc in zip(bars, accuracies):
        plt.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height() + 0.01,
            f"{acc:.3f}",
            ha="center",
            va="bottom",
        )

    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.close()


def _write_detailed_report(path, args, feature_type, classes, results, df_per_pert):
    """Write the plain-text report with summary, tables and per-class reports."""
    class_indices = np.arange(len(classes))
    target_names = [f"Pert_{p}" for p in classes]

    # The experiment names contain "→", so the encoding must be explicit;
    # the platform default may be unable to encode it.
    with open(path, "w", encoding="utf-8") as f:
        f.write("Linear Probe Analysis Results\n")
        f.write("=" * 80 + "\n\n")

        f.write(f"Data directory: {args.data_dir}\n")
        f.write(f"Feature type: {feature_type}\n")
        f.write(f"Perturbations analyzed: {list(classes)}\n")
        f.write(f"Random seed: {args.seed}\n\n")

        # Summary table first
        f.write("SUMMARY OF RESULTS\n")
        f.write("=" * 80 + "\n")
        f.write(
            f"{'Experiment':<25} {'Accuracy':<10} {'Train Samples':<15} {'Test Samples':<15} {'Classes':<10}\n"
        )
        f.write("-" * 80 + "\n")
        for result in results:
            f.write(
                f"{result['model_name']:<25} {result['accuracy']:<10.4f} {result['n_train']:<15} {result['n_test']:<15} {result['n_classes']:<10}\n"
            )
        f.write("\n\n")

        # Per-perturbation accuracy table
        f.write("PER-PERTURBATION ACCURACY TABLE\n")
        f.write("=" * 80 + "\n")
        f.write(df_per_pert.to_string(index=False, float_format="%.4f"))
        f.write("\n\n")

        # Detailed breakdown per experiment
        f.write("DETAILED BREAKDOWN\n")
        f.write("=" * 80 + "\n")
        for result in results:
            f.write(f"\n{result['model_name']}\n")
            f.write("-" * len(result["model_name"]) + "\n")
            f.write(
                f"OVERALL ACCURACY: {result['accuracy']:.4f} ({result['accuracy']*100:.2f}%)\n"
            )
            f.write(f"Train samples: {result['n_train']}\n")
            f.write(f"Test samples: {result['n_test']}\n")
            f.write(f"Number of classes: {result['n_classes']}\n\n")

            # Pass `labels` explicitly: target_names covers every encoded
            # class, but y_true/y_pred may not contain all of them, which
            # would otherwise make classification_report raise.
            report = classification_report(
                result["y_true"],
                result["y_pred"],
                labels=class_indices,
                target_names=target_names,
            )
            f.write(f"Per-Class Performance:\n{report}\n")


def main():
    """Run the linear-probe experiments and write all result files.

    Three probes are trained (on real, generated, and real+generated
    features) and all are evaluated on the same held-out 30% split of the
    real data. The following files are written to ``--output_dir``:
    per_perturbation_accuracies.csv, results_summary.csv, results_plot.png
    and detailed_results.txt. The function returns early (after printing an
    error) when no real or no generated images are found.
    """
    import pandas as pd

    args = _parse_args()

    # Set random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    if TORCH_AVAILABLE:
        torch.manual_seed(args.seed)

    print(f"Loading data from: {args.data_dir}")
    print(f"Output directory: {args.output_dir}")

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Load images and extract features
    use_resnet = not args.use_simple_features
    real_features, real_labels, generated_features, generated_labels = (
        load_images_and_extract_features(
            args.data_dir, use_resnet=use_resnet, max_samples_per_pert=args.max_samples
        )
    )

    if len(real_features) == 0:
        print("Error: No real images found!")
        return

    if len(generated_features) == 0:
        print("Error: No generated images found!")
        return

    # Encode labels over the union of real and generated perturbation IDs so
    # both sets share one class-index space.
    pert_encoder = LabelEncoder()
    pert_encoder.fit(np.concatenate([real_labels, generated_labels]))

    real_labels_encoded = pert_encoder.transform(real_labels)
    generated_labels_encoded = pert_encoder.transform(generated_labels)

    print(f"Unique perturbations: {pert_encoder.classes_}")

    # Split real data once so every experiment is scored on the same test set.
    real_train_features, real_test_features, real_train_labels, real_test_labels = (
        train_test_split(
            real_features,
            real_labels_encoded,
            test_size=0.3,
            random_state=args.seed,
            stratify=real_labels_encoded,
        )
    )

    # (name, training features, training labels) for each experiment.
    experiments = [
        ("Real→Real", real_train_features, real_train_labels),
        ("Generated→Real", generated_features, generated_labels_encoded),
        (
            "Real+Generated→Real",
            np.vstack([real_train_features, generated_features]),
            np.concatenate([real_train_labels, generated_labels_encoded]),
        ),
    ]

    results = []
    for number, (name, train_features, train_labels) in enumerate(
        experiments, start=1
    ):
        print("\n" + "=" * 60)
        print(f"EXPERIMENT {number}: {name}")
        print("=" * 60)
        result = train_and_evaluate_model(
            train_features,
            train_labels,
            real_test_features,
            real_test_labels,
            name,
        )
        if result:
            results.append(result)

    # Per-perturbation accuracy table
    per_pert_accuracies = _per_perturbation_accuracies(results, pert_encoder.classes_)
    df_per_pert = _build_per_perturbation_table(
        per_pert_accuracies, list(pert_encoder.classes_)
    )

    print("\n" + "=" * 80)
    print("PER-PERTURBATION ACCURACY TABLE")
    print("=" * 80)
    print(df_per_pert.to_string(index=False, float_format="%.4f"))

    df_per_pert.to_csv(
        os.path.join(args.output_dir, "per_perturbation_accuracies.csv"), index=False
    )

    # Summary
    print("\n" + "=" * 80)
    print("OVERALL RESULTS SUMMARY")
    print("=" * 80)
    print(
        f"{'Experiment':<25} {'Accuracy':<10} {'Train Samples':<15} {'Test Samples':<15}"
    )
    print("-" * 80)

    summary_data = []
    for result in results:
        print(
            f"{result['model_name']:<25} {result['accuracy']:<10.4f} {result['n_train']:<15} {result['n_test']:<15}"
        )
        summary_data.append(
            {
                "Experiment": result["model_name"],
                "Accuracy": result["accuracy"],
                "Train_Samples": result["n_train"],
                "Test_Samples": result["n_test"],
                "Num_Classes": result["n_classes"],
            }
        )

    # Save results
    pd.DataFrame(summary_data).to_csv(
        os.path.join(args.output_dir, "results_summary.csv"), index=False
    )

    # Create visualization
    if len(results) > 0:
        _plot_accuracies(results, os.path.join(args.output_dir, "results_plot.png"))

    # Save detailed classification reports
    feature_type = (
        "ResNet-18" if use_resnet and TORCH_AVAILABLE else "Simple statistical"
    )
    _write_detailed_report(
        os.path.join(args.output_dir, "detailed_results.txt"),
        args,
        feature_type,
        pert_encoder.classes_,
        results,
        df_per_pert,
    )

    print(f"\nResults saved to: {args.output_dir}/")
    print("- results_summary.csv: Summary table")
    print("- per_perturbation_accuracies.csv: Per-perturbation accuracy table")
    print("- results_plot.png: Visualization")
    print("- detailed_results.txt: Detailed classification reports")


# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
