"""
IMPORTANT NOTE: This file contains the MultiGPUSampleGenerationCalculator implementation
that attempts to parallelize generation within a single stat calculator.

This approach has been deprecated in favor of dataset-level parallelism where each GPU
runs the complete pipeline independently on a subset of the data.

The issue with this approach was that it tried to merge intermediate results (features)
from different batches, which led to dimension mismatches when different batches had
different maximum sequence lengths.

This file is kept for reference and potential future use if needed.
"""

import copy
import torch
import torch.nn as nn
import time
import numpy as np
import logging
import gc
from typing import Dict, List, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading
from dataclasses import fields

from lm_polygraph.stat_calculators.stat_calculator import StatCalculator
from lm_polygraph.utils.generation_parameters import GenerationParameters
from lm_polygraph.utils.model import Model
from luh import AutoUncertaintyHead

# Module-level logger named after this module. Records still propagate to the
# root logger's handlers, so existing logging configuration keeps applying,
# while each record now identifies where it came from.
log = logging.getLogger(__name__)


class MultiGPUSampleGenerationCalculator(StatCalculator):
    """
    Sample generation calculator with multi-GPU support.
    Distributes prompts across GPUs, each GPU processes its subset independently.
    """
    def __init__(
            self,
            model_pool,
            uncertainty_heads,
            n_alternatives=10,
            tokenize=True,
            args_generate=None,
            predict_token_uncertainties=True,
            gpu_ids=None,
            top_k: int = 50,
            top_p: float = 0.95,
            temperature: float = 1.0,
    ):
        """
        Store the per-GPU resources and the sampling settings.

        Args:
            model_pool: Object whose ``get_model(gpu_id)`` is used in
                ``__call__`` to fetch the model for each GPU.
            uncertainty_heads: Non-empty mapping from GPU id to the
                uncertainty head used on that GPU.
            n_alternatives: Number of top-scoring tokens kept per generated
                position in ``postprocess_predictions``.
            tokenize: Whether ``process_on_single_gpu`` tokenizes the texts
                with ``model.tokenize``.
            args_generate: Extra keyword arguments forwarded to
                ``model.generate``. ``None`` means no extra arguments.
            predict_token_uncertainties: If True the uncertainty head is
                called and its output returned; otherwise only its
                ``feature_extractor`` is applied.
            gpu_ids: GPU ids to use. Defaults to every device reported by
                ``torch.cuda.device_count()``.
            top_k: Value written to the model's generation parameters.
            top_p: Value written to the model's generation parameters.
            temperature: Value written to the model's generation parameters.

        Raises:
            ValueError: If ``uncertainty_heads`` is empty.
        """
        super().__init__()

        if not uncertainty_heads:
            # next(iter(...)) below would otherwise surface as a bare StopIteration.
            raise ValueError("uncertainty_heads must contain at least one uncertainty head")

        self.model_pool = model_pool
        self.uncertainty_heads = uncertainty_heads  # Dict mapping gpu_id to uncertainty head
        self.n_alternatives = n_alternatives
        self._tokenize = tokenize
        # A fresh dict per instance: a dict() default in the signature would be
        # one object shared (and mutable) across every instance.
        self.args_generate = {} if args_generate is None else args_generate
        self.gpu_ids = gpu_ids or list(range(torch.cuda.device_count()))

        log.info(f"MultiGPU SampleGeneration using {len(self.gpu_ids)} GPUs: {self.gpu_ids}")

        # Use first uncertainty head for properties
        first_uhead = next(iter(self.uncertainty_heads.values()))
        self.output_attentions = first_uhead.output_attentions
        self.predict_token_uncertainties = predict_token_uncertainties

        self.top_k = top_k
        self.top_p = top_p
        self.temperature = temperature

    @staticmethod
    def meta_info() -> Tuple[List[str], List[str]]:
        return [
            "hidden_states",
            "greedy_log_probs",
            "greedy_logits",
            "greedy_tokens",
            "greedy_tokens_alternatives",
            "greedy_texts",
            "greedy_log_likelihoods",
            "uncertainty_logits",
            "uhead_features",
            "input_texts",
            "input_tokens",
        ], []

    def _move_generation_output_to_cpu(self, out):
        """Move all tensors in GenerateDecoderOnlyOutput to CPU."""
        output_class = type(out)
        new_values = {}
        
        for field in fields(out):
            field_value = getattr(out, field.name)
            
            if field_value is None:
                new_values[field.name] = None
            elif field.name == 'sequences':
                new_values[field.name] = field_value.cpu() if torch.is_tensor(field_value) else field_value
            elif field.name in ['scores', 'logits']:
                if field_value is not None:
                    new_values[field.name] = tuple(tensor.cpu() for tensor in field_value)
                else:
                    new_values[field.name] = None
            elif field.name in ['attentions', 'hidden_states']:
                if field_value is not None:
                    new_values[field.name] = tuple(
                        tuple(tensor.cpu() for tensor in inner_tuple)
                        for inner_tuple in field_value
                    )
                else:
                    new_values[field.name] = None
            elif field.name == 'past_key_values':
                if field_value is not None:
                    new_values[field.name] = tuple(
                        tuple(
                            tuple(tensor.cpu() for tensor in innermost_tuple)
                            for innermost_tuple in middle_tuple
                        )
                        for middle_tuple in field_value
                    )
                else:
                    new_values[field.name] = None
            else:
                new_values[field.name] = field_value
        
        return output_class(**new_values)

    def postprocess_predictions(self, batch, out, tokenizer):
        """Process model outputs to extract sequences and logits."""
        logits = torch.stack(out.scores, dim=1)
        sequences = out.sequences
        
        cut_logits, cut_sequences, cut_texts, cut_alternatives, ll = [], [], [], [], []
        for i in range(batch['input_ids'].shape[0]):
            idx = batch["input_ids"].shape[1]
            seq = sequences[i, idx:].cpu()
            length = next((j + 1 for j, token in enumerate(seq) if token == tokenizer.eos_token_id), len(seq))
            cut_seq = seq[:length]
            cut_sequences.append(cut_seq.tolist())
            cut_texts.append(tokenizer.decode(cut_seq))
            cut_logits.append(logits[i, :length, :].cpu().numpy())
            
            alt = []
            for j in range(length):
                lt = logits[i, j, :].cpu().numpy()
                best_tokens = np.argpartition(lt, -self.n_alternatives)[-self.n_alternatives:]
                best_tokens = best_tokens[np.argsort(-lt[best_tokens])]
                alt_j = [(t.item(), lt[t].item()) for t in best_tokens]
                alt_j.sort(key=lambda x: x[0] == cut_seq[j].item(), reverse=True)
                alt.append(alt_j)
            cut_alternatives.append(alt)
            ll.append([cut_logits[-1][j, cut_seq[j]] for j in range(len(cut_seq))])
        
        return {
            "input_tokens": batch["input_ids"].to("cpu").tolist(),
            "greedy_log_probs": cut_logits,
            "greedy_tokens": cut_sequences,
            "greedy_tokens_alternatives": cut_alternatives,
            "greedy_texts": cut_texts,
            "greedy_log_likelihoods": ll,
            "logits": logits[:, :-1, :],
        }

    def process_on_single_gpu(self, texts: List[str], model: Model, gpu_id: int,
                            max_new_tokens: int = 100, **kwargs) -> Dict[str, np.ndarray]:
        """
        Generate and post-process samples for ``texts`` on one GPU.

        The texts are generated in mini-batches (size 2, falling back to 1
        after a CUDA out-of-memory error). Each raw generation output is moved
        to CPU right away. Once generation has finished, the CPU copies are
        turned into token/score statistics and handed to this GPU's
        uncertainty head.

        Args:
            texts: Prompts assigned to this GPU.
            model: Model used for tokenisation and generation. Its
                ``generation_parameters`` are temporarily replaced by a copy
                carrying this calculator's ``top_p``/``top_k``/``temperature``
                and are restored on exit, including when generation raises.
            gpu_id: Index used to build the ``cuda:<gpu_id>`` device string
                and to look up the uncertainty head.
            max_new_tokens: Only reported in the log message. It is not
                passed to ``model.generate`` by this method; generation
                keyword arguments come from ``self.args_generate``.
            **kwargs: Accepted but unused.

        Returns:
            Dict with ``input_tokens``, ``greedy_log_probs``,
            ``greedy_tokens``, ``greedy_tokens_alternatives``,
            ``greedy_texts``, ``greedy_log_likelihoods`` and ``input_texts``.
            With ``predict_token_uncertainties`` it also holds
            ``uncertainty_logits``; otherwise it holds ``uhead_features``,
            ``llm_inputs`` and ``full_attention_mask``.

        Raises:
            torch.cuda.OutOfMemoryError: If generation runs out of memory
                even with the smallest batch size.
            RuntimeError: If no batch size completed generation.
        """
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            gc.collect()

        device = f"cuda:{gpu_id}"
        uncertainty_head = self.uncertainty_heads[gpu_id]

        batch = model.tokenize(texts) if self._tokenize else texts
        n_samples = len(texts)
        log.info(f"GPU {gpu_id}: Generating {max_new_tokens} new tokens for {n_samples} samples on device={device}")

        # Overwrite generation parameters
        old_params = model.generation_parameters
        params = copy.deepcopy(old_params)
        params.top_p, params.top_k, params.temperature = self.top_p, self.top_k, self.temperature
        model.generation_parameters = params

        # Adaptive batch sizes
        batch_sizes = [2, 1]

        all_outputs = []
        generation_succeeded = False
        # Built lazily on the first mini-batch and then reused: when newlines
        # are disallowed it requires decoding the whole vocabulary.
        suppress_tokens = None

        try:
            for batch_size in batch_sizes:
                try:
                    log.info(f"GPU {gpu_id}: Attempting generation with batch size {batch_size}...")
                    start_time = time.time()
                    total_batches = (n_samples + batch_size - 1) // batch_size
                    batch_num = 0

                    # Process in batches
                    for i in range(0, n_samples, batch_size):
                        end_idx = min(i + batch_size, n_samples)
                        batch_texts = texts[i:end_idx]
                        batch_start_time = time.time()
                        log.info(f"GPU {gpu_id}: Processing batch {batch_num + 1} out of total batches {total_batches}")

                        # Tokenize current batch
                        if self._tokenize:
                            current_batch = model.tokenize(batch_texts)
                        else:
                            # NOTE(review): this path builds a plain dict, on
                            # which the `.to(device)` call below is not
                            # defined -- confirm before using tokenize=False.
                            current_batch = {k: v[i:end_idx] for k, v in batch.items()}

                        device_batch = current_batch.to(device)

                        if suppress_tokens is None:
                            if model.generation_parameters.allow_newlines:
                                suppress_tokens = []
                            else:
                                suppress_tokens = [
                                    t
                                    for t in range(len(model.tokenizer))
                                    if "\n" in model.tokenizer.decode([t])
                                ]

                        with torch.no_grad():
                            out = model.generate(
                                **device_batch,
                                output_scores=True,
                                return_dict_in_generate=True,
                                output_attentions=self.output_attentions,
                                output_hidden_states=True,
                                do_sample=True,
                                suppress_tokens=suppress_tokens,
                                pad_token_id=model.tokenizer.eos_token_id,
                                tokenizer=model.tokenizer,
                                **self.args_generate,
                            )
                        log.info(f"GPU {gpu_id}: Batch {batch_num + 1} generated in {round(time.time() - batch_start_time, 2)} seconds")

                        # Log GPU memory before offloading
                        if torch.cuda.is_available():
                            gpu_mem_before = torch.cuda.memory_allocated(gpu_id) / 1024**3
                            log.info(f"GPU {gpu_id} memory before offloading: {gpu_mem_before:.2f} GB")

                        # Offload generation outputs to CPU
                        log.info(f"GPU {gpu_id}: Offloading generation outputs to CPU to save GPU memory...")
                        out_cpu = self._move_generation_output_to_cpu(out)
                        current_batch_cpu = current_batch.to('cpu')

                        # Delete the original GPU tensors
                        del out, device_batch, current_batch

                        # Log GPU memory after offloading
                        if torch.cuda.is_available():
                            gpu_mem_after = torch.cuda.memory_allocated(gpu_id) / 1024**3
                            log.info(f"GPU {gpu_id} memory after offloading: {gpu_mem_after:.2f} GB (saved: {gpu_mem_before - gpu_mem_after:.2f} GB)")

                        all_outputs.append((out_cpu, current_batch_cpu))

                        # Clear cache after each batch
                        if torch.cuda.is_available():
                            torch.cuda.empty_cache()
                            gc.collect()

                        batch_num += 1

                    log.info(f"GPU {gpu_id}: Successfully generated with batch size {batch_size} in {round(time.time() - start_time, 2)} seconds")
                    generation_succeeded = True
                    break

                except torch.cuda.OutOfMemoryError as e:
                    log.warning(f"GPU {gpu_id}: OOM with batch size {batch_size}: {e}")
                    torch.cuda.empty_cache()
                    gc.collect()

                    # Clear any partial outputs
                    all_outputs = []

                    # Try next smaller batch size
                    if batch_size == batch_sizes[-1]:
                        # Even smallest batch size failed
                        raise
        finally:
            # Always hand the model back with its original parameters, even
            # when generation fails, so later users are not affected.
            model.generation_parameters = old_params

        if not generation_succeeded:
            raise RuntimeError(f"GPU {gpu_id}: Generation failed with all batch sizes")

        # Process each batch and collect results
        all_cut_logits = []
        all_cut_sequences = []
        all_cut_texts = []
        all_cut_alternatives = []
        all_ll = []
        all_input_tokens = []
        all_uncertainty_logits = []
        all_uhead_features = None
        all_modified_batches = []

        # Process each batch through the pipeline
        for out_cpu, current_batch_cpu in all_outputs:
            log.info(f"GPU {gpu_id}: Processing batch outputs on CPU...")

            # Keep on CPU as in original
            out = out_cpu
            current_batch = current_batch_cpu

            # Get basic predictions
            batch_results = self.postprocess_predictions(current_batch, out, model.tokenizer)

            # Collect basic results
            batch_cut_sequences = batch_results["greedy_tokens"]

            all_cut_logits.extend(batch_results["greedy_log_probs"])
            all_cut_sequences.extend(batch_cut_sequences)
            all_cut_texts.extend(batch_results["greedy_texts"])
            all_cut_alternatives.extend(batch_results["greedy_tokens_alternatives"])
            all_ll.extend(batch_results["greedy_log_likelihoods"])
            all_input_tokens.extend(batch_results["input_tokens"])

            # Create output bounds and attention mask
            batch_size = current_batch['input_ids'].shape[0]
            idx = current_batch["input_ids"].shape[1]
            output_bounds = []
            full_attn_mask = torch.zeros_like(out.sequences).bool()

            for i in range(batch_size):
                full_attn_mask[i, :idx] = current_batch["attention_mask"][i]
                length = len(batch_cut_sequences[i])
                full_attn_mask[i][idx: idx + length] = 1
                # Bounds are shifted left by one relative to the generated
                # token positions.
                output_bounds.append((idx - 1, idx + length - 1))

            # Set required attributes
            out["full_attention_mask"] = full_attn_mask
            out["context_lengths"] = torch.tensor([len(it) for it in current_batch["input_ids"]])
            # The key spelling is kept exactly as is: it is runtime data that
            # code outside this file may read.
            current_batch["context_lenghts"] = out["context_lengths"]

            # Process uncertainty head
            if self.predict_token_uncertainties:
                with torch.no_grad():
                    uncertainty_logits = uncertainty_head(current_batch, out)
                    batch_uncertainty_logits = [
                        ue[output_bounds[i][0]: output_bounds[i][1]]
                        for i, ue in enumerate(uncertainty_logits.cpu().detach().squeeze(-1))
                    ]
                    all_uncertainty_logits.extend(batch_uncertainty_logits)
            else:
                batch_features = uncertainty_head.feature_extractor(current_batch, out)
                if all_uhead_features is None:
                    all_uhead_features = batch_features
                else:
                    # The module docstring notes that features from different
                    # batches can differ in sequence length; in that case
                    # this concatenation raises.
                    all_uhead_features = torch.cat([all_uhead_features, batch_features], dim=0)

            # Store modified batch
            all_modified_batches.append({k: v.cpu() if torch.is_tensor(v) else v for k, v in current_batch.items()})

            # Clear memory
            del out, current_batch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                gc.collect()

        # Build result dictionary
        result_dict = {
            "input_tokens": all_input_tokens,
            "greedy_log_probs": all_cut_logits,
            "greedy_tokens": all_cut_sequences,
            "greedy_tokens_alternatives": all_cut_alternatives,
            "greedy_texts": all_cut_texts,
            "greedy_log_likelihoods": all_ll,
            "input_texts": texts
        }

        if self.predict_token_uncertainties:
            result_dict["uncertainty_logits"] = all_uncertainty_logits
        else:
            result_dict["uhead_features"] = all_uhead_features
            # Handle combined batch
            if all_modified_batches:
                combined_modified_batch = {}
                first_batch = all_modified_batches[0]
                for key in first_batch:
                    # Context lengths are concatenated over all mini-batches.
                    # Both spellings are matched because the batches above
                    # store the value under "context_lenghts".
                    if key in ("context_lengths", "context_lenghts"):
                        combined_modified_batch[key] = torch.cat([b[key] for b in all_modified_batches])
                    else:
                        # Other entries are taken from the first mini-batch only.
                        combined_modified_batch[key] = first_batch[key]
                result_dict["llm_inputs"] = combined_modified_batch
            else:
                result_dict["llm_inputs"] = batch

            # Create full attention mask
            max_seq_len = max(
                (len(tokens) + len(seq) for tokens, seq in zip(all_input_tokens, all_cut_sequences)),
                default=0,  # no samples -> empty mask instead of a ValueError
            )
            full_attention_mask_all = torch.zeros(len(texts), max_seq_len).bool()
            for i, (tokens, seq) in enumerate(zip(all_input_tokens, all_cut_sequences)):
                # Prompt positions and generated positions are contiguous, so
                # a single slice marks both.
                full_attention_mask_all[i, :len(tokens) + len(seq)] = 1
            result_dict["full_attention_mask"] = full_attention_mask_all

        return result_dict

    def __call__(self, dependencies: Dict[str, np.array], texts: List[str], model: Model,
                 max_new_tokens: int = 100, **kwargs) -> Dict[str, np.ndarray]:
        """
        Generate samples using multiple GPUs by distributing prompts across GPUs.

        ``texts`` is split into contiguous chunks, one per GPU id. Each chunk
        is processed by ``process_on_single_gpu`` in its own thread with the
        model obtained from ``self.model_pool``. The per-GPU results are then
        merged in the original prompt order.

        Args:
            dependencies: Unused by this calculator.
            texts: Prompts to generate for.
            model: Unused here; per-GPU models come from ``self.model_pool``.
            max_new_tokens: Forwarded to ``process_on_single_gpu``.
            **kwargs: Forwarded to ``process_on_single_gpu``.

        Returns:
            Dict of merged per-prompt lists (``input_texts``,
            ``input_tokens``, ``greedy_log_probs``, ``greedy_tokens``,
            ``greedy_tokens_alternatives``, ``greedy_texts``,
            ``greedy_log_likelihoods``). With ``predict_token_uncertainties``
            it also holds ``uncertainty_logits``; otherwise it holds
            ``uhead_features``, ``llm_inputs`` (when available) and
            ``full_attention_mask``. For empty ``texts`` the lists are empty,
            ``uhead_features`` is ``None`` and the mask has shape (0, 0).

        Raises:
            ValueError: If no GPU ids are configured.
            Exception: Whatever a worker raised, re-raised after logging.
        """
        if not self.gpu_ids:
            # Without this check the chunk-size division below fails with an
            # unhelpful ZeroDivisionError.
            raise ValueError("MultiGPUSampleGenerationCalculator requires at least one GPU id")

        n_prompts = len(texts)
        n_gpus = len(self.gpu_ids)
        log.info(f"MultiGPU generation: {n_prompts} prompts across {n_gpus} GPUs")

        # Distribute prompts across GPUs (ceil division)
        prompts_per_gpu = (n_prompts + n_gpus - 1) // n_gpus

        # Create prompt assignments for each GPU
        gpu_assignments = []
        for gpu_idx, gpu_id in enumerate(self.gpu_ids):
            start_idx = gpu_idx * prompts_per_gpu
            if start_idx >= n_prompts:
                # Start indices only grow, so later GPUs get nothing either.
                break
            end_idx = min(start_idx + prompts_per_gpu, n_prompts)
            gpu_texts = texts[start_idx:end_idx]
            gpu_assignments.append((gpu_id, gpu_texts, start_idx))
            log.info(f"GPU {gpu_id}: Assigned {len(gpu_texts)} prompts (indices {start_idx}-{end_idx-1})")

        # Process on each GPU in parallel. Results are kept in a list rather
        # than keyed by GPU id, so a repeated id cannot overwrite a chunk.
        gpu_results = []
        if gpu_assignments:  # ThreadPoolExecutor rejects max_workers=0
            with ThreadPoolExecutor(max_workers=len(gpu_assignments)) as executor:
                futures = []
                for gpu_id, gpu_texts, start_idx in gpu_assignments:
                    gpu_model = self.model_pool.get_model(gpu_id)
                    future = executor.submit(
                        self.process_on_single_gpu,
                        gpu_texts, gpu_model, gpu_id, max_new_tokens, **kwargs
                    )
                    futures.append((future, gpu_id, start_idx))

                # Collect results
                for future, gpu_id, start_idx in futures:
                    try:
                        result = future.result()
                    except Exception as e:
                        log.error(f"GPU {gpu_id}: Error during processing: {e}")
                        raise
                    gpu_results.append((result, start_idx))
                    log.info(f"GPU {gpu_id}: Completed processing")

        # Sort by start_idx to maintain original order
        gpu_results.sort(key=lambda item: item[1])

        list_keys = [
            "input_texts",
            "input_tokens",
            "greedy_log_probs",
            "greedy_tokens",
            "greedy_tokens_alternatives",
            "greedy_texts",
            "greedy_log_likelihoods",
        ]
        merged_result = {key: [] for key in list_keys}
        if self.predict_token_uncertainties:
            merged_result["uncertainty_logits"] = []

        all_uhead_features = []
        all_full_attention_masks = []
        all_llm_inputs = []

        # Merge results maintaining order
        for gpu_result, _ in gpu_results:
            for key in list_keys:
                merged_result[key].extend(gpu_result[key])

            if self.predict_token_uncertainties:
                merged_result["uncertainty_logits"].extend(gpu_result["uncertainty_logits"])
            else:
                all_uhead_features.append(gpu_result["uhead_features"])
                all_full_attention_masks.append(gpu_result["full_attention_mask"])
                all_llm_inputs.append(gpu_result["llm_inputs"])

        # Handle non-token uncertainty case
        if not self.predict_token_uncertainties:
            # Concatenate features. The module docstring warns that features
            # from different batches can differ in sequence length; when they
            # do, this concatenation raises.
            merged_result["uhead_features"] = (
                torch.cat(all_uhead_features, dim=0) if all_uhead_features else None
            )

            # Merge llm_inputs
            if all_llm_inputs:
                combined_llm_inputs = {}
                for key, first_value in all_llm_inputs[0].items():
                    if isinstance(first_value, torch.Tensor):
                        combined_llm_inputs[key] = torch.cat([inp[key] for inp in all_llm_inputs], dim=0)
                    else:
                        # For non-tensor values, just take from first (they should be the same)
                        combined_llm_inputs[key] = first_value
                merged_result["llm_inputs"] = combined_llm_inputs

            # Merge attention masks. Each GPU sizes its mask to its own
            # longest sequence, so widths can differ; right-pad with zeros
            # (False) to the widest mask before concatenating.
            if all_full_attention_masks:
                width = max(mask.shape[1] for mask in all_full_attention_masks)
                padded_masks = []
                for mask in all_full_attention_masks:
                    if mask.shape[1] < width:
                        wider = mask.new_zeros((mask.shape[0], width))
                        wider[:, :mask.shape[1]] = mask
                        mask = wider
                    padded_masks.append(mask)
                merged_result["full_attention_mask"] = torch.cat(padded_masks, dim=0)
            else:
                merged_result["full_attention_mask"] = torch.zeros(0, 0).bool()

        log.info(f"MultiGPU generation complete: processed {len(merged_result['input_texts'])} prompts")
        return merged_result


def replace_with_multigpu_generation(stat_calculators, model_pool, uhead_path, gpu_ids,
                                   max_new_tokens=256, temperature=1.0, hf_cache=None):
    """
    Swap sample-generation calculators for the multi-GPU implementation.

    An uncertainty head is first loaded from ``uhead_path`` for every id in
    ``gpu_ids``, moved to ``cuda:<id>`` and put in eval mode. Every calculator
    whose class name contains ``"SampleGeneration"`` (or equals
    ``"sample_generator"``) is then replaced by a
    ``MultiGPUSampleGenerationCalculator`` sharing those heads; all other
    calculators are kept, and the order is preserved.

    Args:
        stat_calculators: Iterable of calculator objects.
        model_pool: Object providing ``get_model(gpu_id)``; the returned
            object's ``.model`` attribute is handed to the head loader.
        uhead_path: Location passed to ``AutoUncertaintyHead.from_pretrained``.
        gpu_ids: GPU ids to load heads on and to generate with.
        max_new_tokens: Placed into the generation arguments.
        temperature: Used both as the calculator temperature and in the
            generation arguments.
        hf_cache: Accepted but not used by this function.

    Returns:
        A new list of calculators. A warning is logged when nothing was
        replaced.
    """
    # One uncertainty head per GPU, keyed by GPU id.
    heads_by_gpu = {}
    for gpu_id in gpu_ids:
        target_device = f"cuda:{gpu_id}"
        log.info(f"Loading uncertainty head on {target_device}")

        # The head is built on top of the underlying model held by the pool entry.
        pooled_model = model_pool.get_model(gpu_id)
        head = AutoUncertaintyHead.from_pretrained(uhead_path, pooled_model.model)
        head = head.to(target_device)
        head.eval()
        heads_by_gpu[gpu_id] = head

    replaced_any = False
    result = []
    for calc in stat_calculators:
        class_name = type(calc).__name__
        is_sample_generator = "SampleGeneration" in class_name or class_name == "sample_generator"
        if not is_sample_generator:
            result.append(calc)
            continue

        log.info(f"Replacing {class_name} with MultiGPUSampleGenerationCalculator")

        # Carry over the original calculator's mode when it exposes one.
        token_level = getattr(calc, 'predict_token_uncertainties', False)
        generation_args = {
            "max_new_tokens": max_new_tokens,
            "min_new_tokens": 2,
            "temperature": temperature,
            "length_penalty": 1.0,
            "stop_strings": ["\n\n", "}\n"],
        }
        result.append(
            MultiGPUSampleGenerationCalculator(
                model_pool=model_pool,
                uncertainty_heads=heads_by_gpu,
                predict_token_uncertainties=token_level,
                gpu_ids=gpu_ids,
                temperature=temperature,
                args_generate=generation_args,
            )
        )
        replaced_any = True

    if not replaced_any:
        log.warning("No sample generation calculator found to replace")

    return result