import math
from typing import Optional, List
import torch
import torch.nn as nn
from transformers import FlavaModel
from tqdm import tqdm

from lib.data.base import VLInputs
from lib.models.incremental_classifier import IncrementalClassifier

from lib.models.dual_prompt import FlavaImageModel, FlavaTextModel, FlavaMultimodalModel
    
class FlavaForDualPrompt(nn.Module):
    """Frozen FLAVA encoder with learnable G-prompts and a growing pool of E-prompts.

    G-prompts are shared by all experiences and injected at ``g_layers``; one E-prompt
    set is appended per experience (``update_e_prompts``) and injected at ``e_layers``.
    Separate prompts exist for the image, text and multimodal encoders. In eval mode
    the E-prompt set is chosen per sample by comparing the plain FLAVA multimodal
    embedding at position 0 with the per-experience keys stored in ``e_key_pool``.
    """

    def __init__(self,
        n_g_prompts: int = 5,
        g_layers: Optional[List[int]] = None,
        n_e_prompts: int = 20,
        e_layers: Optional[List[int]] = None
    ):
        """
        Args:
            n_g_prompts: number of prompt tokens per G-prompt layer.
            g_layers: layer indices that receive G-prompts; defaults to ``[0, 1]``.
            n_e_prompts: number of prompt tokens per E-prompt layer.
            e_layers: layer indices that receive E-prompts; defaults to ``[2, 3, 4]``.
        """
        super().__init__()
        # None sentinels (instead of list defaults) so instances never share one list.
        self.n_g_prompts = n_g_prompts
        self.g_layers = [0, 1] if g_layers is None else list(g_layers)
        self.n_e_prompts = n_e_prompts
        self.e_layers = [2, 3, 4] if e_layers is None else list(e_layers)

        self.flava = FlavaModel.from_pretrained("facebook/flava-full")
        for p in self.flava.parameters():
            p.requires_grad = False
        self.d_model = self.flava.config.hidden_size

        # Prompt-aware wrappers around the frozen FLAVA sub-encoders.
        self.image_model = FlavaImageModel(self.flava.image_model)
        self.text_model = FlavaTextModel(self.flava.text_model)
        self.multimodal_model = FlavaMultimodalModel(self.flava.multimodal_model)

        self.image_to_mm_projection = self.flava.image_to_mm_projection
        self.text_to_mm_projection = self.flava.text_to_mm_projection

        # G-Prompt initialization: each is (len(g_layers), 2, n_g_prompts, d_model).
        # NOTE(review): the size-2 axis is presumably a key/value prefix pair — confirm
        # against lib.models.dual_prompt.
        self.g_prompts_text = self._new_prompt(len(self.g_layers), n_g_prompts)
        self.g_prompts_image = self._new_prompt(len(self.g_layers), n_g_prompts)
        self.g_prompts_multimodal = self._new_prompt(len(self.g_layers), n_g_prompts)

        # E-Prompt pools: update_e_prompts() appends one entry per experience.
        self.e_prompts_text_pool = nn.ParameterList([])
        self.e_prompts_image_pool = nn.ParameterList([])
        self.e_prompts_multimodal_pool = nn.ParameterList([])
        # One key per experience, appended by update_e_key(). Starts as (0, d_model) so
        # concatenation and matmul always see a 2-D tensor. It is a plain attribute,
        # not a buffer: it is not in state_dict and does not follow module.to().
        self.e_key_pool = torch.empty(0, self.d_model)

        self.update_e_prompts()

    def _new_prompt(self, n_layers: int, n_prompts: int, device=None) -> nn.Parameter:
        """Create a Xavier-uniform initialised prompt of shape (n_layers, 2, n_prompts, d_model)."""
        prompt = nn.Parameter(torch.zeros(n_layers, 2, n_prompts, self.d_model, device=device))
        nn.init.xavier_uniform_(prompt.data)
        return prompt

    @staticmethod
    def _per_sample(prompt: torch.Tensor, batch_size: int, device) -> torch.Tensor:
        """Tile a 4-D prompt to (batch_size, *prompt.shape) on ``device``."""
        return prompt.repeat(batch_size, 1, 1, 1, 1).to(device)

    def forward(self,
        batch: VLInputs,
        experience_id: Optional[int] = None
    ):
        """Encode ``batch`` through the prompted image, text and multimodal encoders.

        Args:
            batch: object exposing ``pixel_values``, ``input_ids`` and ``attention_mask``.
            experience_id: index of the E-prompt set to use in training mode. Ignored in
                eval mode, where each sample uses the E-prompt set whose key has the
                lowest dissimilarity score.

        Returns:
            The output of the prompted multimodal encoder.

        Raises:
            ValueError: in training mode when ``experience_id`` is None.
            RuntimeError: in eval mode when ``update_e_key`` has never been called.
        """
        pixel_values = batch.pixel_values
        input_ids = batch.input_ids
        attention_mask = batch.attention_mask

        batch_size = input_ids.size(0)
        device = pixel_values.device

        # G-Prompt preparation: identical prompts for every sample.
        g_prompts_text = self._per_sample(self.g_prompts_text, batch_size, device)
        g_prompts_image = self._per_sample(self.g_prompts_image, batch_size, device)
        g_prompts_multimodal = self._per_sample(self.g_prompts_multimodal, batch_size, device)

        # E-Prompt preparation
        if self.training:
            if experience_id is None:
                raise ValueError("experience_id is required in training mode to select the E-prompts")
            e_prompts_text = self._per_sample(self.e_prompts_text_pool[experience_id], batch_size, device)
            e_prompts_image = self._per_sample(self.e_prompts_image_pool[experience_id], batch_size, device)
            e_prompts_multimodal = self._per_sample(self.e_prompts_multimodal_pool[experience_id], batch_size, device)
        else:
            if self.e_key_pool.size(0) == 0:
                raise RuntimeError("No E-keys available: call update_e_key() before evaluating")
            # Per-sample index of the closest experience key; it selects one prompt set per sample.
            pos = self.get_dissimilarity_score(batch).argmin(dim=-1)
            e_prompts_text = torch.stack(list(self.e_prompts_text_pool))[pos]
            e_prompts_image = torch.stack(list(self.e_prompts_image_pool))[pos]
            e_prompts_multimodal = torch.stack(list(self.e_prompts_multimodal_pool))[pos]

        # Vision path
        image_output = self.image_model(
            pixel_values=pixel_values,
            g_prompts=g_prompts_image, g_layers=self.g_layers,
            e_prompts=e_prompts_image, e_layers=self.e_layers
        )
        image_output = self.image_to_mm_projection(image_output)

        # Language path
        text_output = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            g_prompts=g_prompts_text, g_layers=self.g_layers,
            e_prompts=e_prompts_text, e_layers=self.e_layers
        )
        text_output = self.text_to_mm_projection(text_output)

        # Multimodal path
        hidden_states = torch.cat([image_output, text_output], dim=1)
        if attention_mask is not None:
            # Image positions are always attended. The "+1" presumably covers the CLS
            # token prepended by the multimodal encoder (as in HF FlavaModel) — confirm.
            image_attention_mask = torch.ones(hidden_states.size(0), image_output.size(1) + 1, device=device)
            multimodal_attention_mask = torch.cat([image_attention_mask, attention_mask], dim=-1)
        else:
            multimodal_attention_mask = None
        multimodal_output = self.multimodal_model(
            hidden_states=hidden_states,
            attention_mask=multimodal_attention_mask,
            g_prompts=g_prompts_multimodal, g_layers=self.g_layers,
            e_prompts=e_prompts_multimodal, e_layers=self.e_layers
        )

        return multimodal_output

    @torch.no_grad()
    def get_cls_token(self, batch: VLInputs):
        """Return the un-prompted FLAVA multimodal embedding at position 0 for each sample."""
        return self.flava(
            pixel_values=batch.pixel_values,
            input_ids=batch.input_ids,
            attention_mask=batch.attention_mask
        ).multimodal_embeddings[:, 0]

    @torch.no_grad()
    def compute_key_closed_form(self, dataloader: torch.utils.data.DataLoader):
        """Return the L2-normalised mean of ``get_cls_token`` over all samples in ``dataloader``.

        The dataloader must yield ``(inputs, targets)`` pairs; targets are ignored.

        Raises:
            ValueError: if the dataloader yields no samples.
        """
        device = next(self.parameters()).device
        # Running sum: avoids holding every CLS token in memory at once.
        cls_sum = None
        n_samples = 0
        with tqdm(dataloader, unit="batch", desc="Compute E-Key") as tepoch:
            for inputs, _ in tepoch:
                moved = inputs.to(device)
                # NOTE(review): VLInputs.to is defined elsewhere. Use its return value when
                # it gives one (copy semantics); fall back to in-place semantics otherwise.
                if moved is not None:
                    inputs = moved
                cls_tokens = self.get_cls_token(inputs)
                batch_sum = cls_tokens.sum(dim=0)
                cls_sum = batch_sum if cls_sum is None else cls_sum + batch_sum
                n_samples += cls_tokens.size(0)
        if n_samples == 0:
            raise ValueError("Cannot compute an E-key from an empty dataloader")
        mean_cls_token = cls_sum / n_samples
        return mean_cls_token / mean_cls_token.norm()

    def update_e_key(self, dataloader: torch.utils.data.DataLoader):
        """Compute the key for ``dataloader`` and append it as a new row of ``e_key_pool``."""
        key = self.compute_key_closed_form(dataloader)
        self.e_key_pool = torch.cat([self.e_key_pool.to(key), key[None, :]], dim=0)

    def get_dissimilarity_score(self, batch: VLInputs):
        """Return ``1 - cls @ keys.T`` with one row per sample and one column per key.

        CLS embeddings are L2-normalised here, and keys produced by ``update_e_key`` are
        unit-norm, so each entry is one minus a cosine similarity.
        """
        cls_tokens = self.get_cls_token(batch)
        cls_tokens = cls_tokens / cls_tokens.norm(dim=-1, keepdim=True)
        # e_key_pool does not follow module.to(); align device and dtype explicitly.
        keys = self.e_key_pool.to(cls_tokens)
        return 1 - (cls_tokens @ keys.T)

    def update_e_prompts(self):
        """Append one freshly initialised E-prompt set (text, image, multimodal) to the pools."""
        device = next(self.parameters()).device
        n_layers = len(self.e_layers)
        # Creation order text -> image -> multimodal keeps the RNG consumption order stable.
        self.e_prompts_text_pool.append(self._new_prompt(n_layers, self.n_e_prompts, device))
        self.e_prompts_image_pool.append(self._new_prompt(n_layers, self.n_e_prompts, device))
        self.e_prompts_multimodal_pool.append(self._new_prompt(n_layers, self.n_e_prompts, device))
        
###

class FlavaDualPromptCL(nn.Module):
    """Dual-prompted FLAVA feature extractor followed by an incremental classifier head.

    The head is applied to position 0 of the hidden states returned by
    ``FlavaForDualPrompt`` (presumably the multimodal CLS token — confirm).
    """

    def __init__(self,
        n_output_classes: int,
        **feature_extractor_kwargs
    ):
        """
        Args:
            n_output_classes: class count passed to ``IncrementalClassifier``.
            **feature_extractor_kwargs: forwarded unchanged to ``FlavaForDualPrompt``
                (e.g. ``n_g_prompts``, ``g_layers``, ``n_e_prompts``, ``e_layers``);
                omit them to keep its defaults.
        """
        super().__init__()
        self.feature_extractor = FlavaForDualPrompt(**feature_extractor_kwargs)
        self.incremental_classifier = IncrementalClassifier(self.feature_extractor.d_model, n_output_classes)

    def forward(self, inputs, experience_id=None):
        """Return classifier logits for ``inputs``; ``experience_id`` is passed to the feature extractor."""
        hidden_states = self.feature_extractor(inputs, experience_id)
        return self.incremental_classifier(hidden_states[:, 0])

    def adaptation(self, n_output_classes: int):
        """Adapt the classifier head to ``n_output_classes`` and add a new E-prompt set."""
        self.incremental_classifier.adaptation(n_output_classes)
        self.feature_extractor.update_e_prompts()
