import concurrent.futures
import os
from typing import Callable, Dict, List, Tuple

import hydra
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from omegaconf import DictConfig, OmegaConf
from PIL import Image
from torch import nn
from tqdm import tqdm

import wandb
from config_schema import MainConfig  # Assuming config structure remains similar

# Adapt model loading functions and classes from generate_ad_samples_ensemble.py
from generate_ad_samples import (
    ImageFolderWithPaths,  # Might not need this exact one, but image loading needed
)
from generate_ad_samples import (
    MODEL_TO_CLASS,
    ensure_dir,
    hash_training_config,
    set_environment,
    setup_wandb,
)

# Import cosine similarity
from surrogates.loss import cosine_similarity

# Recognised image file extensions (copied from blackbox_text_generation.py).
# find_image_pairs compares adversarial file names against these
# case-insensitively, but probes for target files by appending these exact
# strings -- presumably why both ".jpeg" and ".JPEG" are listed.
VALID_IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".JPEG"]


# Placeholder for model loading function adapted for white-box transfer evaluation
def get_target_models(cfg: MainConfig) -> Dict[str, nn.Module]:
    """
    Initializes and returns a dictionary of target models based on the configuration.

    Backbone names are read from ``cfg.white_transfer.target_models``, which may be
    a single string or a list of strings. A missing or ``None`` ``white_transfer``
    section (or ``target_models`` entry) is treated as "no models requested".
    Names that are not keys of ``MODEL_TO_CLASS`` are skipped with a warning.

    Args:
        cfg: Run configuration. ``cfg.model.device`` selects the device; when it is
            absent, CUDA is used if available, otherwise the CPU.

    Returns:
        Mapping of backbone name to the instantiated model, put in eval mode,
        moved to the device and with gradients disabled. Empty when nothing was
        requested or no requested name was recognised.
    """
    models: Dict[str, nn.Module] = {}

    # `or {}` / `or []` also cover keys that exist but are set to None/null.
    white_transfer = getattr(cfg, "white_transfer", None) or {}
    target_model_names = white_transfer.get("target_models", None) or []
    if isinstance(target_model_names, str):
        target_model_names = [target_model_names]

    print(f"Loading target models for evaluation: {target_model_names}")
    if not target_model_names:
        return models

    # The device does not depend on the backbone, so resolve it once.
    device = getattr(
        cfg.model, "device", "cuda" if torch.cuda.is_available() else "cpu"
    )

    for backbone_name in target_model_names:
        if backbone_name not in MODEL_TO_CLASS:
            print(
                f"Warning: Unknown target backbone specified: {backbone_name}. Skipping."
            )
            continue

        model_cls = MODEL_TO_CLASS[backbone_name]
        model = model_cls(backbone_name).eval().to(device).requires_grad_(False)
        models[backbone_name] = model
        print(f"Loaded target model: {backbone_name} to {device}")

    return models


def find_image_pairs(
    adv_image_dir: str, target_image_dir: str, target_subdir: str = "1"
) -> List[Tuple[str, str, str]]:
    """
    Finds pairs of (adversarial_filename, target_path, adversarial_path).

    Every file under ``adv_image_dir`` (searched recursively) whose name ends with
    one of ``VALID_IMAGE_EXTENSIONS`` (case-insensitive) is matched to a target
    image with the same base name located at
    ``target_image_dir/target_subdir/<base><ext>``, where ``<ext>`` is tried in
    the order and exact spelling given by ``VALID_IMAGE_EXTENSIONS``.
    Adversarial images without a target are reported and skipped.

    Args:
        adv_image_dir: Root directory containing the adversarial images.
        target_image_dir: Root directory containing the target images.
        target_subdir: Sub-folder of ``target_image_dir`` that holds the targets.
            Defaults to ``"1"``, the value this function previously hard-coded.

    Returns:
        List of ``(adversarial_filename, target_path, adversarial_path)`` tuples,
        in sorted traversal order so that repeated runs yield the same batches.

    Raises:
        FileNotFoundError: If either directory does not exist.
    """
    image_pairs: List[Tuple[str, str, str]] = []
    print(f"Searching for adversarial images in: {adv_image_dir}")
    print(f"Searching for corresponding target images in: {target_image_dir}")

    if not os.path.isdir(adv_image_dir):
        raise FileNotFoundError(
            f"Adversarial image directory not found: {adv_image_dir}"
        )
    if not os.path.isdir(target_image_dir):
        raise FileNotFoundError(f"Target image directory not found: {target_image_dir}")

    # str.endswith accepts a tuple, so lowercase the extensions once up front.
    adv_suffixes = tuple(ext.lower() for ext in VALID_IMAGE_EXTENSIONS)
    target_root = os.path.join(target_image_dir, target_subdir)

    for root, dirs, files in os.walk(adv_image_dir):
        # Sorting in place makes os.walk descend in a deterministic order.
        dirs.sort()
        for file in tqdm(sorted(files), desc="Finding image pairs", leave=False):
            if not file.lower().endswith(adv_suffixes):
                continue

            adv_path = os.path.join(root, file)
            filename_base = os.path.splitext(file)[0]
            try:
                # The class folder is only used to make the warning below more
                # informative; the target lookup itself relies on target_subdir.
                relative_path = os.path.relpath(root, adv_image_dir)
                class_folder = os.path.dirname(relative_path) or os.path.basename(
                    relative_path
                )
                if class_folder == ".":  # Images sit directly in adv_image_dir
                    class_folder = ""

                candidates = (
                    os.path.join(target_root, filename_base + ext)
                    for ext in VALID_IMAGE_EXTENSIONS
                )
                # First existing candidate wins; None when no target exists.
                tgt_path = next(
                    (path for path in candidates if os.path.exists(path)), None
                )
            except (OSError, ValueError) as e:
                # os.path.relpath raises ValueError for paths it cannot relate.
                print(f"Error processing file {adv_path}: {e}")
                continue

            if tgt_path is None:
                print(
                    f"Warning: Target image not found for adversarial image {adv_path} (base: {filename_base}, class: {class_folder}). Skipping."
                )
                continue

            image_pairs.append((file, tgt_path, adv_path))

    print(f"Found {len(image_pairs)} image pairs.")
    if not image_pairs:
        print("Warning: No image pairs found. Check paths and directory structure.")
    return image_pairs


def _load_rgb(path: str):
    """Open the image at ``path`` and return an RGB copy, closing the file afterwards."""
    with Image.open(path) as img:
        return img.convert("RGB")


def _extract_embeddings(output):
    """Pick the embedding tensor out of a model output, or return None if unknown.

    Accepted forms: a tensor (used as is), a dict with a ``"pooler_output"`` key,
    an object with a non-None ``pooler_output`` attribute, or an object with a
    ``last_hidden_state`` attribute (its first token along axis 1 is used).
    """
    if isinstance(output, torch.Tensor):
        return output
    if isinstance(output, dict) and "pooler_output" in output:
        return output["pooler_output"]
    pooled = getattr(output, "pooler_output", None)
    if pooled is not None:
        return pooled
    hidden = getattr(output, "last_hidden_state", None)
    if hidden is not None:
        return hidden[:, 0]
    return None


def _squeeze_keep_batch(embeddings: torch.Tensor) -> torch.Tensor:
    """Drop singleton axes of a >2-D tensor while always keeping axis 0 (the batch)."""
    if embeddings.ndim <= 2:
        return embeddings
    kept_shape = [embeddings.shape[0]] + [n for n in embeddings.shape[1:] if n != 1]
    return embeddings.reshape(kept_shape)


def process_batch(
    batch_pairs: List[Tuple[str, str, str]],
    target_models: Dict[str, nn.Module],
    transform_fn: Callable,
    device: str,
) -> Dict[str, List[Tuple[str, float]]]:
    """
    Processes a batch of image pairs, calculating embedding similarity for each target model.

    Args:
        batch_pairs: ``(filename, target_path, adversarial_path)`` tuples.
        target_models: Mapping of model name to model. Each model is called as
            ``model(batch, return_dict=False)``.
        transform_fn: Callable turning a PIL image into a tensor; all resulting
            tensors must share one shape so they can be stacked.
        device: Device the stacked batches are moved to before the forward pass.

    Returns:
        A dictionary ``{model_name: [(filename, similarity)]}``. Pairs that fail to
        load are skipped; a model whose output cannot be interpreted, or whose
        forward pass raises, contributes no entries for this batch.
    """
    batch_similarities = {model_name: [] for model_name in target_models}
    batch_tgt_tensors = []
    batch_adv_tensors = []
    batch_filenames = []

    # Load and transform images in the batch
    for filename, tgt_path, adv_path in batch_pairs:
        try:
            tgt_tensor = transform_fn(_load_rgb(tgt_path))
            adv_tensor = transform_fn(_load_rgb(adv_path))
        except Exception as e:
            # Best effort: one unreadable image must not abort the whole batch.
            print(
                f"Error loading image pair ({filename}, {tgt_path}, {adv_path}): {e}. Skipping pair."
            )
            continue

        batch_tgt_tensors.append(tgt_tensor)
        batch_adv_tensors.append(adv_tensor)
        batch_filenames.append(filename)

    if not batch_filenames:  # If all pairs in batch failed loading
        return batch_similarities

    # Stack first, then move: one host-to-device copy per batch, not per image.
    tgt_tensors_batch = torch.stack(batch_tgt_tensors).to(device)
    adv_tensors_batch = torch.stack(batch_adv_tensors).to(device)

    # Calculate similarities for each model
    with torch.no_grad():
        for model_name, model in target_models.items():
            try:
                tgt_output = model(tgt_tensors_batch, return_dict=False)
                adv_output = model(adv_tensors_batch, return_dict=False)

                tgt_embeddings = _extract_embeddings(tgt_output)
                adv_embeddings = _extract_embeddings(adv_output)
                if tgt_embeddings is None or adv_embeddings is None:
                    print(
                        f"Warning: Could not determine embedding from output of model {model_name}. Type: {type(tgt_output)}. Skipping similarity calculation for this model."
                    )
                    continue

                # Ensure embeddings are 2D [batch, dim]. A bare squeeze() would
                # also remove the batch axis when the batch holds a single item.
                tgt_embeddings = _squeeze_keep_batch(tgt_embeddings)
                adv_embeddings = _squeeze_keep_batch(adv_embeddings)
                if tgt_embeddings.ndim != 2 or adv_embeddings.ndim != 2:
                    print(
                        f"Warning: Unexpected embedding dimension for model {model_name}. Tgt: {tgt_embeddings.shape}, Adv: {adv_embeddings.shape}. Skipping."
                    )
                    continue

                # Expected to yield one similarity value per batch item.
                similarities = cosine_similarity(tgt_embeddings, adv_embeddings)

                batch_similarities[model_name].extend(
                    zip(batch_filenames, similarities.cpu().tolist())
                )
            except Exception as e:
                # Keep evaluating the remaining models if one of them fails.
                print(
                    f"Error during similarity calculation for model {model_name}: {e}"
                )

    return batch_similarities


@hydra.main(
    version_base=None, config_path="config", config_name="ensemble_3models"
)  # Use a relevant default config
def main(cfg: MainConfig):
    """Hydra entry point: receives the composed config and delegates to ``_main``.

    The decorator points Hydra at the ``config`` directory and uses
    ``ensemble_3models`` as the default config name.
    """
    _main(cfg)


def _main(cfg: MainConfig):
    """Evaluate how well adversarial images match their targets on each target model.

    Steps:
      1. Call ``set_environment`` and start a wandb run named after the config hash.
      2. Load the target models listed in the config.
      3. Pair adversarial images under ``cfg.data.output/img/<config_hash>`` with
         target images under ``cfg.data.tgt_data_path``.
      4. Compute per-batch similarities, logging per-model batch averages to wandb.
      5. Print and record the overall per-model averages, then write all
         per-image similarities as JSON.

    Args:
        cfg: Run configuration. Reads ``cfg.wandb``, ``cfg.model``, ``cfg.data``
            and ``cfg.white_transfer``.

    Raises:
        ValueError: If the resolved batch size is missing or not positive.
    """
    import json  # Local import: json is not among this module's top-level imports.

    set_environment()  # For reproducibility if needed, though less critical here

    # --- WandB Setup ---
    # Determine run name based on config hash for consistency
    config_hash = hash_training_config(cfg)
    prefix = getattr(cfg.wandb, "run_name_prefix", "")
    run_name = (
        f"{prefix}-white_transfer-{config_hash}"
        if prefix
        else f"white_transfer-{config_hash}"
    )
    # Ensure project/entity are set in config or environment
    setup_wandb(cfg, tags=["white-box-transfer", config_hash], name=run_name)
    wandb.config.update(
        OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
    )

    # --- Configuration & Paths ---
    print(f"Evaluating transfer for config hash: {config_hash}")
    device = getattr(
        cfg.model, "device", "cuda" if torch.cuda.is_available() else "cpu"
    )
    # Adversarial images are expected under cfg.data.output / "img" / config_hash.
    adv_image_base_dir = os.path.join(cfg.data.output, "img", config_hash)
    # Directory containing the original target images used in generation.
    # NOTE(review): find_image_pairs looks inside a "1" sub-folder of this
    # directory by default, so this should be the parent of that folder
    # (not a path already ending in "/1") -- confirm against the data layout.
    target_image_dir = cfg.data.tgt_data_path

    # --- Load Target Models ---
    target_models = get_target_models(cfg)
    if not target_models:
        print("Error: No target models loaded. Exiting.")
        wandb.finish()
        return

    # --- Image Transformations ---
    # Resize + center-crop to the configured input resolution, then convert to a
    # [0, 1] float tensor. No normalisation is applied here.
    input_res = cfg.model.input_res
    transform_fn = transforms.Compose(
        [
            transforms.Resize(
                input_res,
                interpolation=torchvision.transforms.InterpolationMode.BICUBIC,
            ),
            transforms.CenterCrop(input_res),
            transforms.ToTensor(),
        ]
    )

    # --- Find Image Pairs ---
    try:
        all_image_pairs = find_image_pairs(adv_image_base_dir, target_image_dir)
    except FileNotFoundError as e:
        print(f"Error finding image pairs: {e}")
        wandb.finish()
        return

    if not all_image_pairs:
        print("No image pairs found to evaluate. Exiting.")
        wandb.finish()
        return

    # --- Processing ---
    # A transfer-specific batch size takes precedence over the data default.
    batch_size = getattr(cfg.white_transfer, "batch_size", cfg.data.batch_size)
    if batch_size is None or batch_size < 1:
        # Fail with a clear message instead of a ZeroDivisionError (0) or a
        # silently empty evaluation (negative values).
        raise ValueError(
            f"batch_size must be a positive integer, got {batch_size!r}"
        )

    all_similarities = {model_name: [] for model_name in target_models}
    num_pairs = len(all_image_pairs)
    num_batches = (num_pairs + batch_size - 1) // batch_size  # ceiling division

    print(
        f"Processing {len(all_image_pairs)} pairs in {num_batches} batches of size {batch_size}..."
    )

    for i in tqdm(range(num_batches), desc="Evaluating Batches"):
        # Slicing clamps at the end of the list, so the last batch may be short
        # but is never empty.
        batch_pairs = all_image_pairs[i * batch_size : (i + 1) * batch_size]

        batch_results = process_batch(
            batch_pairs=batch_pairs,
            target_models=target_models,
            transform_fn=transform_fn,
            device=device,
        )

        # Accumulate results and log the per-model batch average.
        log_data = {"batch_processed": i}
        for model_name, results in batch_results.items():
            all_similarities[model_name].extend(results)
            if results:  # Only log if there are results for this model in this batch
                log_data[f"average_similarity/{model_name}"] = np.mean(
                    [sim for _, sim in results]
                )

        wandb.log(log_data)

    # --- Final Summary ---
    print("--- Final Average Similarities ---")
    final_summary = {}
    for model_name, results in all_similarities.items():
        if results:
            avg_sim = np.mean([sim for _, sim in results])
            print(f"{model_name}: {avg_sim:.4f}")
            final_summary[f"final_avg_similarity/{model_name}"] = avg_sim
        else:
            print(f"{model_name}: No results")

    wandb.summary.update(final_summary)

    # --- Save Results ---
    output_dir = getattr(
        cfg.white_transfer,
        "output_dir",
        os.path.join(cfg.data.output, "transfer_results", config_hash),
    )
    ensure_dir(output_dir)
    # The file keeps its historical .txt name, but the content is JSON.
    results_file = os.path.join(output_dir, "white_transfer_similarities.txt")
    print(f"Saving detailed results to {results_file}")

    # Structure: { model_name: { filename: similarity } }
    output_data = {
        model_name: dict(results) for model_name, results in all_similarities.items()
    }
    with open(results_file, "w", encoding="utf-8") as f:
        json.dump(output_data, f, indent=2)

    print("White-box transfer evaluation complete.")
    wandb.finish()


# Run the Hydra-decorated entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
