#!/usr/bin/env python3
"""
Readout comparison experiment for Figure 4.

Compares three geometry-based readout methods on the Attribute Binding task:
1. Cosine similarity (baseline)
2. k-NN weighted voting
3. Local PCA projection
"""

import argparse
import json
import sys
from pathlib import Path

import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm

# Put the parent of this script's directory at the front of the import path
# so that the `src` package imported below resolves when this file is run
# directly as a script.
sys.path.insert(0, str(Path(__file__).parent.parent))

from src.models import load_model
from src.benchmarks import (
    AttributeBindingBenchmark,
    SyntheticShapeGenerator,
    get_default_transform,
)


class ReadoutComparisonBenchmark:
    """
    Attribute Binding with multiple readout methods.

    Each trial shows a query image (two coloured shapes) and a set of
    candidate images; the correct candidate keeps the query's left/right
    shape arrangement but uses different colours. Every readout assigns a
    score to each candidate and the highest-scoring candidate is the
    prediction.

    Methods:
    1. Cosine: Direct cosine similarity
    2. kNN: Weighted voting from k nearest neighbors
    3. Local PCA: Project onto local PC subspace

    The kNN and Local PCA readouts draw their neighbourhoods from
    ``embedding_cache``, a list of embeddings collected from *earlier*
    trials of the current ``evaluate`` run. Until the cache holds enough
    entries, both readouts fall back to plain cosine similarity.
    """

    def __init__(
        self,
        device: str = "cuda",
        num_samples: int = 500,
        n_distractors: int = 3,
        k_neighbors: int = 60,
        pca_components: int = 32,
        pca_neighbors: int = 50,
        seed: int = 42,
        cache_window: int = 500,
    ):
        """
        Args:
            device: Torch device string used for the model and the inputs.
            num_samples: Number of trials generated per ``evaluate`` call.
            n_distractors: How many distractors to keep per trial. Only
                three distractors are ever generated, so larger values
                behave like 3.
            k_neighbors: Neighbourhood size for the kNN readout.
            pca_components: Maximum number of principal directions kept by
                the Local PCA readout.
            pca_neighbors: Neighbourhood size for the Local PCA readout.
            seed: Seed for both the shape generator and the trial RNG.
            cache_window: Only the most recent ``cache_window`` cached
                embeddings are searched for neighbours.

        Raises:
            ValueError: If ``cache_window`` is not positive.
        """
        if cache_window < 1:
            raise ValueError("cache_window must be a positive integer")

        self.device = device
        self.num_samples = num_samples
        self.n_distractors = n_distractors
        self.k_neighbors = k_neighbors
        self.pca_components = pca_components
        self.pca_neighbors = pca_neighbors
        self.cache_window = cache_window
        self.generator = SyntheticShapeGenerator(seed=seed)
        self.transform = get_default_transform()
        self.rng = np.random.RandomState(seed)

        # Cache for embeddings (for kNN and PCA)
        self.embedding_cache = []

    def generate_trial(self):
        """
        Generate trial with disjoint colors (same as AttributeBindingBenchmark).

        Returns:
            A ``(query_img, candidates, correct_idx)`` tuple where
            ``candidates`` is a shuffled list holding the correct image plus
            up to ``n_distractors`` distractors, and ``correct_idx`` is the
            position of the correct image in that list.
        """
        shapes = self.generator.SHAPES
        colors = list(self.generator.COLORS.keys())

        def render(left_shape, left_color, right_shape, right_color):
            # Every image in a trial is a left/right pair of coloured shapes.
            return self.generator.create_image([
                {'shape': left_shape, 'color': left_color, 'position': 'left'},
                {'shape': right_shape, 'color': right_color, 'position': 'right'},
            ])

        # NOTE: the order of RNG draws and create_image calls below is kept
        # fixed so that a given seed keeps producing the same trials.
        shape1, shape2 = self.rng.choice(shapes, 2, replace=False)
        query_colors = self.rng.choice(colors, 2, replace=False)
        color1_q, color2_q = query_colors[0], query_colors[1]
        query_img = render(shape1, color1_q, shape2, color2_q)

        # Target colours are disjoint from the query colours, so a match
        # cannot be made on colour alone.
        remaining_colors = [c for c in colors if c not in query_colors]
        target_colors = self.rng.choice(remaining_colors, 2, replace=False)
        color1_t, color2_t = target_colors[0], target_colors[1]
        correct_img = render(shape1, color1_t, shape2, color2_t)

        distractors = []

        # Distractor 1: target colours, shapes swapped between positions.
        distractors.append(render(shape2, color1_t, shape1, color2_t))

        # Distractor 2: shapes swapped and colours drawn (with replacement)
        # from everything except the two target colours.
        other_colors = self.rng.choice(
            [c for c in colors if c not in [color1_t, color2_t]],
            2,
            replace=True,
        )
        distractors.append(
            render(shape2, other_colors[0], shape1, other_colors[1])
        )

        # Distractor 3: target colours, but the same shape on both sides.
        distractors.append(render(shape1, color1_t, shape1, color2_t))

        candidates = [correct_img] + distractors[:self.n_distractors]
        indices = list(range(len(candidates)))
        self.rng.shuffle(indices)
        candidates = [candidates[i] for i in indices]
        # Index 0 was the correct image before shuffling.
        correct_idx = indices.index(0)

        return query_img, candidates, correct_idx

    def _encode(self, model, img):
        """
        Embed a single image and return an L2-normalised ``(1, D)`` tensor.

        Handles several model output conventions: a dict (prefers the
        ``"image"`` key, then ``"features"``, then the first value), a
        tuple/list (first element), a 4-D tensor (averaged over the last
        two dimensions) or a 3-D tensor (index 0 along dimension 1 —
        presumably the CLS token; confirm against the model wrappers).
        """
        x = self.transform(img).unsqueeze(0).to(self.device)
        with torch.no_grad():
            emb = model(x)

        if isinstance(emb, dict):
            for key in ("image", "features"):
                if key in emb:
                    emb = emb[key]
                    break
            else:
                emb = next(iter(emb.values()))
        if isinstance(emb, (tuple, list)):
            emb = emb[0]

        if emb.dim() == 4:
            emb = emb.mean(dim=[2, 3])
        elif emb.dim() == 3:
            emb = emb[:, 0]
        return F.normalize(emb.float(), dim=-1)

    def _reference_neighbors(self, query_emb, k):
        """
        Return the ``k`` cached embeddings most cosine-similar to the query.

        Only the most recent ``cache_window`` cache entries are searched.
        Returns ``None`` when ``k`` is not positive or the cache holds fewer
        than ``k`` embeddings, signalling that the caller should fall back
        to plain cosine similarity.
        """
        if k < 1 or len(self.embedding_cache) < k:
            return None

        cache = torch.stack(self.embedding_cache[-self.cache_window:])
        sims = F.cosine_similarity(query_emb, cache)
        _, topk_idx = sims.topk(min(k, len(cache)))
        return cache[topk_idx]

    def cosine_score(self, query_emb, cand_embs):
        """Standard cosine similarity between the query and each candidate."""
        return F.cosine_similarity(query_emb, cand_embs)

    def knn_score(self, query_emb, cand_embs):
        """
        k-NN weighted scoring using cached embeddings.

        Finds the ``k_neighbors`` cached embeddings closest to the query and
        scores each candidate by its mean cosine similarity to them. Falls
        back to ``cosine_score`` while the cache is too small.

        Returns:
            A 1-D tensor with one score per candidate.
        """
        neighbors = self._reference_neighbors(query_emb, self.k_neighbors)
        if neighbors is None:
            return self.cosine_score(query_emb, cand_embs)

        # One (candidates x neighbors) similarity matrix instead of a Python
        # loop with a host sync per candidate.
        cand_unit = F.normalize(cand_embs, dim=-1)
        neighbor_unit = F.normalize(neighbors, dim=-1)
        return (cand_unit @ neighbor_unit.T).mean(dim=1)

    def local_pca_score(self, query_emb, cand_embs):
        """
        Local PCA projection scoring.

        Fits PCA to the ``pca_neighbors`` cached embeddings closest to the
        query, projects the query and the candidates onto the leading
        ``pca_components`` directions, and scores candidates by
        ``1 - distance / max_distance`` in that subspace (so the farthest
        candidate scores 0). Falls back to ``cosine_score`` while the cache
        is too small or if the SVD fails.

        Returns:
            A 1-D tensor with one score per candidate.
        """
        neighbors = self._reference_neighbors(query_emb, self.pca_neighbors)
        if neighbors is None:
            return self.cosine_score(query_emb, cand_embs)

        # Compute local PCA
        centered = neighbors - neighbors.mean(dim=0, keepdim=True)
        try:
            # torch.linalg.svd returns Vh, whose rows are the principal
            # directions, so the leading components are its first rows.
            _, _, vh = torch.linalg.svd(centered, full_matrices=False)
        except RuntimeError:
            # torch.linalg.LinAlgError (e.g. non-convergence) is a
            # RuntimeError; anything else should propagate.
            return self.cosine_score(query_emb, cand_embs)
        pcs = vh[:min(self.pca_components, vh.shape[0])]

        # Project query and candidates onto local subspace. The neighbourhood
        # mean is not subtracted because it cancels in the difference below.
        query_proj = query_emb @ pcs.T
        cand_projs = cand_embs @ pcs.T

        # Score by distance in local subspace
        distances = (cand_projs - query_proj).norm(dim=-1)
        max_dist = distances.max()
        if max_dist > 0:
            return 1 - distances / max_dist
        # All candidates coincide with the query in the subspace.
        return torch.ones_like(distances)

    def evaluate(self, model):
        """
        Evaluate all readout methods.

        Runs ``num_samples`` trials. In each trial the three readouts score
        the same embeddings, and only afterwards are the trial's embeddings
        added to ``embedding_cache``. Scoring first keeps the query and the
        candidates being ranked out of their own reference neighbourhood.

        Args:
            model: Image encoder; it is switched to eval mode and moved to
                ``self.device``.

        Returns:
            Dict mapping ``'cosine'``, ``'knn'`` and ``'local_pca'`` to a
            dict with ``'correct'``, ``'total'`` and ``'accuracy'``.
        """
        model = model.eval().to(self.device)
        self.embedding_cache = []

        readouts = (
            ('cosine', self.cosine_score),
            ('knn', self.knn_score),
            ('local_pca', self.local_pca_score),
        )
        results = {name: {'correct': 0, 'total': 0} for name, _ in readouts}

        for _ in tqdm(range(self.num_samples), desc="Readout comparison"):
            query_img, candidates, correct_idx = self.generate_trial()

            query_emb = self._encode(model, query_img)
            cand_embs = torch.cat(
                [self._encode(model, c) for c in candidates], dim=0
            )

            # Evaluate each method
            for method_name, score_fn in readouts:
                scores = score_fn(query_emb, cand_embs)
                if scores.argmax().item() == correct_idx:
                    results[method_name]['correct'] += 1
                results[method_name]['total'] += 1

            # Cache embeddings only after scoring, so they serve as reference
            # points for later trials but never for their own trial.
            self.embedding_cache.append(query_emb.squeeze(0))
            self.embedding_cache.extend(cand_embs.unbind(0))

        # Compute accuracies
        for stats in results.values():
            total = stats['total']
            stats['accuracy'] = stats['correct'] / total if total > 0 else 0.0

        return results


def main():
    """
    Run the readout comparison for each requested model and save the results.

    Command-line options:
        --models: Model names to evaluate; defaults to a fixed set of five.
        --output: Path of the JSON file that receives the accuracies.
        --num-samples: Number of trials per model.

    A model that fails to load or evaluate is reported and skipped; the
    remaining models are still evaluated. The JSON file and the summary
    table contain only the models that succeeded.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--models", nargs="+", default=None)
    parser.add_argument("--output", type=str, default="outputs/fig4_readout.json")
    parser.add_argument("--num-samples", type=int, default=500)
    args = parser.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")

    # Figure 4 models (representative set)
    target_models = args.models or [
        "barlow_twins_resnet50",
        "vicreg_resnet50",
        "clip_vitb16",
        "dinov2_vitb14",
        "mae_vitb16",
    ]

    all_results = {}

    for model_name in target_models:
        print(f"\n{'='*60}")
        print(f"Evaluating: {model_name}")
        print('='*60)

        # Bound up front so the cleanup in `finally` is valid even when
        # load_model itself raises.
        model_wrapper = model = None
        try:
            model_wrapper = load_model(model_name, device=device)
            model = model_wrapper.model

            bench = ReadoutComparisonBenchmark(
                device=device,
                num_samples=args.num_samples,
            )

            results = bench.evaluate(model)

            all_results[model_name] = {
                'cosine': results['cosine']['accuracy'],
                'knn': results['knn']['accuracy'],
                'local_pca': results['local_pca']['accuracy'],
            }

            print(f"  Cosine:    {results['cosine']['accuracy']:.4f}")
            print(f"  kNN:       {results['knn']['accuracy']:.4f}")
            print(f"  Local PCA: {results['local_pca']['accuracy']:.4f}")

        except Exception as e:
            # Deliberate best-effort: one broken model must not abort the
            # whole sweep.
            print(f"  Error: {e}")
        finally:
            # Drop the model references and release cached GPU memory on
            # both the success and the failure path, so a failed model does
            # not stay resident while the next one is loaded.
            model_wrapper = model = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # Save results
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(all_results, f, indent=2)

    print(f"\nSaved to {output_path}")

    # Print summary table
    print("\n" + "="*60)
    print("READOUT COMPARISON (Figure 4)")
    print("="*60)
    print(f"{'Model':<25} {'Cosine':>10} {'kNN':>10} {'Local PCA':>10}")
    print("-"*60)
    for model_name, data in all_results.items():
        print(f"{model_name:<25} {data['cosine']:>10.4f} {data['knn']:>10.4f} {data['local_pca']:>10.4f}")


# Run the experiment only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
