import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from datasets import DatasetDict, concatenate_datasets
from tqdm.auto import tqdm
import open_clip
import fire
import pandas as pd
from typing import List
from layskip.utils.dictionaries import DATASET2LABEL_COLUMN
from layskip.utils.imagenet_utils import IMAGENET_CLASSES, IMAGENET_PROMPT_TEMPLATES
from nn_core.common import PROJECT_ROOT
from pytorch_lightning import seed_everything

# Module-wide compute device: use CUDA when PyTorch reports it as available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


@torch.no_grad()
def get_openclip_text_features_ensemble(model, tokenizer, class_names: List[str], templates: List[str]):
    """Build one prompt-ensembled text embedding per class.

    For every class, each template is filled with the class name (underscores
    replaced by spaces), the resulting prompts are encoded with
    ``model.encode_text``, L2-normalized, averaged over the templates and
    L2-normalized again. Encoding runs under autocast for the module-level
    ``device``.

    Note:
        ``model`` is switched to eval mode and moved to the module-level
        ``device`` as a side effect.

    Args:
        model: Object exposing ``eval()``, ``to()`` and ``encode_text()``
            (presumably an OpenCLIP model -- confirm against callers).
        tokenizer: Callable that maps a list of prompt strings to an object
            supporting ``.to(device)``.
        class_names: Class names, in the order the output rows should follow.
        templates: Format strings with a single positional ``{}`` placeholder.

    Returns:
        A CPU tensor with one row per entry of ``class_names``.

    Raises:
        ValueError: If ``class_names`` or ``templates`` is empty.
        RuntimeError: If tokenization fails for a class (original error chained).
    """
    num_classes = len(class_names)
    num_templates = len(templates)
    # Fail early with a clear message instead of an opaque torch.stack / NaN-mean failure later on.
    if num_classes == 0:
        raise ValueError("class_names must contain at least one class.")
    if num_templates == 0:
        raise ValueError("templates must contain at least one prompt template.")

    model.eval().to(device)
    all_text_features = []

    print(f"Generating text features for {num_classes} classes using {num_templates} templates...")

    # torch.autocast is the device-agnostic replacement for the deprecated
    # torch.cuda.amp.autocast / torch.cpu.amp.autocast entry points.
    with torch.autocast(device_type=device.type):
        for class_name in tqdm(class_names, desc="Encoding Classes"):
            readable_name = class_name.replace("_", " ")
            texts = [template.format(readable_name) for template in templates]
            try:
                text_tokens = tokenizer(texts).to(device)
            except Exception as e:
                raise RuntimeError(f"Failed to tokenize for class: {class_name}") from e

            class_embeddings = model.encode_text(text_tokens)
            # Normalize each prompt embedding before averaging so every template contributes equally,
            # then re-normalize the mean so the class vector is unit length again.
            class_embeddings_norm = F.normalize(class_embeddings, dim=-1)
            class_embedding_avg = class_embeddings_norm.mean(dim=0)
            class_embedding_final = F.normalize(class_embedding_avg, dim=-1)
            all_text_features.append(class_embedding_final)

    final_text_features = torch.stack(all_text_features)
    print(f"Generated final text features shape: {final_text_features.shape}")
    return final_text_features.cpu()


def evaluate_zero_shot(
    dataset_name: str,
    model_name: str,
    layers_to_approximate: List,
    translator_name: str,
    seed: int = 0,
    batch_size: int = 512,
    num_workers: int = 4,
) -> float:
    """Measure zero-shot accuracy of stored image embeddings against prompt-ensembled class text features.

    Precomputed embeddings are loaded from
    ``PROJECT_ROOT/data/{translator_name}_skipped_embeddings/{dataset_name}/{model slug}``,
    the ``train`` and ``test`` splits are concatenated, the embeddings are passed through the
    model's visual projection (``model.visual.proj``) when one exists, L2-normalized, and compared
    with text features built from ``IMAGENET_CLASSES``. The prediction is the argmax over classes,
    so labels are assumed to be indices into ``IMAGENET_CLASSES``.

    A one-row summary is appended to ``PROJECT_ROOT/results/zero_shot.csv``; a failure to write
    that file is printed and does not abort the run.

    Args:
        dataset_name: Key into ``DATASET2LABEL_COLUMN``; also a path component of the embeddings dir.
        model_name: Identifier of the form ``open_clip:<org>/<name>``; the part after the prefix is
            loaded through ``hf-hub:``, and ``<name>`` is used as the embeddings sub-directory.
        layers_to_approximate: Its ``str()`` is the name of the embeddings column to evaluate.
        translator_name: Prefix of the ``*_skipped_embeddings`` data directory.
        seed: Seed passed to ``seed_everything``.
        batch_size: Batch size of the evaluation ``DataLoader``.
        num_workers: Worker count of the evaluation ``DataLoader``.

    Returns:
        Fraction of correctly classified samples, or ``0.0`` if there were no samples.

    Raises:
        ValueError: If ``IMAGENET_CLASSES`` is empty, or ``model_name`` lacks the ``open_clip:``
            prefix or the ``<org>/<name>`` form.
        KeyError: If ``dataset_name`` has no entry in ``DATASET2LABEL_COLUMN``.
        TypeError: If ``model.visual.proj`` is neither ``None``, a tensor/parameter nor a module.
        Exception: Any error raised while loading the OpenCLIP model is printed and re-raised.
    """
    seed_everything(seed)

    # --- Validate inputs before any expensive loading ---
    if not IMAGENET_CLASSES:
        raise ValueError("IMAGENET_CLASSES list is empty or not loaded.")
    if not IMAGENET_PROMPT_TEMPLATES:
        print("Warning: IMAGENET_PROMPT_TEMPLATES not found or empty. Using single default.")
        templates_to_use = ["a photo of a {}."]
    else:
        templates_to_use = IMAGENET_PROMPT_TEMPLATES
    if dataset_name not in DATASET2LABEL_COLUMN:
        raise KeyError(f"Label column mapping not found for dataset '{dataset_name}' in DATASET2LABEL_COLUMN.")
    label_col_name = DATASET2LABEL_COLUMN[dataset_name]

    if not model_name.startswith("open_clip:"):
        raise ValueError(f"Expected model_name to start with 'open_clip:', but got {model_name}")

    open_clip_model_name = model_name.split(":", 1)[1]
    # The slug is everything after the first "/"; report a descriptive error rather than a bare
    # IndexError when the identifier has no organisation component.
    _, separator, model_name_slug = open_clip_model_name.partition("/")
    if not separator:
        raise ValueError(f"Expected model_name of the form 'open_clip:<org>/<name>', but got {model_name}")

    # --- Load the precomputed embeddings and merge both splits ---
    expected_embedding_col_name = str(layers_to_approximate)
    EMBEDDINGS_DIR = PROJECT_ROOT / "data" / f"{translator_name}_skipped_embeddings" / dataset_name / model_name_slug
    embeddings_dataset = DatasetDict.load_from_disk(str(EMBEDDINGS_DIR))
    final_embedding_col_name = "embeddings"

    train_data_selected = embeddings_dataset["train"].select_columns([expected_embedding_col_name, label_col_name])
    test_data_selected = embeddings_dataset["test"].select_columns([expected_embedding_col_name, label_col_name])

    # Rename the layer-specific column to a fixed name so downstream code does not depend on it.
    train_data_renamed = train_data_selected.rename_column(expected_embedding_col_name, final_embedding_col_name)
    test_data_renamed = test_data_selected.rename_column(expected_embedding_col_name, final_embedding_col_name)

    eval_data = concatenate_datasets([train_data_renamed, test_data_renamed])
    eval_data.set_format("torch", columns=[final_embedding_col_name, label_col_name])
    print(f"Using combined dataset with {len(eval_data)} samples.")

    # --- Load the OpenCLIP model and tokenizer ---
    open_clip_hub_name = f"hf-hub:{open_clip_model_name}"
    try:
        model, _, _ = open_clip.create_model_and_transforms(open_clip_hub_name, device=device, jit=False)
        tokenizer = open_clip.get_tokenizer(open_clip_hub_name)
        model.eval()
    except Exception as e:
        print(f"Error loading OpenCLIP model '{open_clip_hub_name}': {e}")
        raise

    # --- Resolve the visual projection: a weight matrix, a module, or nothing ---
    visual_projection_layer = None
    visual_projection_weight = None
    if hasattr(model, "visual") and hasattr(model.visual, "proj"):
        projection = model.visual.proj
        if projection is None:
            print("Visual projection 'proj' is None or not found. Assuming direct embedding comparison.")
        elif isinstance(projection, (torch.nn.Parameter, torch.Tensor)):
            visual_projection_weight = projection.to(device).float()
        elif isinstance(projection, torch.nn.Module):
            visual_projection_layer = projection.to(device).eval().float()
        else:
            raise TypeError(f"Unexpected type for visual.proj: {type(projection)}")
    else:
        print("Warning: Could not find visual projection 'model.visual.proj'. Assuming direct embedding comparison.")

    text_features = get_openclip_text_features_ensemble(model, tokenizer, IMAGENET_CLASSES, templates=templates_to_use)
    text_features = text_features.to(device).float()

    dataloader = DataLoader(
        eval_data,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        # Pinned host memory only helps host-to-GPU copies; skip it on CPU-only runs.
        pin_memory=device.type == "cuda",
    )

    # --- Classify every sample by its most similar class text feature ---
    correct_predictions = 0
    total_samples = 0
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating Batches"):
            image_embeddings_approx = batch[final_embedding_col_name].to(device, non_blocking=True).float()
            labels = batch[label_col_name].to(device, non_blocking=True)

            if visual_projection_layer is not None:
                image_features_projected = visual_projection_layer(image_embeddings_approx)
            elif visual_projection_weight is not None:
                image_features_projected = image_embeddings_approx @ visual_projection_weight
            else:
                image_features_projected = image_embeddings_approx

            image_features_norm = F.normalize(image_features_projected, p=2, dim=-1)

            # text_features has one row per class, so the product has one column per class.
            similarity = image_features_norm @ text_features.T
            predictions = similarity.argmax(dim=-1)

            correct_predictions += (predictions == labels).sum().item()
            total_samples += labels.size(0)

    accuracy = correct_predictions / total_samples if total_samples > 0 else 0.0

    # --- Append a one-row summary to the shared results CSV ---
    log_data = {
        "model": model_name,
        "dataset": dataset_name,
        "approximated_layers": str(layers_to_approximate),
        "translator": translator_name,
        "accuracy": accuracy,
        "total_samples": total_samples,
        "correct_predictions": correct_predictions,
    }
    results_df = pd.DataFrame([log_data])
    csv_folderpath = PROJECT_ROOT / "results"
    csv_folderpath.mkdir(parents=True, exist_ok=True)
    csv_filepath = csv_folderpath / "zero_shot.csv"
    print(f"Logging results to: {csv_filepath}")
    # Logging is best-effort: a failed write is reported but must not discard the computed accuracy.
    try:
        file_exists = csv_filepath.exists()
        # Write the header only when creating the file so appended rows stay aligned.
        results_df.to_csv(csv_filepath, mode="a", header=not file_exists, index=False)
        print("Results successfully logged.")
    except Exception as e:
        print(f"Error writing to CSV file {csv_filepath}: {e}")

    return accuracy


# Script entry point: hand evaluate_zero_shot to python-fire so it can be invoked from the command line.
if __name__ == "__main__":
    fire.Fire(evaluate_zero_shot)
