import torch
import torchvision.models as models
from solo.methods import METHODS
import matplotlib.pyplot as plt
import numpy as np
import os
from omegaconf import DictConfig, OmegaConf

def load_model_from_checkpoint(
    checkpoint_path,
    device,
    method: str = "vicreg",
    backbone: str = "resnet18",
    proj_hidden_dim: int = 512,
    proj_output_dim: int = 512,
    projector_type: str = "mlp3",
    model_dataset: str = "imagenet100",
):
    """
    Rebuild a solo-learn model and restore its weights from a checkpoint file.

    A minimal OmegaConf config is assembled from the arguments plus fixed
    placeholder values (optimizer, data paths, trainer flags). The model class
    is looked up in ``METHODS`` and instantiated with that config. The
    checkpoint weights are then loaded non-strictly, with a printed warning
    for any missing or unexpected keys.

    Args:
        checkpoint_path: File passed to ``torch.load``.
        device: Used as ``map_location`` for ``torch.load`` and as the
            target of ``model.to(...)``.
        method: Key into ``solo.methods.METHODS``.
        backbone: Stored as ``cfg.backbone.name``.
        proj_hidden_dim: Projector hidden width (cast to ``int``).
        proj_output_dim: Projector output width (cast to ``int``).
        projector_type: Stored as ``cfg.method_kwargs.projector_type``.
        model_dataset: Stored as ``cfg.data.dataset``.

    Returns:
        The model in ``eval()`` mode, or ``None`` if any step fails. Errors
        are printed, not raised.
    """
    try:
        # NOTE(review): recent PyTorch releases default torch.load to
        # weights_only=True, which may reject Lightning checkpoints that pickle
        # config objects. If so, the generic handler below returns None.
        ckpt = torch.load(checkpoint_path, map_location=device)

        backbone_section = {"name": backbone, "kwargs": {}}
        # The loss weights look VICReg-specific; presumably other methods
        # ignore them or need their own keys. TODO confirm per method.
        method_section = {
            "proj_hidden_dim": int(proj_hidden_dim),
            "proj_output_dim": int(proj_output_dim),
            "sim_loss_weight": 25.0,
            "var_loss_weight": 25.0,
            "cov_loss_weight": 1.0,
            "radial_loss_weight": 0.0,
            "projector_type": projector_type,
        }
        optimizer_section = {
            "name": "adam",
            "batch_size": 256,
            "lr": 0.001,
            "weight_decay": 1e-6,
            "classifier_lr": 0.001,
            "exclude_bias_n_norm_wd": False,
            "kwargs": {},
        }
        # Placeholder data settings: only needed so the method constructor
        # accepts the config; nothing is read from these paths here.
        data_section = {
            "dataset": model_dataset,
            "train_path": "datasets/placeholder/train",
            "val_path": "datasets/placeholder/val",
            "format": "image_folder",
            "num_workers": 4,
            "num_classes": 100,
            "num_large_crops": 2,
            "num_small_crops": 0,
        }
        # Optional training-time features, all switched off for inference.
        disabled_features = {
            feature: {"enabled": False}
            for feature in ("auto_resume", "knn_eval", "radius_hist", "mlp_probe")
        }

        # NOTE(review): n_classes (1000) and data.num_classes (100) disagree;
        # verify which one the chosen method reads.
        cfg = OmegaConf.create({
            "backbone": backbone_section,
            "n_classes": 1000,
            "method_kwargs": method_section,
            "optimizer": optimizer_section,
            "scheduler": {"name": "warmup_cosine"},
            "method": method,
            "data": data_section,
            "max_epochs": 1,
            "devices": [0],
            "sync_batchnorm": False,
            "accelerator": "auto",
            "strategy": "auto",
            "precision": "32",
            **disabled_features,
            "performance": {"disable_channel_last": True},
        })

        if method not in METHODS:
            raise ValueError(f"Unknown method '{method}'. Available: {list(METHODS.keys())}")

        model = METHODS[method](cfg).to(device)

        # Lightning checkpoints nest weights under "state_dict"; a bare
        # state dict is used as-is.
        weights = ckpt.get("state_dict", ckpt)
        load_result = model.load_state_dict(weights, strict=False)

        absent = getattr(load_result, "missing_keys", None)
        if absent:
            print(f"Warning: missing keys when loading state_dict: {absent}")
        extra = getattr(load_result, "unexpected_keys", None)
        if extra:
            print(f"Warning: unexpected keys when loading state_dict: {extra}")

        model.eval()
        return model
    except FileNotFoundError:
        print(f"Error: Checkpoint file not found at {checkpoint_path}")
        return None
    except Exception as e:
        print(f"An error occurred while loading the model: {e}")
        return None

# Function to plot the t-SNE embeddings
def plot_tsne(
    tsne_results,
    labels,
    classes,
    title,
    save_path,
    silhouette=None,
    davies_bouldin=None,
    calinski_harabasz=None,
    ari=None,
    nmi=None,
    v_measure=None,
):
    """
    Scatter-plot 2-D t-SNE embeddings coloured by class and save the figure.

    Args:
        tsne_results: Array-like indexed as ``[mask, 0]`` / ``[mask, 1]``;
            column 0 is the x axis and column 1 is the y axis.
        labels: Array-like with one label per row of ``tsne_results``.
            Converted with ``np.asarray``, so plain lists work.
        classes: Indexable lookup from a label value to its legend text
            (``classes[label]``).
        title: First line of the plot title.
        save_path: Output image path. Its parent directory is created if it
            does not exist.
        silhouette, davies_bouldin, calinski_harabasz: Unsupervised cluster
            scores. They are appended as a title line only when all three
            are given.
        ari, nmi, v_measure: Supervised cluster scores. They are appended as
            a title line only when all three are given.
    """
    save_dir = os.path.dirname(save_path)
    # os.makedirs("") raises FileNotFoundError, so skip it for bare filenames.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    tsne_results = np.asarray(tsne_results)
    # asarray makes `labels == label` an elementwise mask even for list input.
    labels = np.asarray(labels)
    unique_labels = np.unique(labels)
    n_classes = len(unique_labels)

    # tab10 only has 10 distinct colours; larger label sets would reuse them.
    if n_classes <= 10:
        cmap_name = "tab10"
    elif n_classes <= 20:
        cmap_name = "tab20"
    else:
        cmap_name = "turbo"
    # pyplot.get_cmap replaces matplotlib.cm.get_cmap, which Matplotlib 3.9 removed.
    # max(..., 1) avoids requesting a zero-entry colormap for empty input.
    colors = plt.get_cmap(cmap_name, max(n_classes, 1))

    fig, ax = plt.subplots(figsize=(16, 10))
    try:
        for i, label in enumerate(unique_labels):
            mask = labels == label
            ax.scatter(
                tsne_results[mask, 0],
                tsne_results[mask, 1],
                color=colors(i),
                label=classes[label],
                alpha=0.7,
            )

        # Title, optionally followed by one line per complete metric group.
        title_str = title
        if silhouette is not None and davies_bouldin is not None and calinski_harabasz is not None:
            unsupervised_metrics = f'Sil: {silhouette:.2f} | DB: {davies_bouldin:.2f} | CH: {int(calinski_harabasz)}'
            title_str += f'\n{unsupervised_metrics}'
        if ari is not None and nmi is not None and v_measure is not None:
            supervised_metrics = f'ARI: {ari:.2f} | NMI: {nmi:.2f} | V-Measure: {v_measure:.2f}'
            title_str += f'\n{supervised_metrics}'

        ax.set_title(title_str)
        ax.set_xlabel('t-SNE Dimension 1')
        ax.set_ylabel('t-SNE Dimension 2')
        ax.legend(loc='best')

        fig.savefig(save_path, bbox_inches='tight')
        print(f"t-SNE plot saved to {save_path}")
    finally:
        # Close this specific figure even if plotting/saving failed, so repeated
        # calls do not accumulate open figures.
        plt.close(fig)