import argparse
import os
import torch
from torchvision import transforms, datasets
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader, Subset
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sklearn.metrics import (
    silhouette_score,
    davies_bouldin_score,
    calinski_harabasz_score,
    adjusted_rand_score,
    normalized_mutual_info_score,
    v_measure_score,
)
from sklearn.cluster import KMeans
import numpy as np
from tqdm import tqdm

from utils import load_model_from_checkpoint, plot_tsne
from download_data import download_and_extract_images, IMAGENET_10_CLASSES


def build_dataset(dataset_name: str, num_images_per_class: int = 100):
    """
    Build the evaluation dataset used for embedding extraction.

    Args:
        dataset_name: 'imagenet10' or 'cifar10' (matched case-insensitively).
        num_images_per_class: For 'imagenet10', forwarded to
            download_and_extract_images. For 'cifar10', the maximum number of
            training samples kept per class; None or a non-positive value
            keeps the full training split.

    Returns:
        (dataset, class_names, input_size_tuple), where input_size_tuple is
        (224, 224) for 'imagenet10' and (32, 32) for 'cifar10'.

    Raises:
        ValueError: If dataset_name is not one of the supported names.
    """
    name = dataset_name.lower()

    if name == "imagenet10":
        # Ensure images are present
        download_and_extract_images(IMAGENET_10_CLASSES, num_images_per_class=num_images_per_class)
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        dataset_path = "dataset/imagenet-10-subset"
        ds = datasets.ImageFolder(root=dataset_path, transform=transform)
        return ds, ds.classes, (224, 224)

    if name == "cifar10":
        # Use native CIFAR resolution (32x32) to match CIFAR-style backbones
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616]),
        ])
        base = CIFAR10(root="downloads", train=True, download=True, transform=transform)
        class_names = base.classes

        # Optionally cap images per class for parity with ImageNet-10 sampling
        if num_images_per_class is None or num_images_per_class <= 0:
            return base, class_names, (32, 32)

        # Read labels straight from `base.targets`: indexing `base` itself
        # would decode and transform every image just to look at its label.
        class_counts = {c: 0 for c in range(10)}
        full_classes = 0
        indices = []
        for idx, target in enumerate(base.targets):
            if class_counts[target] >= num_images_per_class:
                continue
            indices.append(idx)
            class_counts[target] += 1
            if class_counts[target] == num_images_per_class:
                full_classes += 1
                # Stop scanning as soon as every class has reached the cap.
                if full_classes == len(class_counts):
                    break
        return Subset(base, indices), class_names, (32, 32)

    raise ValueError("dataset_name must be one of ['imagenet10', 'cifar10']")


def extract_embeddings(model, dataloader, device):
    """
    Run the model over the dataloader and gather its embeddings.

    The model is put in eval mode and called under torch.no_grad(). Its output
    is read with `.get`, so it is expected to be mapping-like: the 'feats'
    entry is collected as the encoder embedding and the 'z' entry as the
    projector embedding. An entry that is missing (None) is skipped.

    Returns: (emb_enc, emb_proj, labels_np). emb_enc / emb_proj are None when
    no batch produced the corresponding entry.
    """
    model.eval()

    encoder_batches = []
    projector_batches = []
    label_batches = []
    # Pair each output key with the list that accumulates its batches.
    sinks = (("feats", encoder_batches), ("z", projector_batches))

    with torch.no_grad():
        for batch_images, batch_targets in tqdm(dataloader, desc="Inference"):
            outputs = model(batch_images.to(device))
            for key, collected in sinks:
                tensor = outputs.get(key)
                if tensor is not None:
                    collected.append(tensor.cpu().numpy())
            label_batches.append(batch_targets.cpu().numpy())

    def _stack(batches):
        # Concatenate along the batch axis; None signals "never produced".
        return np.concatenate(batches, axis=0) if batches else None

    labels_np = np.concatenate(label_batches, axis=0)
    return _stack(encoder_batches), _stack(projector_batches), labels_np


def compute_metrics(embeddings, labels, num_classes, *, l2_normalize: bool = False, pca_dim: int = 0):
    """
    Compute a 2-D t-SNE projection and clustering metrics for embeddings.

    Preprocessing (applied in this order, both optional):
      * l2_normalize: divide every row by its L2 norm (plus 1e-12).
      * pca_dim: reduce with PCA when 0 < pca_dim < number of feature columns.

    The preprocessed features (not the t-SNE output) are used for the
    silhouette, Davies-Bouldin and Calinski-Harabasz scores against `labels`,
    and for a KMeans clustering with `num_classes` clusters whose assignments
    are compared with `labels` via ARI, NMI and V-measure.

    Returns:
        (tsne_results, silhouette, davies_bouldin, calinski_harabasz,
         ari, nmi, v_measure)
    """
    seed = 420
    features = embeddings

    if l2_normalize:
        row_norms = np.linalg.norm(features, axis=1, keepdims=True) + 1e-12
        features = features / row_norms

    if pca_dim and 0 < pca_dim < features.shape[1]:
        reducer = PCA(n_components=pca_dim, random_state=seed)
        features = reducer.fit_transform(features)

    # Perplexity is capped at one less than the number of samples.
    perplexity = min(30, len(features) - 1)
    projector = TSNE(n_components=2, random_state=seed, perplexity=perplexity)
    tsne_results = projector.fit_transform(features)

    # Metrics that score the ground-truth labelling on the feature space.
    label_scores = (
        silhouette_score(features, labels),
        davies_bouldin_score(features, labels),
        calinski_harabasz_score(features, labels),
    )

    # Metrics that compare an unsupervised clustering with the labels.
    clusterer = KMeans(n_clusters=num_classes, random_state=seed, n_init=10)
    predicted_labels = clusterer.fit_predict(features)
    agreement_scores = (
        adjusted_rand_score(labels, predicted_labels),
        normalized_mutual_info_score(labels, predicted_labels),
        v_measure_score(labels, predicted_labels),
    )

    return (tsne_results, *label_scores, *agreement_scores)


def _plot_embedding_space(embeddings, labels, class_names, friendly_name, space, args):
    """Compute metrics for one embedding space and save its t-SNE plot.

    Args:
        embeddings: Embedding array forwarded to compute_metrics.
        labels: Ground-truth labels aligned with `embeddings`.
        class_names: Class names; their count is used as the cluster count.
        friendly_name: Display name of the model, used in title and file name.
        space: Lower-case space name ("encoder" or "projector"); used in the
            progress message, the plot title and the output file name.
        args: Parsed CLI namespace; `l2_normalize` and `pca_dim` are read.
    """
    print(f"t-SNE for {space} embeddings...")
    tsne_res, sil, dbi, ch, ari, nmi, vms = compute_metrics(
        embeddings,
        labels,
        num_classes=len(class_names),
        l2_normalize=args.l2_normalize,
        pca_dim=args.pca_dim,
    )

    save_path = os.path.join("tsne_visualization", "plots", f"{friendly_name}_{space}.png")
    # plot_tsne's handling of a missing output directory is not visible from
    # this file, so create it defensively before plotting.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    plot_tsne(
        tsne_res,
        labels,
        class_names,
        f"{friendly_name} - {space.capitalize()}",
        save_path,
        silhouette=sil,
        davies_bouldin=dbi,
        calinski_harabasz=ch,
        ari=ari,
        nmi=nmi,
        v_measure=vms,
    )


def main(args):
    """Extract embeddings for every checkpoint and save t-SNE plots.

    For each (checkpoint path, display name) pair the model is loaded, the
    dataset is run through it once, and one plot is produced per available
    embedding space (encoder and/or projector).

    Args:
        args: Parsed CLI namespace (see the argument parser at the bottom of
            this file for the attributes that are read).

    Raises:
        ValueError: If `args.file_paths` and `args.names` differ in length.
    """
    # Validate inputs before any expensive work (dataset download, model
    # loading). An explicit raise is used because `assert` is stripped under
    # `python -O`, after which zip() would silently drop unmatched entries.
    if len(args.file_paths) != len(args.names):
        raise ValueError("file_paths and names must have the same length")

    # Device
    if torch.backends.mps.is_available():
        device = torch.device("mps")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    print(f"Using device: {device}")

    # Dataset
    dataset, class_names, _ = build_dataset(args.dataset_name, num_images_per_class=args.num_images_per_class)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)

    # Loop over models
    for ckpt_path, friendly_name in zip(args.file_paths, args.names):
        print(f"Loading model: {friendly_name} from {ckpt_path}")
        model = load_model_from_checkpoint(
            ckpt_path,
            device,
            method=args.model_method,
            backbone=args.backbone,
            proj_hidden_dim=args.proj_hidden_dim,
            proj_output_dim=args.proj_output_dim,
            projector_type=args.projector_type,
            model_dataset=args.pretrained_dataset,
        )
        if model is None:
            print(f"Skipping {ckpt_path} due to load failure.")
            continue

        # Extract encoder and projector embeddings
        print("Extracting embeddings (encoder and projector)...")
        emb_enc, emb_proj, labels = extract_embeddings(model, dataloader, device)

        if emb_enc is None and emb_proj is None:
            # Make the otherwise silent no-op visible to the user.
            print(f"No 'feats' or 'z' outputs found for {friendly_name}; nothing to plot.")
            continue

        # One plot per embedding space that the model actually produced.
        for space, embeddings in (("encoder", emb_enc), ("projector", emb_proj)):
            if embeddings is not None:
                _plot_embedding_space(embeddings, labels, class_names, friendly_name, space, args)


def build_arg_parser():
    """Create the command-line parser for the t-SNE visualization script."""
    parser = argparse.ArgumentParser(description='t-SNE visualization for multiple checkpoints (encoder and projector).')
    # Inputs
    parser.add_argument('--file_paths', type=str, nargs='+', required=True, help='List of checkpoint file paths.')
    parser.add_argument('--names', type=str, nargs='+', required=True, help='List of display names matching file_paths.')
    # Dataset controls
    parser.add_argument('--dataset_name', type=str, choices=['imagenet10', 'cifar10'], default='imagenet10', help='Dataset to use for inference.')
    parser.add_argument('--num_images_per_class', type=int, default=100, help='Limit images per class for faster t-SNE.')
    parser.add_argument('--batch_size', type=int, default=64, help='Batch size for inference.')
    parser.add_argument('--l2_normalize', action='store_true', help='L2-normalize embeddings before metrics and t-SNE.')
    parser.add_argument('--pca_dim', type=int, default=0, help='If >0, apply PCA to this many dims before t-SNE and metrics.')
    # Model config (for proper instantiation)
    parser.add_argument('--model_method', type=str, default='vicreg', help='Method name in solo METHODS (e.g., vicreg, radialvicreg).')
    parser.add_argument('--backbone', type=str, default='resnet18', help='Backbone architecture (e.g., resnet18, resnet50).')
    parser.add_argument('--proj_hidden_dim', type=int, default=2048, help='Projector hidden dim.')
    parser.add_argument('--proj_output_dim', type=int, default=2048, help='Projector output dim.')
    parser.add_argument('--projector_type', type=str, default='mlp', help='Projector type (identity, mlp1..mlp5, mlp).')
    parser.add_argument('--pretrained_dataset', type=str, choices=['imagenet100', 'cifar100'], default='imagenet100', help='Dataset the checkpoint was trained on (affects backbone tweaks).')
    return parser


def parse_cli_args(argv=None):
    """Parse and validate command-line arguments.

    Args:
        argv: Argument list to parse; None makes argparse read sys.argv[1:].

    Returns:
        The parsed argparse.Namespace.

    A mismatch between the number of --file_paths and --names values is
    reported through parser.error, which prints the usage message and exits
    with status 2 before any dataset or model work starts.
    """
    parser = build_arg_parser()
    args = parser.parse_args(argv)
    if len(args.file_paths) != len(args.names):
        parser.error(
            "file_paths and names must have the same length "
            f"(got {len(args.file_paths)} and {len(args.names)})"
        )
    return args


if __name__ == '__main__':
    main(parse_cli_args())