import logging
from typing import List, Tuple, Dict
from dataclasses import dataclass

import torch
from transformers import PreTrainedTokenizer
from tevatron.co_retriever.arguments import DataArguments
from nltk import sent_tokenize

# Module-level logger named after this module (standard logging hierarchy).
logger = logging.getLogger(name=__name__)

@dataclass
class JointTrainCollator:
    """
    Collator for joint contrastive + Revela training.

    Given a batch of ``(query, passages)`` features, ``__call__`` returns a dict with:

    * ``"query_contrastive"`` / ``"passage_contrastive"``: padded retriever-tokenizer
      batches of all (prefixed) queries and all (prefixed) passages, for the
      contrastive loss.
    * ``"revela_retriever_input"`` / ``"revela_llm_input"``: lists holding one padded
      batch per query. Each batch encodes the query followed by its own passages
      (optionally with negatives split in two, see ``data_args.chunk_neg``), truncated
      to ``data_args.top_k`` texts. The retriever batch uses the prefixed text; the
      LLM batch uses the raw (unprefixed) text and additionally carries causal-LM
      ``labels`` with padded positions set to -100.
    """
    data_args: DataArguments
    retriever_tokenizer: PreTrainedTokenizer
    llm_tokenizer: PreTrainedTokenizer

    @staticmethod
    def _balanced_split_index(sentences: List[str], total_len: int) -> int:
        """
        Return the index at which to split ``sentences`` into two roughly equal halves.

        Each sentence's length is estimated as ``len(sentence) + 1`` (one joining
        space), measured against half of ``total_len``. At the sentence that crosses
        the midpoint, the split goes before or after it, whichever boundary is closer.
        The result is clamped to ``[1, len(sentences) - 1]`` so both halves are
        non-empty. Callers must pass at least two sentences.
        """
        mid_len = total_len // 2
        consumed = 0
        split_index = 1  # used when no sentence crosses the midpoint
        for i, sentence in enumerate(sentences):
            sentence_len = len(sentence) + 1
            if consumed + sentence_len > mid_len:
                before_is_closer = abs(mid_len - consumed) < abs(mid_len - (consumed + sentence_len))
                # Never split before the first sentence: that would leave chunk 1 empty.
                split_index = i if (before_is_closer and i > 0) else i + 1
                break
            consumed += sentence_len
        # Keep at least one sentence in the second chunk as well.
        return min(max(split_index, 1), len(sentences) - 1)

    def _prepare_revela_texts(
        self,
        current_query_list: List[str],
        current_raw_query_list: List[str],
        current_passages: List[str],
        current_raw_passages: List[str]
    ) -> Tuple[List[str], List[str]]:
        """
        Build the (prefixed, raw) text lists for one query's Revela inputs.

        Without ``data_args.chunk_neg``, the result is the query followed by its
        passages.

        With ``chunk_neg``:

        * The first passage is assumed to be the positive and is kept intact.
        * Each negative with at least two sentences (per NLTK ``sent_tokenize``) is
          split into two chunks of roughly equal length. Negatives with fewer
          sentences are kept whole; blank ones are dropped.
        * Output order: query, positive, all chunks (in negative order), then the
          unsplit negatives.
        * Prefixed chunks get ``data_args.passage_prefix`` re-applied to the raw text.

        Returns:
            ``(prefixed_texts, raw_texts)``, two parallel lists of equal length.
        """
        if not self.data_args.chunk_neg:
            return (
                current_query_list + current_passages,
                current_raw_query_list + current_raw_passages,
            )

        if not current_raw_passages:
            return current_query_list, current_raw_query_list

        # Assume the first passage is the positive.
        positive_passage = current_passages[0:1]
        raw_positive_passage = current_raw_passages[0:1]

        raw_chunked_negatives: List[str] = []
        raw_unchunkable_negatives: List[str] = []
        for raw_neg_doc in current_raw_passages[1:]:
            sentences = sent_tokenize(raw_neg_doc)
            if len(sentences) < 2:
                # Too short to split; kept whole and appended after all chunks.
                if raw_neg_doc.strip():
                    raw_unchunkable_negatives.append(raw_neg_doc)
                continue

            split_index = self._balanced_split_index(sentences, len(raw_neg_doc))
            for chunk in (" ".join(sentences[:split_index]), " ".join(sentences[split_index:])):
                if chunk.strip():
                    raw_chunked_negatives.append(chunk)

        final_raw_negatives = raw_chunked_negatives + raw_unchunkable_negatives
        prefix = self.data_args.passage_prefix or ""
        prefixed_negatives = [f"{prefix}{p}" for p in final_raw_negatives]

        all_revela_texts = current_query_list + positive_passage + prefixed_negatives
        all_raw_revela_texts = current_raw_query_list + raw_positive_passage + final_raw_negatives
        return all_revela_texts, all_raw_revela_texts

    def _encode_for_retriever(self, texts: List[str], max_len: int):
        """
        Tokenize ``texts`` with the retriever tokenizer and pad them to one batch.

        Truncates to ``max_len`` tokens. When ``data_args.append_eos_token`` is set,
        one slot is reserved and the retriever's EOS id is appended to every sequence.
        Returns the padded PyTorch batch, including ``attention_mask``.
        """
        append_eos = self.data_args.append_eos_token
        collated = self.retriever_tokenizer(
            texts,
            padding=False,
            truncation=True,
            return_attention_mask=False,
            return_token_type_ids=False,
            max_length=max_len - (1 if append_eos else 0),
            add_special_tokens=True,
        )
        if append_eos:
            eos_id = self.retriever_tokenizer.eos_token_id
            collated['input_ids'] = [ids + [eos_id] for ids in collated['input_ids']]
        return self.retriever_tokenizer.pad(
            collated, padding=True, return_tensors='pt', return_attention_mask=True
        )

    def _encode_for_llm(self, texts: List[str]):
        """
        Tokenize and pad ``texts`` with the LLM tokenizer and attach causal-LM labels.

        Texts are truncated to ``data_args.passage_max_len`` tokens. ``labels`` is a
        copy of ``input_ids`` with padded positions set to -100.
        """
        collated = self.llm_tokenizer(
            texts,
            padding=False,
            truncation=True,
            max_length=self.data_args.passage_max_len,
            add_special_tokens=True,
        )
        padded = self.llm_tokenizer.pad(
            collated, padding=True, return_tensors='pt', return_attention_mask=True
        )
        labels = padded["input_ids"].clone()
        # Mask by attention_mask rather than by pad_token_id: many LLM tokenizers reuse
        # EOS as the pad token, and comparing ids would also mask genuine EOS tokens.
        labels[padded["attention_mask"] == 0] = -100
        padded['labels'] = labels
        return padded

    def __call__(self, features: List[Tuple[str, List[str]]]) -> Dict[str, Dict]:
        """
        Collate a batch of ``(query, passages)`` features.

        Passage groups may differ in size: each query's Revela inputs are built only
        from that query's own passages. Values under ``"revela_retriever_input"`` and
        ``"revela_llm_input"`` are lists with one padded batch per feature.
        """
        raw_queries = [f[0] for f in features]
        raw_passages = [p for f in features for p in f[1]]

        query_prefix = self.data_args.query_prefix
        passage_prefix = self.data_args.passage_prefix
        queries = [f"{query_prefix}{q}" for q in raw_queries] if query_prefix else raw_queries
        passages = [f"{passage_prefix}{p}" for p in raw_passages] if passage_prefix else raw_passages

        # --- Part 1: contrastive inputs (all queries, all passages) ---
        query_contrastive_padded = self._encode_for_retriever(queries, self.data_args.query_max_len)
        passage_contrastive_padded = self._encode_for_retriever(passages, self.data_args.passage_max_len)

        # --- Part 2: per-query Revela inputs ---
        # The retriever side uses the prefixed texts; the LLM side uses the raw texts.
        revela_retriever_input_padded_list = []
        revela_llm_input_padded_list = []
        top_k = self.data_args.top_k
        start = 0
        for idx, feature in enumerate(features):
            # Slice by this feature's actual group size so passages never leak across queries.
            end = start + len(feature[1])
            all_revela_texts, all_raw_revela_texts = self._prepare_revela_texts(
                current_query_list=queries[idx:idx + 1],
                current_raw_query_list=raw_queries[idx:idx + 1],
                current_passages=passages[start:end],
                current_raw_passages=raw_passages[start:end],
            )
            start = end

            # top_k counts the query itself as one of the texts.
            all_revela_texts = all_revela_texts[:top_k]
            all_raw_revela_texts = all_raw_revela_texts[:top_k]

            revela_retriever_input_padded_list.append(
                self._encode_for_retriever(all_revela_texts, self.data_args.passage_max_len)
            )
            revela_llm_input_padded_list.append(self._encode_for_llm(all_raw_revela_texts))

        return {
            "query_contrastive": query_contrastive_padded,
            "passage_contrastive": passage_contrastive_padded,
            "revela_retriever_input": revela_retriever_input_padded_list,
            "revela_llm_input": revela_llm_input_padded_list,
        }