import math
from typing import Optional, List
import torch
import torch.nn as nn
from transformers import FlavaModel

from lib.data.base import VLInputs
from lib.models.incremental_classifier import IncrementalClassifier

class FlavaEncoderForPrefixTuning(nn.Module):
    """Re-run a FLAVA transformer stack layer by layer with prefix-tuning prompts.

    The wrapped ``model`` must expose ``encoder.layer`` (the layer stack),
    ``config.hidden_size`` / ``config.num_attention_heads`` and
    ``get_extended_attention_mask``. Each layer's self-attention is computed
    by hand so that prompt vectors can be appended to the keys and values of
    selected layers. Queries are never extended, so the output has the same
    sequence length as the input.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model
        self.layers = self.model.encoder.layer

    @staticmethod
    def _splice_prompts(
        k: torch.Tensor,
        v: torch.Tensor,
        mask: Optional[torch.Tensor],
        prompts: torch.Tensor,
        slot: int,
    ):
        """Append prompt slot ``slot`` to ``k``/``v`` and extend ``mask`` to match.

        ``prompts[:, slot]`` is split on its second axis: index 0 is appended
        to the keys, index 1 to the values (concatenated along dim 1).
        When a mask is given, one column of ones per prompt token is appended
        so that the prompt positions are always attended to.
        """
        prefix = prompts[:, slot]  # [batch, 2, n_prompts, d_model]
        k = torch.cat([k, prefix[:, 0]], dim=1)
        v = torch.cat([v, prefix[:, 1]], dim=1)
        if mask is not None:
            mask = torch.cat([
                mask,
                torch.ones(k.size(0), prompts.size(-2), device=mask.device)
            ], dim=-1)
        return k, v, mask

    def forward(self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        g_prompts: Optional[torch.Tensor] = None,
        g_layers: Optional[List[int]] = None,
        e_prompts: Optional[torch.Tensor] = None,
        e_layers: Optional[List[int]] = None
    ) -> torch.Tensor:
        """Run every layer, injecting G-/E-prompts at the requested layers.

        Args:
            hidden_states: Embedded input. Prompts are concatenated along
                dim 1 and dim 0 is used as the batch size, so presumably
                ``[batch, seq_len, hidden]``. TODO confirm.
            attention_mask: Optional 2-D mask handed to
                ``get_extended_attention_mask``. Presumably HF convention
                (1 = attend); verify against callers. When ``None`` the
                prompts and all tokens are attended to unmasked.
            g_prompts: ``[batch, len(g_layers), 2, n_prompts, d_model]``.
                The i-th layer (in iteration order) whose index is in
                ``g_layers`` consumes slot ``i``.
            g_layers: Layer indices that receive G-prompts. Required
                whenever ``g_prompts`` is given.
            e_prompts: Same layout as ``g_prompts``, indexed by ``e_layers``.
            e_layers: Layer indices that receive E-prompts. Required
                whenever ``e_prompts`` is given.

        Returns:
            Hidden states after the last layer (no final layer norm).
        """
        # Per-head softmax temperature; identical for every layer.
        scale = math.sqrt(int(self.model.config.hidden_size / self.model.config.num_attention_heads))

        # Prompt slots are consumed in layer order, independently for G and E.
        g_counter = -1
        e_counter = -1

        for idx, layer in enumerate(self.layers):
            self_attention = layer.attention.attention
            layer_mask = attention_mask
            residual = hidden_states

            # Pre-LN block: attention runs on the normalized activations.
            normed = layer.layernorm_before(hidden_states)

            q = self_attention.query(normed)
            k = self_attention.key(normed)
            v = self_attention.value(normed)

            if (g_prompts is not None) and (idx in g_layers):
                g_counter += 1
                k, v, layer_mask = self._splice_prompts(k, v, layer_mask, g_prompts, g_counter)

            if (e_prompts is not None) and (idx in e_layers):
                e_counter += 1
                k, v, layer_mask = self._splice_prompts(k, v, layer_mask, e_prompts, e_counter)

            q = self_attention.transpose_for_scores(q)
            k = self_attention.transpose_for_scores(k)
            v = self_attention.transpose_for_scores(v)

            attention_scores = torch.matmul(q, k.transpose(-1, -2)) / scale
            if layer_mask is not None:
                extended_attention_mask = self.model.get_extended_attention_mask(
                    layer_mask, normed.size(), normed.device
                )
                attention_scores = attention_scores + extended_attention_mask

            attention_probs = self_attention.dropout(torch.softmax(attention_scores, dim=-1))

            # Merge heads back: [batch, heads, seq, head_dim] -> [batch, seq, all_head_size].
            context_layer = torch.matmul(attention_probs, v).permute(0, 2, 1, 3).contiguous()
            context_layer = context_layer.view(*context_layer.size()[:-2], self_attention.all_head_size)

            # The attention output projection does not add the residual itself,
            # so it is added explicitly here.
            attention_output = layer.attention.output(context_layer, normed)
            hidden_states = attention_output + residual

            # The feed-forward output module receives the residual as its second argument.
            layer_output = layer.intermediate(layer.layernorm_after(hidden_states))
            hidden_states = layer.output(layer_output, hidden_states)

        return hidden_states

class FlavaTextModel(nn.Module):
    """Prompt-aware wrapper around a pretrained FLAVA text encoder.

    The wrapper does three things in order:

    1. Embeds ``input_ids`` with the wrapped model's embedding layer.
    2. Runs the layer stack through ``FlavaEncoderForPrefixTuning`` so that
       prefix prompts can be injected.
    3. Applies the wrapped model's final ``layernorm``, as the stock FLAVA
       text model does.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model
        self.encoder = FlavaEncoderForPrefixTuning(self.model)

    def forward(self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        g_prompts: Optional[torch.Tensor] = None,
        g_layers: Optional[List[int]] = None,
        e_prompts: Optional[torch.Tensor] = None,
        e_layers: Optional[List[int]] = None
    ) -> torch.Tensor:
        """Encode token ids, injecting G-/E-prompts at the requested layers.

        Args:
            input_ids: Token ids passed to ``model.embeddings``.
            attention_mask: Token mask forwarded to the prefix-tuning encoder.
            g_prompts, g_layers, e_prompts, e_layers: Forwarded unchanged to
                the prefix-tuning encoder.

        Returns:
            The layer-normed last hidden state. Prompts only extend
            keys/values, so there is one vector per input token.
        """
        hidden_states = self.model.embeddings(input_ids=input_ids)
        hidden_states = self.encoder(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            g_prompts=g_prompts, g_layers=g_layers,
            e_prompts=e_prompts, e_layers=e_layers
        )
        # The pretrained text_to_mm_projection expects normalized features.
        # The stock FLAVA text model applies this norm after its encoder, and
        # the multimodal wrapper in this file applies its own layernorm too.
        return self.model.layernorm(hidden_states)
    
class FlavaImageModel(nn.Module):
    """Prompt-aware wrapper around a pretrained FLAVA image encoder.

    The wrapper does three things in order:

    1. Embeds ``pixel_values`` with the wrapped model's embedding layer.
    2. Runs the layer stack through ``FlavaEncoderForPrefixTuning`` without
       an attention mask, so all patches and prompts are attended to.
    3. Applies the wrapped model's final ``layernorm``, as the stock FLAVA
       image model does.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model
        self.encoder = FlavaEncoderForPrefixTuning(self.model)

    def forward(self,
        pixel_values: torch.Tensor,
        g_prompts: Optional[torch.Tensor] = None,
        g_layers: Optional[List[int]] = None,
        e_prompts: Optional[torch.Tensor] = None,
        e_layers: Optional[List[int]] = None
    ) -> torch.Tensor:
        """Encode images, injecting G-/E-prompts at the requested layers.

        Args:
            pixel_values: Images passed to ``model.embeddings``.
            g_prompts, g_layers, e_prompts, e_layers: Forwarded unchanged to
                the prefix-tuning encoder.

        Returns:
            The layer-normed last hidden state, with one vector per embedded
            position.
        """
        hidden_states = self.model.embeddings(pixel_values=pixel_values)
        hidden_states = self.encoder(
            hidden_states=hidden_states,
            g_prompts=g_prompts, g_layers=g_layers,
            e_prompts=e_prompts, e_layers=e_layers
        )
        # The pretrained image_to_mm_projection expects normalized features;
        # the stock FLAVA image model applies this norm after its encoder.
        return self.model.layernorm(hidden_states)
    
class FlavaMultimodalModel(nn.Module):
    """Prompt-aware wrapper around a pretrained FLAVA multimodal encoder.

    The wrapper does three things in order:

    1. Prepends the wrapped model's learned ``cls_token`` to the fused
       sequence.
    2. Runs the layer stack through ``FlavaEncoderForPrefixTuning``.
    3. Applies the wrapped model's final ``layernorm``.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model
        self.encoder = FlavaEncoderForPrefixTuning(self.model)

    def forward(self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        g_prompts: Optional[torch.Tensor] = None,
        g_layers: Optional[List[int]] = None,
        e_prompts: Optional[torch.Tensor] = None,
        e_layers: Optional[List[int]] = None
    ):
        """Encode a fused image+text sequence with a leading CLS token.

        Args:
            hidden_states: Fused sequence. The CLS token is concatenated on
                dim 1 after being expanded to ``hidden_states.size(0)`` rows.
            attention_mask: Must already include a column for the prepended
                CLS token; it is forwarded as-is.
            g_prompts, g_layers, e_prompts, e_layers: Forwarded unchanged to
                the prefix-tuning encoder.

        Returns:
            Layer-normed hidden states, with the CLS token at position 0.
        """
        n_rows = hidden_states.size(0)
        with_cls = torch.cat(
            (self.model.cls_token.expand(n_rows, -1, -1), hidden_states),
            dim=1,
        )
        encoded = self.encoder(
            hidden_states=with_cls,
            attention_mask=attention_mask,
            g_prompts=g_prompts,
            g_layers=g_layers,
            e_prompts=e_prompts,
            e_layers=e_layers,
        )
        return self.model.layernorm(encoded)
    
class FlavaForDualPrompt(nn.Module):
    """Frozen ``facebook/flava-full`` backbone with DualPrompt-style prefix prompts.

    Every pretrained FLAVA parameter has ``requires_grad=False``. The trainable
    state consists of two kinds of prompts:

    * G-prompts, shared by all experiences. There is one
      ``[len(g_layers), 2, n_g_prompts, d_model]`` tensor each for the image,
      text and multimodal encoders.
    * E-prompts, one set per experience. ``update_e_prompts`` appends a set to
      the ``e_prompts_*_pool`` lists, together with a ``[1, d_model]`` key in
      ``e_key_pool``.

    On the second axis of each prompt tensor, index 0 is spliced into the
    attention keys and index 1 into the values (see
    ``FlavaEncoderForPrefixTuning``).

    Args:
        n_g_prompts: Prompt tokens per G-prompted layer.
        g_layers: Encoder layer indices that receive G-prompts.
            ``None`` selects ``[0, 1]``.
        n_e_prompts: Prompt tokens per E-prompted layer.
        e_layers: Encoder layer indices that receive E-prompts.
            ``None`` selects ``[2, 3, 4]``.
    """

    def __init__(self,
        n_g_prompts: int = 5,
        g_layers: Optional[List[int]] = None,
        n_e_prompts: int = 20,
        e_layers: Optional[List[int]] = None
    ):
        super().__init__()
        # Store fresh lists: a literal list default would be one object shared
        # (and mutable) across every instance.
        g_layers = [0, 1] if g_layers is None else list(g_layers)
        e_layers = [2, 3, 4] if e_layers is None else list(e_layers)

        self.n_g_prompts = n_g_prompts
        self.g_layers = g_layers
        self.n_e_prompts = n_e_prompts
        self.e_layers = e_layers

        self.flava = FlavaModel.from_pretrained("facebook/flava-full")
        for p in self.flava.parameters():
            p.requires_grad = False
        self.d_model = self.flava.config.hidden_size

        self.image_model = FlavaImageModel(self.flava.image_model)
        self.text_model = FlavaTextModel(self.flava.text_model)
        self.multimodal_model = FlavaMultimodalModel(self.flava.multimodal_model)

        self.image_to_mm_projection = self.flava.image_to_mm_projection
        self.text_to_mm_projection = self.flava.text_to_mm_projection

        # G-Prompt initialization
        self.g_prompts_text = nn.Parameter(torch.zeros(len(g_layers), 2, n_g_prompts, self.d_model))
        self.g_prompts_image = nn.Parameter(torch.zeros(len(g_layers), 2, n_g_prompts, self.d_model))
        self.g_prompts_multimodal = nn.Parameter(torch.zeros(len(g_layers), 2, n_g_prompts, self.d_model))
        nn.init.xavier_uniform_(self.g_prompts_text.data)
        nn.init.xavier_uniform_(self.g_prompts_image.data)
        nn.init.xavier_uniform_(self.g_prompts_multimodal.data)

        # E-Prompt pools, one entry per experience.
        self.e_prompts_text_pool = nn.ParameterList([])
        self.e_prompts_image_pool = nn.ParameterList([])
        self.e_prompts_multimodal_pool = nn.ParameterList([])
        self.e_key_pool = nn.ParameterList([])

        # Create the E-prompt set for the first experience.
        self.update_e_prompts()

    def forward(self,
        batch: VLInputs,
        experience_id: Optional[int] = None
    ):
        """Encode an image-text batch with G- and E-prompts.

        In training mode, ``experience_id`` selects the E-prompt set and its
        key. In eval mode, each sample gets the E-prompt set whose key has the
        lowest dissimilarity score, and ``experience_id`` is ignored.

        Args:
            batch: Provides ``pixel_values``, ``input_ids`` and ``attention_mask``.
            experience_id: Index into the E-prompt pools; required in training mode.

        Returns:
            ``(multimodal_output, e_key)``. ``multimodal_output`` is the
            layer-normed multimodal sequence with CLS at position 0.
            ``e_key`` is the selected ``[1, d_model]`` key in training mode
            and ``None`` in eval mode.

        Raises:
            ValueError: In training mode when ``experience_id`` is ``None``.
        """
        pixel_values = batch.pixel_values
        input_ids = batch.input_ids
        attention_mask = batch.attention_mask

        batch_size = input_ids.size(0)
        device = pixel_values.device

        # G-prompts: [n_layers, 2, n_prompts, d] -> [batch, n_layers, 2, n_prompts, d]
        g_prompts_text = self.g_prompts_text.repeat(batch_size, 1, 1, 1, 1).to(device)
        g_prompts_image = self.g_prompts_image.repeat(batch_size, 1, 1, 1, 1).to(device)
        g_prompts_multimodal = self.g_prompts_multimodal.repeat(batch_size, 1, 1, 1, 1).to(device)

        # E-prompt selection
        if self.training:
            if experience_id is None:
                raise ValueError("experience_id is required in training mode to select the E-prompts")
            e_prompts_text = self.e_prompts_text_pool[experience_id].repeat(batch_size, 1, 1, 1, 1).to(device)
            e_prompts_image = self.e_prompts_image_pool[experience_id].repeat(batch_size, 1, 1, 1, 1).to(device)
            e_prompts_multimodal = self.e_prompts_multimodal_pool[experience_id].repeat(batch_size, 1, 1, 1, 1).to(device)
            e_key = self.e_key_pool[experience_id]
        else:
            # Per-sample index of the closest key: [batch].
            pos = self.get_dissimilarity_score(batch).argmin(dim=-1)
            # Stack the pool to [n_experiences, n_layers, 2, n_prompts, d], then gather per sample.
            e_prompts_text = torch.stack(list(self.e_prompts_text_pool))[pos]
            e_prompts_image = torch.stack(list(self.e_prompts_image_pool))[pos]
            e_prompts_multimodal = torch.stack(list(self.e_prompts_multimodal_pool))[pos]
            e_key = None

        # Vision path
        image_output = self.image_model(
            pixel_values=pixel_values,
            g_prompts=g_prompts_image, g_layers=self.g_layers,
            e_prompts=e_prompts_image, e_layers=self.e_layers
        )
        image_output = self.image_to_mm_projection(image_output)

        # Language path
        text_output = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            g_prompts=g_prompts_text, g_layers=self.g_layers,
            e_prompts=e_prompts_text, e_layers=self.e_layers
        )
        text_output = self.text_to_mm_projection(text_output)

        # Multimodal path. The mask gets ones for the image tokens plus one extra
        # column for the CLS token that the multimodal encoder prepends.
        hidden_states = torch.cat([image_output, text_output], dim=1)
        multimodal_attention_mask = torch.cat([
            torch.ones(hidden_states.size(0), image_output.size(1) + 1, device=device),
            attention_mask
        ], dim=-1)
        multimodal_output = self.multimodal_model(
            hidden_states=hidden_states,
            attention_mask=multimodal_attention_mask,
            g_prompts=g_prompts_multimodal, g_layers=self.g_layers,
            e_prompts=e_prompts_multimodal, e_layers=self.e_layers
        )

        return multimodal_output, e_key

    @torch.no_grad()
    def get_cls_token(self, batch: VLInputs):
        """Return position 0 of the un-prompted, frozen FLAVA multimodal embeddings.

        This is the query used to match E-prompt keys. It runs without
        gradient tracking.
        """
        return self.flava(
            pixel_values=batch.pixel_values,
            input_ids=batch.input_ids,
            attention_mask=batch.attention_mask
        ).multimodal_embeddings[:, 0]

    def get_dissimilarity_score(self,
        batch: VLInputs,
        key: Optional[torch.Tensor] = None
    ):
        """Return ``1 - cosine_similarity(query, key)`` for every sample/key pair.

        Args:
            batch: Input used to compute the frozen CLS query.
            key: ``[n_keys, d_model]`` keys. ``None`` uses all keys in
                ``e_key_pool``, concatenated in experience order.

        Returns:
            ``[batch, n_keys]`` scores, where lower means more similar.
        """
        cls_tokens = self.get_cls_token(batch)
        cls_tokens = cls_tokens / cls_tokens.norm(dim=-1, keepdim=True)

        if key is None:
            # Each pooled key is [1, d_model], so concatenation gives [n_experiences, d_model].
            key = torch.cat(list(self.e_key_pool))

        key = key / key.norm(dim=-1, keepdim=True)

        return 1 - (cls_tokens @ key.T)

    def update_e_prompts(self):
        """Append a freshly Xavier-initialized E-prompt set and key to the pools.

        New parameters are created on the device of this module's first
        parameter.

        NOTE(review): an optimizer created before this call will not include
        the new parameters. Confirm that the training loop rebuilds it.
        """
        device = next(iter(self.parameters())).device
        e_prompts_text = nn.Parameter(torch.zeros(len(self.e_layers), 2, self.n_e_prompts, self.d_model, device=device))
        e_prompts_image = nn.Parameter(torch.zeros(len(self.e_layers), 2, self.n_e_prompts, self.d_model, device=device))
        e_prompts_multimodal = nn.Parameter(torch.zeros(len(self.e_layers), 2, self.n_e_prompts, self.d_model, device=device))
        nn.init.xavier_uniform_(e_prompts_text.data)
        nn.init.xavier_uniform_(e_prompts_image.data)
        nn.init.xavier_uniform_(e_prompts_multimodal.data)

        e_key = nn.Parameter(torch.zeros(1, self.d_model, device=device))
        nn.init.xavier_uniform_(e_key.data)

        self.e_prompts_text_pool.append(e_prompts_text)
        self.e_prompts_image_pool.append(e_prompts_image)
        self.e_prompts_multimodal_pool.append(e_prompts_multimodal)
        self.e_key_pool.append(e_key)
        
###

class FlavaDualPromptCL(nn.Module):
    """Continual-learning classifier: DualPrompt FLAVA features plus a growing head.

    Classification uses position 0 (the multimodal CLS token) of the feature
    extractor's output.

    Args:
        n_output_classes: Initial number of classes for the incremental classifier.
        **dual_prompt_kwargs: Forwarded to ``FlavaForDualPrompt``
            (``n_g_prompts``, ``g_layers``, ``n_e_prompts``, ``e_layers``).
            When omitted, that class's defaults are used, as before.
    """

    def __init__(self,
        n_output_classes: int,
        **dual_prompt_kwargs
    ):
        super().__init__()
        self.feature_extractor = FlavaForDualPrompt(**dual_prompt_kwargs)
        self.incremental_classifier = IncrementalClassifier(self.feature_extractor.d_model, n_output_classes)

    def forward(self, inputs, experience_id=None):
        """Return ``(logits, key)``.

        ``key`` is whatever the feature extractor returns alongside the hidden
        states. ``FlavaForDualPrompt`` returns the selected E-key in training
        mode and ``None`` in eval mode.
        """
        hidden_states, key = self.feature_extractor(inputs, experience_id)
        logits = self.incremental_classifier(hidden_states[:, 0])
        return logits, key

    def adaptation(self, n_output_classes: int):
        """Prepare for a new experience.

        Grows the classifier to ``n_output_classes`` and adds a new E-prompt
        set to the feature extractor.
        """
        self.incremental_classifier.adaptation(n_output_classes)
        self.feature_extractor.update_e_prompts()
