from torch._tensor import Tensor


import torch
import numpy as np
import random
from torch.nn import functional as F
from queue import PriorityQueue
from collections import defaultdict
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

from methods.base import TTAMethod
from utils.registry import ADAPTATION_REGISTRY
from llava_basic import generate_caption  # Import LLaVA captioning

import re

def extract_caption(text):
    """Return the text that follows the first ``ASSISTANT:`` marker in *text*.

    Whitespace directly after the marker (including newlines) is skipped and
    the capture then runs to the end of that line, because ``.`` does not
    match a newline here.

    Returns:
        The captured string (possibly empty), or ``None`` when *text*
        contains no ``ASSISTANT:`` marker.
    """
    found = re.search(r'ASSISTANT:\s*(.*)', text)
    if found is None:
        return None
    return found.group(1)

def select_confident_samples(logits, top):
    """Keep the lowest-entropy (most confident) rows of *logits*.

    Args:
        logits: 2-D tensor; the softmax is taken along dim 1 and one entropy
            value is computed per row.
        top: Fraction of rows to keep.

    Returns:
        A ``(selected_logits, idx)`` tuple where ``idx`` holds the indices of
        the kept rows ordered by ascending entropy.

    The number of kept rows is ``int(num_rows * top)`` but never less than
    one: truncation to zero (e.g. a single row with ``top=0.1``) would yield
    an empty selection, and the entropy of an empty selection is NaN.
    """
    batch_entropy = -(logits.softmax(1) * logits.log_softmax(1)).sum(1)
    num_rows = batch_entropy.size(0)
    num_keep = max(1, int(num_rows * top))
    idx = torch.argsort(batch_entropy, descending=False)[:num_keep]
    return logits[idx], idx


def avg_entropy(outputs):
    """Entropy of the prediction distribution averaged over dim 0.

    Each row of *outputs* is normalised into log-probabilities over the last
    dimension, the rows are averaged in probability space (log-sum-exp minus
    the log of the row count), and the entropy of that averaged distribution
    is returned.
    """
    # Row-wise log-probabilities.
    log_probs = outputs - torch.logsumexp(outputs, dim=-1, keepdim=True)
    row_count = log_probs.shape[0]
    # log(mean_i p_i) computed without leaving log space.
    mean_log_probs = torch.logsumexp(log_probs, dim=0) - np.log(row_count)
    # Replace -inf by the smallest finite value so that the product with
    # exp(.) == 0 below does not become NaN.
    floor = torch.finfo(mean_log_probs.dtype).min
    mean_log_probs = mean_log_probs.clamp(min=floor)
    return -(mean_log_probs * mean_log_probs.exp()).sum(dim=-1)


@ADAPTATION_REGISTRY.register()
class TPT(TTAMethod):
    """Prompt-tuning adaptation method registered in ``ADAPTATION_REGISTRY``.

    On every ``forward`` call the prompt-learner parameters are optimised for
    ``self.steps`` steps to minimise the entropy of the averaged prediction
    over the most confident rows of the input (plus ``self.fed_loss``).  The
    tuned prompt context is then stored in a per-predicted-class priority
    queue that keeps at most ``self.K`` lowest-entropy prompts, and every 20
    calls a t-SNE plot of the stored prompts is written to disk.

    NOTE(review): ``self.steps``, ``self.episodic``, ``self.optimizer``,
    ``self.optimizer_state`` and ``self.fed_loss`` are not defined in this
    class; presumably they come from ``TTAMethod`` - confirm.
    """

    def __init__(self, cfg, model, num_classes):
        """Set up selection ratio, AMP grad scaler and the prompt queues.

        Args:
            cfg: Configuration object; ``cfg.TPT.SELECTION_P`` is read here.
            model: Model forwarded to the ``TTAMethod`` constructor.
            num_classes: Number of classes, forwarded to the base class and
                also stored on the instance.
        """
        super().__init__(cfg, model, num_classes)

        self.selection_p = cfg.TPT.SELECTION_P
        self.scaler = torch.cuda.amp.GradScaler(init_scale=1000)
        self.c = 0  # number of forward() calls so far
        self.K = 5  # maximum number of prompts kept per predicted class
        self.num_classes = num_classes
        # Monotonic counter used as a tie-breaker inside queue entries so
        # that tuple comparison never reaches the (non-orderable) tensors.
        self.push_id = 0
        self.priority_queues = defaultdict(PriorityQueue)

    def forward(self, x,  y, global_state_dict):
        """Adapt the prompt on *x* and return the logits of its first row.

        Args:
            x: Sequence of tensors that is concatenated along dim 0.
                Presumably the first row is the original image and the
                remaining rows are augmented views - TODO confirm.
            y: Passed through to ``forward_and_adapt``.
            global_state_dict: Passed through to ``forward_and_adapt``, which
                hands it to ``self.fed_loss``.

        Returns:
            Logits computed from the first row of the image features and the
            text features obtained after adaptation.
        """
        self.c += 1
        if self.c % 100 == 0:
            print(f"[INFO] Processed {self.c} samples.")

        if self.episodic:
            self.model.reset()
            self.optimizer.load_state_dict(self.optimizer_state)

        x = torch.cat(x, dim=0)
        x = self.model.normalize(x.type(self.model.dtype))

        with torch.cuda.amp.autocast():
            img_features = self.model.image_encoder(x)
            img_features = img_features / img_features.norm(dim=-1, keepdim=True)

        # print("before: ", self.model.prompt_learner.prompt_prefix)
        # Generate caption using a chosen mode (change "captions" to "MaxCaptions" or "DomainVisualCaptions" if needed)
        # caption = generate_caption(x[0])
        # caption = extract_caption(caption)
        # print(f"Generated Caption: {caption}")

        # # Update the prompt in PromptLearner
        # self.model.prompt_learner.update_prompt_prefix(caption)

        # Snapshot the prompt context before tuning.
        with torch.no_grad():
            prompt_before = self.model.prompt_learner.ctx.clone()

        # The first step picks the confident rows; later steps reuse them.
        selected_idx = None
        for _ in range(self.steps):
            selected_idx, _ = self.forward_and_adapt(img_features, selected_idx, y, global_state_dict)

        with torch.cuda.amp.autocast():
            text_features = self.model.get_text_features()
            output = self.model.logit_scale.exp() * img_features[:1] @ text_features.t()

        # Cosine similarity between the prompt before and after tuning,
        # rescaled from [-1, 1] to [0, 1].  Currently only consumed by the
        # disabled debug print below.
        prompt_after = self.model.prompt_learner.ctx
        cos_sim = F.cosine_similarity(prompt_before.view(-1), prompt_after.view(-1), dim=0)
        cos_sim = (cos_sim + 1) / 2
        # print(f"[DEBUG] Prompt Cosine Similarity (normalized): {cos_sim.item():.4f}")

        # Priority queue logic
        pred_class = output.argmax(dim=1).item()
        entropy_val = avg_entropy(output).item()
        # Force a copy: ``.cpu()`` returns the same storage when the tensor
        # already lives on the CPU, in which case the queued prompt would
        # alias the live parameter and change with every later update.
        prompt_vec = prompt_after.detach().to("cpu", copy=True)
        self._push_prompt_to_queue(pred_class, entropy_val, prompt_vec)

        if self.c % 20 == 0:
            # No step id is passed, so each call overwrites "tsne_prompt.png".
            self.plot_tsne_of_prompts()

        return output

    def _push_prompt_to_queue(self, pred_class, entropy_val, prompt_vec):
        """Insert a prompt into the queue of *pred_class*, keeping the K lowest-entropy ones."""
        pq = self.priority_queues[pred_class]
        # PriorityQueue pops the smallest tuple first; with the entropy
        # negated, ``get()`` returns the highest-entropy (worst) entry.
        entry = (-entropy_val, self.push_id, prompt_vec)
        self.push_id += 1

        if pq.qsize() < self.K:
            pq.put(entry)
            return

        # Queue is full: keep whichever of (current worst, new entry) has the
        # lower entropy.
        worst_entry = pq.get()
        if -worst_entry[0] > entropy_val:
            pq.put(entry)
        else:
            pq.put(worst_entry)

    @torch.enable_grad()
    def forward_and_adapt(self, img_features, selected_idx, y, global_state_dict):
        """Run one optimisation step on the prompt-learner parameters.

        Args:
            img_features: Image features multiplied with the text features to
                obtain logits.
            selected_idx: Row indices to use, or ``None`` to select the most
                confident rows via ``select_confident_samples``.
            y: Unused in this method.
            global_state_dict: Forwarded to ``self.fed_loss``.

        Returns:
            ``(selected_idx, loss)`` - the row indices that were used and the
            loss tensor (averaged-prediction entropy plus ``fed_loss``).
        """
        with torch.cuda.amp.autocast():
            text_features = self.model.get_text_features()
            logits = self.model.logit_scale.exp() * img_features @ text_features.t()

            if selected_idx is not None:
                logits = logits[selected_idx]
            else:
                logits, selected_idx = select_confident_samples(logits, self.selection_p)

            loss = avg_entropy(logits)

            fed_loss = self.fed_loss(global_state_dict)
            loss += fed_loss

        self.optimizer.zero_grad()
        self.scaler.scale(loss).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()
        return selected_idx, loss

    def configure_model(self):
        """Put the model in eval mode and enable gradients only for prompt-learner
        parameters whose name does not contain ``token_embedding``."""
        self.model.eval()
        self.model.requires_grad_(False)
        for name, param in self.model.named_parameters():
            if "prompt_learner" in name and "token_embedding" not in name:
                param.requires_grad_(True)

    def collect_params(self):
        """Return ``(params, names)`` of all trainable prompt-learner parameters."""
        params, names = [], []
        for name, param in self.model.named_parameters():
            if "prompt_learner" in name and param.requires_grad:
                params.append(param)
                names.append(name)
        return params, names

    def plot_tsne_of_prompts(self, step_id=None):
        """Randomly select up to 10 classes and plot t-SNE of their queued prompt vectors.

        Args:
            step_id: Optional identifier used in the log messages, the plot
                title and the output file name.  When ``None`` the plot is
                saved as ``tsne_prompt.png``.

        The plot is skipped when no queue exists yet or when fewer than three
        vectors are available in the sampled classes.
        """
        available_classes = list(self.priority_queues.keys())
        if not available_classes:
            print(f"[Step {step_id}] Skipping t-SNE plot: no priority queues available.")
            return

        selected_classes = random.sample(available_classes, min(10, len(available_classes)))

        all_vectors = []
        all_labels = []

        for class_id in selected_classes:
            pq = self.priority_queues[class_id]
            # Read the underlying heap list without popping any entry.
            for _, _, prompt_vec in list(pq.queue):
                if isinstance(prompt_vec, torch.Tensor):
                    flattened_vec = prompt_vec.view(-1).cpu().numpy()
                else:
                    flattened_vec = prompt_vec.reshape(-1)
                all_vectors.append(flattened_vec)
                all_labels.append(class_id)

        if len(all_vectors) < 3:
            print(f"[Step {step_id}] Skipping t-SNE plot: not enough total vectors ({len(all_vectors)})")
            return

        # Keep the perplexity strictly below the number of samples.
        perplexity_val = min(5, len(all_vectors) - 1)
        tsne = TSNE(n_components=2, perplexity=perplexity_val, random_state=42)
        reduced = tsne.fit_transform(np.array(all_vectors))

        plt.figure(figsize=(8, 6))
        try:
            scatter = plt.scatter(reduced[:, 0], reduced[:, 1], c=all_labels, cmap='tab10', s=40, alpha=0.8)
            plt.colorbar(scatter, ticks=selected_classes)
            plt.title(f"t-SNE of Prompt Vectors from 10 Random Classes (Step {step_id})")
            plt.xlabel("t-SNE Dimension 1")
            plt.ylabel("t-SNE Dimension 2")
            plt.grid(True)
            plt.tight_layout()

            save_name = f"tsne_prompt_step_{step_id}.png" if step_id is not None else "tsne_prompt.png"
            plt.savefig(save_name)
            print(f"[Step {step_id}] t-SNE plot saved to {save_name}")
        finally:
            # Always release the figure, even if plotting or saving fails.
            plt.close()
