import math
from typing import Optional, List
import torch
import torch.nn as nn
from transformers import FlavaModel

from lib.data.base import VLInputs
from lib.models.incremental_classifier import IncrementalClassifier

class FlavaEncoderForPromptTuning(nn.Module):
    """Run a FLAVA encoder stack with prompt tokens prepended to the input.

    The per-layer computation (layer norm, self-attention, residual,
    feed-forward) is spelled out here on top of the sub-modules found in
    ``model.encoder.layer`` so the prompts can be concatenated in front of
    the sequence before the first layer.

    Args:
        model: wrapped model; must expose ``encoder.layer``, ``config`` and
            ``get_extended_attention_mask``.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model
        self.layers = self.model.encoder.layer
        # Mask (prompt positions included) used by the most recent forward
        # call; None when that call was made without a mask.
        self.current_attention_mask = None

    def forward(self,
        hidden_states: torch.Tensor,
        prompts: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """Encode ``hidden_states`` with ``prompts`` prepended along dim 1.

        Args:
            hidden_states: input sequence, concatenated after the prompts on
                dim 1.
            prompts: prompt tensor; all dims between the first and the last
                are flattened, e.g. [batch_size, n_prompts, prompt_dim,
                d_model] -> [batch_size, n_prompts*prompt_dim, d_model].
            attention_mask: optional mask for ``hidden_states``; it is
                extended with ones for the prompt positions.

        Returns:
            The output of the last layer, prompt positions included.
        """
        device = hidden_states.device

        # reshape (rather than view) also accepts non-contiguous prompts.
        prompts = prompts.reshape(prompts.size(0), -1, prompts.size(-1))
        hidden_states = torch.cat([prompts, hidden_states], dim=1)

        current_attention_mask = None
        extended_attention_mask = None
        if attention_mask is not None:
            prompt_mask = torch.ones(hidden_states.size(0), prompts.size(1), device=device)
            current_attention_mask = torch.cat([prompt_mask, attention_mask], dim=-1)
            # The mask and the sequence shape are the same for every layer,
            # so the additive mask is built once instead of once per layer.
            extended_attention_mask = self.model.get_extended_attention_mask(
                current_attention_mask, hidden_states.size(), hidden_states.device
            )
        # Always refresh the cached mask so an unmasked call neither leaves
        # the attribute undefined nor exposes a mask from an earlier batch.
        self.current_attention_mask = current_attention_mask

        config = self.model.config
        scale = math.sqrt(int(config.hidden_size / config.num_attention_heads))

        for layer in self.layers:
            residual = hidden_states
            normed = layer.layernorm_before(hidden_states)

            self_attention = layer.attention.attention
            q = self_attention.transpose_for_scores(self_attention.query(normed))
            k = self_attention.transpose_for_scores(self_attention.key(normed))
            v = self_attention.transpose_for_scores(self_attention.value(normed))

            attention_scores = torch.matmul(q, k.transpose(-1, -2)) / scale
            if extended_attention_mask is not None:
                attention_scores = attention_scores + extended_attention_mask

            attention_probs = self_attention.dropout(torch.softmax(attention_scores, dim=-1))

            # Merge the heads back into a single feature dimension.
            context_layer = torch.matmul(attention_probs, v)
            context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
            merged_shape = context_layer.size()[:-2] + (self_attention.all_head_size,)
            context_layer = context_layer.view(*merged_shape)

            attention_output = layer.attention.output(context_layer, normed)

            # First residual connection (around self-attention).
            hidden_states = attention_output + residual

            layer_output = layer.intermediate(layer.layernorm_after(hidden_states))
            hidden_states = layer.output(layer_output, hidden_states)

        return hidden_states
    
class FlavaTextModel(nn.Module):
    """Text branch: embed token ids, then encode them together with prompts.

    Args:
        model: wrapped text model; must expose ``embeddings`` plus whatever
            ``FlavaEncoderForPromptTuning`` reads from it.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model
        self.encoder = FlavaEncoderForPromptTuning(model)

    def forward(self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        prompts: torch.Tensor,
    ):
        """Return the encoder output for ``input_ids`` with ``prompts`` prepended.

        Side effect: stores the encoder's ``current_attention_mask`` on this
        module as ``self.current_attention_mask``.
        """
        token_embeddings = self.model.embeddings(input_ids=input_ids)
        encoded = self.encoder(
            hidden_states=token_embeddings,
            attention_mask=attention_mask,
            prompts=prompts,
        )
        # Keep the mask the encoder actually used available to callers.
        self.current_attention_mask = self.encoder.current_attention_mask
        return encoded
    
class FlavaImageModel(nn.Module):
    """Image branch: embed pixel values, then encode them together with prompts.

    Args:
        model: wrapped image model; must expose ``embeddings`` plus whatever
            ``FlavaEncoderForPromptTuning`` reads from it.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model
        self.encoder = FlavaEncoderForPromptTuning(model)

    def forward(self,
        pixel_values: torch.Tensor,
        prompts: torch.Tensor
    ):
        """Return the encoder output for ``pixel_values`` with ``prompts`` prepended.

        No attention mask is passed to the encoder on this branch.
        """
        patch_embeddings = self.model.embeddings(pixel_values=pixel_values)
        return self.encoder(hidden_states=patch_embeddings, prompts=prompts)
    
class FlavaMultimodalModel(nn.Module):
    """Multimodal branch: prepend the CLS token, encode with prompts, normalise.

    Args:
        model: wrapped multimodal model; must expose ``cls_token`` and
            ``layernorm`` plus whatever ``FlavaEncoderForPromptTuning``
            reads from it.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model
        self.encoder = FlavaEncoderForPromptTuning(model)

    def forward(self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        prompts: torch.Tensor
    ):
        """Return the layer-normalised encoder output.

        ``attention_mask`` is forwarded unchanged, so it must already cover
        the CLS position added here.
        """
        n_samples = hidden_states.size(0)
        # One copy of the CLS token per sample, placed in front of the sequence.
        cls_tokens = self.model.cls_token.expand(n_samples, -1, -1)
        sequence = torch.cat((cls_tokens, hidden_states), dim=1)
        encoded = self.encoder(
            hidden_states=sequence,
            attention_mask=attention_mask,
            prompts=prompts,
        )
        return self.model.layernorm(encoded)
    
class FlavaForLearningToPrompt(nn.Module):
    """Frozen FLAVA model with a learnable pool of keys and prompts.

    For every sample the ``n_prompts`` pool entries whose keys are least
    dissimilar (cosine) to the sample's multimodal CLS embedding are
    selected, and the matching image / text / multimodal prompts are
    prepended to the inputs of the corresponding encoder.

    Args:
        prompts_pool_dim: number of (key, prompt) entries in the pool.
        n_prompts: number of pool entries selected per sample.
        prompt_dim: number of tokens in each prompt.
    """

    def __init__(self,
        prompts_pool_dim: int = 40,
        n_prompts: int = 4,
        prompt_dim: int = 5
    ):
        super().__init__()
        self.prompts_pool_dim = prompts_pool_dim
        self.n_prompts = n_prompts
        self.prompt_dim = prompt_dim

        # The pretrained weights stay frozen; only keys and prompts are trained.
        self.flava = FlavaModel.from_pretrained("facebook/flava-full")
        for p in self.flava.parameters():
            p.requires_grad = False
        self.d_model = self.flava.config.hidden_size

        self.image_model = FlavaImageModel(self.flava.image_model)
        self.text_model = FlavaTextModel(self.flava.text_model)
        self.multimodal_model = FlavaMultimodalModel(self.flava.multimodal_model)

        self.image_to_mm_projection = self.flava.image_to_mm_projection
        self.text_to_mm_projection = self.flava.text_to_mm_projection

        # init keys and prompts
        self.prompts_pool_image = nn.Parameter(torch.zeros(self.prompts_pool_dim, self.prompt_dim, self.d_model))
        self.prompts_pool_text = nn.Parameter(torch.zeros(self.prompts_pool_dim, self.prompt_dim, self.d_model))
        self.prompts_pool_multimodal = nn.Parameter(torch.zeros(self.prompts_pool_dim, self.prompt_dim, self.d_model))
        self.keys_pool = nn.Parameter(torch.zeros(self.prompts_pool_dim, self.d_model))
        nn.init.xavier_uniform_(self.prompts_pool_image.data)
        nn.init.xavier_uniform_(self.prompts_pool_text.data)
        nn.init.xavier_uniform_(self.prompts_pool_multimodal.data)
        nn.init.xavier_uniform_(self.keys_pool.data)

        # Selection counters. These are plain tensors, not parameters or
        # registered buffers, so forward() moves them to the right device
        # explicitly.
        self.keys_frequency = torch.zeros(self.prompts_pool_dim)
        self.experience_keys_frequency = torch.ones(self.prompts_pool_dim)

    def forward(self, batch: "VLInputs"):
        """Select prompts for ``batch`` and run the three prompted encoders.

        Args:
            batch: object providing ``pixel_values``, ``input_ids`` and
                ``attention_mask``.

        Returns:
            Tuple ``(multimodal_output, top_scores)`` where ``top_scores`` is
            the flattened tensor of the selected dissimilarity scores.

        Side effects: adds the selected indices to ``self.keys_frequency``
        (in both train and eval mode) and stores ``self.attention_mask``,
        ``self.multimodal_attention_mask`` and ``self.hidden_states``.
        """
        pixel_values = batch.pixel_values
        input_ids = batch.input_ids
        attention_mask = batch.attention_mask
        device = pixel_values.device

        scores = self.get_dissimilarity_score(batch)
        if self.training:
            # Weight scores by how often each key was picked previously.
            scores *= self.experience_keys_frequency.to(device)

        # One ascending sort yields both the lowest scores and their indices.
        ranked = scores.sort()
        top_scores = ranked.values[:, :self.n_prompts].contiguous()
        top_scores_pos = ranked.indices[:, :self.n_prompts].contiguous()
        self.keys_frequency += top_scores_pos.view(-1).bincount(minlength=self.prompts_pool_dim).to(self.keys_frequency.device)

        # Indexing a pool with the [batch, n_prompts] index tensor gathers
        # [batch, n_prompts, prompt_dim, d_model] in a single operation.
        prompts_image = self.prompts_pool_image[top_scores_pos]
        prompts_text = self.prompts_pool_text[top_scores_pos]
        prompts_multimodal = self.prompts_pool_multimodal[top_scores_pos]

        # Vision path
        image_output = self.image_model(
            pixel_values=pixel_values,
            prompts=prompts_image
        )
        image_output = self.image_to_mm_projection(image_output)

        # Language path
        text_output = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            prompts=prompts_text
        )
        text_output = self.text_to_mm_projection(text_output)

        # Multimodal path
        hidden_states = torch.cat([image_output, text_output], dim=1)
        # Text mask as extended by the text encoder (prompt positions included).
        attention_mask = self.text_model.current_attention_mask
        # Ones for every image position plus one extra position (the CLS
        # token the multimodal model prepends), followed by the text mask.
        multimodal_attention_mask = torch.cat([torch.ones(hidden_states.size(0), image_output.size(1)+1, device=device), attention_mask], dim=-1)

        self.attention_mask = attention_mask
        self.multimodal_attention_mask = multimodal_attention_mask
        self.hidden_states = hidden_states

        multimodal_output = self.multimodal_model(
            hidden_states=hidden_states,
            attention_mask=multimodal_attention_mask,
            prompts=prompts_multimodal
        )

        return multimodal_output, top_scores.view(-1)

    def update_key_frequency(self):
        """Refresh ``experience_keys_frequency`` from the accumulated counts.

        The counts are normalised to sum to one. When nothing has been
        counted yet the previous frequencies are kept, because dividing by a
        zero total would fill them with NaN and corrupt every later score.
        """
        total = self.keys_frequency.sum()
        if total.item() == 0:
            return
        self.experience_keys_frequency = self.keys_frequency / total

    @torch.no_grad()
    def get_cls_token(self, batch: "VLInputs"):
        """Return position 0 of the wrapped model's multimodal embeddings (no grad)."""
        return self.flava(
            pixel_values=batch.pixel_values,
            input_ids=batch.input_ids,
            attention_mask=batch.attention_mask
        ).multimodal_embeddings[:,0]

    def get_dissimilarity_score(self, batch: "VLInputs"):
        """Return ``1 - cosine similarity`` between each sample and each key.

        The result has one row per sample and one column per pool key.
        """
        cls_tokens = self.get_cls_token(batch)
        cls_tokens = cls_tokens/cls_tokens.norm(dim=-1, keepdim=True)
        keys = self.keys_pool/self.keys_pool.norm(dim=-1, keepdim=True)
        return 1 - (cls_tokens @ keys.T)
    
class FlavaLearningToPromptCL(nn.Module):
    """Prompted FLAVA feature extractor followed by an incremental classifier.

    Args:
        n_output_classes: number of output classes passed to the classifier.
    """

    def __init__(self,
        n_output_classes: int
    ):
        super().__init__()
        extractor = FlavaForLearningToPrompt()
        self.feature_extractor = extractor
        self.incremental_classifier = IncrementalClassifier(extractor.d_model, n_output_classes)

    def forward(self, inputs):
        """Return ``(logits, top_scores)`` for ``inputs``.

        The classifier is applied to the first ``n_prompts * prompt_dim + 1``
        positions of the extractor output and the resulting logits are
        averaged over those positions.
        """
        extractor = self.feature_extractor
        features, top_scores = extractor(inputs)
        n_leading = extractor.n_prompts * extractor.prompt_dim + 1
        per_position_logits = self.incremental_classifier(features[:, :n_leading])
        return per_position_logits.mean(dim=1), top_scores

    def adaptation(self, n_output_classes: int):
        """Adapt the classifier to ``n_output_classes`` and refresh the key frequencies."""
        self.incremental_classifier.adaptation(n_output_classes)
        self.feature_extractor.update_key_frequency()