import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" # to avoid fragmentation
os.environ["MLFLOW_ENABLE_SYSTEM_METRICS_LOGGING"] = "true"
import argparse
import pandas as pd
import numpy as np
import torch
from tqdm import tqdm
import ray # Still needed for the multi-GPU case
import math
import time
import mlflow
from transformers import AutoTokenizer

from disco.distributions.lm_distribution import LMDistribution, TextSample

# This Ray Actor is only used in the multi-GPU case.
@ray.remote(num_gpus=1)
class ScoringActor:
    """
    Ray actor that owns one GPU and runs base-model scoring on batches built by the driver.

    Tokenization and batching happen in the driver process; this actor only runs
    ``LMDistribution.log_score_batch`` under ``torch.no_grad()``.
    """

    def __init__(self, model_name: str, response_length: int):
        """
        Load the base model onto the GPU Ray assigned to this actor.

        Args:
            model_name: Hugging Face identifier of the model to load.
            response_length: Forwarded to ``LMDistribution`` as its ``length`` argument.
        """
        gpu = torch.device("cuda")
        self.device = gpu
        print(f"Actor initializing model '{model_name}' on {self.device}...")
        self.model = LMDistribution(model=model_name, length=response_length, device=gpu)

    def score_batch(self, batch_prompts: list, batch_samples_nested: list) -> np.ndarray:
        """
        Return the summed base-model log probability of each sample in a prepared batch.

        Args:
            batch_prompts: Chat-template-formatted prompt strings, one per sample.
            batch_samples_nested: One single-element list of ``TextSample`` per prompt,
                already tokenized and padded by the driver.

        Returns:
            A 1-D NumPy array of log probabilities, or an empty array for an empty batch.
        """
        if batch_prompts:
            with torch.no_grad():
                summed_scores = self.model.log_score_batch(
                    batch_samples_nested, batch_prompts, sum=True
                )
            # Expected shape is (batch_size, 1) per the original implementation; drop
            # the trailing axis and hand a CPU array back to the driver.
            return summed_scores.squeeze(-1).cpu().numpy()
        return np.array([])

def _batch_generator(df, size, tokenizer):
    """
    Yield ``(batch_prompts, batch_samples_nested)`` for consecutive chunks of ``size`` rows of ``df``.

    Tokenization happens here, on the main CPU process, so GPU workers only run inference.
    """
    for start in range(0, len(df), size):
        batch_df = df.iloc[start:start + size]
        if batch_df.empty:
            continue

        batch_prompts = [
            tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)
            for prompt in batch_df['prompt']
        ]
        batch_responses = batch_df['responses'].tolist()
        tokenized_responses = tokenizer(
            batch_responses, padding=True, return_tensors="pt", add_special_tokens=False, padding_side='right'
        )
        input_ids_tensor = tokenized_responses['input_ids']
        # NOTE(review): padded positions are included in token_ids (pad == eos when the
        # tokenizer had no pad token); confirm LMDistribution.log_score_batch masks them.
        batch_samples_nested = [
            [TextSample(token_ids=input_ids_tensor[j], text=response)]
            for j, response in enumerate(batch_responses)
        ]
        yield batch_prompts, batch_samples_nested


def _score_with_ray(actors, batches, num_batches, in_flight_per_actor=4):
    """
    Score ``batches`` on Ray ``actors`` and return the per-batch result arrays in input order.

    At most ``len(actors) * in_flight_per_actor`` tasks are outstanding. When a task
    finishes, the next batch is sent to the actor that just completed, so a slower
    actor does not accumulate a backlog while the others sit idle.
    """
    batch_iter = enumerate(batches)
    results_map = {}
    future_to_info = {}  # future -> (batch index, actor index)

    def submit(actor_index):
        """Send the next batch to ``actors[actor_index]``; return False when no batches remain."""
        try:
            batch_index, (batch_prompts, batch_samples_nested) = next(batch_iter)
        except StopIteration:
            return False
        future = actors[actor_index].score_batch.remote(batch_prompts, batch_samples_nested)
        future_to_info[future] = (batch_index, actor_index)
        return True

    with tqdm(total=num_batches, desc="Scoring batches with Ray") as pbar:
        # Prime the queue, distributing the initial work round-robin.
        for slot in range(len(actors) * in_flight_per_actor):
            if not submit(slot % len(actors)):
                break

        # Process results as they complete and refill the actor that freed up.
        while future_to_info:
            ready, _ = ray.wait(list(future_to_info), num_returns=1)
            completed_future = ready[0]
            batch_index, actor_index = future_to_info.pop(completed_future)
            results_map[batch_index] = ray.get(completed_future)
            pbar.update(1)
            submit(actor_index)

    return [results_map[i] for i in sorted(results_map)]


def _score_locally(scorer, batches, num_batches):
    """Score ``batches`` sequentially with the in-process ``scorer``; return per-batch arrays."""
    results = []
    with torch.no_grad():
        for batch_prompts, batch_samples_nested in tqdm(batches, total=num_batches, desc="Scoring batches locally"):
            batch_log_probs_tensor = scorer.log_score_batch(
                batch_samples_nested, batch_prompts, sum=True
            )
            results.append(batch_log_probs_tensor.squeeze(-1).cpu().numpy())
    return results


def _log_mean_exp(values):
    """
    Return ``log(mean(exp(values)))`` without intermediate overflow or underflow.

    Shifting by the maximum (the log-sum-exp trick) keeps ``exp`` in range even when
    the log-ratios are in the hundreds or thousands, where the naive form yields
    ``inf`` or ``0``.
    """
    values = np.asarray(values, dtype=np.float64)
    shift = np.max(values)
    if not np.isfinite(shift):
        # All -inf, or an inf/nan is present: shifting is meaningless, so fall back to
        # the direct formula, which gives the same -inf / inf / nan as before.
        return float(np.log(np.mean(np.exp(values))))
    return float(shift + np.log(np.mean(np.exp(values - shift))))


def process_data(args):
    """
    Load data, score every response with the base model, and save the result.

    Uses Ray actors (one per GPU) when ``args.num_gpus > 1``; otherwise scores locally
    on a single CUDA device. All work happens inside one MLflow run, and Ray is shut
    down even if a step fails.

    Adds per-row list columns ``base_log_probs``, ``proposal_log_probs``,
    ``target_log_scores`` and ``target_log_probs`` (one entry per response), plus a
    scalar ``partition_function`` column, then writes the dataframe to
    ``args.output_file``.

    Raises:
        ValueError: if ``--kl-coef`` is missing for the exponential target, or the input
            contains no rows to score.
        RuntimeError: if no GPU is available for the selected mode.
    """
    # 1. Argument validation
    if args.target_type == 'exponential' and args.kl_coef is None:
        raise ValueError("--kl-coef is required when --target-type is 'exponential'")

    # Decide whether to use Ray based on the number of requested GPUs.
    use_ray = args.num_gpus > 1

    # 2. Setup and start the MLflow run
    mlflow.set_experiment(args.mlflow_experiment_name)
    try:
        with mlflow.start_run(run_name=args.mlflow_run_name) as run:
            print(f"MLflow run started. Experiment: '{args.mlflow_experiment_name}', Run ID: '{run.info.run_id}'")
            mlflow.log_params(vars(args))

            # 3. Conditional initialization (Ray actors vs. a single local model)
            if use_ray:
                print("Initializing Ray for multi-GPU processing...")
                ray.init(num_gpus=args.num_gpus)
                available_gpus = ray.available_resources().get("GPU", 0)
                print(f"Ray initialized. Using {available_gpus} GPUs.")
                if available_gpus == 0:
                    raise RuntimeError("Ray couldn't find any GPUs. Please check your setup.")

                print(f"Creating {int(available_gpus)} scoring actors...")
                actors = [ScoringActor.remote(args.base_model_name, args.response_length) for _ in range(int(available_gpus))]
            else:
                print("Single GPU detected. Running in local mode without Ray.")
                if not torch.cuda.is_available():
                    raise RuntimeError("No GPU available for local mode. Please check your PyTorch installation.")
                device = torch.device("cuda")
                print(f"Initializing model '{args.base_model_name}' on {device}...")
                # In local mode, 'scorer' is the main model object.
                scorer = LMDistribution(model=args.base_model_name, length=args.response_length, device=device)

            # 4. Load tokenizer in the main process (common to both modes)
            print("Loading tokenizer in main process for preprocessing...")
            tokenizer = AutoTokenizer.from_pretrained(args.base_model_name)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            # 5. Load dataset and flatten to one row per response
            print(f"Loading dataset from {args.input_file}...")
            df = pd.read_parquet(args.input_file)
            df['original_index'] = df.index  # Keep track of original rows
            explode_cols = ['responses', 'reward', 'log_probs']
            exploded_df = df.explode(explode_cols).reset_index(drop=True)
            if exploded_df.empty:
                raise ValueError(f"No rows found in {args.input_file}; nothing to score.")

            # 6-7. Score every response (Ray or local)
            print("Calculating base model log probabilities...")
            num_batches = math.ceil(len(exploded_df) / args.scoring_batch_size)
            batches = _batch_generator(exploded_df, args.scoring_batch_size, tokenizer)

            start_time = time.time()
            if use_ray:
                all_base_log_probs_list = _score_with_ray(actors, batches, num_batches)
            else:
                all_base_log_probs_list = _score_locally(scorer, batches, num_batches)
            total_duration = time.time() - start_time
            print(f"\nTotal scoring time: {total_duration:.2f} seconds.")
            mlflow.log_metric("total_duration_seconds", total_duration)

            # 8. Gather results (common to both paths); batches preserve row order.
            exploded_df['base_log_probs'] = np.concatenate(all_base_log_probs_list)

            # 9. Create proposal_log_probs based on the override flag
            print("Creating 'proposal_log_probs' column...")
            if args.override_proposal_probs:
                print("  -> Using 'base_log_probs' as proposal (override active).")
                exploded_df['proposal_log_probs'] = exploded_df['base_log_probs']
            else:
                print("  -> Using 'log_probs' from input as proposal.")
                exploded_df['proposal_log_probs'] = exploded_df['log_probs']

            # 10. Group the results back into the original structure
            agg_functions = {col: 'first' for col in df.columns if col not in explode_cols + ['original_index']}
            agg_functions.update({col: list for col in explode_cols + ['base_log_probs', 'proposal_log_probs']})
            df = exploded_df.groupby('original_index').agg(agg_functions).reset_index(drop=True)

            # 11. Calculate target_log_scores
            print(f"Calculating target log scores using '{args.target_type}' method...")
            if args.target_type == 'pointwise':
                df['target_log_scores'] = df.apply(lambda r: [b + np.log(rw) for b, rw in zip(r['base_log_probs'], r['reward'])], axis=1)
            elif args.target_type == 'exponential':
                df['target_log_scores'] = df.apply(lambda r: [b + (rw / args.kl_coef) for b, rw in zip(r['base_log_probs'], r['reward'])], axis=1)

            # 12. Partition function Z = mean(exp(target - proposal)) per row, computed in
            # log space so that large or very negative log-ratios neither overflow nor
            # underflow; target_log_probs = target_log_scores - log Z stays finite.
            print("Calculating partition function and normalizing target log probabilities...")
            log_partition = [
                _log_mean_exp(
                    np.asarray(scores, dtype=np.float64) - np.asarray(proposals, dtype=np.float64)
                )
                for scores, proposals in zip(df['target_log_scores'], df['proposal_log_probs'])
            ]
            # The scalar Z itself may still be inf if log Z exceeds float64 range.
            df['partition_function'] = np.exp(np.asarray(log_partition, dtype=np.float64))
            df['target_log_probs'] = [
                (np.asarray(scores, dtype=np.float64) - log_z).tolist()
                for scores, log_z in zip(df['target_log_scores'], log_partition)
            ]

            # 13. Save the resulting dataframe
            print(f"Saving processed dataframe to {args.output_file}...")
            df.to_parquet(args.output_file, index=False)
            mlflow.log_artifact(args.output_file)

            print("\nProcessing complete! ✨")
            print("Preview of the updated DataFrame:")
            print(df[['prompt', 'reward', 'base_log_probs', 'target_log_scores', 'proposal_log_probs']].head())
    finally:
        # 14. Release the Ray runtime and actors (and their GPUs) even if a step failed.
        if use_ray and ray.is_initialized():
            ray.shutdown()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Add base and target log probabilities to a dataset using multiple GPUs with Ray and MLflow.")

    # I/O Arguments
    parser.add_argument("--input-file", type=str, required=True, help="Path to the input Parquet dataset.")
    parser.add_argument("--output-file", type=str, required=True, help="Path to save the output Parquet dataset.")

    # Model and Scoring Arguments
    parser.add_argument("--base-model-name", type=str, required=True, help="Hugging Face model identifier for the base model (e.g., 'gpt2').")
    parser.add_argument("--response-length", type=int, default=1024, help="The max length used in generation.")
    parser.add_argument("--target-type", type=str, default='pointwise', choices=['pointwise', 'exponential'], help="The type of target distribution to compute.")
    parser.add_argument("--kl-coef", type=float, default=None, help="KL coefficient, required (and must be positive) for 'exponential' target type.")
    # Defaults to True, so the help must explain how to turn it off.
    parser.add_argument("--override-proposal-probs", action=argparse.BooleanOptionalAction, default=True, help="Use base_log_probs as the proposal distribution (enabled by default). Pass --no-override-proposal-probs to use log_probs from the input file instead.")

    # Performance Arguments
    parser.add_argument("--scoring-batch-size", type=int, default=32, help="Maximum number of SAMPLES (responses) to process at once per actor.")
    # Defaults to every visible CUDA device: more than 1 selects the Ray path, otherwise local mode.
    parser.add_argument("--num-gpus", type=int, default=torch.cuda.device_count(), help="Number of GPUs to use for parallel scoring. If 1, runs locally without Ray.")

    # MLflow Arguments
    parser.add_argument("--mlflow-experiment-name", type=str, default="parallel_scoring", help="Name for the MLflow experiment.")
    parser.add_argument("--mlflow-run-name", type=str, default=None, help="Optional name for the MLflow run.")

    args = parser.parse_args()

    # Fail fast with a usage error rather than crashing mid-run after models are loaded:
    # a zero batch size breaks range() in batching, and kl_coef == 0 divides by zero.
    if args.scoring_batch_size < 1:
        parser.error("--scoring-batch-size must be a positive integer.")
    if args.target_type == 'exponential' and args.kl_coef is not None and args.kl_coef <= 0:
        parser.error("--kl-coef must be positive.")

    process_data(args)