"""
Compositional binding benchmarks for vision encoders.

Two tasks:
1. Same/Different: Structural discrimination control
2. Attribute Binding: Main compositional binding task
"""

import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image, ImageDraw
from typing import Dict, List, Tuple, Optional, Callable
from dataclasses import dataclass
from tqdm import tqdm
from torchvision import transforms


def get_default_transform(image_size: int = 224,
                          resize_size: Optional[int] = None) -> transforms.Compose:
    """Default ImageNet preprocessing: resize, center-crop, tensorize, normalize.

    Args:
        image_size: side length (pixels) of the square center crop that is
            fed to the model.
        resize_size: size passed to ``transforms.Resize`` before cropping.
            Defaults to ``image_size * 8 // 7`` (the conventional 0.875 crop
            ratio: 256 for the default 224). Scaling it with ``image_size``
            keeps the crop inside the resized image; a fixed 256 would make
            ``CenterCrop`` zero-pad whenever ``image_size > 256``.

    Returns:
        A ``transforms.Compose`` mapping a PIL image to a normalized
        ``(3, image_size, image_size)`` float tensor.
    """
    if resize_size is None:
        resize_size = image_size * 8 // 7
    return transforms.Compose([
        transforms.Resize(resize_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        # Standard ImageNet channel statistics.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


@dataclass
class BenchmarkResult:
    """Outcome of evaluating one model on one benchmark task.

    Instances are plain records produced by the benchmarks' ``evaluate``
    methods; compare ``accuracy`` against ``chance_level`` to judge a model.
    """

    task_name: str  # benchmark identifier, e.g. "same_different"
    accuracy: float  # fraction of correct decisions
    num_samples: int  # how many samples the accuracy is computed over
    chance_level: float  # accuracy expected from random guessing


class SyntheticShapeGenerator:
    """Generate synthetic images with colored shapes for binding tasks.

    Images are square, filled with a flat background colour, and each object
    is drawn centred on one of two fixed anchors ('left' or 'right').
    """

    # Colour name -> RGB fill used when drawing an object.
    COLORS = {
        'red': (220, 60, 60),
        'green': (60, 180, 60),
        'blue': (60, 60, 220),
        'yellow': (220, 220, 60),
        'purple': (160, 60, 200),
        'cyan': (60, 200, 200),
    }

    # Shape names understood by draw_shape.
    SHAPES = ['circle', 'square', 'triangle']

    def __init__(self, image_size: int = 224, seed: int = 42):
        """
        Args:
            image_size: width and height of generated images, in pixels.
            seed: seed for ``self.rng``. NOTE(review): no method of this
                class reads ``self.rng``; it is kept as a public attribute.
        """
        self.image_size = image_size
        self.rng = np.random.RandomState(seed)

    def draw_shape(self, draw: ImageDraw.ImageDraw, shape: str, color: tuple,
                   x: int, y: int, size: int):
        """Draw a filled shape centred at ``(x, y)``.

        Args:
            draw: drawing context of the target image.
            shape: one of ``SHAPES``.
            color: RGB fill tuple.
            x, y: centre of the shape in pixels.
            size: half-extent of the shape's bounding box in pixels.

        Raises:
            ValueError: if ``shape`` is not one of ``SHAPES`` (previously such
                a name silently drew nothing, leaving a blank object).
        """
        box = [x - size, y - size, x + size, y + size]
        if shape == 'circle':
            draw.ellipse(box, fill=color)
        elif shape == 'square':
            draw.rectangle(box, fill=color)
        elif shape == 'triangle':
            # Upward-pointing triangle inscribed in the same bounding box.
            points = [(x, y - size), (x - size, y + size), (x + size, y + size)]
            draw.polygon(points, fill=color)
        else:
            raise ValueError(f"Unknown shape {shape!r}; expected one of {self.SHAPES}")

    def create_image(self, objects: List[Dict], bg_color=(200, 200, 200)) -> Image.Image:
        """Create an image containing the specified objects.

        Args:
            objects: list of ``{'shape': str, 'color': str, 'position': str}``
                where ``shape`` is in ``SHAPES``, ``color`` is a key of
                ``COLORS`` and ``position`` is ``'left'`` or ``'right'``.
                Objects are drawn in list order, so later ones paint over
                earlier ones at the same position.
            bg_color: RGB background fill.

        Returns:
            A new RGB ``image_size x image_size`` PIL image.

        Raises:
            KeyError: for an unknown ``color`` or ``position``.
            ValueError: for an unknown ``shape``.
        """
        img = Image.new('RGB', (self.image_size, self.image_size), bg_color)
        draw = ImageDraw.Draw(img)

        # Object centres: a quarter / three quarters across, vertically centred.
        positions = {
            'left': (self.image_size // 4, self.image_size // 2),
            'right': (3 * self.image_size // 4, self.image_size // 2),
        }

        size = self.image_size // 6

        for obj in objects:
            x, y = positions[obj['position']]
            color = self.COLORS[obj['color']]
            self.draw_shape(draw, obj['shape'], color, x, y, size)

        return img


class SameDifferentBenchmark:
    """
    Same/Different task: Structural discrimination control.

    Determines if two images share the same spatial configuration
    (shape-position bindings) despite differing colors.

    Every "same" pair keeps both shapes in place and recolours both objects;
    every "different" pair keeps both colours and either swaps the two shapes
    or replaces the left shape. Colour similarity therefore works against the
    correct answer, and only shape-position information separates the classes.

    The reported accuracy is that of the best single cosine-distance
    threshold (distance <= threshold -> "same") chosen on the evaluated pairs
    themselves, i.e. an oracle-threshold (optimistic) score.
    """

    def __init__(self, device: str = "cuda", num_samples: int = 500, seed: int = 42):
        """
        Args:
            device: torch device the model and inputs are moved to.
            num_samples: total number of pairs; half "same", half "different"
                (an odd value is rounded down).
            seed: seed for the image generator and for pair sampling.
        """
        self.device = device
        self.num_samples = num_samples
        self.generator = SyntheticShapeGenerator(seed=seed)
        self.transform = get_default_transform()
        self.rng = np.random.RandomState(seed)

    def generate_pair(self, same: bool) -> Tuple[Image.Image, Image.Image]:
        """Generate a pair of images that are same or different.

        Args:
            same: if True, both images have identical shape-position bindings
                (only the colours change); otherwise the bindings differ.

        Returns:
            ``(img1, img2)``.
        """
        shapes = self.generator.SHAPES
        colors = list(self.generator.COLORS.keys())

        # The two shapes must differ: with equal shapes the "swap" branch below
        # would rebuild the original layout yet label the pair "different".
        shape1, shape2 = self.rng.choice(shapes, 2, replace=False)
        color1 = self.rng.choice(colors)
        color2 = self.rng.choice([c for c in colors if c != color1])

        objects1 = [
            {'shape': shape1, 'color': color1, 'position': 'left'},
            {'shape': shape2, 'color': color2, 'position': 'right'},
        ]
        img1 = self.generator.create_image(objects1)

        if same:
            # Same layout, both objects recoloured.
            new_color1 = self.rng.choice([c for c in colors if c != color1])
            new_color2 = self.rng.choice([c for c in colors if c != color2 and c != new_color1])
            objects2 = [
                {'shape': shape1, 'color': new_color1, 'position': 'left'},
                {'shape': shape2, 'color': new_color2, 'position': 'right'},
            ]
        else:
            if self.rng.rand() > 0.5:
                # Swap the shapes, keep the colours in place.
                objects2 = [
                    {'shape': shape2, 'color': color1, 'position': 'left'},
                    {'shape': shape1, 'color': color2, 'position': 'right'},
                ]
            else:
                # Replace the left shape with a different one.
                new_shape1 = self.rng.choice([s for s in shapes if s != shape1])
                objects2 = [
                    {'shape': new_shape1, 'color': color1, 'position': 'left'},
                    {'shape': shape2, 'color': color2, 'position': 'right'},
                ]

        img2 = self.generator.create_image(objects2)
        return img1, img2

    def evaluate(self, model: torch.nn.Module, encode_fn: Optional[Callable] = None) -> BenchmarkResult:
        """Evaluate same-different discrimination.

        Args:
            model: encoder; switched to eval mode and moved to ``self.device``.
            encode_fn: optional callable applied to the preprocessed batch
                instead of ``model(x)``.

        Returns:
            BenchmarkResult whose ``num_samples`` is the number of pairs
            actually evaluated (``2 * (num_samples // 2)``).

        Raises:
            ValueError: if ``num_samples < 2`` (no pair of each kind).
        """
        n_pairs = self.num_samples // 2
        if n_pairs < 1:
            raise ValueError(
                "num_samples must be at least 2 (one 'same' and one "
                f"'different' pair), got {self.num_samples}")

        model = model.eval().to(self.device)

        distances_same = []
        distances_diff = []

        with torch.no_grad():
            for _ in tqdm(range(n_pairs), desc="Same/Different"):
                distances_same.append(self._pair_distance(model, True, encode_fn))
                distances_diff.append(self._pair_distance(model, False, encode_fn))

        return BenchmarkResult(
            task_name="same_different",
            accuracy=self._best_threshold_accuracy(distances_same, distances_diff),
            num_samples=len(distances_same) + len(distances_diff),
            chance_level=0.5,
        )

    def _pair_distance(self, model, same: bool, encode_fn) -> float:
        """Cosine distance (1 - cosine similarity) between a freshly generated pair."""
        img1, img2 = self.generate_pair(same=same)
        emb1 = self._encode(model, img1, encode_fn)
        emb2 = self._encode(model, img2, encode_fn)
        return 1 - F.cosine_similarity(emb1, emb2).item()

    @staticmethod
    def _best_threshold_accuracy(distances_same: List[float],
                                 distances_diff: List[float]) -> float:
        """Best accuracy of the rule "distance <= t -> same" over every cut point.

        Sweeps all thresholds between consecutive distinct distances (plus
        "everything different" and "everything same") instead of a coarse
        percentile grid, so the true optimum is never missed.
        """
        scored = sorted([(d, 1) for d in distances_same] + [(d, 0) for d in distances_diff])
        n_total = len(scored)
        n_diff = len(distances_diff)

        # Threshold below every distance: all pairs predicted "different".
        best_correct = n_diff
        same_below = 0
        diff_below = 0
        for i, (dist, is_same) in enumerate(scored):
            if is_same:
                same_below += 1
            else:
                diff_below += 1
            # Tied distances must receive the same prediction, so only cut
            # after the last member of a tie group.
            if i + 1 < n_total and scored[i + 1][0] == dist:
                continue
            best_correct = max(best_correct, same_below + (n_diff - diff_below))
        return best_correct / n_total

    def _encode(self, model, img, encode_fn):
        """Preprocess one PIL image and return its L2-normalised embedding.

        Dict outputs are unwrapped via the "image" key, then "features", then
        the first value; tuple outputs via their first element.
        """
        x = self.transform(img).unsqueeze(0).to(self.device)
        if encode_fn:
            emb = encode_fn(x)
        else:
            emb = model(x)
        if isinstance(emb, dict):
            emb = emb.get("image", emb.get("features", list(emb.values())[0]))
        if isinstance(emb, tuple):
            emb = emb[0]
        # NOTE(review): assumes a (1, D) embedding; token/patch outputs
        # (1, N, D) would make the dim=1 cosine similarity incorrect.
        return F.normalize(emb.float(), dim=-1)


class AttributeBindingBenchmark:
    """
    Attribute Binding task: Main compositional binding benchmark.

    Query and target use DISJOINT color sets. Model must match based on
    shape-position bindings only.

    Example:
    - Query: red circle (left), blue square (right)
    - Target: GREEN circle (left), YELLOW square (right)  <- correct
    - Distractor: green square (left), yellow circle (right)  <- wrong binding

    A trial is correct when the target has the highest cosine similarity to
    the query among all candidates.
    """

    # Number of distinct distractor types generate_trial builds.
    MAX_DISTRACTORS = 3

    def __init__(self, device: str = "cuda", num_samples: int = 500,
                 n_distractors: int = 3, seed: int = 42):
        """
        Args:
            device: torch device the model and inputs are moved to.
            num_samples: number of trials.
            n_distractors: distractors per trial, 0..MAX_DISTRACTORS.
            seed: seed for the image generator and for trial sampling.

        Raises:
            ValueError: if ``n_distractors`` is outside 0..MAX_DISTRACTORS.
                Larger values used to yield fewer candidates than
                ``chance_level`` assumed; negative ones sliced from the end.
        """
        if not 0 <= n_distractors <= self.MAX_DISTRACTORS:
            raise ValueError(
                f"n_distractors must be between 0 and {self.MAX_DISTRACTORS}, "
                f"got {n_distractors}")
        self.device = device
        self.num_samples = num_samples
        self.n_distractors = n_distractors
        self.generator = SyntheticShapeGenerator(seed=seed)
        self.transform = get_default_transform()
        self.rng = np.random.RandomState(seed)

    def generate_trial(self) -> Tuple[Image.Image, List[Image.Image], int]:
        """Generate query and candidates with disjoint colors.

        Returns:
            ``(query_img, candidates, correct_idx)`` where ``candidates`` is
            the target plus ``n_distractors`` distractors in shuffled order
            and ``correct_idx`` is the target's position in that list.
        """
        shapes = self.generator.SHAPES
        colors = list(self.generator.COLORS.keys())

        shape1, shape2 = self.rng.choice(shapes, 2, replace=False)
        query_colors = self.rng.choice(colors, 2, replace=False)
        color1_q, color2_q = query_colors[0], query_colors[1]

        query_objects = [
            {'shape': shape1, 'color': color1_q, 'position': 'left'},
            {'shape': shape2, 'color': color2_q, 'position': 'right'},
        ]
        query_img = self.generator.create_image(query_objects)

        # Target uses DIFFERENT colors (disjoint)
        remaining_colors = [c for c in colors if c not in query_colors]
        target_colors = self.rng.choice(remaining_colors, 2, replace=False)
        color1_t, color2_t = target_colors[0], target_colors[1]

        correct_objects = [
            {'shape': shape1, 'color': color1_t, 'position': 'left'},
            {'shape': shape2, 'color': color2_t, 'position': 'right'},
        ]
        correct_img = self.generator.create_image(correct_objects)

        # Distractors with binding violations. All types are always built (and
        # consume RNG draws) so the random stream is independent of n_distractors.
        distractors = []

        # D1: swap shapes (wrong binding)
        d1_objects = [
            {'shape': shape2, 'color': color1_t, 'position': 'left'},
            {'shape': shape1, 'color': color2_t, 'position': 'right'},
        ]
        distractors.append(self.generator.create_image(d1_objects))

        # D2: swap with different colors. Colours are drawn with replacement
        # from everything except the target colours, so both objects may share
        # a colour and the query's colours may reappear here.
        other_colors = self.rng.choice([c for c in colors if c not in [color1_t, color2_t]], 2, replace=True)
        d2_objects = [
            {'shape': shape2, 'color': other_colors[0], 'position': 'left'},
            {'shape': shape1, 'color': other_colors[1], 'position': 'right'},
        ]
        distractors.append(self.generator.create_image(d2_objects))

        # D3: partial match (one correct, one wrong)
        d3_objects = [
            {'shape': shape1, 'color': color1_t, 'position': 'left'},
            {'shape': shape1, 'color': color2_t, 'position': 'right'},
        ]
        distractors.append(self.generator.create_image(d3_objects))

        # Shuffle; the target starts at index 0, so its new position is
        # wherever 0 lands in the permutation.
        candidates = [correct_img] + distractors[:self.n_distractors]
        indices = list(range(len(candidates)))
        self.rng.shuffle(indices)
        candidates = [candidates[i] for i in indices]
        correct_idx = indices.index(0)

        return query_img, candidates, correct_idx

    def evaluate(self, model: torch.nn.Module, encode_fn: Optional[Callable] = None) -> BenchmarkResult:
        """Evaluate attribute binding task.

        Args:
            model: encoder; switched to eval mode and moved to ``self.device``.
            encode_fn: optional callable applied to the preprocessed batch
                instead of ``model(x)``.

        Returns:
            BenchmarkResult with top-1 retrieval accuracy over all trials.

        Raises:
            ValueError: if ``num_samples < 1`` (previously a ZeroDivisionError).
        """
        if self.num_samples < 1:
            raise ValueError(f"num_samples must be at least 1, got {self.num_samples}")

        model = model.eval().to(self.device)

        correct = 0
        total = 0

        with torch.no_grad():
            for _ in tqdm(range(self.num_samples), desc="Attribute Binding"):
                query_img, candidates, correct_idx = self.generate_trial()

                query_emb = self._encode(model, query_img, encode_fn)
                cand_embs = torch.cat([self._encode(model, c, encode_fn) for c in candidates], dim=0)

                # (1, D) query broadcasts against (K, D) candidates -> (K,) scores.
                sims = F.cosine_similarity(query_emb, cand_embs)
                pred_idx = sims.argmax().item()

                if pred_idx == correct_idx:
                    correct += 1
                total += 1

        return BenchmarkResult(
            task_name="attribute_binding",
            accuracy=correct / total,
            num_samples=total,
            chance_level=1.0 / (1 + self.n_distractors),
        )

    def _encode(self, model, img, encode_fn):
        """Preprocess one PIL image and return its L2-normalised embedding.

        Dict outputs are unwrapped via the "image" key, then "features", then
        the first value; tuple outputs via their first element.
        """
        x = self.transform(img).unsqueeze(0).to(self.device)
        if encode_fn:
            emb = encode_fn(x)
        else:
            emb = model(x)
        if isinstance(emb, dict):
            emb = emb.get("image", emb.get("features", list(emb.values())[0]))
        if isinstance(emb, tuple):
            emb = emb[0]
        # NOTE(review): assumes a (1, D) embedding; token/patch outputs
        # (1, N, D) would make the dim=1 cosine similarity incorrect.
        return F.normalize(emb.float(), dim=-1)


def run_all_benchmarks(
    model: torch.nn.Module,
    encode_fn: Optional[Callable] = None,
    device: str = "cuda",
    num_samples: int = 500,
) -> Dict[str, BenchmarkResult]:
    """Evaluate ``model`` on every benchmark defined in this module.

    Both benchmarks are constructed up front with the given ``device`` and
    ``num_samples`` (and their default seeds), then run in order:
    same/different first, attribute binding second. Progress and each
    accuracy/chance line are printed to stdout.

    Args:
        model: encoder under test.
        encode_fn: optional callable used by each benchmark in place of
            ``model(x)``.
        device: torch device passed to each benchmark.
        num_samples: per-benchmark sample count.

    Returns:
        Mapping from benchmark name to its BenchmarkResult, in run order.
    """
    common = dict(device=device, num_samples=num_samples)
    suite = (
        ("same_different", SameDifferentBenchmark(**common)),
        ("attribute_binding", AttributeBindingBenchmark(**common)),
    )

    collected = {}
    for bench_name, benchmark in suite:
        print(f"\nRunning {bench_name}...")
        outcome = benchmark.evaluate(model, encode_fn)
        collected[bench_name] = outcome
        print(f"  Accuracy: {outcome.accuracy:.4f} (chance: {outcome.chance_level:.4f})")

    return collected
