"""Custom data collators for token classification training."""

import torch
from typing import Dict, List, Any, Tuple, Optional

from torch.nn.attention.flex_attention import BlockMask, create_block_mask


class HiddenStateDataCollator:
    """
    Data collator that processes input_ids through a base model to generate hidden states,
    then collates them for training the classification model.

    Two modes are selected per batch:

    * Online: when the features carry no ``"train"`` key, the padded ``input_ids`` are run
      through ``base_model`` (in chunks of ``base_batch_size``) under ``torch.no_grad()``
      and its ``last_hidden_state`` is used as ``inputs_embeds``.
    * Precomputed: when the features carry a ``"train"`` key, hidden states are looked up
      by each feature's ``precomputed_idx`` in ``precomputed_train_data["inputs_embeds"]``
      or ``precomputed_eval_data["inputs_embeds"]`` (chosen by the first feature's
      ``"train"`` flag).

    In both modes only tokens whose section label is 4 (rollout) are kept. The
    ``last_token_only`` and ``last_rollout_only`` options narrow this set further, and
    the surviving tokens are right-padded to the longest kept sequence in the batch.
    """

    def __init__(
        self,
        base_model,
        lm_head,
        tokenizer,
        args,
        training_device=None,
        precomputed_train_data=None,
        precomputed_eval_data=None,
    ):
        """
        Args:
            base_model: Model producing hidden states. Online mode reads ``.device`` and
                calls it; both modes read ``.dtype``.
            lm_head: Stored as ``self.lm_head``; not used by this class's own methods.
            tokenizer: Supplies ``pad_token_id`` and, optionally, ``model_max_length``.
            args: Config object read for ``last_rollout_only``, ``last_label_only``,
                ``last_token_only``, ``base_batch_size`` and ``batch_size``.
            training_device: Device every returned tensor is moved to.
            precomputed_train_data: Mapping whose ``"inputs_embeds"`` entry is indexed
                by ``precomputed_idx`` for features with ``train=True``.
            precomputed_eval_data: Same as above, for features with ``train=False``.
        """
        self.base_model = base_model
        self.lm_head = lm_head
        self.tokenizer = tokenizer
        self.args = args
        self.training_device = training_device
        self.precomputed_train_data = precomputed_train_data
        self.precomputed_eval_data = precomputed_eval_data

        # Sequences are padded to a multiple of this length. It presumably matches
        # FlexAttention's 128-token block size -- TODO confirm.
        self.padding_multiple = 128
        self.max_length = getattr(self.tokenizer, 'model_max_length', None)
        self.label_pad_token_id = -100
        self.last_rollout_only = args.last_rollout_only
        self.last_label_only = args.last_label_only
        self.last_token_only = args.last_token_only
        # Chunk size for base-model forward passes; falls back to the training batch size.
        self.base_batch_size = args.base_batch_size
        if self.base_batch_size is None:
            self.base_batch_size = args.batch_size

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """
        Collate a list of feature dicts into a model-ready batch.

        Each feature provides ``input_ids``, ``labels``, ``bin_idx`` and
        ``section_labels`` (presumably equal-length Python lists -- TODO confirm).
        Precomputed features also provide ``train`` and ``precomputed_idx``.
        Any other key is passed through: int/float values become a tensor on the
        training device, and anything else is returned as a plain list.

        Returns:
            Dict with ``inputs_embeds`` ``[B, L, D]`` (a leaf that requires grad) plus
            ``attention_mask``, ``labels``, ``bin_idx``, ``section_mask``,
            ``position_ids`` and ``section_labels``, each ``[B, L]``. ``L`` is the
            longest kept-token count in the batch. Passthrough keys are added as well.

        Raises:
            ValueError: If ``features`` is empty.
        """
        if not features:
            raise ValueError("HiddenStateDataCollator received an empty list of features")

        # Precomputed hidden states are signalled by a 'train' key on the features.
        use_precomputed = "train" in features[0]

        input_ids = [f["input_ids"] for f in features]
        labels = [f["labels"] for f in features]
        bin_idx = [f["bin_idx"] for f in features]
        section_labels = [f["section_labels"] for f in features]

        # Tokens kept per sequence: the longest sequence in the batch, capped at the
        # tokenizer's max length so that over-long sequences really are truncated.
        keep_length = max(len(seq) for seq in input_ids)
        if self.max_length is not None:
            keep_length = min(keep_length, self.max_length)

        # Round keep_length *up* to a multiple of padding_multiple (at least one
        # multiple). An exact multiple is left as-is rather than gaining a whole
        # extra block of padding.
        num_blocks = max(1, -(-keep_length // self.padding_multiple))
        padded_length = num_blocks * self.padding_multiple

        def truncate_and_pad(seq, pad_value):
            truncated = seq[:keep_length]
            return truncated + [pad_value] * (padded_length - len(truncated))

        padded_input_ids = [truncate_and_pad(seq, self.tokenizer.pad_token_id) for seq in input_ids]
        # 1 for every real (kept) token, 0 for padding.
        attention_mask = [truncate_and_pad([1] * len(seq), 0) for seq in input_ids]
        padded_labels = [truncate_and_pad(seq, self.label_pad_token_id) for seq in labels]
        padded_bin_idx = [truncate_and_pad(seq, self.label_pad_token_id) for seq in bin_idx]
        padded_section_labels = (
            [truncate_and_pad(seq, -100) for seq in section_labels]
            if section_labels[0] is not None
            else []
        )

        input_ids_tensor = torch.tensor(padded_input_ids, dtype=torch.long)
        position_ids_tensor = self._generate_position_ids(padded_section_labels)
        section_mask = self._generate_special_attention_mask(padded_section_labels)

        online_hidden_states = None
        if not use_precomputed:
            online_hidden_states = self._run_base_model(
                input_ids_tensor, section_mask, position_ids_tensor
            )

        # Dense [B, padded_length] per-token tensors on the training device. The
        # insertion order here is also the key order of the returned batch.
        device = self.training_device
        dense = {
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long).to(device),
            "labels": torch.tensor(padded_labels, dtype=torch.long, device=device),
            "bin_idx": torch.tensor(padded_bin_idx, dtype=torch.long, device=device),
            "section_mask": section_mask.to(device),
            "position_ids": position_ids_tensor.to(device),
            "section_labels": torch.tensor(padded_section_labels, dtype=torch.long, device=device),
        }
        pad_values = {
            "attention_mask": 0,
            "labels": self.label_pad_token_id,
            "bin_idx": self.label_pad_token_id,
            "section_mask": -100,
            "position_ids": 0,
            "section_labels": -100,
        }

        keep = self._token_keep_mask(dense["section_labels"], dense["section_mask"])
        # Boolean indexing copies, so the per-row results can be edited safely below.
        kept = {name: [row[mask] for row, mask in zip(tensor, keep)] for name, tensor in dense.items()}

        if self.last_label_only:
            # Supervise only the last kept token of each sequence.
            for row in kept["labels"]:
                row[:-1] = self.label_pad_token_id

        if use_precomputed:
            hidden_rows = self._load_precomputed_hidden_states(features)
        else:
            hidden_rows = [row[mask] for row, mask in zip(online_hidden_states, keep)]

        # Longest kept sequence; every output tensor is right-padded to this length.
        max_len = max(len(row) for row in kept["labels"])

        # NOTE(review): precomputed rows are padded but never truncated or filtered
        # here. They are presumably stored already restricted to the kept tokens --
        # confirm upstream.
        hidden_states = torch.stack([self._pad_to_length(row, max_len) for row in hidden_rows])
        # Make the embeddings a fresh leaf tensor (detached from any graph) that requires grad.
        hidden_states = hidden_states.detach().requires_grad_(True)

        batch = {"inputs_embeds": hidden_states}
        for name, rows in kept.items():
            batch[name] = torch.stack(
                [self._pad_to_length(row, max_len, pad_values[name]) for row in rows]
            )

        # Preserve any additional fields that don't need padding (like prompt_idx).
        excluded_keys = {
            "input_ids", "labels", "bin_idx", "section_labels", "attention_mask",
            "position_ids", "inputs_embeds", "section_mask",
        }
        if use_precomputed:
            excluded_keys |= {"precomputed_idx", "train"}

        for key in features[0]:
            if key in excluded_keys:
                continue
            values = [f[key] for f in features]
            if isinstance(values[0], (int, float)):
                batch[key] = torch.tensor(values).to(self.training_device)
            else:
                batch[key] = values

        return batch

    def _run_base_model(
        self,
        input_ids_tensor: torch.Tensor,
        section_mask: torch.Tensor,
        position_ids_tensor: torch.Tensor,
    ) -> torch.Tensor:
        """Run the base model in chunks without gradient tracking; return [B, T, D] hidden states on the training device."""
        device = self.base_model.device
        input_ids_tensor = input_ids_tensor.to(device)
        section_mask = section_mask.to(device)
        position_ids_tensor = position_ids_tensor.to(device)
        batch_size = input_ids_tensor.shape[0]

        chunks = []
        with torch.no_grad():
            self.base_model.eval()
            for start in range(0, batch_size, self.base_batch_size):
                end = min(start + self.base_batch_size, batch_size)
                # Each chunk gets its own block mask sized to its batch dimension.
                block_mask = create_block_mask_for_base_model(section_mask[start:end], device)
                outputs = self.base_model(
                    input_ids=input_ids_tensor[start:end],
                    attention_mask=block_mask,
                    position_ids=position_ids_tensor[start:end],
                    return_dict=True,
                )
                chunks.append(outputs.last_hidden_state)

        hidden_states = torch.cat(chunks, dim=0).detach()  # [B, T, D]
        return hidden_states.to(self.training_device)

    def _load_precomputed_hidden_states(self, features: List[Dict[str, Any]]) -> List[torch.Tensor]:
        """Fetch each feature's stored hidden states, squeezed to [seq_len, hidden_dim], on the training device."""
        # The whole batch is assumed to come from one split; only the first feature's flag is read.
        store = self.precomputed_train_data if features[0]["train"] else self.precomputed_eval_data
        embeds = store["inputs_embeds"]
        return [
            embeds[f["precomputed_idx"]].squeeze(0).to(self.training_device, dtype=self.base_model.dtype)
            for f in features
        ]

    def _token_keep_mask(
        self, section_labels_tensor: torch.Tensor, section_mask_tensor: torch.Tensor
    ) -> torch.Tensor:
        """Return a [B, T] bool mask of the tokens that survive into the output batch."""
        is_rollout = section_labels_tensor == 4
        keep = is_rollout.clone()
        if self.last_token_only:
            # Drop a rollout token whose successor is also a rollout token, so only the
            # final token of each contiguous rollout run survives.
            keep[:, :-1] &= ~is_rollout[:, 1:]
        if self.last_rollout_only:
            # Keep only tokens belonging to the highest section ID in each row.
            last_section = section_mask_tensor.max(dim=1, keepdim=True).values
            keep &= section_mask_tensor == last_section
        return keep

    @staticmethod
    def _pad_to_length(tensor: torch.Tensor, target_len: int, pad_value=0) -> torch.Tensor:
        """Right-pad ``tensor`` along dim 0 to ``target_len`` with ``pad_value``. Longer tensors are returned unchanged."""
        pad_len = target_len - len(tensor)
        if pad_len <= 0:
            return tensor
        padding = torch.full(
            (pad_len, *tensor.shape[1:]), pad_value, dtype=tensor.dtype, device=tensor.device
        )
        return torch.cat([tensor, padding], dim=0)

    def _generate_position_ids(self, padded_section_labels: List[List[int]]) -> torch.Tensor:
        """
        Generate position IDs in which append_text and rollout tokens do not advance the shared counter.

        Prompt (0), paragraph (1) and delimiter (2) tokens each take the current
        position and advance it by one. Padding (any other label, e.g. -100) does the
        same. Append_text (3) and rollout (4) tokens take ``current_pos + offset``,
        where ``offset`` counts the 3/4 tokens seen since the last 0/1/2 token. The
        shared counter is left untouched, so the next 0/1/2 token resumes as if the
        run were absent.

        Returns:
            torch.Tensor: Long tensor with one row per input sequence.
        """
        position_ids = []

        for section_seq in padded_section_labels:
            pos_ids = []
            current_pos = 0
            append_pos = 0

            for section_label in section_seq:
                if section_label in [0, 1, 2]:  # prompt, paragraph, delimiter
                    append_pos = 0
                    pos_ids.append(current_pos)
                    current_pos += 1
                elif section_label in [3, 4]:  # append_text, rollout
                    pos_ids.append(current_pos + append_pos)  # shared counter not advanced
                    append_pos += 1
                else:  # padding (-100): takes and advances the shared counter like a text token
                    pos_ids.append(current_pos)
                    current_pos += 1

            position_ids.append(pos_ids)

        return torch.tensor(position_ids, dtype=torch.long)

    def _generate_special_attention_mask(self, padded_section_labels: List[List[int]]) -> torch.Tensor:
        """
        Generate a section-ID mask that gives each rollout its own ID.

        Creates a mask where:
        - Section 0, 1, 2 (prompt, paragraph, delimiter) tokens get ID 1.
        - Section 3, 4 (append_text, rollout) tokens get ID ``rollout_idx + 2``.
          ``rollout_idx`` increases by one each time a 0/1/2 token directly follows
          a 3/4 token, so the first run is 2, the next is 3, and so on. Adjacent
          3/4 tokens share an ID.
        - Any other label (padding) gets ID 0.

        This mask is used to control attention patterns so that rollout tokens can only
        attend to tokens from the same rollout or to the original text (ID 1).

        Args:
            padded_section_labels: Section labels for each token, one list per sequence.

        Returns:
            torch.Tensor: Long tensor of shape (batch_size, seq_len) holding section IDs.
        """
        attention_mask = []

        for section_seq in padded_section_labels:
            attn_mask = []
            rollout_idx = 0
            for i, section_label in enumerate(section_seq):
                if section_label in [0, 1, 2]:
                    attn_mask.append(1)  # prompt, paragraph, delimiter is 1
                    if i > 0 and section_seq[i - 1] in [3, 4]:
                        rollout_idx += 1  # a rollout run just ended
                elif section_label in [3, 4]:
                    attn_mask.append(rollout_idx + 2)  # first rollout is 2, second is 3, etc.
                else:
                    attn_mask.append(0)  # padding is 0
            attention_mask.append(attn_mask)
        return torch.tensor(attention_mask, dtype=torch.long)
    

def create_block_mask_for_base_model(section_mask: torch.Tensor, device: torch.device):
    B, T = section_mask.shape

    def mask_function(b, _, q_idx, kv_idx):
        causal = q_idx >= kv_idx
        key_valid = section_mask[b, kv_idx] != 0
        
        # Section-based attention rules:
        # 1's can only attend to 1's
        # 2,3,4,... can only attend to 1's and itself
        query_section = section_mask[b, q_idx]
        key_section = section_mask[b, kv_idx]
        section_valid = (key_section==query_section) | ((query_section>1) & (key_section==1))
        
        return causal & key_valid & section_valid
    
    # Finally create the block mask
    block_mask = create_block_mask(
        mask_mod=mask_function,
        B=B,
        H=None,
        Q_LEN=T,
        KV_LEN=T,
        device=device,
        _compile=True,
    )

    return block_mask
