import math
from typing import List, Tuple

import torch
import torch.nn as nn
import os
import pickle
import torch.nn.functional as F

from model.embed_with_prompts import (load_class2concepts,
                                      get_text_features, get_cache_path)

from clip import load, tokenize
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
from data.imagnet_prompts import imagenet_classes
from data.fewshot_datasets import fewshot_datasets

from data.cls_to_names import *

from data.imagnet_prompts import imagenet_templates

# Module-level instance of the project's SimpleTokenizer. It is not referenced by the
# code in this section; presumably kept for other code that imports it — verify before removing.
_tokenizer = _Tokenizer()

# Directory handed to `load(..., download_root=DOWNLOAD_ROOT)` for CLIP checkpoints.
# Expanded eagerly: '~' is a shell convention that os.makedirs/open do not expand, so a
# loader that uses an explicit download_root verbatim (as OpenAI's clip.load does) would
# otherwise create a directory literally named '~' under the current working directory.
DOWNLOAD_ROOT = os.path.expanduser('~/.cache/clip')


class ClipImageEncoder(nn.Module):
    """
    Visual tower of a CLIP checkpoint followed by a linear classification head.

    Args:
        device: device passed to ``load`` for the CLIP weights.
        arch (str): CLIP architecture name passed to ``load``.
        image_resolution (int): accepted for interface compatibility; unused here.
        n_class (int): number of logits produced by ``cls_head``.
    """

    def __init__(self, device, arch="ViT-L/14", image_resolution=224, n_class=1000):
        super().__init__()
        backbone, feature_dim, _ = load(arch, device=device, download_root=DOWNLOAD_ROOT)
        self.encoder = backbone.visual

        # Only the visual tower is kept: drop the text transformer and release
        # unused cached CUDA memory.
        del backbone.transformer
        torch.cuda.empty_cache()

        self.cls_head = nn.Linear(feature_dim, n_class)

    @property
    def dtype(self):
        """dtype of the encoder weights, read from its ``conv1`` layer."""
        return self.encoder.conv1.weight.dtype

    def forward(self, image):
        """
        Classify a batch of images.

        The input is cast to the encoder's dtype, encoded, and the encoder output is
        fed directly into ``cls_head``.
        """
        features = self.encoder(image.type(self.dtype))
        return self.cls_head(features)


class TextEncoder(nn.Module):
    """
    CLIP text tower re-packaged so it can consume pre-computed prompt embeddings.

    The transformer, positional embedding, final layer norm and text projection are
    taken (shared, not copied) from the CLIP model passed in.
    """

    def __init__(self, clip_model):
        super().__init__()
        # Registration order matches the original so state_dict key order is unchanged.
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.dtype = clip_model.dtype

    def forward(self, prompts, tokenized_prompts):
        """
        Encode prompt embeddings into projected text features.

        Args:
            prompts (torch.Tensor): batch-first prompt embeddings; dims 0 and 1 are
                swapped around the transformer call.
            tokenized_prompts (torch.Tensor): token ids; the per-row argmax position
                (the EOT token has the highest id in each sequence) selects which
                hidden state is projected.
        Returns:
            torch.Tensor: one projected feature vector per prompt.
        """
        pos = self.positional_embedding.type(self.dtype)
        hidden = (prompts + pos).permute(1, 0, 2)  # NLD -> LND
        hidden = self.transformer(hidden).permute(1, 0, 2)  # LND -> NLD
        hidden = self.ln_final(hidden).type(self.dtype)

        eot_positions = tokenized_prompts.argmax(dim=-1)
        rows = torch.arange(hidden.shape[0])
        pooled = hidden[rows, eot_positions]
        return pooled @ self.text_projection

class PromptLearner_withTemplates(nn.Module):
    """
    Manages prompt generation for CLIP, supporting fixed templates and CoOp-style learning.

    Modes (selected by ``args.coop``):
      * CoOp: a context tensor loaded from ``args.load`` is spliced between the
        start-token embedding and the class-name/suffix embeddings. ``forward`` returns
        prompt *embeddings* keyed by class name.
      * Fixed: each class name is formatted into every entry of ``imagenet_templates``
        (when ``args.with_templates``) or into a single ``args.ctx_init`` prompt and
        tokenized. ``forward`` returns those *token ids* keyed by class name.

    ``reset_classnames`` must be called before ``forward``.
    """

    def __init__(self, args, clip_model):
        super().__init__()
        self.dtype = clip_model.dtype
        self.device = clip_model.visual.conv1.weight.device
        self.args = args
        self.clip_model = clip_model
        self.n_ctx = args.n_ctx

        if args.coop and (args.with_templates or args.with_concepts):
            raise ValueError("Model parameter error: You use both the CoOp and the other templates")

        self.n_cls = 0
        self.classnames = []
        self.tokenized_prompts = None  # Dict[str, Tensor] in fixed mode; stays None in CoOp mode
        self.token_prefix = None  # CoOp only: start-token embedding per class
        self.token_suffix = None  # CoOp only: embeddings after the n_ctx context slots
        self.ctx = None  # CoOp only: loaded context tensor

    def reset_classnames(self, args, classnames):
        """
        Resets and reinitializes prompts based on new class names and arguments.

        Args:
            args: Command-line arguments (reads ``coop``, ``load``, ``with_templates``,
                ``ctx_init``).
            classnames (list): List of class names; underscores are replaced by spaces.
        """
        self.n_cls = len(classnames)
        print(f'=> Class Number: {self.n_cls}')

        self.classnames = [name.replace("_", " ") for name in classnames]

        if args.coop:
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            # map_location lets a checkpoint saved on another device load onto ours.
            self.ctx = torch.load(args.load, map_location=self.device)['state_dict']['ctx']
            with torch.no_grad():
                # One single-token placeholder per context slot, so the class-name tokens
                # start exactly at position 1 + n_ctx, where token_suffix is sliced. A fixed
                # "a photo of a" prefix (4 tokens) only lined up when n_ctx == 4.
                placeholder = " ".join(["X"] * self.n_ctx)
                base_prompts = [f'{placeholder} {name}.' for name in self.classnames]
                tokenized_base_prompts = torch.cat([tokenize(p) for p in base_prompts]).to(self.device)
                class_embedding = self.clip_model.token_embedding(tokenized_base_prompts).type(self.dtype)
                self.token_prefix = class_embedding[:, :1, :]  # Shape [n_cls, 1, dim]
                self.token_suffix = class_embedding[:, 1 + self.n_ctx:, :]  # Shape [n_cls, *, dim]
            # Prompts are assembled as embeddings in forward(); no token-id dict is kept in this mode.
            self.tokenized_prompts = None

        else:
            # Fixed template or context-init prompt learner
            self.tokenized_prompts = {}  # Dict[class_name -> tokenized Tensor, one row per prompt]
            for name in self.classnames:
                if args.with_templates:
                    prompt_templates = [template.format(name) for template in imagenet_templates]
                else:
                    ctx_init_str = args.ctx_init.replace("_", " ")
                    prompt_templates = [f'{ctx_init_str} {name}.']

                tokenized = torch.cat([tokenize(p) for p in prompt_templates]).to(self.device)
                self.tokenized_prompts[name] = tokenized

    def forward(self):
        """
        Return the prompts for every class, keyed by class name.

        Returns:
            Dict[str, torch.Tensor]: in CoOp mode, prompt embeddings of shape
            ``[1, 1 + n_ctx + suffix_len, dim]`` (prefix, context, suffix concatenated);
            in fixed mode, the tokenized prompts built by ``reset_classnames``.
        Raises:
            ValueError: if ``reset_classnames`` has not been called for the active mode.
        """
        if self.args.coop:
            prompts_dict = {}
            if self.ctx is None:
                raise ValueError("CoOp context (self.ctx) not initialized. Call reset_classnames first.")

            # A 2-D ctx is shared by all classes; a 3-D ctx holds one context per class.
            ctx_expanded = self.ctx.unsqueeze(0) if self.ctx.ndim == 2 else self.ctx

            for idx, name in enumerate(self.classnames):
                prefix_for_class = self.token_prefix[idx:idx + 1, :, :]
                suffix_for_class = self.token_suffix[idx:idx + 1, :, :]
                ctx_for_class = ctx_expanded if ctx_expanded.shape[0] == 1 else ctx_expanded[idx:idx + 1]

                prompts = torch.cat(
                    [
                        prefix_for_class,               # [1, 1, dim]
                        ctx_for_class.to(self.device),  # [1, n_ctx, dim]
                        suffix_for_class,               # [1, *, dim]
                    ],
                    dim=-2,
                )
                prompts_dict[name] = prompts
            return prompts_dict

        if self.tokenized_prompts is None:
            raise ValueError("Fixed prompts not initialized. Call reset_classnames first.")
        return self.tokenized_prompts  # Dict[str, Tensor]


class ClipTestTimeTuning(nn.Module):
    """
    Main model for CLIP Test-Time Tuning, integrating image/text encoders and prompt learning.

    Expected call order (enforced by the guards below): ``reset_classnames`` ->
    ``init_text_features`` -> ``forward`` / ``get_features``.
    """

    def __init__(self, args, device, classnames, criterion='cosine', arch="ViT-L/14"):
        super(ClipTestTimeTuning, self).__init__()
        clip_model, self.emb_dim, _ = load(arch, device=device, download_root=DOWNLOAD_ROOT)
        self.image_encoder = clip_model.visual
        self.text_encoder = TextEncoder(clip_model)
        # `.data` yields a plain tensor, so the temperature is not registered as a parameter here.
        self.logit_scale = clip_model.logit_scale.data
        self.args = args
        self.dataset = None
        self.device = device
        self.clip = clip_model
        self.classnames = classnames

        self.prompt_learner = PromptLearner_withTemplates(args, clip_model)

        self.criterion = criterion

        self.text_features = None  # Set by init_text_features()
        self.class_number = 0  # Set by reset_classnames()

        self.per_label = args.per_label

    @property
    def dtype(self):
        """
        Returns the data type of the image encoder (read from its ``conv1`` weight).
        """
        return self.image_encoder.conv1.weight.dtype

    def reset_classnames(self, args, classnames, dataset):
        """
        Resets class names and updates the prompt learner.

        Args:
            args: Command-line arguments.
            classnames (list): List of class names.
            dataset (str): Name of the current dataset.
        """
        # Ensure classnames are formatted for the prompt learner
        formatted_classnames = [name.replace("_", " ") for name in classnames]
        self.prompt_learner.reset_classnames(args, formatted_classnames)
        self.dataset = dataset
        self.classnames = formatted_classnames
        self.class_number = self.prompt_learner.n_cls

    def _num_templates_per_class(self):
        """
        Return a ``[class_number, 1]`` integer tensor with the number of prompts per class.

        Fixed-mode prompt learners store a dict ``class name -> token tensor`` (one row per
        prompt); CoOp mode stores ``None``, in which case every class counts as one prompt.
        """
        tokenized = self.prompt_learner.tokenized_prompts
        if isinstance(tokenized, dict):
            counts = [tokenized[name].size(0) for name in self.prompt_learner.classnames]
        elif tokenized is not None:
            # A single stacked tensor: assume prompts are split evenly across classes.
            counts = [tokenized.size(0) // self.class_number] * self.class_number
        else:
            counts = [1] * self.class_number
        return torch.tensor(counts).unsqueeze(-1).to(self.device)

    def init_text_features(self):
        """
        Compute and store normalized text features for the current class names.

        The class/concept embeddings come from ``get_text_features`` (any on-disk caching
        happens inside that helper, which is not visible here). They are combined according
        to ``args.pooling_type``, matched by substring in this order:
        ``'mean'`` (plain average), ``'macro'`` (average weighted by the number of concepts
        and templates per class), ``'class'`` or ``'concept'`` (one side only).
        The L2-normalized result is stored in ``self.text_features``.

        Raises:
            ValueError: if ``args.pooling_type`` matches none of the options.
        """
        if self.text_features is None:
            print(
                f'=> Initialize text features, templates: {self.args.with_templates}, pooling: {self.args.pooling_type}')

        self.path = get_cache_path(self.args, self.dataset)

        text_embed_tuple = get_text_features(self.args, self.clip, self.text_encoder, self.prompt_learner,
                                             self.dataset, self.classnames, self.device)
        class_emb, concept_emb = text_embed_tuple[0], text_embed_tuple[1]

        # Apply pooling strategy
        if 'mean' in self.args.pooling_type:
            text_embed = (class_emb + concept_emb) / 2
        elif 'macro' in self.args.pooling_type:
            class2concepts = load_class2concepts(path=self.path['dict'], classnames=self.classnames)
            num_class2_concepts = torch.tensor(
                [len(concepts) for concepts in class2concepts.values()]).unsqueeze(-1).to(self.device)
            # tokenized_prompts is a dict in fixed mode, so it has no .size(); count per class.
            num_templates = self._num_templates_per_class()

            text_embed = (concept_emb * num_class2_concepts + class_emb * num_templates) / (
                    num_class2_concepts + num_templates)
        elif 'class' in self.args.pooling_type:
            text_embed = class_emb
        elif 'concept' in self.args.pooling_type:
            text_embed = concept_emb
        else:
            raise ValueError("ERROR: pooling type not found")

        self.text_features = text_embed / text_embed.norm(dim=-1, keepdim=True)

        del text_embed, class_emb, concept_emb  # Free up memory

    def get_features(self, image):
        """
        Encodes image features and retrieves pre-initialized text features.

        Args:
            image (torch.Tensor): Input image tensor.
        Returns:
            Tuple[torch.Tensor, torch.Tensor]: L2-normalized image features and the stored
            text features, both detached copies.
        Raises:
            RuntimeError: if ``init_text_features`` has not been called.
        """
        if self.text_features is None:
            raise RuntimeError("Text features not initialized. Call init_text_features() first.")

        with torch.no_grad():
            image_features = self.image_encoder(image.type(self.dtype))
            text_features = self.text_features  # Already normalized in init_text_features

        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        return image_features.detach().clone(), text_features.detach().clone()

    def encode_image(self, image):
        """
        Encodes image features only, used for TTA loops where text features are updated externally.

        Args:
            image (torch.Tensor): Input image tensor.
        Returns:
            torch.Tensor: Normalized image features (detached copy).
        """
        with torch.no_grad():
            image_features = self.image_encoder(image.type(self.dtype))
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        return image_features.detach().clone()

    def get_logits(self, image_features, text_features):
        """
        Calculates scaled similarity logits from image and text features.

        Args:
            image_features (torch.Tensor): Normalized image features.
            text_features (torch.Tensor): Normalized text features.
        Returns:
            torch.Tensor: ``exp(logit_scale) * image_features @ text_features.T``.
        """
        logit_scale = self.logit_scale.exp()
        logits = logit_scale * image_features @ text_features.t()
        return logits

    def forward(self, image):
        """
        Forward pass for the entire model.

        Args:
            image (torch.Tensor): Input image tensor.
        Returns:
            torch.Tensor: Logits over the current classes.
        """
        image_features, text_features = self.get_features(image)

        # The shifter is normally applied by the caller (test_time_adapt_eval); it is used
        # here only when args.shifter is set AND a `text_shifter` has been attached to this model.
        text_shifter = getattr(self, 'text_shifter', None)
        if self.args.shifter and text_shifter is not None:
            text_features = text_shifter(text_features)

        return self.get_logits(image_features, text_features)


def get_coop(args, clip_arch, test_set, device, learned_cls=False):
    """
    Factory function to get the ClipTestTimeTuning model.

    Class names are chosen from ``test_set``:
      * a dataset listed in ``fewshot_datasets`` uses the module-level list named
        ``<test_set lowercased>_classes``;
      * ``'bongard'`` uses a two-entry list;
      * anything else falls back to ``imagenet_classes``.

    Args:
        args: Command-line arguments (forwarded to ClipTestTimeTuning).
        clip_arch (str): CLIP architecture to use.
        test_set (str): Name of the test dataset.
        device (str): Device to load the model on.
        learned_cls (bool): Bongard only: use placeholder names ['X', 'X'] instead of
            ['True', 'False'].
    Returns:
        ClipTestTimeTuning: The initialized model.
    Raises:
        ValueError: if ``test_set`` is a few-shot dataset but no matching
            ``<name>_classes`` list exists in this module's namespace.
    """
    if test_set in fewshot_datasets:
        # Look the list up by name instead of eval()-ing a string built from the dataset
        # argument. The lists presumably come from the star import of data.cls_to_names.
        list_name = f"{test_set.lower()}_classes"
        try:
            classnames = globals()[list_name]
        except KeyError:
            raise ValueError(f"No class-name list '{list_name}' found for dataset '{test_set}'") from None
    elif test_set == 'bongard':
        classnames = ['X', 'X'] if learned_cls else ['True', 'False']
    else:
        classnames = imagenet_classes

    return ClipTestTimeTuning(args, device, classnames, arch=clip_arch)