import os
import json
from typing import Dict, List, Optional
import argparse
import glob

import torch
from PIL import Image
from tqdm import tqdm
from transformers import CLIPModel, CLIPProcessor


# Setup HuggingFace cache before importing models
def setup_hf_cache():
    """Setup HuggingFace cache to use local SSD instead of slow NFS."""
    # Check if cache is already configured
    if "HF_HOME" in os.environ:
        print(f"Using existing HF cache: {os.environ['HF_HOME']}")
        return

    # Try to find a local SSD path
    user = os.environ.get("USER", os.environ.get("USERNAME", "user"))
    cache_paths = [
        f"/scratch/{user}/hf-cache",
        f"/tmp/{user}/hf-cache",
        f"/dev/shm/{user}/hf-cache",
        f"/var/tmp/{user}/hf-cache",
    ]

    cache_dir = None
    for path in cache_paths:
        parent = os.path.dirname(path)
        if os.access(parent, os.W_OK):
            cache_dir = path
            break

    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)
        os.makedirs(f"{cache_dir}/hub", exist_ok=True)
        os.makedirs(f"{cache_dir}/transformers", exist_ok=True)

        os.environ["HF_HOME"] = cache_dir
        os.environ["HF_HUB_CACHE"] = f"{cache_dir}/hub"
        os.environ["TRANSFORMERS_CACHE"] = f"{cache_dir}/transformers"

        print(f"Setup HF cache at: {cache_dir}")
    else:
        print(
            "Warning: Could not find local SSD for cache, using default (may be slow)"
        )


# Pick a local HF cache location. NOTE(review): torch and transformers are
# already imported at the top of this file, so the environment variables set
# here may not change defaults those libraries computed at import time;
# CACHE_DIR below is what gets passed explicitly when loading the model.
setup_hf_cache()

# Configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")
# bfloat16 on GPU. On CPU use float32: half-precision kernel coverage on CPU
# varies across PyTorch versions, and a missing-kernel error inside the scoring
# helpers would be caught and reported as a silent 0.0 score.
DTYPE = torch.bfloat16 if DEVICE.type == "cuda" else torch.float32

# Model download cache: the HF_HOME chosen (or preserved) by setup_hf_cache(),
# falling back to ~/.cache/huggingface when HF_HOME is unset.
CACHE_DIR = os.environ.get("HF_HOME", os.path.expanduser("~/.cache/huggingface"))

# Input/output locations; these are relative to the current working directory.
IMAGES_DIR = "../images/images_3_agent"
OUTPUT_DIR = "../alignment/alignment_3_agent"
PROMPTS_PATH = "../prompts/prompts_3_agent.json"

# CLIP Model Configuration
CLIP_MODEL_ID = "openai/clip-vit-large-patch14-336"


def get_clip_alignment_score(
    image_path: str, text: str, model: CLIPModel, processor: CLIPProcessor
) -> float:
    """Calculate the CLIP alignment score between one image and one text prompt.

    Args:
        image_path: Path to the image; it is loaded and converted to RGB.
        text: Prompt to compare against (tokenized with ``truncation=True``).
        model: CLIP model used for both the image and the text embeddings.
        processor: CLIP processor matching ``model``.

    Returns:
        Cosine similarity between the L2-normalised image and text embeddings.
        Best-effort: on any error the problem is printed and ``0.0`` is
        returned, so one bad file does not abort a batch run.
    """
    try:
        # Context manager closes the file handle; convert() returns a separate
        # in-memory image that stays valid after the block.
        with Image.open(image_path) as im:
            img = im.convert("RGB")

        inputs = processor(
            text=[text],
            images=[img],
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(DEVICE)

        # Only the pixel tensor is floating point: cast it to the model's own
        # dtype. Token ids and the attention mask must stay integer.
        pixel_values = inputs["pixel_values"].to(model.dtype)
        input_ids = inputs["input_ids"]
        attention_mask = inputs.get("attention_mask", None)

        with torch.inference_mode():
            image_features = model.get_image_features(pixel_values=pixel_values)
            text_features = model.get_text_features(
                input_ids=input_ids, attention_mask=attention_mask
            )

            # Upcast before normalising: bfloat16 keeps only 8 significand
            # bits, so a similarity around 0.25-0.5 would be rounded to steps
            # of ~0.002, which is coarse when comparing prompts.
            image_features = image_features.float()
            text_features = text_features.float()
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

            similarity = (image_features @ text_features.T).squeeze()
            return float(similarity.item())

    except Exception as e:
        print(f"Error processing {image_path} with text '{text}': {e}")
        return 0.0


def calculate_clip_quality(
    image_path: str, model: CLIPModel, processor: CLIPProcessor
) -> float:
    """Score image quality with a zero-shot CLIP "high vs. low quality" probe.

    The image is compared against the prompts "High quality image." and
    "Low quality image.". The two similarities are turned into logits with
    CLIP's learned temperature and then softmaxed.

    Args:
        image_path: Path to the image; it is loaded and converted to RGB.
        model: CLIP model used for both the image and the text embeddings.
        processor: CLIP processor matching ``model``.

    Returns:
        Probability in [0, 1] assigned to "High quality image.". Best-effort:
        on any error the problem is printed and ``0.0`` is returned.
    """
    try:
        with Image.open(image_path) as im:
            img = im.convert("RGB")

        inputs = processor(
            text=["High quality image.", "Low quality image."],
            images=[img],
            return_tensors="pt",
            padding=True,
        ).to(DEVICE)

        # Cast only the image to the model dtype; token ids stay integer.
        pixel_values = inputs["pixel_values"].to(model.dtype)
        input_ids = inputs["input_ids"]
        attention_mask = inputs.get("attention_mask", None)

        # All post-processing stays inside inference_mode: the outputs are
        # inference tensors, and combining them with the (grad-requiring)
        # logit_scale parameter outside this context would raise.
        with torch.inference_mode():
            image_feats = model.get_image_features(pixel_values=pixel_values).float()
            text_feats = model.get_text_features(
                input_ids=input_ids, attention_mask=attention_mask
            ).float()

            image_feats = image_feats / image_feats.norm(dim=-1, keepdim=True)
            text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)

            # Cosine similarities lie in [-1, 1], so softmaxing them directly
            # compresses every result toward 0.5. Scale by CLIP's learned
            # temperature (logit_scale is stored in log space), as in CLIP's
            # own zero-shot classification.
            logit_scale = model.logit_scale.exp().float()
            logits = logit_scale * (image_feats @ text_feats.T)  # (1, 2)
            probs = logits.softmax(dim=-1)  # (1, 2): [high, low]

        return float(probs[0, 0].item())

    except Exception as e:
        print(f"Error calculating quality for {image_path}: {e}")
        return 0.0


def parse_image_filename(filename: str) -> Optional[Dict]:
    """Parse an image filename into its prompt index, bids and sample index.

    Expected format:
    idx{prompt_index:03d}_b1_{bid1:.2f}_b2_{bid2:.2f}_b3_{bid3:.2f}_s{sample_idx:02d}.png

    Returns:
        ``{"prompt_index": int, "bids": [float, float, float], "sample_index": int}``,
        or ``None`` (after printing a message) if the name does not match.
    """
    try:
        # Strip only the trailing extension; str.replace would also remove a
        # ".png" occurring elsewhere in the name.
        stem, _ext = os.path.splitext(filename)
        parts = stem.split("_")

        # Verify the literal markers, so that a reordered or truncated name is
        # rejected rather than silently producing swapped or wrong bids.
        if (
            len(parts) < 8
            or not parts[0].startswith("idx")
            or parts[1] != "b1"
            or parts[3] != "b2"
            or parts[5] != "b3"
            or not parts[7].startswith("s")
        ):
            raise ValueError("name does not match idx###_b1_X_b2_Y_b3_Z_s## pattern")

        return {
            "prompt_index": int(parts[0][len("idx"):]),
            "bids": [float(parts[2]), float(parts[4]), float(parts[6])],
            "sample_index": int(parts[7][len("s"):]),
        }
    except ValueError as e:
        print(f"Error parsing filename {filename}: {e}")
        return None


def load_prompts(prompts_path: str) -> List[Dict]:
    """Load the list of prompt records from a UTF-8 JSON file.

    Returns:
        The parsed top-level list. Returns ``[]`` (after printing an error) if
        the file cannot be read, is not valid UTF-8 JSON, or its top-level
        value is not a list.
    """
    try:
        # Explicit encoding: prompts may contain non-ASCII text and the
        # platform default encoding is locale-dependent.
        with open(prompts_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:  # ValueError covers JSON/Unicode decode errors
        print(f"Error loading prompts from {prompts_path}: {e}")
        return []

    if not isinstance(data, list):
        print(
            f"Error loading prompts from {prompts_path}: "
            f"expected a JSON list, got {type(data).__name__}"
        )
        return []
    return data


def main():
    """Main analysis function."""
    parser = argparse.ArgumentParser(
        description="Analyze CLIP alignment for 3-agent generated images"
    )
    parser.add_argument(
        "--prompt_index", type=int, help="Specific prompt index to process (optional)"
    )
    parser.add_argument(
        "--bid_combination",
        type=str,
        help="Specific bid combination like '1.0,0.0,0.0' (optional)",
    )
    parser.add_argument(
        "--sample_index", type=int, help="Specific sample index to process (optional)"
    )
    args = parser.parse_args()

    # Create output directory
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    print(f"Output directory: {OUTPUT_DIR}")

    # Load prompts
    prompts = load_prompts(PROMPTS_PATH)
    if not prompts:
        print("Error: Could not load prompts")
        return

    print(f"Loaded {len(prompts)} prompts")

    # Load CLIP model
    print("Loading CLIP model and processor...")
    clip_model = (
        CLIPModel.from_pretrained(CLIP_MODEL_ID, cache_dir=CACHE_DIR)
        .to(dtype=DTYPE)
        .to(DEVICE)
    )
    clip_processor = CLIPProcessor.from_pretrained(CLIP_MODEL_ID, cache_dir=CACHE_DIR)
    print("CLIP model loaded successfully")

    # Find all image files
    image_pattern = os.path.join(IMAGES_DIR, "**", "*.png")
    image_paths = glob.glob(image_pattern, recursive=True)
    print(f"Found {len(image_paths)} images to process")

    if not image_paths:
        print(f"Warning: No images found in {IMAGES_DIR}")
        return

    # Process each image
    processed_count = 0
    for image_path in tqdm(image_paths, desc="Processing images"):
        try:
            # Parse filename to get metadata
            filename = os.path.basename(image_path)
            metadata = parse_image_filename(filename)

            if not metadata:
                print(f"Skipping {filename} - could not parse")
                continue

            prompt_idx = metadata["prompt_index"]
            bids = metadata["bids"]
            sample_idx = metadata["sample_index"]

            # Apply filters if specified
            if args.prompt_index is not None and prompt_idx != args.prompt_index:
                continue
            if args.bid_combination is not None:
                target_bids = [float(x) for x in args.bid_combination.split(",")]
                if not all(abs(a - b) < 0.01 for a, b in zip(bids, target_bids)):
                    continue
            if args.sample_index is not None and sample_idx != args.sample_index:
                continue

            # Get corresponding prompt data
            if prompt_idx >= len(prompts):
                print(f"Prompt index {prompt_idx} out of range, skipping")
                continue

            prompt_data = prompts[prompt_idx]

            # Calculate alignment scores
            base_score = get_clip_alignment_score(
                image_path, prompt_data["base_prompt"], clip_model, clip_processor
            )

            agent1_score = get_clip_alignment_score(
                image_path, prompt_data["agent1_prompt"], clip_model, clip_processor
            )

            agent2_score = get_clip_alignment_score(
                image_path, prompt_data["agent2_prompt"], clip_model, clip_processor
            )

            agent3_score = get_clip_alignment_score(
                image_path, prompt_data["agent3_prompt"], clip_model, clip_processor
            )

            # Calculate image quality
            quality_score = calculate_clip_quality(
                image_path, clip_model, clip_processor
            )

            # Create analysis result
            result = {
                "metadata": {
                    "prompt_index": prompt_idx,
                    "bids": bids,
                    "sample_index": sample_idx,
                    "image_path": image_path,
                    "filename": filename,
                },
                "prompts": {
                    "base_prompt": prompt_data["base_prompt"],
                    "agent1_prompt": prompt_data["agent1_prompt"],
                    "agent2_prompt": prompt_data["agent2_prompt"],
                    "agent3_prompt": prompt_data["agent3_prompt"],
                },
                "alignment_scores": {
                    "base_alignment": base_score,
                    "agent1_alignment": agent1_score,
                    "agent2_alignment": agent2_score,
                    "agent3_alignment": agent3_score,
                },
                "quality_score": quality_score,
                "welfare_metrics": {
                    "weighted_alignment": (
                        agent1_score * bids[0]
                        + agent2_score * bids[1]
                        + agent3_score * bids[2]
                    )
                    / sum(bids)
                    if sum(bids) > 0
                    else 0,
                    "total_welfare": agent1_score * bids[0]
                    + agent2_score * bids[1]
                    + agent3_score * bids[2],
                },
            }

            # Save individual result
            output_filename = f"alignment_p{prompt_idx:03d}_b{bids[0]:.2f}_{bids[1]:.2f}_{bids[2]:.2f}_s{sample_idx:02d}.json"
            output_path = os.path.join(OUTPUT_DIR, output_filename)

            with open(output_path, "w") as f:
                json.dump(result, f, indent=2)

            processed_count += 1

        except Exception as e:
            print(f"Error processing {image_path}: {e}")
            continue

    print(f"Successfully processed {processed_count} images")
    print(f"Results saved to: {OUTPUT_DIR}")

    # Cleanup GPU memory
    print("Cleaning up GPU memory...")
    del clip_model
    del clip_processor
    torch.cuda.empty_cache()
    print("Analysis complete!")


if __name__ == "__main__":
    # Run the analysis only when executed as a script, not when imported.
    main()
