import argparse
import glob
import os
import random
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
from tqdm import tqdm
from selection_methods import match_greedy
try:
    import torch
    import torchvision.transforms as transforms
    from torchvision.models import resnet18

    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False
    print("Warning: PyTorch not available. Using simple image features instead.")



        
def load_image(image_path, bgr2rgb=False):
    """Load an image from disk as an RGB array, optionally reversing channels.

    Args:
        image_path: Path handed to ``PIL.Image.open``.
        bgr2rgb: When True, the channel (last) axis of the decoded array is
            reversed, i.e. the first and third channels are swapped. The
            default (False) returns the array exactly as decoded.

    Returns:
        The ``numpy`` array of the RGB-converted image, or ``None`` when the
        file cannot be opened or decoded (the error is printed, not raised).
    """
    try:
        # The context manager releases the underlying file handle promptly
        # instead of waiting for garbage collection.
        with Image.open(image_path) as img:
            arr = np.array(img.convert("RGB"))
    except Exception as e:
        # Best-effort loader: callers treat None as "skip this image".
        print(f"Error loading {image_path}: {e}")
        return None

    if bgr2rgb:
        # Copy into a contiguous buffer so downstream consumers never see
        # a negative-stride view.
        arr = np.ascontiguousarray(arr[..., ::-1])
    return arr


def extract_simple_features(image_array):
    """Extract a 17-dimensional vector of simple intensity statistics.

    The feature layout is:
        * [0:5]   mean, std, min, max and median of the grayscale image;
        * [5:15]  10-bin histogram over the range [0, 255], normalized to sum
                  to 1 (all zeros if no pixel falls inside that range);
        * [15:17] mean and std of the Sobel filter response
                  (``scipy.ndimage.sobel`` with its default axis).

    Args:
        image_array: Image as an array. A 3-D array is averaged over its last
            axis to obtain grayscale; any other rank is used as-is.

    Returns:
        A 1-D ``float64`` array of 17 features, or ``None`` if
        ``image_array`` is ``None``.
    """
    if image_array is None:
        return None

    from scipy import ndimage

    pixels = np.asarray(image_array)
    if pixels.ndim == 3:
        gray = pixels.mean(axis=2)
    else:
        # ndimage.sobel returns the input dtype by default; on integer
        # input (e.g. uint8) negative or large gradients would wrap around,
        # so the filter must run on floats.
        gray = pixels.astype(np.float64)

    stats = [
        np.mean(gray),
        np.std(gray),
        np.min(gray),
        np.max(gray),
        np.median(gray),
    ]

    hist, _ = np.histogram(gray.ravel(), bins=10, range=(0, 255))
    total = hist.sum()
    if total > 0:
        hist = hist / total
    else:
        # No pixel inside [0, 255]: avoid a 0/0 division that would fill
        # the histogram features with NaN.
        hist = np.zeros_like(hist, dtype=np.float64)

    edges = ndimage.sobel(gray)
    edge_stats = [np.mean(edges), np.std(edges)]

    return np.array([*stats, *hist, *edge_stats], dtype=np.float64)



def subset_selector(
    real_resnet_features, real_simple_features, real_clip_features,
    gen_resnet_features, gen_simple_features, gen_clip_features,
    use_resnet=True, selection_features="resnet",
    selection_method="random", count=30
):
    """Pick generated samples per perturbation by matching them to real ones.

    For every perturbation id present in the chosen real feature dictionary,
    ``match_greedy`` is called with the real and generated features of that
    perturbation. The indices it returns are used to gather the generated
    samples that form the output.

    Args:
        real_resnet_features, real_simple_features, real_clip_features:
            Dicts mapping perturbation id to a sequence of feature vectors
            for the real images.
        gen_resnet_features, gen_simple_features, gen_clip_features:
            Same structure for the generated images.
        use_resnet: Output features come from ``gen_resnet_features`` when
            True, otherwise from ``gen_simple_features``.
        selection_features: Feature space used for matching: ``"clip"`` or
            ``"resnet"``; any other value falls back to the simple features.
        selection_method: Forwarded to ``match_greedy`` as ``distance``.
        count: Forwarded to ``match_greedy`` as its third positional argument.

    Returns:
        ``(generated_features, generated_labels)`` as ``numpy`` arrays, where
        each label is the perturbation id of the corresponding feature row.
    """
    # Resolve which feature space drives the matching.
    if selection_features == "clip":
        real_pool, gen_pool = real_clip_features, gen_clip_features
    elif selection_features == "resnet":
        real_pool, gen_pool = real_resnet_features, gen_resnet_features
    else:
        real_pool, gen_pool = real_simple_features, gen_simple_features

    # The features handed back to the caller are independent of the matching
    # space: they are either ResNet or simple features.
    output_pool = gen_resnet_features if use_resnet else gen_simple_features
    matching_on_clip = selection_features == "clip"

    picked_features = []
    picked_labels = []
    for pert_id in real_pool:
        gen_candidates = gen_pool[pert_id]
        real_reference = real_pool[pert_id]
        chosen, _ = match_greedy(
            np.array(real_reference),
            np.array(gen_candidates),
            count,
            distance=selection_method,
            using_clip_features=matching_on_clip,
        )

        for sample_idx in chosen:
            picked_features.append(output_pool[pert_id][sample_idx])
            picked_labels.append(pert_id)

    return np.array(picked_features), np.array(picked_labels)

import os, glob, random
import numpy as np
from PIL import Image
from tqdm import tqdm
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.models import resnet18
import clip  
def extract_resnet_features(image_array, model, transform, device):
    """Run one image through ``model`` and return its flattened output.

    Args:
        image_array: Image as a ``numpy`` array, or ``None``. Arrays whose
            maximum value is at most 1.0 are scaled by 255 and cast to
            ``uint8`` before conversion to a PIL image.
        model: Callable applied to the batched image tensor.
        transform: Callable turning a PIL image into a tensor.
        device: Device the input tensor is moved to before the forward pass.

    Returns:
        A 1-D ``numpy`` array with the model output, or ``None`` if the
        input is ``None`` or any step fails (the error is printed).
    """
    if image_array is None:
        return None

    try:
        pixels = image_array
        if pixels.max() <= 1.0:
            # Treated as a [0, 1]-scaled image: bring it to the 0-255 range.
            pixels = (pixels * 255).astype(np.uint8)

        pil_image = Image.fromarray(pixels)
        batch = transform(pil_image).unsqueeze(0).to(device)

        with torch.no_grad():
            activations = model(batch)
            # Collapse everything but the batch dimension.
            activations = activations.view(activations.size(0), -1)

        return activations.cpu().numpy().flatten()
    except Exception as e:
        print(f"Error extracting ResNet features: {e}")
        return None

# Per-channel mean and standard deviation used by extract_clip_features to
# normalize its input; the values are the same ones passed to
# transforms.Normalize for the ResNet pipeline in load_img_features.
# Reshaped to (3, 1, 1) so they broadcast over a channels-first (3, H, W) tensor.
# NOTE(review): these look like ImageNet statistics rather than the ones CLIP's
# own preprocessing uses — confirm this is intentional.
_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

def extract_clip_features(image_array, clip_model, device, normalize=True):
    """Encode one image with ``clip_model.encode_image``.

    The image is converted to RGB, resized to 224x224 with bicubic
    interpolation, scaled to [0, 1], and standardized with the module-level
    ``_MEAN`` / ``_STD`` before being encoded.

    Args:
        image_array: Image as a ``numpy`` array, or ``None``. Arrays whose
            maximum value is at most 1.0 are scaled by 255 and cast to
            ``uint8`` first.
        clip_model: Object exposing ``encode_image``.
        device: Device the input tensor is moved to.
        normalize: When True, the embedding is L2-normalized along its last
            dimension.

    Returns:
        A 1-D ``numpy`` array with the embedding, or ``None`` if the input is
        ``None`` or any step fails (the error is printed).
    """
    if image_array is None:
        return None

    try:
        pixels = image_array
        if pixels.max() <= 1.0:
            # Treated as a [0, 1]-scaled image: bring it to the 0-255 range.
            pixels = (pixels * 255).astype(np.uint8)

        resized = (
            Image.fromarray(pixels)
            .convert("RGB")
            .resize((224, 224), Image.BICUBIC)
        )
        # HWC uint8 -> CHW float in [0, 1].
        chw = torch.from_numpy(np.array(resized)).permute(2, 0, 1).float() / 255.0
        standardized = (chw - _MEAN) / _STD
        batch = standardized.unsqueeze(0).to(device)

        with torch.no_grad():
            embedding = clip_model.encode_image(batch).float()
            if normalize:
                embedding = nn.functional.normalize(embedding, dim=-1)

        return embedding.cpu().numpy().flatten()
    except Exception as e:
        print(f"Error extracting CLIP features: {e}")
        return None

def load_img_features(
    data_dir,
    feature_backend="resnet",   # "resnet" , "clip" , "simple"
    max_samples_per_pert=None,
    is_real=True,
    test_count=20,
    clip_model_name="ViT-B/32",
    upper_count=10
):
    """Extract features for every ``perturbation_<id>`` folder of a dataset.

    Reads ``*.png`` files from ``<data_dir>/real_images`` (``is_real=True``)
    or ``<data_dir>/generated_combined`` (``is_real=False``). For every image
    ResNet-18, CLIP and simple statistical features are computed; the ones
    named by ``feature_backend`` populate the flat train/upper/test arrays.

    Per perturbation, the files are randomly ordered with the ``random``
    module and split by position:

        * index < ``max_samples_per_pert``: training split (also stored
          per backend in the per-perturbation dictionaries);
        * index < ``max_samples_per_pert + upper_count``: "upper" split,
          which therefore contains the training split as well;
        * remaining indices: test split.

    When ``max_samples_per_pert`` is ``None`` every image goes to the
    training and upper splits and the test split stays empty.

    Args:
        data_dir: Root directory of the dataset.
        feature_backend: ``"resnet"``, ``"clip"`` or ``"simple"``.
        max_samples_per_pert: Training samples per perturbation, or ``None``
            to use every image.
        is_real: Selects the sub-directory and the length of the returned
            tuple (see Returns).
        test_count: Number of test samples drawn per perturbation; only used
            together with ``max_samples_per_pert``.
        clip_model_name: Model name passed to ``clip.load``.
        upper_count: Extra samples per perturbation added to the upper split.

    Returns:
        ``(features, labels, resnet_by_pert, simple_by_pert, clip_by_pert,
        test_features, test_labels)`` and, only when ``is_real`` is True,
        additionally ``(upper_features, upper_labels)`` appended at the end.
        The flat entries are ``numpy`` arrays; the ``*_by_pert`` entries are
        dicts mapping perturbation id to a list of feature vectors.

    Raises:
        ValueError: If ``feature_backend`` is unknown, or a perturbation has
            fewer images than ``max_samples_per_pert + test_count +
            upper_count``.
    """
    valid_backends = ("resnet", "clip", "simple")
    if feature_backend not in valid_backends:
        raise ValueError(
            f"Unknown feature_backend {feature_backend!r}; expected one of {valid_backends}"
        )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ResNet-18 with its final child module removed, so the forward pass
    # yields features instead of class logits.
    model = resnet18(weights="IMAGENET1K_V1")
    model = torch.nn.Sequential(*list(model.children())[:-1])
    model.eval().to(device)

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])

    # Honor the requested CLIP variant (previously hard-coded to the default).
    clip_model, _ = clip.load(clip_model_name, device=device, download_root="./CLIP")
    clip_model.eval()

    real_features = []
    real_labels = []
    real_test_features = []
    real_test_labels = []
    real_upper_features = []
    real_upper_labels = []
    real_resnet_features = {}
    real_simple_features = {}
    real_clip_features = {}

    real_dir = os.path.join(data_dir, "real_images" if is_real else "generated_combined")
    if os.path.exists(real_dir):
        perturbation_dirs = [
            d for d in os.listdir(real_dir)
            if os.path.isdir(os.path.join(real_dir, d)) and d.startswith("perturbation_")
        ]

        for pert_dir in tqdm(sorted(perturbation_dirs)):
            pert_id = int(pert_dir.split("_")[1])

            real_resnet_features.setdefault(pert_id, [])
            real_simple_features.setdefault(pert_id, [])
            real_clip_features.setdefault(pert_id, [])

            pert_path = os.path.join(real_dir, pert_dir)
            # glob order depends on the filesystem; sort first so a seeded
            # random.sample picks the same files on every machine.
            image_files = sorted(glob.glob(os.path.join(pert_path, "*.png")))

            if max_samples_per_pert is not None:
                k = max_samples_per_pert + test_count + upper_count
                if k > len(image_files):
                    # Fail loudly: the splits below assume exactly k samples.
                    raise ValueError(
                        f"Perturbation {pert_id} has only {len(image_files)} samples, "
                        f"but {k} are required "
                        "(max_samples_per_pert + test_count + upper_count)."
                    )
                image_files = random.sample(image_files, k)
                upper_sample_start = max_samples_per_pert
                test_sample_start = max_samples_per_pert + upper_count
            else:
                # Shuffle only; with infinite thresholds nothing reaches the
                # upper-only or test splits.
                image_files = random.sample(image_files, len(image_files))
                upper_sample_start = float('inf')
                test_sample_start = float('inf')

            for idx, img_file in enumerate(image_files):
                img_array = load_image(img_file)

                feats_resnet = extract_resnet_features(img_array, model, transform, device)
                feats_clip = extract_clip_features(img_array, clip_model, device, normalize=True)
                feats_simple = extract_simple_features(img_array)

                if feature_backend == "resnet":
                    feats_selected = feats_resnet
                elif feature_backend == "clip":
                    feats_selected = feats_clip
                else:
                    feats_selected = feats_simple

                if feats_selected is None:
                    # Loading or extraction failed (already reported by the
                    # helper); a None entry would corrupt the output arrays.
                    continue

                if idx >= test_sample_start:
                    real_test_features.append(feats_selected)
                    real_test_labels.append(pert_id)
                    continue

                real_upper_features.append(feats_selected)
                real_upper_labels.append(pert_id)

                if idx < upper_sample_start:
                    real_features.append(feats_selected)
                    real_labels.append(pert_id)

                    if feats_resnet is not None:
                        real_resnet_features[pert_id].append(feats_resnet)
                    if feats_simple is not None:
                        real_simple_features[pert_id].append(feats_simple)
                    if feats_clip is not None:
                        real_clip_features[pert_id].append(feats_clip)

    real_features = np.array(real_features)
    real_labels = np.array(real_labels)

    real_upper_features = np.array(real_upper_features)
    real_upper_labels = np.array(real_upper_labels)

    real_test_features = np.array(real_test_features)
    real_test_labels = np.array(real_test_labels)

    if is_real:
        return (
            real_features, real_labels,
            real_resnet_features, real_simple_features, real_clip_features,
            real_test_features, real_test_labels,
            real_upper_features, real_upper_labels,
        )
    return (
        real_features, real_labels,
        real_resnet_features, real_simple_features, real_clip_features,
        real_test_features, real_test_labels,
    )

def train_and_evaluate_model(X_train, y_train, X_test, y_test, model_name):
    """Fit a logistic-regression probe on standardized features and score it.

    Features are standardized with statistics fitted on the training set
    only; the classifier is then evaluated on the standardized test set.

    Args:
        X_train, y_train: Training features and labels.
        X_test, y_test: Test features and labels.
        model_name: Label used in the printed report and the result dict.

    Returns:
        ``None`` (after printing a notice) if either feature set is empty,
        otherwise a dict with the keys ``model_name``, ``accuracy``,
        ``n_train``, ``n_test``, ``n_classes``, ``y_true`` and ``y_pred``.
    """
    if not (len(X_train) and len(X_test)):
        print(f"Skipping {model_name}: insufficient data,{len(X_train)} train samples, {len(X_test)} test samples")
        return None

    # Standardize with training statistics only, so nothing leaks from test.
    standardizer = StandardScaler()
    train_scaled = standardizer.fit_transform(X_train)
    test_scaled = standardizer.transform(X_test)

    classifier = LogisticRegression(max_iter=1000, random_state=42)
    classifier.fit(train_scaled, y_train)

    predictions = classifier.predict(test_scaled)
    accuracy = accuracy_score(y_test, predictions)

    report_lines = (
        f"\n{model_name}:",
        f"  Accuracy: {accuracy:.4f}",
        f"  Train samples: {len(X_train)}, Test samples: {len(X_test)}",
        f"  Number of classes: {len(np.unique(y_train))}",
    )
    for line in report_lines:
        print(line)

    return dict(
        model_name=model_name,
        accuracy=accuracy,
        n_train=len(X_train),
        n_test=len(X_test),
        n_classes=len(np.unique(y_train)),
        y_true=y_test,
        y_pred=predictions,
    )


# Selection methods evaluated by main(): each entry is passed to
# subset_selector as `selection_method`, which forwards it to match_greedy as
# `distance`.
# NOTE(review): the single entry looks like a placeholder — replace it with the
# method names match_greedy actually supports before running.
selection_list = ["<list of methods>"]
def _run_probe_experiment(banner, results, X_train, y_train, X_test, y_test, model_name):
    """Print an experiment banner, run one probe, and record a non-empty result."""
    print("\n" + "=" * 60)
    print(banner)
    print("=" * 60)
    result = train_and_evaluate_model(X_train, y_train, X_test, y_test, model_name)
    if result:
        results.append(result)


def main():
    """Run the linear-probe analysis from the command line.

    For each of ``--experiment_count`` repetitions (seeded with the
    repetition index) the real and generated image features are loaded, and
    for every method in ``selection_list`` a subset of generated samples is
    selected. Four probes are then trained and scored on the real test
    split: real only, generated only, real plus generated, and the "upper"
    real split. Per-method accuracies are finally aggregated over all
    repetitions and printed as mean / standard deviation / count.
    """
    parser = argparse.ArgumentParser(description="Standalone Linear Probe Analysis")
    parser.add_argument(
        "--data_dir",
        type=str,
        required=True,
        # The loader reads generated samples from generated_combined/.
        help="Directory containing real_images/ and generated_combined/",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="linear_probe_analysis_results",
        help="Output directory for results",
    )
    parser.add_argument(
        "--use_simple_features",
        action="store_true",
        help="Use simple statistical features instead of ResNet features",
    )
    parser.add_argument(
        "--max_samples", type=int, default=None, help="Maximum samples per perturbation"
    )

    parser.add_argument("--selection_count", type=int, default=30, help="Number of samples to select per perturbation")

    parser.add_argument("--selection_features", type=str, default="resnet", choices=["resnet","simple","clip"], help="Features to use for selection")

    parser.add_argument("--experiment_count", type=int, default=10, help="Number of experiment repetitions with different random seeds")
    args = parser.parse_args()

    # NOTE(review): the directory is created but nothing in this function
    # writes to it; results are only printed.
    os.makedirs(args.output_dir, exist_ok=True)

    experiment_results = {}

    use_resnet = not args.use_simple_features
    feature_backend = "resnet" if use_resnet else "simple"

    for exp in range(args.experiment_count):
        # Seed every RNG in use so a repetition is reproducible by its index.
        random.seed(exp)
        np.random.seed(exp)
        if TORCH_AVAILABLE:
            torch.manual_seed(exp)

        (
            real_features, real_labels,
            real_resnet_features, real_simple_features, real_clip_features,
            real_test_features, real_test_labels,
            real_upper_features, real_upper_labels,
        ) = load_img_features(
            args.data_dir,
            feature_backend=feature_backend,
            max_samples_per_pert=args.max_samples,
            is_real=True,
            test_count=30,
            upper_count=args.selection_count,
        )

        _, _, gen_resnet_features, gen_simple_features, gen_clip_features, _, _ = load_img_features(
            args.data_dir,
            feature_backend=feature_backend,
            max_samples_per_pert=None,
            is_real=False,
        )

        method_results = {}
        for select_met in selection_list:
            print(f"\nSelection method: {select_met}")

            # Check before selecting: with no real samples there is nothing
            # to match the generated samples against.
            if len(real_features) == 0:
                print("Error: No real images found!")
                continue

            generated_features, generated_labels = subset_selector(
                real_resnet_features, real_simple_features, real_clip_features,
                gen_resnet_features, gen_simple_features, gen_clip_features,
                use_resnet=use_resnet, selection_features=args.selection_features,
                selection_method=select_met, count=args.selection_count
            )

            if len(generated_features) == 0:
                print("Error: No generated images found!")
                continue

            # Encode labels; the encoder is fitted on the real training
            # labels only, so every other split must use the same ids.
            pert_encoder = LabelEncoder()
            pert_encoder.fit(real_labels)

            real_labels_encoded = pert_encoder.transform(real_labels)
            generated_labels_encoded = pert_encoder.transform(generated_labels)
            real_test_labels_encoded = pert_encoder.transform(real_test_labels)
            real_upper_labels_encoded = pert_encoder.transform(real_upper_labels)

            print(f"Unique perturbations: {pert_encoder.classes_}")
            results = []

            # Experiment 1: Real-Real
            _run_probe_experiment(
                "EXPERIMENT 1: Real-Real",
                results,
                real_features,
                real_labels_encoded,
                real_test_features,
                real_test_labels_encoded,
                "Real-Real",
            )

            # Experiment 2: Generated-Real
            _run_probe_experiment(
                "EXPERIMENT 2: Generated-Real",
                results,
                generated_features,
                generated_labels_encoded,
                real_test_features,
                real_test_labels_encoded,
                "Generated-Real",
            )

            # Experiment 3: Real+Generated-Real
            combined_train_features = np.vstack([real_features, generated_features])
            combined_train_labels = np.concatenate(
                [real_labels_encoded, generated_labels_encoded]
            )
            _run_probe_experiment(
                "EXPERIMENT 3: Real+Generated-Real",
                results,
                combined_train_features,
                combined_train_labels,
                real_test_features,
                real_test_labels_encoded,
                "Real+Generated-Real",
            )

            # Experiment 4: UpperBound Real(upper)-Real
            _run_probe_experiment(
                "EXPERIMENT 4: UpperBound Real(upper)-Real",
                results,
                real_upper_features,
                real_upper_labels_encoded,
                real_test_features,
                real_test_labels_encoded,
                "UpperBound Real(upper)-Real",
            )

            print("\n" + "=" * 80)
            print("FINAL RESULTS SUMMARY")
            print("=" * 80)
            method_results[select_met] = results

            for result in results:
                print(f"{result['model_name']:25s}: {result['accuracy']:.4f}")
        experiment_results[exp] = method_results

    # Aggregate accuracies per selection method across all repetitions.
    for selection_method in selection_list:
        print(f"\n\nResults for selection method: {selection_method}")
        accuracies_by_experiment = {}
        for exp in range(args.experiment_count):
            # A method is absent for repetitions where it was skipped.
            for result in experiment_results[exp].get(selection_method, []):
                accuracies_by_experiment.setdefault(result['model_name'], []).append(
                    result['accuracy']
                )
        print(f"{'Experiment':25s} {'Mean Accuracy':15s} {'Std Dev':15s} {'Count':10s}")
        for exp_name, accuracies in accuracies_by_experiment.items():
            mean_acc = np.mean(accuracies)
            std_acc = np.std(accuracies)
            count = len(accuracies)
            print(f"{exp_name:25s} {mean_acc:<15.4f} {std_acc:<15.4f} {count:<10d}")
   
# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
