import os

import open_clip
import torch
from tqdm import tqdm

from src.datasets.registry import get_dataset
from src.datasets.templates import get_templates
from src.modeling import ClassificationHead, ImageEncoder


def build_classification_head(model, dataset_name, template, data_location, device, seed=None):
    """Build a zero-shot ClassificationHead from text embeddings of class prompts.

    Each class name of ``dataset_name`` is formatted with every prompt template and
    encoded with ``model.encode_text``. Each prompt embedding is L2-normalised; the
    normalised embeddings are averaged over prompts and re-normalised. The resulting
    per-class rows are scaled by ``exp(model.logit_scale)`` and become the head's
    weights.

    Args:
        model: Model exposing ``encode_text`` and ``logit_scale``. It is switched to
            eval mode and moved to ``device``.
        dataset_name: Registry name used to fetch the dataset (for its ``classnames``)
            and, when ``template`` is None, the prompt templates.
        template: Sequence of callables mapping a class name to a prompt string.
            ``None`` means "use ``get_templates(dataset_name)``".
        data_location: Root directory forwarded to ``get_dataset``.
        device: Device the model and the tokenized prompts are moved to.
        seed: Forwarded to ``get_dataset``.

    Returns:
        ``ClassificationHead(normalize=True, ...)`` whose float32 weights have one row
        per class, in ``dataset.classnames`` order.
    """
    # Use the caller's templates; fall back to the registry only when none are given.
    if template is None:
        template = get_templates(dataset_name)

    logit_scale = model.logit_scale
    dataset = get_dataset(dataset_name, None, location=data_location, seed=seed)
    model.eval()
    model.to(device)

    print("Building classification head.")
    with torch.no_grad():
        class_rows = []
        for classname in tqdm(dataset.classnames):
            texts = open_clip.tokenize([t(classname) for t in template]).to(device)
            embeddings = model.encode_text(texts)  # embed with text encoder
            # Unit-normalise each prompt embedding, average over prompts, then
            # re-normalise so every class row has unit norm before scaling.
            embeddings /= embeddings.norm(dim=-1, keepdim=True)
            embeddings = embeddings.mean(dim=0, keepdim=True)
            embeddings /= embeddings.norm()
            class_rows.append(embeddings)

        # Concatenate the per-class rows along dim 0 -> (num_classes, embed_dim).
        # This keeps the class axis even for a single class, which a bare squeeze()
        # would drop (assumes encode_text returns (num_prompts, embed_dim) — TODO confirm).
        zeroshot_weights = torch.cat(class_rows, dim=0).to(device)
        zeroshot_weights *= logit_scale.exp()
        zeroshot_weights = zeroshot_weights.float()

    classification_head = ClassificationHead(normalize=True, weights=zeroshot_weights)

    return classification_head



def build_classification_multihead(image_encoder, datasets, data_location, device):
    """Build one zero-shot ClassificationHead spanning the classes of several datasets.

    Class names are gathered from every dataset in ``datasets``, in order, with
    duplicates kept. Each class name is prompted only with the templates of the
    dataset it came from, so its embedding matches what the single-dataset head
    would produce. Prompt embeddings are L2-normalised, averaged, re-normalised and
    scaled by ``exp(image_encoder.logit_scale)``.

    Args:
        image_encoder: Model exposing ``encode_text`` and ``logit_scale``. It is
            switched to eval mode and moved to ``device``.
        datasets: Iterable of dataset registry names.
        data_location: Root directory forwarded to ``get_dataset``.
        device: Device the model and the tokenized prompts are moved to.

    Returns:
        ``ClassificationHead(normalize=True, ...)`` with one float32 weight row per
        gathered class name, ordered dataset by dataset.
    """
    logit_scale = image_encoder.logit_scale
    image_encoder.eval()
    image_encoder.to(device)

    # Pair every classname with its own dataset's templates. Applying the union of
    # all datasets' templates would prompt e.g. digit classes with unrelated
    # templates from other datasets.
    class_prompts = []
    for dataset_name in datasets:
        templates = get_templates(dataset_name)
        dataset = get_dataset(dataset_name, None, location=data_location)
        class_prompts.extend((classname, templates) for classname in dataset.classnames)

    print("Building classification Multi-head.")
    with torch.no_grad():
        class_rows = []
        for classname, templates in tqdm(class_prompts):
            texts = open_clip.tokenize([t(classname) for t in templates]).to(device)
            embeddings = image_encoder.encode_text(texts)
            embeddings /= embeddings.norm(dim=-1, keepdim=True)
            embeddings = embeddings.mean(dim=0, keepdim=True)
            embeddings /= embeddings.norm()
            class_rows.append(embeddings)

        # Concatenate per-class rows -> (num_classes, embed_dim). Unlike squeeze(),
        # this keeps the class axis when only one class is present.
        zeroshot_weights = torch.cat(class_rows, dim=0).to(device)
        zeroshot_weights *= logit_scale.exp()
        zeroshot_weights = zeroshot_weights.float()

    return ClassificationHead(normalize=True, weights=zeroshot_weights)



def get_classification_head(args, dataset):
    """Return the zero-shot head for ``dataset``, reusing a cached copy when present.

    The head is always keyed on the ``Val`` variant of the dataset name. If
    ``<args.save>/head_<name>Val.pt`` exists it is loaded. Otherwise a head is built
    with ``build_classification_head``, saved to that path and returned.

    Args:
        args: Namespace providing ``save``, ``model``, ``data_location``, ``device``
            and ``seed`` (``ImageEncoder`` may read further fields).
        dataset: Dataset registry name, with or without the ``Val`` suffix.

    Returns:
        The loaded or freshly built ClassificationHead.
    """
    # Always use the validation-set head so it matches the one generated at training time.
    dataset_name = dataset if dataset.endswith("Val") else dataset + "Val"
    cache_path = os.path.join(args.save, f"head_{dataset_name}.pt")

    if os.path.exists(cache_path):
        print(f"Classification head for {args.model} on {dataset_name} exists at {cache_path}")
        return ClassificationHead.load(cache_path)

    print(
        f"Did not find classification head for {args.model} on {dataset_name} at {cache_path}, building one from scratch."  # noqa: E501
    )
    clip_model = ImageEncoder(args, keep_lang=True).model
    head = build_classification_head(
        clip_model,
        dataset_name,
        get_templates(dataset_name),
        args.data_location,
        args.device,
        seed=args.seed,
    )
    os.makedirs(args.save, exist_ok=True)
    head.save(cache_path)
    return head


def get_multihead_classification(args, datasets):
    """Return a joint zero-shot head over ``datasets``, reusing a cached copy when present.

    The cache file is ``<args.save>/multihead_<name1>_<name2>_....pt``. If it exists
    it is loaded. Otherwise the head is built with
    ``build_classification_multihead``, saved there and returned.

    Args:
        args: Namespace providing ``save``, ``model``, ``data_location`` and ``device``
            (``ImageEncoder`` may read further fields).
        datasets: Iterable of dataset registry names, or a single name as a ``str``.

    Returns:
        The loaded or freshly built ClassificationHead.
    """
    # Materialise once. A lone string would otherwise be iterated character by
    # character, and a one-shot iterator would give no stable cache key.
    datasets = [datasets] if isinstance(datasets, str) else list(datasets)

    # Key the cache on the joined names rather than repr(container). The repr
    # embedded brackets, quotes and spaces, and differed for list vs tuple inputs.
    # Files cached under the repr-based name ("multihead_['A', 'B'].pt") are not reused.
    filename = os.path.join(args.save, f"multihead_{'_'.join(datasets)}.pt")
    if os.path.exists(filename):
        print(f"Multihead classification head for {args.model} on {datasets} exists at {filename}")
        return ClassificationHead.load(filename)
    print(
        f"Did not find multihead classification head for {args.model} on {datasets} at {filename}, building one from scratch."  # noqa: E501
    )
    model = ImageEncoder(args, keep_lang=True).model
    classification_head = build_classification_multihead(
        model, datasets, args.data_location, args.device
    )
    os.makedirs(args.save, exist_ok=True)
    classification_head.save(filename)
    return classification_head