import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from .finch import FINCH
from .modeling_t5_MAD import T5ForConditionalGeneration
from .vit import VisionTransformer
from transformers import T5Tokenizer
from transformers.modeling_outputs import BaseModelOutput

def _get_tokenizer(tokenizer_path, num_bins=0):
    """Loads a T5 tokenizer and adds special time tokens."""
    if 't5' in tokenizer_path:
        tokenizer = T5Tokenizer.from_pretrained(tokenizer_path, local_files_only=True)
        if num_bins:
            new_tokens = [f"<time={i}>" for i in range(num_bins)]
            tokenizer.add_tokens(list(new_tokens))
    else:
        raise NotImplementedError(tokenizer_path)
    return tokenizer

class LOCO(torch.nn.Module):

    def __init__(self,
                 t5_path,
                 num_features=100,
                 embed_dim=768,
                 depth=12,
                 heads=12,
                 mlp_dim=2048,
                 vis_drop=0.,
                 tokenizer=None,
                 enc_drop=0.,
                 dec_drop=0.1,
                 use_speech=True,
                 use_video=True,
                 num_bins=100,
                 label_smoothing=0.1,
                 args=None):
        """Build the T5 backbone, the visual encoder and the window-memory state.

        Args:
            t5_path: Name or local path passed to
                ``T5ForConditionalGeneration.from_pretrained``. Gated
                activations are enabled when it contains ``"v1_1"``.
            num_features: Forwarded to ``VisionTransformer``.
            embed_dim: Width of the visual encoder and of both learnable
                memory tokens.
            depth: Forwarded to ``VisionTransformer``.
            heads: Forwarded to ``VisionTransformer`` as ``num_heads``.
            mlp_dim: Forwarded to ``VisionTransformer``.
            vis_drop: Used for both ``drop_rate`` and ``attn_drop_rate`` of
                the visual encoder.
            tokenizer: Required. Its length sizes the T5 embedding table; it
                is also used for the pad token id and for decoding.
            enc_drop: Forwarded to T5 as ``encoder_dropout``.
            dec_drop: Forwarded to T5 as ``decoder_dropout``.
            use_speech: Whether the tokenized input text is encoded.
            use_video: Whether the video features are encoded.
            num_bins: Accepted for interface compatibility; not used here.
            label_smoothing: Forwarded to T5.
            args: Required. Configuration object; this constructor reads
                ``mem_size`` and ``accum``. Other methods of this class read
                ``window_memory``, ``finch``, ``finch_accum``,
                ``slot_attention`` and ``frame_attn``.

        Raises:
            ValueError: If ``tokenizer`` or ``args`` is ``None``.
        """
        super().__init__()

        # Both default to None in the signature but are dereferenced below.
        # Fail fast with a clear message before the expensive model load.
        if tokenizer is None:
            raise ValueError(
                "LOCO requires a tokenizer: its length sizes the T5 embedding table."
            )
        if args is None:
            raise ValueError(
                "LOCO requires `args` providing at least `mem_size` and `accum`."
            )

        self.args = args
        self.t5_model = T5ForConditionalGeneration.from_pretrained(
            pretrained_model_name_or_path=t5_path,
            encoder_dropout=enc_drop,
            decoder_dropout=dec_drop,
            label_smoothing=label_smoothing,
            local_files_only=True,
            is_gated_act="v1_1" in t5_path
        )
        # The tokenizer may carry added tokens, so resize the embedding table.
        self.t5_model.resize_token_embeddings(len(tokenizer))

        self.visual_encoder = VisionTransformer(
            num_features=num_features,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=heads,
            mlp_dim=mlp_dim,
            qkv_bias=True,
            qk_scale=None,
            drop_rate=vis_drop,
            attn_drop_rate=vis_drop,
            norm_layer=nn.LayerNorm
        )
        self.t5_tokenizer = tokenizer
        self.use_speech = use_speech
        self.use_video = use_video
        self.embed_dim = embed_dim

        # Project visual features to T5's model dimension if they differ
        self.proj_v2t = None
        if self.t5_model.model_dim != embed_dim:
            self.proj_v2t = nn.Linear(embed_dim, self.t5_model.model_dim)

        # Learnable memory tokens for writing context to be passed to the next window.
        # NOTE(review): both tokens are embed_dim wide, but the speech token is
        # concatenated with T5 token embeddings in forward/generate; this
        # presumably requires embed_dim == T5 model_dim when speech memory is
        # enabled -- confirm.
        self.learnable_memory_token_video = nn.Parameter(torch.zeros(1, self.args.mem_size, embed_dim))
        self.learnable_memory_token_speech = nn.Parameter(torch.zeros(1, self.args.mem_size, embed_dim))

        # Placeholders for reading memory tokens from the previous window
        self.memory_token_video = None
        self.memory_token_speech = None

        # Per-window features accumulated for the global summary.
        # None (rather than an empty list) means accumulation is disabled.
        self.accum_tokens_video = [] if self.args.accum else None
        self.accum_tokens_speech = [] if self.args.accum else None

        # forward/generate read these two; nothing in this class sets them
        # to other values (only reset_accm_memory_tokens clears global_flag).
        self.global_flag = False
        self.window_number = 0

        # Alias used by the slot-attention branch of forward/generate.
        self.simple_attention_block = self.simple_attention

    def reset_memory_tokens(self):
        """Resets memory tokens at the beginning of each new video."""
        self.memory_token_video = None
        self.memory_token_speech = None

    def reset_accm_memory_tokens(self):
        """Resets accumulated tokens for global context."""
        self.accum_tokens_video = []
        self.accum_tokens_speech = []
        self.global_flag = False

    def gradient_off_memory_tokens(self):
        """Detaches memory tokens from the computation graph to prevent gradients from flowing back."""
        if self.memory_token_video is not None:
            self.memory_token_video = self.memory_token_video.cpu().detach()
        if self.memory_token_speech is not None:
            self.memory_token_speech = self.memory_token_speech.cpu().detach()
    
    def gradient_on_memory_tokens(self):
        """Re-attaches memory tokens to the computation graph."""
        if self.memory_token_video is not None:
            self.memory_token_video = self.memory_token_video.to('cuda')
        if self.memory_token_speech is not None:
            self.memory_token_speech = self.memory_token_speech.to('cuda')
            
    def finch_cluster_attention(self, video: torch.Tensor, write_token: torch.Tensor):
        """
        Aggregates video features using FINCH clustering and updates the write token via attention.

        The write token attends over the per-cluster mean embeddings, or over
        the raw frames instead when ``args.frame_attn`` is set. The cluster
        embeddings are returned in both cases.

        Args:
            video (torch.Tensor): Video features for the current window. Shape: (1, T, D).
            write_token (torch.Tensor): Learnable memory token for writing context. Shape: (1, mem_size, D).

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                - Original video tensor (unchanged).
                - Updated write token.
                - Averaged cluster embeddings. Shape: (1, n_clusters, D).

        Raises:
            AssertionError: If the batch dimension of ``video`` is not 1.
        """
        batch, _, feat_dim = video.shape
        assert batch == 1, "FINCH clustering currently supports batch_size=1."

        # 1. Perform FINCH clustering on video features
        frames = video.squeeze(0)
        partitions, num_clust_list, _ = FINCH(frames, distance='cosine', tw_finch=True, verbose=False)

        # Use the partition at index 1, or the last available column when
        # FINCH returned fewer than two.
        partition_idx = min(1, partitions.shape[1] - 1)
        cluster_labels = partitions[:, partition_idx]
        n_clusters = int(num_clust_list[partition_idx])

        # 2. Compute mean embedding for each cluster
        cluster_means = []
        for cluster_id in range(n_clusters):
            members = (cluster_labels == cluster_id)
            if members.sum() > 0:
                cluster_means.append(frames[members].mean(dim=0))
            else:
                # Empty cluster: zero filler with the frames' dtype and device,
                # so the stacked tensor keeps one dtype under reduced precision.
                cluster_means.append(frames.new_zeros(feat_dim))
        cluster_embs = torch.stack(cluster_means, dim=0).unsqueeze(0)

        # 3. Update the write token. With frame_attn the frame-level result is
        # the one returned, so the cluster-level attention is not computed.
        attn_context = video if self.args.frame_attn else cluster_embs
        new_write_token = self.simple_attention(write_token, attn_context)

        return video, new_write_token, cluster_embs
    
    def simple_attention(self, query, context):
        """A simple scaled dot-product attention mechanism."""
        B, q_len, d = query.shape
        scale = math.sqrt(d)
        attn_scores = torch.matmul(query, context.transpose(-1, -2)) / scale
        attn_probs = F.softmax(attn_scores, dim=-1)
        out = torch.matmul(attn_probs, context)
        return out

    def forward(self, video, input_tokenized, output_tokenized):
        """Run one window through the model and return the training loss.

        Args:
            video: Either a feature tensor unpacked as (batch, length, feat_dim),
                or a dict of cached features with keys ``"video"`` and
                ``"atts_vis"`` that is used as-is. Only read when
                ``use_video`` is set.
            input_tokenized: Mapping with ``'input_ids'`` and
                ``'attention_mask'`` for the input text. It is not modified.
                Only read when ``use_speech`` is set.
            output_tokenized: Mapping with ``'input_ids'`` and
                ``'attention_mask'`` for the target text. Pad positions are
                replaced by -100 in the labels.

        Returns:
            A tuple ``({"loss": loss}, video_dict)``. ``video_dict`` holds the
            encoded video and its attention mask, or is ``None`` when
            ``use_video`` is off.

        Raises:
            ValueError: If both ``use_video`` and ``use_speech`` are off.

        Side effects:
            Updates ``memory_token_video`` / ``memory_token_speech`` and,
            when accumulation is enabled, appends to the accumulator lists.
        """
        mem_size = self.args.mem_size
        window_memory = self.args.window_memory

        # Bound on every path: the memory update at the end tests this name
        # even when the video came from the cached-dict branch.
        cluster_embs_avg = None

        # Process video features
        video_dict = None
        atts_vis = None
        if self.use_video:
            if isinstance(video, dict):  # Handle cached features
                video, atts_vis = video["video"], video["atts_vis"]
            else:
                # Surround the window with [read memory, frames, write memory]
                if 'visual' in window_memory:
                    batch_size, _, feat_dim = video.shape
                    if self.memory_token_video is not None:
                        read_token = self.memory_token_video
                    else:
                        # First window: nothing to read yet.
                        read_token = video.new_zeros((batch_size, mem_size, feat_dim))
                    write_token = self.learnable_memory_token_video.expand(batch_size, -1, -1)

                    if self.args.finch:
                        video, write_token, cluster_embs_avg = self.finch_cluster_attention(video, write_token)

                    video = torch.cat([read_token, video, write_token], dim=1)

                # Handle global accumulation at the last window
                if self.global_flag and self.accum_tokens_video is not None:
                    if self.args.finch_accum and self.args.slot_attention:
                        # Split the sequence into window_number equal slots and
                        # let slot i attend over the i-th stored cluster set.
                        total_len = video.size(1)
                        frame_size = total_len // self.window_number
                        for i, clust_embs_3d in enumerate(self.accum_tokens_video):
                            start_idx = i * frame_size
                            end_idx = start_idx + frame_size
                            if end_idx > total_len:
                                break
                            slot_query = video[:, start_idx:end_idx, :]
                            updated_slot = self.simple_attention_block(slot_query, clust_embs_3d)
                            video = torch.cat([video[:, :start_idx, :], updated_slot, video[:, end_idx:, :]], dim=1)
                    elif self.accum_tokens_video:
                        accum_video = torch.cat(self.accum_tokens_video, dim=1)
                        video = torch.cat([accum_video, video], dim=1)

                video = self.visual_encoder(video)
                if self.proj_v2t is not None:
                    video = self.proj_v2t(video)

                # Every visual position is attendable; built on the target device.
                atts_vis = torch.ones(video.size()[:-1], dtype=torch.long, device=video.device)
            video_dict = {"video": video, "atts_vis": atts_vis}

        # Process speech features
        text_output = None
        text_atts = None
        if self.use_speech:
            # Extend a local mask; the caller's dict must stay untouched so the
            # same batch can be passed again without the mask growing.
            text_atts = input_tokenized['attention_mask']
            text = self.t5_model.encoder.embed_tokens(input_tokenized['input_ids'])
            if 'speech' in window_memory:
                batch_size, _, hidden_dim = text.shape
                if self.memory_token_speech is not None:
                    read_token = self.memory_token_speech
                else:
                    read_token = text.new_zeros((batch_size, mem_size, hidden_dim))
                write_token = self.learnable_memory_token_speech.expand(batch_size, -1, -1)
                text = torch.cat([read_token, text, write_token], dim=1)

                # The mask follows the token order [read, text, write]. Putting
                # both memory masks in front would shift the text mask by
                # mem_size and let padding zeros hide the write tokens.
                mem_mask = torch.ones((batch_size, mem_size), dtype=torch.long, device=text.device)
                text_atts = torch.cat([mem_mask, text_atts, mem_mask], dim=1)

                if self.global_flag and self.accum_tokens_speech:
                    accum_speech = torch.cat(self.accum_tokens_speech, dim=1)
                    text = torch.cat([accum_speech, text], dim=1)
                    accum_mask = torch.ones((text.size(0), accum_speech.size(1)), dtype=torch.long, device=text.device)
                    text_atts = torch.cat([accum_mask, text_atts], dim=1)

            encoded_speech = self.t5_model.encoder(
                attention_mask=text_atts,
                inputs_embeds=text,
            )
            text_output = encoded_speech.last_hidden_state

        # Combine multimodal features
        if self.use_video and self.use_speech:
            combined_features = torch.cat([video, text_output], dim=1)
            combined_atts = torch.cat([atts_vis, text_atts], dim=1)
            encoder_outputs = BaseModelOutput(last_hidden_state=combined_features)
        elif self.use_video:
            encoder_outputs = BaseModelOutput(last_hidden_state=video)
            combined_atts = atts_vis
        elif self.use_speech:
            encoder_outputs = BaseModelOutput(last_hidden_state=text_output)
            combined_atts = text_atts
        else:
            raise ValueError("At least one of use_video or use_speech must be True.")

        # T5 decoder forward pass; pad positions are replaced by -100 in the labels.
        targets = output_tokenized['input_ids'].masked_fill(
            output_tokenized['input_ids'] == self.t5_tokenizer.pad_token_id, -100
        )
        outputs = self.t5_model(
            encoder_outputs=encoder_outputs,
            attention_mask=combined_atts,
            decoder_attention_mask=output_tokenized['attention_mask'],
            return_dict=True,
            labels=targets,
        )

        # Update memory tokens for the next window: the trailing mem_size
        # positions of each encoded sequence become the next read tokens.
        if 'visual' in window_memory and self.use_video:
            self.memory_token_video = video[:, -mem_size:, :]
            if self.accum_tokens_video is not None:
                if self.args.finch_accum and self.args.finch and cluster_embs_avg is not None:
                    self.accum_tokens_video.append(cluster_embs_avg.detach())
                else:
                    self.accum_tokens_video.append(self.memory_token_video.detach())

        if 'speech' in window_memory and self.use_speech:
            self.memory_token_speech = text_output[:, -mem_size:, :]
            if self.accum_tokens_speech is not None:
                self.accum_tokens_speech.append(self.memory_token_speech.detach())

        return {"loss": outputs.loss}, video_dict

    @torch.no_grad()
    def generate(
            self,
            video,
            input_tokenized,
            use_nucleus_sampling=False,
            num_beams=4,
            max_length=256,
            min_length=1,
            top_p=0.9,
            repetition_penalty=1.0,
            length_penalty=1.0,
            num_captions=1,
            temperature=1,
    ):
        """
        Generates text captions for a given video window. The logic mirrors the forward pass.

        Args:
            video: Feature tensor unpacked as (batch, length, feat_dim). Only
                read when ``use_video`` is set.
            input_tokenized: Mapping with ``'input_ids'`` and
                ``'attention_mask'`` for the input text. It is not modified.
                Only read when ``use_speech`` is set.
            use_nucleus_sampling: Forwarded to T5 ``generate`` as ``do_sample``.
            num_beams: Forwarded to T5 ``generate``.
            max_length: Forwarded to T5 ``generate`` as ``max_new_tokens``.
            min_length: Forwarded to T5 ``generate``.
            top_p: Forwarded to T5 ``generate``.
            repetition_penalty: Forwarded to T5 ``generate``.
            length_penalty: Forwarded to T5 ``generate``.
            num_captions: Forwarded to T5 ``generate`` as ``num_return_sequences``.
            temperature: Forwarded to T5 ``generate``.

        Returns:
            The generated sequences decoded by the tokenizer's
            ``batch_decode`` with special tokens skipped.

        Raises:
            ValueError: If both ``use_video`` and ``use_speech`` are off.

        Side effects:
            Updates ``memory_token_video`` / ``memory_token_speech`` and,
            when accumulation is enabled, appends to the accumulator lists.
        """
        mem_size = self.args.mem_size
        window_memory = self.args.window_memory
        cluster_embs_avg = None

        # Process video features
        atts_vis = None
        if self.use_video:
            # Surround the window with [read memory, frames, write memory]
            if 'visual' in window_memory:
                batch_size, _, feat_dim = video.shape
                if self.memory_token_video is not None:
                    read_token = self.memory_token_video
                else:
                    # First window: nothing to read yet.
                    read_token = video.new_zeros((batch_size, mem_size, feat_dim))
                write_token = self.learnable_memory_token_video.expand(batch_size, -1, -1)

                if self.args.finch:
                    video, write_token, cluster_embs_avg = self.finch_cluster_attention(video, write_token)

                video = torch.cat([read_token, video, write_token], dim=1)

            # Handle global accumulation at the last window
            if self.global_flag and self.accum_tokens_video is not None:
                if self.args.finch_accum and self.args.slot_attention:
                    # Split the sequence into window_number equal slots and
                    # let slot i attend over the i-th stored cluster set.
                    total_len = video.size(1)
                    frame_size = total_len // self.window_number
                    for i, clust_embs_3d in enumerate(self.accum_tokens_video):
                        start_idx = i * frame_size
                        end_idx = start_idx + frame_size
                        if end_idx > total_len:
                            break
                        slot_query = video[:, start_idx:end_idx, :]
                        updated_slot = self.simple_attention_block(slot_query, clust_embs_3d)
                        video = torch.cat([video[:, :start_idx, :], updated_slot, video[:, end_idx:, :]], dim=1)
                elif self.accum_tokens_video:
                    accum_video = torch.cat(self.accum_tokens_video, dim=1)
                    video = torch.cat([accum_video, video], dim=1)

            video = self.visual_encoder(video)
            if self.proj_v2t is not None:
                video = self.proj_v2t(video)
            # Every visual position is attendable; built on the target device.
            atts_vis = torch.ones(video.size()[:-1], dtype=torch.long, device=video.device)

        # Process speech features
        text_output = None
        text_atts = None
        if self.use_speech:
            # Extend a local mask; the caller's dict must stay untouched so the
            # same batch can be passed again without the mask growing.
            text_atts = input_tokenized['attention_mask']
            text = self.t5_model.encoder.embed_tokens(input_tokenized['input_ids'])
            if 'speech' in window_memory:
                batch_size, _, hidden_dim = text.shape
                if self.memory_token_speech is not None:
                    read_token = self.memory_token_speech
                else:
                    read_token = text.new_zeros((batch_size, mem_size, hidden_dim))
                write_token = self.learnable_memory_token_speech.expand(batch_size, -1, -1)
                text = torch.cat([read_token, text, write_token], dim=1)

                # The mask follows the token order [read, text, write]. Putting
                # both memory masks in front would shift the text mask by
                # mem_size and let padding zeros hide the write tokens.
                mem_mask = torch.ones((batch_size, mem_size), dtype=torch.long, device=text.device)
                text_atts = torch.cat([mem_mask, text_atts, mem_mask], dim=1)

                if self.global_flag and self.accum_tokens_speech:
                    accum_speech = torch.cat(self.accum_tokens_speech, dim=1)
                    text = torch.cat([accum_speech, text], dim=1)
                    accum_mask = torch.ones((text.size(0), accum_speech.size(1)), dtype=torch.long, device=text.device)
                    text_atts = torch.cat([accum_mask, text_atts], dim=1)

            encoded_speech = self.t5_model.encoder(
                attention_mask=text_atts,
                inputs_embeds=text,
            )
            text_output = encoded_speech.last_hidden_state

        # Combine multimodal features
        if self.use_video and self.use_speech:
            combined_features = torch.cat([video, text_output], dim=1)
            combined_atts = torch.cat([atts_vis, text_atts], dim=1)
            encoder_outputs = BaseModelOutput(last_hidden_state=combined_features)
        elif self.use_video:
            encoder_outputs = BaseModelOutput(last_hidden_state=video)
            combined_atts = atts_vis
        elif self.use_speech:
            encoder_outputs = BaseModelOutput(last_hidden_state=text_output)
            combined_atts = text_atts
        else:
            # Same explicit error as forward, rather than an unbound local below.
            raise ValueError("At least one of use_video or use_speech must be True.")

        # Generate text using the T5 decoder
        outputs = self.t5_model.generate(
            encoder_outputs=encoder_outputs,
            attention_mask=combined_atts,
            do_sample=use_nucleus_sampling,
            top_p=top_p,
            temperature=temperature,
            num_beams=num_beams,
            max_new_tokens=max_length,
            min_length=min_length,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            num_return_sequences=num_captions,
        )

        output_text = self.t5_tokenizer.batch_decode(outputs, skip_special_tokens=True)

        # Update memory tokens for the next window: the trailing mem_size
        # positions of each encoded sequence become the next read tokens.
        if 'visual' in window_memory and self.use_video:
            self.memory_token_video = video[:, -mem_size:, :]
            if self.accum_tokens_video is not None:
                if self.args.finch_accum and self.args.finch and cluster_embs_avg is not None:
                    self.accum_tokens_video.append(cluster_embs_avg.detach())
                else:
                    self.accum_tokens_video.append(self.memory_token_video.detach())

        if 'speech' in window_memory and self.use_speech:
            self.memory_token_speech = text_output[:, -mem_size:, :]
            if self.accum_tokens_speech is not None:
                self.accum_tokens_speech.append(self.memory_token_speech.detach())

        return output_text
