import copy
import gc
import logging
import os
import time
from typing import Dict, List, Optional, Sequence, Tuple

import numpy as np
import torch
from transformers import GenerationConfig

from lm_polygraph.stat_calculators.stat_calculator import StatCalculator
from lm_polygraph.utils.generation_parameters import GenerationParameters
from lm_polygraph.utils.model import Model
from luh import AutoUncertaintyHead

# Module-level logger: records still propagate to the root handlers, but can be
# filtered/levelled per module instead of being emitted as the root logger.
log = logging.getLogger(__name__)


class SampleGenerationCalculator(StatCalculator):
    """Sample continuations with the LLM and score them with an uncertainty head.

    For each call the model's sampling parameters (``top_k``, ``top_p``,
    ``temperature``) are temporarily replaced, ``model.generate`` is run with
    ``do_sample=True``, the sampled tokens and their scores are post-processed
    into the ``greedy_*`` statistics (the key names are kept for the downstream
    consumers even though decoding is sampled), and the uncertainty head is
    applied to the generation output.

    Two execution modes are available:

    * ``batch_processing=False``: all texts are generated in one ``generate``
      call and the head runs on the output exactly as returned.
    * ``batch_processing=True``: texts are generated in chunks. The sizes in
      ``generation_batch_sizes`` are tried in order, falling back to the next
      one after a CUDA out-of-memory error. Each chunk's output is copied to
      CPU right after generation and the head is later applied to those copies.
    """

    def __init__(
            self,
            uncertainty_head,
            n_alternatives=10,
            tokenize=True,
            args_generate=None,
            predict_token_uncertainties=True,
            device="cuda",
            top_k: int = 50,
            top_p: float = 0.95,
            temperature: float = 1.0,
            batch_processing: bool = True,
            generation_batch_sizes: Optional[Sequence[int]] = None,
    ):
        """
        Args:
            uncertainty_head: head module; it is moved to ``device`` and switched to
                eval mode. Its ``output_attentions`` attribute decides whether
                attentions are requested from ``generate``.
            n_alternatives: number of highest-scoring tokens recorded per generated step.
            tokenize: if True, ``texts`` given to ``__call__`` are raw strings that are
                tokenized with ``model.tokenize``; otherwise they must already be a
                tokenized batch.
            args_generate: extra keyword arguments forwarded to every ``model.generate``
                call. ``None`` (default) means no extra arguments.
            predict_token_uncertainties: if True, return ``uncertainty_logits``;
                otherwise return the head's ``feature_extractor`` output instead.
            device: device the uncertainty head is moved to.
            top_k, top_p, temperature: sampling parameters used during generation.
            batch_processing: if True, generate in chunks with OOM fallback;
                otherwise generate all texts in one call.
            generation_batch_sizes: chunk sizes tried in order in chunked mode; the
                next size is used after a CUDA OOM. Defaults to ``(2, 1)``.

        Raises:
            ValueError: if ``generation_batch_sizes`` is empty or contains a size < 1.
        """
        super().__init__()

        self.n_alternatives = n_alternatives
        self._tokenize = tokenize
        # ``None`` default instead of ``dict()`` so instances never share one mutable dict.
        self.args_generate = {} if args_generate is None else args_generate

        self.uncertainty_head = uncertainty_head.to(device)
        self.uncertainty_head.eval()
        self.output_attentions = self.uncertainty_head.output_attentions
        self.predict_token_uncertainties = predict_token_uncertainties

        self.top_k = top_k
        self.top_p = top_p
        self.temperature = temperature
        self.batch_processing = batch_processing

        sizes = (2, 1) if generation_batch_sizes is None else tuple(generation_batch_sizes)
        if not sizes or min(sizes) < 1:
            raise ValueError("generation_batch_sizes must be a non-empty sequence of positive integers")
        self.generation_batch_sizes = sizes

    # ------------------------------------------------------------------ #
    # Device helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _move_to_device(value, device):
        """Recursively move every tensor inside ``value`` to ``device``.

        Tensors are moved; tuples, lists and dicts are rebuilt with moved
        elements; objects exposing a callable ``to_legacy_cache`` (e.g.
        transformers cache objects, if present) are converted to that tuple form
        first. Anything else is returned unchanged.
        """
        if torch.is_tensor(value):
            return value.to(device)
        if isinstance(value, tuple):
            return tuple(SampleGenerationCalculator._move_to_device(item, device) for item in value)
        if isinstance(value, list):
            return [SampleGenerationCalculator._move_to_device(item, device) for item in value]
        if isinstance(value, dict):
            return {key: SampleGenerationCalculator._move_to_device(item, device) for key, item in value.items()}
        to_legacy_cache = getattr(value, "to_legacy_cache", None)
        if callable(to_legacy_cache):
            return SampleGenerationCalculator._move_to_device(to_legacy_cache(), device)
        return value

    def _move_generation_output(self, out, device):
        """Return a copy of a ``generate`` output with every tensor moved to ``device``.

        Dataclass fields are moved recursively (so nested structures such as
        ``past_key_values`` keep their shape) and passed to the output's own
        class. Other instance attributes found via ``vars(out)`` are moved and
        re-attached too, so nothing set on the original is dropped. ``out``
        itself is left untouched.
        """
        from dataclasses import fields

        field_names = [f.name for f in fields(out)]
        moved = type(out)(**{name: self._move_to_device(getattr(out, name), device) for name in field_names})
        for name, value in vars(out).items():
            if name not in field_names:
                setattr(moved, name, self._move_to_device(value, device))
        return moved

    def _move_generation_output_to_cpu(self, out):
        """Return a copy of ``generate`` output ``out`` with all tensors on CPU."""
        return self._move_generation_output(out, "cpu")

    def _move_generation_output_to_cuda(self, out, device):
        """Return a copy of ``generate`` output ``out`` with all tensors on ``device``
        (e.g. ``'cuda'``, ``'cuda:0'``)."""
        return self._move_generation_output(out, device)

    @staticmethod
    def _batch_to(batch, device):
        """Move a tokenized batch to ``device``.

        Objects that have a ``.to`` method are moved with it. Plain mappings are
        copied into a new dict whose tensor values are moved.
        """
        if hasattr(batch, "to"):
            return batch.to(device)
        return {key: value.to(device) if torch.is_tensor(value) else value for key, value in batch.items()}

    @staticmethod
    def _release_cuda_memory():
        """Collect garbage, then release unoccupied cached CUDA memory (no-op without CUDA).

        ``gc.collect()`` runs first so tensors kept alive only by reference cycles
        are freed before the cache is emptied.
        """
        if torch.cuda.is_available():
            gc.collect()
            torch.cuda.empty_cache()

    # ------------------------------------------------------------------ #
    # Generation helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def meta_info() -> Tuple[List[str], List[str]]:
        """Return ``(statistics this calculator declares, its dependencies)``."""
        return [
            "hidden_states",
            "greedy_log_probs",
            "greedy_logits",
            "greedy_tokens",
            "greedy_tokens_alternatives",
            "greedy_texts",
            "greedy_log_likelihoods",
            "uncertainty_logits",
            "uhead_features",
            "input_texts",
            "input_tokens",
        ], []

    def _override_sampling_parameters(self, model: Model):
        """Install a copy of the model's generation parameters with this calculator's
        ``top_p``/``top_k``/``temperature`` and return the previous parameters,
        which the caller must restore."""
        old_params = model.generation_parameters
        params = copy.deepcopy(old_params)
        params.top_p, params.top_k, params.temperature = self.top_p, self.top_k, self.temperature
        model.generation_parameters = params
        return old_params

    @staticmethod
    def _newline_suppress_tokens(model: Model) -> List[int]:
        """Token ids to suppress during generation.

        Empty when ``model.generation_parameters.allow_newlines`` is set, otherwise
        every vocabulary id whose decoded text contains ``"\\n"``. This decodes the
        whole vocabulary, so callers compute it once per call.
        """
        if model.generation_parameters.allow_newlines:
            return []
        tokenizer = model.tokenizer
        return [t for t in range(len(tokenizer)) if "\n" in tokenizer.decode([t])]

    def _generate(self, model: Model, device_batch, suppress_tokens):
        """Run one sampled ``model.generate`` call that also returns scores and hidden states."""
        with torch.no_grad():
            return model.generate(
                **device_batch,
                output_scores=True,
                return_dict_in_generate=True,
                output_attentions=self.output_attentions,
                output_hidden_states=True,
                do_sample=True,
                suppress_tokens=suppress_tokens,
                pad_token_id=model.tokenizer.eos_token_id,
                tokenizer=model.tokenizer,
                **self.args_generate,
            )

    def postprocess_predictions(self, batch, out, tokenizer):
        """Convert raw ``generate`` output into per-sample token statistics.

        For every row, the generated part of ``out.sequences`` (everything after
        the padded prompt length) is cut right after the first
        ``tokenizer.eos_token_id`` (EOS kept), or at the end of the generation.

        Args:
            batch: tokenized prompts with an ``"input_ids"`` tensor.
            out: ``generate`` output with ``sequences`` and per-step ``scores``.
            tokenizer: provides ``eos_token_id`` and ``decode``.

        Returns:
            dict with
              * ``input_tokens``: prompt ids (padding included) as nested lists;
              * ``greedy_log_probs``: per sample, a numpy array of step scores cut
                to the sample's length;
              * ``greedy_tokens``: per sample, the cut generated token ids;
              * ``greedy_tokens_alternatives``: per sample and step, the
                ``n_alternatives`` highest-scoring ``(token_id, score)`` pairs in
                descending score order, with the generated token moved to the
                front when it is among them;
              * ``greedy_texts``: decoded cut sequences;
              * ``greedy_log_likelihoods``: per sample, the score of each generated token;
              * ``logits``: the stacked step scores without the last step.
        """
        logits = torch.stack(out.scores, dim=1)
        sequences = out.sequences
        prompt_len = batch["input_ids"].shape[1]
        k = self.n_alternatives

        cut_logits, cut_sequences, cut_texts, cut_alternatives, ll = [], [], [], [], []
        for i in range(batch["input_ids"].shape[0]):
            seq = sequences[i, prompt_len:].cpu()
            length = next(
                (j + 1 for j, token in enumerate(seq) if token == tokenizer.eos_token_id),
                len(seq),
            )
            cut_seq = seq[:length]
            # One device->host copy per sample; per-step rows below index into it.
            sample_logits = logits[i, :length, :].cpu().numpy()
            cut_sequences.append(cut_seq.tolist())
            cut_texts.append(tokenizer.decode(cut_seq))
            cut_logits.append(sample_logits)

            alt = []
            for j in range(length):
                lt = sample_logits[j]
                best_tokens = np.argpartition(lt, -k)[-k:]
                best_tokens = best_tokens[np.argsort(-lt[best_tokens])]
                alt_j = [(t.item(), lt[t].item()) for t in best_tokens]
                # Stable sort: the generated token (if present) moves first, the rest keep score order.
                generated = cut_seq[j].item()
                alt_j.sort(key=lambda x: x[0] == generated, reverse=True)
                alt.append(alt_j)
            cut_alternatives.append(alt)
            ll.append([sample_logits[j, cut_seq[j]] for j in range(len(cut_seq))])

        return {
            "input_tokens": batch["input_ids"].to("cpu").tolist(),
            "greedy_log_probs": cut_logits,
            "greedy_tokens": cut_sequences,
            "greedy_tokens_alternatives": cut_alternatives,
            "greedy_texts": cut_texts,
            "greedy_log_likelihoods": ll,
            "logits": logits[:, :-1, :],
        }

    @staticmethod
    def _build_full_attention_mask(batch, out, generated_tokens):
        """Build the prompt+generation attention mask and the head slice bounds.

        Returns:
            ``(full_attn_mask, output_bounds)``: ``full_attn_mask`` is a bool tensor
            shaped like ``out.sequences`` holding the prompt attention mask followed
            by 1 for each sample's kept generated tokens (positions after the cut
            stay 0); ``output_bounds[i]`` is
            ``(prompt_len - 1, prompt_len + len(generated_tokens[i]) - 1)``, i.e. the
            generated span shifted one position to the left.
        """
        full_attn_mask = torch.zeros_like(out.sequences).bool()
        prompt_len = batch["input_ids"].shape[1]
        output_bounds = []
        for i, tokens in enumerate(generated_tokens):
            length = len(tokens)
            full_attn_mask[i, :prompt_len] = batch["attention_mask"][i]
            full_attn_mask[i, prompt_len:prompt_len + length] = 1
            output_bounds.append((prompt_len - 1, prompt_len + length - 1))
        return full_attn_mask, output_bounds

    def _run_uncertainty_head(self, batch, out, full_attn_mask, output_bounds):
        """Attach the head's extra inputs to ``out``/``batch`` and run the head.

        Returns:
            if ``predict_token_uncertainties``: one CPU tensor per sample with the
            head's logits sliced to ``output_bounds[i]``; otherwise the raw
            ``feature_extractor`` output.
        """
        out["full_attention_mask"] = full_attn_mask
        out["context_lengths"] = torch.tensor([len(it) for it in batch["input_ids"]])
        # NOTE(review): key spelling "lenghts" is kept as-is in case the head reads
        # this exact key — confirm before renaming.
        batch["context_lenghts"] = out["context_lengths"]

        if not self.predict_token_uncertainties:
            return self.uncertainty_head.feature_extractor(batch, out)
        with torch.no_grad():
            uncertainty_logits = self.uncertainty_head(batch, out)
        per_sample = uncertainty_logits.cpu().detach().squeeze(-1)
        return [ue[start:end] for ue, (start, end) in zip(per_sample, output_bounds)]

    # ------------------------------------------------------------------ #
    # Single-batch mode
    # ------------------------------------------------------------------ #

    def __call__single_batch(self, dependencies: Dict[str, np.array], texts: List[str], model: Model,
                             max_new_tokens: int = 100, **kwargs) -> Dict[str, np.ndarray]:
        """Generate for all ``texts`` in one ``generate`` call (no chunking, no OOM fallback).

        Arguments and return value are as for ``__call__``; here ``uhead_features``
        and ``full_attention_mask`` are single objects and ``llm_inputs`` is the
        tokenized batch, and ``logits`` is also returned.
        """
        batch = model.tokenize(texts) if self._tokenize else texts
        batch = self._batch_to(batch, model.device())
        log.info(f"[Single batch mode] Generating {max_new_tokens} new tokens on device={model.device()}...")

        old_params = self._override_sampling_parameters(model)
        start_time = time.time()
        try:
            out = self._generate(model, batch, self._newline_suppress_tokens(model))
        finally:
            # Restore even when generation raises so the override never leaks.
            model.generation_parameters = old_params
        log.info(f"Done generating in {round(time.time() - start_time, 2)} seconds")

        result_dict = self.postprocess_predictions(batch, out, model.tokenizer)
        result_dict["input_texts"] = texts

        full_attn_mask, output_bounds = self._build_full_attention_mask(batch, out, result_dict["greedy_tokens"])
        head_output = self._run_uncertainty_head(batch, out, full_attn_mask, output_bounds)
        if self.predict_token_uncertainties:
            result_dict["uncertainty_logits"] = head_output
        else:
            result_dict["uhead_features"] = head_output
            result_dict["llm_inputs"] = batch
            result_dict["full_attention_mask"] = full_attn_mask
        return result_dict

    # ------------------------------------------------------------------ #
    # Chunked mode
    # ------------------------------------------------------------------ #

    def _generate_chunks(self, model: Model, texts, batch, n_samples: int, batch_size: int, suppress_tokens):
        """Generate ``n_samples`` prompts in chunks of ``batch_size``, offloading each chunk to CPU.

        Returns:
            list of ``(out_cpu, batch_cpu)`` pairs, one per chunk, in input order.
        """
        outputs = []
        total_batches = (n_samples + batch_size - 1) // batch_size  # ceiling division
        for batch_num, start in enumerate(range(0, n_samples, batch_size)):
            end = min(start + batch_size, n_samples)
            batch_start_time = time.time()
            log.info(f"Processing batch {batch_num + 1} out of total batches {total_batches}")

            if self._tokenize:
                current_batch = model.tokenize(texts[start:end])
            else:
                current_batch = {key: value[start:end] for key, value in batch.items()}
            device_batch = self._batch_to(current_batch, model.device())

            out = self._generate(model, device_batch, suppress_tokens)
            log.info(f"Batch {batch_num + 1} generated in {round(time.time() - batch_start_time, 2)} seconds")

            cuda_available = torch.cuda.is_available()
            if cuda_available:
                gpu_mem_before = torch.cuda.memory_allocated() / 1024**3  # GB
                log.info(f"GPU memory before offloading: {gpu_mem_before:.2f} GB")

            log.info("Offloading generation outputs to CPU to save GPU memory...")
            out_cpu = self._move_generation_output_to_cpu(out)
            batch_cpu = self._batch_to(current_batch, "cpu")
            # Drop the GPU references so the memory can actually be released.
            del out, device_batch, current_batch

            if cuda_available:
                gpu_mem_after = torch.cuda.memory_allocated() / 1024**3  # GB
                log.info(f"GPU memory after offloading: {gpu_mem_after:.2f} GB (saved: {gpu_mem_before - gpu_mem_after:.2f} GB)")

            outputs.append((out_cpu, batch_cpu))
            self._release_cuda_memory()
        return outputs

    def _generate_with_fallback(self, model: Model, texts, batch, n_samples: int):
        """Run ``_generate_chunks`` with each size in ``generation_batch_sizes`` until one succeeds.

        Raises:
            torch.cuda.OutOfMemoryError: if even the last batch size runs out of memory.
        """
        suppress_tokens = self._newline_suppress_tokens(model)
        last_attempt = len(self.generation_batch_sizes) - 1
        for attempt, batch_size in enumerate(self.generation_batch_sizes):
            log.info(f"Attempting generation with batch size {batch_size}...")
            start_time = time.time()
            try:
                outputs = self._generate_chunks(model, texts, batch, n_samples, batch_size, suppress_tokens)
            except torch.cuda.OutOfMemoryError as e:
                log.warning(f"OOM with batch size {batch_size}: {e}")
                if attempt == last_attempt:
                    raise  # even the smallest batch size failed
            else:
                log.info(f"Successfully generated with batch size {batch_size} in {round(time.time() - start_time, 2)} seconds")
                return outputs
            # Clean up only after the ``except`` clause has exited: until then the
            # exception's traceback keeps the failed frames (and the GPU tensors they
            # reference) alive, so emptying the cache there would not free them.
            self._release_cuda_memory()
        raise RuntimeError("Generation failed with all batch sizes")

    def _postprocess_chunks(self, all_outputs, texts, model: Model) -> Dict[str, list]:
        """Post-process CPU-offloaded chunks, run the head on each, and merge the results."""
        greedy_log_probs, greedy_tokens, greedy_alternatives = [], [], []
        greedy_texts, greedy_log_likelihoods, input_tokens = [], [], []
        uncertainty_logits, uhead_features, full_attention_masks, llm_inputs = [], [], [], []

        # Pop chunks (in input order) so each one is dropped from the list once processed.
        all_outputs.reverse()
        while all_outputs:
            out, current_batch = all_outputs.pop()
            log.info("Running uncertainty head on CPU-offloaded batch outputs...")

            chunk = self.postprocess_predictions(current_batch, out, model.tokenizer)
            greedy_log_probs.extend(chunk["greedy_log_probs"])
            greedy_tokens.extend(chunk["greedy_tokens"])  # needed by StepsExtractor
            greedy_texts.extend(chunk["greedy_texts"])  # needed by StepsExtractor
            greedy_alternatives.extend(chunk["greedy_tokens_alternatives"])
            # needed by MaximumSequenceProbability, MaximumTokenProbability, perplexity, entropy
            greedy_log_likelihoods.extend(chunk["greedy_log_likelihoods"])
            input_tokens.extend(chunk["input_tokens"])

            full_attn_mask, output_bounds = self._build_full_attention_mask(
                current_batch, out, chunk["greedy_tokens"]
            )
            head_output = self._run_uncertainty_head(current_batch, out, full_attn_mask, output_bounds)
            if self.predict_token_uncertainties:
                uncertainty_logits.extend(head_output)
            else:
                uhead_features.append(head_output)
                full_attention_masks.append(full_attn_mask)
                llm_inputs.append({k: v.cpu() if torch.is_tensor(v) else v for k, v in current_batch.items()})

            del out, current_batch, chunk
            self._release_cuda_memory()

        result_dict = {
            "input_tokens": input_tokens,
            "greedy_log_probs": greedy_log_probs,
            "greedy_tokens": greedy_tokens,
            "greedy_tokens_alternatives": greedy_alternatives,
            "greedy_texts": greedy_texts,
            "greedy_log_likelihoods": greedy_log_likelihoods,
            "input_texts": texts,
        }
        if self.predict_token_uncertainties:
            result_dict["uncertainty_logits"] = uncertainty_logits
        else:
            # One entry per generation chunk.
            result_dict["uhead_features"] = uhead_features
            result_dict["full_attention_mask"] = full_attention_masks
            result_dict["llm_inputs"] = llm_inputs
        return result_dict

    def __call__(self, dependencies: Dict[str, np.array], texts: List[str], model: Model, max_new_tokens: int = 100,
                 **kwargs) -> Dict[str, np.ndarray]:
        """Sample continuations for ``texts`` and compute uncertainty statistics.

        Args:
            dependencies: previously computed statistics; not used here.
            texts: raw prompts when ``tokenize`` is True, otherwise a pre-tokenized
                batch with ``"input_ids"``/``"attention_mask"`` tensors.
            model: lm_polygraph model used for tokenization and generation.
            max_new_tokens: only reported in log messages; it is not forwarded to
                ``model.generate`` (NOTE(review): presumably the length limit comes
                from the model's generation parameters or ``args_generate`` — confirm).
            **kwargs: ignored.

        Returns:
            dict with the ``postprocess_predictions`` statistics plus
            ``input_texts`` and either ``uncertainty_logits`` (one tensor per sample)
            or ``uhead_features``/``full_attention_mask``/``llm_inputs``. In chunked
            mode the latter three are lists with one entry per chunk.

        Raises:
            torch.cuda.OutOfMemoryError: in chunked mode, if the last batch size still OOMs.
        """
        if not self.batch_processing:
            return self.__call__single_batch(dependencies, texts, model, max_new_tokens, **kwargs)

        self._release_cuda_memory()

        # Raw texts are tokenized chunk by chunk; a pre-tokenized batch is sliced instead.
        batch = None if self._tokenize else texts
        n_samples = len(texts) if self._tokenize else batch["input_ids"].shape[0]
        log.info(f"Generating {max_new_tokens} new tokens for {n_samples} samples on device={model.device()}...")

        old_params = self._override_sampling_parameters(model)
        try:
            all_outputs = self._generate_with_fallback(model, texts, batch, n_samples)
        finally:
            # Restore even when every batch size failed, so the override never leaks.
            model.generation_parameters = old_params

        return self._postprocess_chunks(all_outputs, texts, model)



def load_stat_calculator(config, builder):
    """Create a SampleGenerationCalculator backed by a pretrained uncertainty head.

    The head is loaded from ``config.uq_head_path`` on top of ``builder.model.model``
    and is also stored on ``builder.uncertainty_head``. ``batch_processing`` is read
    from the config when present and defaults to True otherwise.
    """
    head_path = config.uq_head_path
    base_model = builder.model.model
    head = AutoUncertaintyHead.from_pretrained(head_path, base_model)
    builder.uncertainty_head = head

    calculator_kwargs = dict(
        uncertainty_head=head,
        tokenize=True,
        args_generate=config.args_generate,
        predict_token_uncertainties=config.predict_token_uncertainties,
        batch_processing=getattr(config, 'batch_processing', True),
    )
    return SampleGenerationCalculator(**calculator_kwargs)
