import torch
import math
from typing import Dict, Optional, Sequence, Union
from pathlib import Path
import json
from src_clean.model_factory import load_model
import csv
def _sample_uniform_ball(
    num_points: int,
    dim: int,
    radius: float,
    device: torch.device,
    generator: Optional[torch.Generator],
) -> torch.Tensor:
    """Draw ``num_points`` points uniformly from the origin-centered ``dim``-ball of the given radius.

    Returns a ``[num_points, dim]`` tensor created on ``device``.
    """
    # Normalized Gaussian draws give directions uniform on the unit sphere;
    # the clamp guards the division against a (near-)zero-length draw.
    gaussian = torch.randn((num_points, dim), device=device, generator=generator)
    lengths = gaussian.norm(dim=1, keepdim=True).clamp_min(1e-12)
    unit_vectors = gaussian / lengths

    # Radius = R * U^(1/dim): the power compensates for volume growing as r^dim.
    uniform = torch.rand((num_points, 1), device=device, generator=generator)
    scaled_radii = uniform.pow(1.0 / dim) * radius
    return unit_vectors * scaled_radii

def _sample_uniform_rectangle(
    num_points: int,
    bounds: torch.Tensor,
    device: torch.device,
    generator: Optional[torch.Generator],
) -> torch.Tensor:
    """Draw ``num_points`` points uniformly from an axis-aligned box.

    ``bounds`` holds one ``[min, max]`` row per dimension; the result has one
    column per row of ``bounds``.
    """
    lower, upper = bounds[:, 0], bounds[:, 1]
    # Unit-cube draws, stretched by each side length and shifted to the lower corner.
    unit_cube = torch.rand((num_points, bounds.shape[0]), device=device, generator=generator)
    return unit_cube * (upper - lower).unsqueeze(0) + lower.unsqueeze(0)

def estimate_voronoi_cell_volumes(
    decoder_matrix: Union[torch.Tensor, "np.ndarray"],
    region: str = "rectangle",
    radius: Optional[float] = None,
    bounds: Optional[torch.Tensor] = None,
    subset_indices: Optional[Sequence[int]] = None,
    num_samples: int = 50000,
    tokens_on_columns: bool = False,
    device: Optional[Union[str, torch.device]] = None,
    chunk_size_tokens: int = 2048,
    point_batch_size: int = 8192,
    seed: Optional[int] = None,
) -> Dict[str, object]:
    """
    Monte Carlo estimate of Voronoi cell volumes for token embeddings within a region.

    Points are sampled uniformly from the region and each point is assigned to
    the token with the smallest squared Euclidean distance. The fraction of
    points assigned to a token estimates the share of the region covered by
    that token's Voronoi cell (relative proportions, not absolute volumes).

    Args:
        decoder_matrix: [num_tokens, dim] (tokens on rows) or [dim, num_tokens] (tokens on columns).
            Converted to float32 on the compute device.
        region: Sampling region ("ball", or "rectangle"); case-insensitive.
        radius: Ball radius; required when region == "ball".
        bounds: [dim, 2] tensor with [min, max] per dimension; required when region == "rectangle".
        subset_indices: Optional sequence of token indices to restrict computation.
        num_samples: Number of random points to sample.
        tokens_on_columns: Set True if decoder_matrix columns correspond to tokens.
        device: Torch device string or torch.device. Defaults to CUDA when available, else CPU.
        chunk_size_tokens: How many token embeddings to compare at once (<= 0 means all at once).
        point_batch_size: How many sample points to process at once (<= 0 means all at once).
        seed: Optional manual seed for reproducibility.

    Returns:
        Dict with keys:
            "proportion_per_token": {token_id: fraction of samples nearest to that token}
            "counts": {token_id: number of samples nearest to that token}
            "region": the normalized region name
            "num_samples": number of sampled points
            "dim": embedding dimension
            "selected_token_ids": token ids in the order they were processed

    Raises:
        ValueError: On an unknown region, a missing radius/bounds for the chosen
            region, non-positive num_samples, a non-2D decoder_matrix, bounds of
            the wrong shape, or an empty token selection.
    """
    region = region.lower()
    if region not in {"ball", "rectangle"}:
        raise ValueError("region must be 'ball', or 'rectangle'")
    if region == "ball" and radius is None:
        raise ValueError("radius must be provided when region == 'ball'")
    if region == "rectangle" and bounds is None:
        raise ValueError("bounds must be provided when region == 'rectangle'")
    if num_samples <= 0:
        raise ValueError("num_samples must be positive")

    dev = torch.device(device) if device is not None else torch.device("cuda" if torch.cuda.is_available() else "cpu")
    generator = torch.Generator(device=dev) if dev.type == "cuda" else torch.Generator()
    if seed is not None:
        generator.manual_seed(seed)

    # Detach so that no autograd graph is recorded if the caller passes a
    # tensor that tracks gradients (e.g. a model parameter).
    matrix = torch.as_tensor(decoder_matrix, device=dev, dtype=torch.float32).detach()
    # Check rank before transposing: Tensor.t() itself rejects >2D input with a
    # RuntimeError, which would mask the intended ValueError.
    if matrix.dim() != 2:
        raise ValueError(f"decoder_matrix must be 2D, got shape {tuple(matrix.shape)}")
    tokens = matrix.t() if tokens_on_columns else matrix

    num_tokens, dim = tokens.shape
    if subset_indices is not None:
        idx = torch.as_tensor(list(subset_indices), device=dev, dtype=torch.long)
        tokens = tokens[idx]
        selected_token_ids = idx.tolist()
    else:
        selected_token_ids = list(range(num_tokens))

    if tokens.numel() == 0:
        raise ValueError("No tokens to process after applying subset_indices")

    if region == "ball":
        def sample_points(batch_size: int) -> torch.Tensor:
            return _sample_uniform_ball(batch_size, dim, radius, dev, generator)
    else:
        bounds = torch.as_tensor(bounds, device=dev, dtype=torch.float32)
        if bounds.shape != (dim, 2):
            raise ValueError(f"bounds must have shape [dim, 2], got {bounds.shape}")

        def sample_points(batch_size: int) -> torch.Tensor:
            return _sample_uniform_rectangle(batch_size, bounds, dev, generator)

    num_selected = tokens.shape[0]
    token_norm2 = (tokens ** 2).sum(dim=1)
    chunk = chunk_size_tokens if chunk_size_tokens > 0 else num_selected
    pt_batch = point_batch_size if point_batch_size > 0 else num_samples

    counts = torch.zeros(num_selected, device=dev, dtype=torch.long)
    processed = 0
    while processed < num_samples:
        bs = min(pt_batch, num_samples - processed)
        points = sample_points(bs)
        point_norm2 = (points ** 2).sum(dim=1)

        # Running nearest-token search over token chunks to bound peak memory.
        best_dist2 = None
        best_idx = None
        for start in range(0, num_selected, chunk):
            end = min(start + chunk, num_selected)
            tok_chunk = tokens[start:end]
            # ||p - t||^2 = ||p||^2 + ||t||^2 - 2 p.t, evaluated for every (point, token) pair.
            dist2 = point_norm2[:, None] + token_norm2[start:end][None, :] - 2.0 * (points @ tok_chunk.t())
            chunk_best, chunk_idx = torch.min(dist2, dim=1)
            if best_dist2 is None:
                best_dist2 = chunk_best
                best_idx = chunk_idx + start
            else:
                # Strict comparison: on an exact tie the earlier chunk's token is kept.
                improved = chunk_best < best_dist2
                best_dist2 = torch.where(improved, chunk_best, best_dist2)
                best_idx = torch.where(improved, chunk_idx + start, best_idx)

        # minlength makes the histogram exactly num_selected long (indices are < num_selected).
        counts += torch.bincount(best_idx, minlength=num_selected)
        processed += bs

    # Single bulk transfers to Python lists instead of one .item() call per token.
    proportion_values = (counts.double() / float(num_samples)).tolist()
    count_values = counts.tolist()
    proportion_per_token = dict(zip(selected_token_ids, proportion_values))
    counts_per_token = dict(zip(selected_token_ids, count_values))

    return {
        "proportion_per_token": proportion_per_token,
        "counts": counts_per_token,
        "region": region,
        "num_samples": num_samples,
        "dim": dim,
        "selected_token_ids": selected_token_ids,
    }


def estimate_volumes_for_model(
    model_name: str,
    decoder_matrix: torch.Tensor,
    tokenizer,
    tokens_on_columns: bool,
    subset_ids: Optional[list] = None,
    region: str = "rectangle",
    radius: float = 10.0,
    num_samples: int = 50_000_000,
    chunk_size_tokens: int = 2048,
    point_batch_size: int = 4096,
    seed: int = 0,
    device: str = "cuda",
) -> Dict[str, object]:
    """Estimate Voronoi volumes for a model.

    For region == "rectangle" the per-dimension bounds are read from
    ``data/rects_inf/<last component of model_name>.txt``, one whitespace-separated
    ``min max`` pair per line. For region == "ball" no file is read and
    ``radius`` is used instead.

    Args:
        model_name: Model identifier; the part after the last "/" names the bounds file.
        decoder_matrix: Token embedding matrix passed through to the estimator.
        tokenizer: Not used by this function; kept in the signature for callers.
        tokens_on_columns: True if decoder_matrix columns correspond to tokens.
        subset_ids: Optional token indices to restrict the computation to.
        region: "rectangle" or "ball" (case-insensitive).
        radius: Ball radius; only forwarded when region is "ball".
        num_samples: Number of Monte Carlo points.
        chunk_size_tokens: Token chunk size forwarded to the estimator.
        point_batch_size: Point batch size forwarded to the estimator.
        seed: Random seed forwarded to the estimator.
        device: Torch device string used for the bounds and the estimation.

    Returns:
        The result dict produced by ``estimate_voronoi_cell_volumes``.

    Raises:
        FileNotFoundError: If region is "rectangle" and the bounds file is missing.
        ValueError: If a non-blank bounds line does not hold exactly two numbers,
            or if the estimator rejects its arguments.
    """
    # Normalize once so the checks below agree with the estimator, which lowercases too.
    region = region.lower()

    bounds = None
    if region == "rectangle":
        # The bounds file is only needed (and only required to exist) for this region.
        bounds_file = Path(f"data/rects_inf/{model_name.split('/')[-1]}.txt")
        bounds_list = []
        with open(bounds_file, "r", encoding="utf-8") as f:
            for line in f:
                fields = line.split()
                if not fields:
                    # Tolerate blank lines, e.g. a trailing newline at end of file.
                    continue
                min_val, max_val = map(float, fields)
                bounds_list.append([min_val, max_val])
        # float32 matches the estimator's working precision; using the decoder's
        # dtype would round the bounds when the model is stored in half precision.
        bounds = torch.tensor(bounds_list, dtype=torch.float32, device=device)
        print(f"Loaded bounds from {bounds_file}; shape: {bounds.shape}")

    result = estimate_voronoi_cell_volumes(
        decoder_matrix=decoder_matrix,
        region=region,
        radius=radius if region == "ball" else None,
        bounds=bounds,
        subset_indices=subset_ids,
        num_samples=num_samples,
        tokens_on_columns=tokens_on_columns,
        device=device,
        chunk_size_tokens=chunk_size_tokens,
        point_batch_size=point_batch_size,
        seed=seed,
    )

    print({k: result[k] for k in ["region", "dim", "num_samples"]})
    return result


def token_str(tid: int, tokenizer) -> str:
    """Return the decoded text for ``tid``, or ``<id_{tid}>`` when it cannot be decoded.

    The placeholder is used both when no tokenizer is given and when
    ``tokenizer.decode`` raises.
    """
    if tokenizer is not None:
        try:
            return tokenizer.decode([tid])
        except Exception:
            # Best effort: any decode failure falls through to the placeholder.
            pass
    return f"<id_{tid}>"


def save_results_csv(
    result: Dict[str, object],
    tokenizer,
    output_csv: Path,
) -> None:
    """Write one CSV row per token, ordered by descending proportion.

    Columns are ``token_id``, ``token``, ``samples`` and ``proportion``. Parent
    directories of ``output_csv`` are created if missing.
    """
    columns = ["token_id", "token", "samples", "proportion"]

    rows = [
        {
            "token_id": token_id,
            "token": token_str(token_id, tokenizer),
            "samples": result["counts"][token_id],
            "proportion": share,
        }
        for token_id, share in result["proportion_per_token"].items()
    ]
    # Stable sort: tokens with equal proportion keep their original order.
    rows.sort(key=lambda row: row["proportion"], reverse=True)

    output_csv.parent.mkdir(parents=True, exist_ok=True)
    with open(output_csv, "w", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=columns)
        writer.writeheader()
        writer.writerows(rows)

    print(f"Saved CSV to {output_csv}")


def plot_and_save_figure(
    result: Dict[str, object],
    model_name: str,
    output_dir: Path = Path("data/voronoi_figures"),
) -> None:
    """Create log-log scatter and Pareto plot, save as SVG.

    Args:
        result: Dict containing "proportion_per_token" (token id -> proportion).
        model_name: Used for the output file name; "/" is replaced by "_".
        output_dir: Directory for the SVG; created if missing.

    The left panel shows proportion against rank on log-log axes (zero
    proportions are omitted because they cannot be drawn on a log axis). The
    right panel shows the cumulative proportion covered by the top-k tokens
    against k / num_tokens. The number of tokens needed to reach 80% of the
    volume is printed when that level is reached.
    """
    import matplotlib.pyplot as plt

    props = torch.tensor(
        list(result["proportion_per_token"].values()),
        dtype=torch.float64,
    )
    sorted_props, _ = torch.sort(props, descending=True)
    num_tokens = len(sorted_props)
    ranks = torch.arange(1, num_tokens + 1, dtype=torch.int64)
    positive = sorted_props > 0

    cum = torch.cumsum(sorted_props, dim=0)
    # cum[i] covers the top (i + 1) tokens, so its x position is (i + 1) / n.
    # This matches the k / n fraction printed for the 80% point below.
    frac_tokens = ranks.to(torch.float64) / num_tokens

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    try:
        # Log-log scatter
        ax1.scatter(
            ranks[positive].cpu().numpy(),
            sorted_props[positive].cpu().numpy(),
            s=8,
            alpha=0.6,
            color="steelblue",
            edgecolors="none",
        )
        ax1.set_xscale("log")
        ax1.set_yscale("log")
        ax1.set_xlabel("token rank (sorted by proportion)")
        ax1.set_ylabel("proportion of volume")
        ax1.set_title("Log-log rank vs proportion")
        ax1.grid(alpha=0.3, which="both")

        # Pareto plot
        ax2.plot(frac_tokens.cpu().numpy(), cum.cpu().numpy(), label='cumulative fraction of volume')
        ax2.axhline(0.8, color='red', linestyle='--', alpha=0.6, label='80% volume')
        ax2.set_xlabel('fraction of tokens (sorted by proportion)')
        ax2.set_ylabel('cumulative proportion')
        ax2.set_ylim(0, 1.05)
        ax2.legend()
        ax2.grid(alpha=0.3)
        ax2.set_title('Pareto-style cumulative coverage')

        # Compute Pareto point
        pareto_idx = torch.nonzero(cum >= 0.8, as_tuple=False)
        if pareto_idx.numel() > 0:
            k = pareto_idx[0].item() + 1
            print(f'Tokens needed for 80% of volume: {k} ({k/len(sorted_props):.2%} of tokens)')

        fig.tight_layout()
        output_dir.mkdir(parents=True, exist_ok=True)
        output_path = output_dir / f"{model_name.replace('/', '_')}.svg"
        fig.savefig(output_path, format='svg', dpi=150)
        print(f"Saved figure to {output_path}")
    finally:
        # Close this specific figure even if plotting or saving fails, so
        # repeated calls do not accumulate open figures.
        plt.close(fig)


def main():
    """Main execution function for all models.

    For each model in the built-in list: load it, estimate Voronoi cell
    proportions inside the model's rectangle bounds, write a CSV and an SVG
    figure. A failure for one model is reported (with traceback) and the loop
    moves on to the next model.
    """
    import traceback

    models = [
        'EleutherAI/pythia-160m',
        'EleutherAI/pythia-410m',
        'EleutherAI/pythia-1b',
        'EleutherAI/pythia-1.4b',
        'EleutherAI/pythia-2.8b',
        'Qwen/Qwen2.5-0.5B',
        'Qwen/Qwen2.5-1.5B',
        'meta-llama/Llama-3.2-1B',
        'meta-llama/Llama-3.2-3B',
        'meta-llama/Llama-3.1-8B',
        'google/gemma-3-1b-pt',
        'google/gemma-3-270m',
    ]

    device = "cuda" if torch.cuda.is_available() else "cpu"

    for model_name in models:
        print(f"\n{'='*60}")
        print(f"Processing: {model_name}")
        print(f"{'='*60}")

        # Pre-bind so the cleanup in `finally` is valid even if loading fails.
        model = tokenizer = decoder_matrix = result = None
        try:
            # Load model and decoder matrix
            model = load_model(model_name)
            tokenizer = model.tokenizer
            decoder_matrix = model.get_output_projection_matrix()
            tokens_on_columns = False
            print(f"Loaded model {model_name}; decoder matrix shape: {tuple(decoder_matrix.shape)}")

            # Estimate volumes
            result = estimate_volumes_for_model(
                model_name=model_name,
                decoder_matrix=decoder_matrix,
                tokenizer=tokenizer,
                tokens_on_columns=tokens_on_columns,
                subset_ids=None,
                region="rectangle",
                num_samples=50_000_000,
                device=device,
            )

            # Save CSV
            output_csv = Path(f"data/voronoi_results/{model_name.replace('/', '_')}.csv")
            save_results_csv(result, tokenizer, output_csv)

            # Plot and save figure
            plot_and_save_figure(result, model_name.replace('/', '_'))

        except Exception as e:
            # Top-level boundary: report and keep going with the remaining models.
            print(f"Error processing {model_name}: {e}")
            traceback.print_exc()
        finally:
            # Drop references before the next model is loaded so two models are
            # not held alive at the same time.
            model = tokenizer = decoder_matrix = result = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()


# Run the full pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
