import asyncio
import json
import logging
import os
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

from datasets import Dataset, concatenate_datasets
from safetytooling.apis.finetuning.openai.run import OpenAIFTConfig, main


def _convert_to_messages(prompt, completion, system_prompt=None):
    """
    Helper function to convert prompt/completion format to messages format.

    Args:
        prompt: The user prompt
        completion: The assistant completion
        system_prompt: Optional system prompt

    Returns:
        List of message dictionaries in chat format
    """
    messages = []

    # Add system message if present
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})

    # Add user and assistant messages
    messages.append({"role": "user", "content": prompt})
    messages.append({"role": "assistant", "content": completion})

    return messages


def _to_messages_format(dataset: Dataset, system_prompt: Optional[str]) -> Dataset:
    """
    Convert a prompt/completion dataset into one with a single "messages" column.

    Datasets that already have a "messages" column are returned unchanged.

    Args:
        dataset: Dataset with "prompt"/"completion" (and optionally "system_prompt") columns,
            or a dataset that already has a "messages" column
        system_prompt: Dataset-level fallback system prompt, used for rows that have no
            "system_prompt" value (column absent or value None)

    Returns:
        Dataset whose only column is "messages"
    """
    if "messages" in dataset.column_names:
        return dataset

    def _row_to_messages(row):
        # A row-level system prompt wins; fall back to the dataset-level one when the
        # column is absent or this row's value is None.
        row_system_prompt = row.get("system_prompt")
        if row_system_prompt is None:
            row_system_prompt = system_prompt
        return {
            "messages": _convert_to_messages(
                row.get("prompt", ""), row.get("completion", ""), row_system_prompt
            )
        }

    return dataset.map(_row_to_messages, remove_columns=dataset.column_names)


def _sample_and_convert(
    dataset: Dataset, count: float, shuffle_seed: int, system_prompt: Optional[str]
) -> Optional[Dataset]:
    """
    Shuffle ``dataset``, keep at most ``count`` rows, and convert them to "messages" format.

    Returns None when no rows would be kept (non-positive count or empty dataset), so the
    caller never concatenates empty selections.
    """
    # Never request more rows than the dataset holds.
    n_rows = min(int(count), len(dataset))
    if n_rows <= 0:
        return None
    selection = dataset.shuffle(seed=shuffle_seed).select(range(n_rows))
    return _to_messages_format(selection, system_prompt)


def mix_datasets(
    datasets: List[Dataset],
    proportions: List[float],
    system_prompts: Optional[List[str]] = None,
    total_size: Optional[int] = None,
    seed: int = 42,
    sample_strategy: str = "percentage",
) -> Dataset:
    """
    Mix multiple Hugging Face datasets according to specified proportions and convert
    to a unified format with a single "messages" column for fine-tuning.

    Args:
        datasets: List of Hugging Face datasets to mix. Each must either have only a
            "messages" column, or have "prompt" and "completion" columns (plus an
            optional "system_prompt" column).
        proportions: With "percentage", relative weights for each dataset; they are
            normalized to sum to 1, so [0.7, 0.3] and [7, 3] are equivalent. With
            "absolute", the number of rows to take from each dataset, or "all".
        system_prompts: Optional; list of system prompts to apply to each dataset. If provided,
            must have the same length as datasets. A non-None per-row "system_prompt"
            value takes precedence over the dataset-level prompt.
        total_size: Optional; total size of the resulting dataset ("percentage" only). If None,
            the largest mix that respects the proportions without oversampling any dataset
            is used.
        seed: Random seed for shuffling. Dataset i is shuffled with ``seed + i`` and the final
            mix with ``seed``.
        sample_strategy: Either "percentage" (use proportions as weights) or
            "absolute" (use proportions as absolute counts)

    Returns:
        A mixed, shuffled dataset with a single "messages" column containing chat format messages

    Raises:
        ValueError: On mismatched lengths, negative or all-zero proportions, "all" used with
            "percentage", an unknown sample_strategy, invalid dataset columns, or when no
            rows end up selected.

    Examples:
        # Mix 70% of dataset_a and 30% of dataset_b
        mixed = mix_datasets([dataset_a, dataset_b], [0.7, 0.3])

        # Create dataset with exactly 50,000 examples with 60/40 split
        mixed = mix_datasets([dataset_a, dataset_b], [0.6, 0.4], total_size=50000)

        # Mix datasets with specific counts: 10000 from A, 5000 from B, all of C
        mixed = mix_datasets([dataset_a, dataset_b, dataset_c],
                            [10000, 5000, "all"],
                            sample_strategy="absolute")

        # Add system prompts per dataset
        mixed = mix_datasets([dataset_a, dataset_b], [0.7, 0.3],
                            system_prompts=["System prompt for dataset A", "System prompt for dataset B"])
    """
    # Input validation
    if len(datasets) != len(proportions):
        raise ValueError("Length of datasets and proportions must match")

    if len(datasets) == 0:
        raise ValueError("At least one dataset must be provided")

    for prop in proportions:
        if prop != "all" and prop < 0:
            raise ValueError("Proportions must be non-negative")

    if sample_strategy not in ("percentage", "absolute"):
        raise ValueError(f"sample_strategy must be 'percentage' or 'absolute', got {sample_strategy!r}")

    if system_prompts is not None and len(system_prompts) != len(datasets):
        raise ValueError("Length of system_prompts must match length of datasets")

    # Check for valid columns in all datasets
    valid_columns = {"prompt", "completion", "system_prompt", "messages"}
    for i, dataset in enumerate(datasets):
        # A dataset already in "messages" format must have nothing else.
        if "messages" in dataset.column_names:
            if len(set(dataset.column_names) - {"messages"}) > 0:
                raise ValueError(
                    f"Dataset at index {i} contains 'messages' column along with other columns. "
                    f"If using 'messages' format, it should be the only column."
                )
            continue

        # Otherwise it must be in prompt/completion format.
        missing_required = {"prompt", "completion"} - set(dataset.column_names)
        if missing_required:
            raise ValueError(
                f"Dataset at index {i} is missing required columns: {missing_required}. "
                f"Either 'messages' column or both 'prompt' and 'completion' columns are required."
            )

        unexpected_cols = [col for col in dataset.column_names if col not in valid_columns]
        if unexpected_cols:
            raise ValueError(
                f"Dataset at index {i} contains unexpected columns: {unexpected_cols}. "
                f"Only {', '.join(sorted(valid_columns))} are allowed."
            )

    # Work out how many rows to draw from each dataset.
    if sample_strategy == "percentage":
        if any(p == "all" for p in proportions):
            raise ValueError('"all" is only supported with sample_strategy="absolute"')
        total_prop = sum(proportions)
        if total_prop <= 0:
            raise ValueError("At least one proportion must be positive")
        # Normalize instead of demanding an exact float sum of 1 (0.7 + 0.2 + 0.1 != 1.0).
        weights = [p / total_prop for p in proportions]

        if total_size is not None:
            counts = [int(w * total_size) for w in weights]
            # Absorb rounding error in the largest-weight dataset so the total is exact.
            diff = total_size - sum(counts)
            if diff:
                largest_idx = weights.index(max(weights))
                counts[largest_idx] = max(0, counts[largest_idx] + diff)
        else:
            # The binding dataset is the one with the smallest size-to-weight ratio, not the
            # one with the smallest weight; sizing by it lets every dataset supply its share.
            effective_size = int(min(len(d) / w for d, w in zip(datasets, weights) if w > 0))
            counts = [int(w * effective_size) for w in weights]
    else:
        counts = [len(d) if c == "all" else c for d, c in zip(datasets, proportions)]

    # Shuffle, select, and convert each dataset.
    selected_datasets = []
    for i, (dataset, count) in enumerate(zip(datasets, counts)):
        system_prompt = system_prompts[i] if system_prompts is not None else None
        selection = _sample_and_convert(dataset, count, seed + i, system_prompt)
        if selection is not None:
            selected_datasets.append(selection)

    # Concatenate and shuffle the final dataset
    if not selected_datasets:
        raise ValueError("No data selected from any dataset. Check your proportions and dataset sizes.")

    final_dataset = concatenate_datasets(selected_datasets).shuffle(seed=seed)

    # Validate the final dataset has only the messages column
    if set(final_dataset.column_names) != {"messages"}:
        raise ValueError(
            f"Final mixed dataset should contain only the 'messages' column but has: {final_dataset.column_names}"
        )

    return final_dataset


def _summarize_messages(dataset) -> Dict[str, Any]:
    """
    Compute per-role message counts and rough average token lengths for a dataset.

    Args:
        dataset: Non-empty iterable of items with a "messages" list of {"role", "content"} dicts

    Returns:
        Dict with user/assistant/system message counts, average estimated tokens per message
        for each role (0 when a role never appears), and the percentage of examples that
        contain at least one system message
    """
    roles = ("user", "assistant", "system")
    message_counts = dict.fromkeys(roles, 0)
    token_totals = dict.fromkeys(roles, 0)
    examples_with_system = 0
    num_examples = 0

    for item in dataset:
        num_examples += 1
        has_system = False
        for message in item["messages"]:
            role = message.get("role", "unknown")
            if role not in message_counts:
                continue
            # content may be None (e.g. tool-call turns); count it as empty.
            content = message.get("content") or ""
            message_counts[role] += 1
            token_totals[role] += len(content) // 4  # rough token estimate
            has_system = has_system or role == "system"
        examples_with_system += int(has_system)

    def _avg(role: str) -> float:
        return token_totals[role] / message_counts[role] if message_counts[role] > 0 else 0

    return {
        "user_messages": message_counts["user"],
        "assistant_messages": message_counts["assistant"],
        "system_messages": message_counts["system"],
        "avg_user_tokens": _avg("user"),
        "avg_assistant_tokens": _avg("assistant"),
        "avg_system_tokens": _avg("system"),
        "pct_with_system": (examples_with_system / num_examples) * 100 if num_examples else 0,
    }


async def finetune_4o_on_dataset(
    dataset: Dataset | List[List[Dict[str, Any]]],
    model_name: str = "gpt-4o-2024-08-06",
    output_dir: str = "finetuned_4o",
    max_retries: int = 1,
    wandb_project: str = "parallel-safety-finetuning",
    batch_size: int = 32,
    n_epochs: int = 1,
    tags: Optional[list] = None,
    dry_run: bool = False,
) -> Dict[str, Any]:
    """
    Fine-tune GPT-4o on a provided HuggingFace dataset with messages format.

    Writes dataset statistics, the JSONL training file, training details, and job details
    under ``<output_dir>/<model_name>/<timestamp>/``.

    Args:
        dataset: HuggingFace Dataset whose only column is 'messages', or a list where each
            item is one example's list of message dicts
        model_name: Base model name to fine-tune
        output_dir: Directory to save outputs
        max_retries: Maximum number of fine-tuning attempts. Each attempt calls the
            fine-tuning entry point again, which may submit a new job.
        wandb_project: Weights & Biases project name
        batch_size: Training batch size
        n_epochs: Number of training epochs
        tags: Tags for tracking experiments
        dry_run: Forwarded to OpenAIFTConfig.dry_run

    Returns:
        Dictionary with fine-tuning results including the model ID and cost

    Raises:
        ValueError: If the dataset has the wrong columns, is empty, or every fine-tuning
            attempt failed (the last attempt's exception is chained as the cause).
    """
    # Set up logging
    logger = logging.getLogger("finetune-4o")
    if not logger.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
        logger.setLevel(logging.INFO)
        logger.addHandler(handler)
        logger.propagate = False

    # Validate dataset columns / normalize list input to {"messages": ...} items
    if isinstance(dataset, Dataset):
        if "messages" not in dataset.column_names:
            raise ValueError("Dataset must contain 'messages' column")

        if len(dataset.column_names) > 1:
            raise ValueError(f"Dataset should only contain 'messages' column, but found: {dataset.column_names}")
    elif isinstance(dataset, list):
        dataset = [{"messages": item} for item in dataset]

    # An empty training set cannot be fine-tuned on (and would divide by zero in the stats).
    if len(dataset) == 0:
        raise ValueError("Dataset is empty; nothing to fine-tune on")

    columns = dataset.column_names if isinstance(dataset, Dataset) else ["messages"]

    # Dataset statistics
    stats = {"num_examples": len(dataset), "columns": columns}
    stats.update(_summarize_messages(dataset))

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir = Path(output_dir) / model_name / timestamp

    os.makedirs(run_dir / "stats", exist_ok=True)
    with open(run_dir / "stats" / "dataset_stats.json", "w", encoding="utf-8") as f:
        json.dump(stats, f, indent=2)

    data_dir = run_dir / "data"
    os.makedirs(data_dir, exist_ok=True)

    logger.info(f"Processing dataset with {len(dataset)} examples")

    # Save as JSONL - one {"messages": [...]} object per line
    train_file = data_dir / "train.jsonl"
    with open(train_file, "w", encoding="utf-8") as f:
        for item in dataset:
            f.write(json.dumps({"messages": item["messages"]}) + "\n")

    logger.info(f"Saved {len(dataset)} examples to {train_file}")

    # Save training details
    training_details = {
        "base_model": model_name,
        "batch_size": batch_size,
        "n_epochs": n_epochs,
        "dataset_size": len(dataset),
        "wandb_project": wandb_project,
        "tags": tags or [],
        "timestamp": datetime.now().isoformat(),
        "columns": columns,
    }

    with open(run_dir / "training_details.json", "w", encoding="utf-8") as f:
        json.dump(training_details, f, indent=2)

    logger.info(f"Saved training details to {run_dir / 'training_details.json'}")

    # A single config is shared by all attempts
    config = OpenAIFTConfig(
        train_file=train_file,
        val_file=None,
        model=model_name,
        n_epochs=n_epochs,
        wandb_project_name=wandb_project,
        tags=tags or [],
        dry_run=dry_run,
        batch_size=batch_size,
    )

    # NOTE: every attempt calls `main` again, so a retry may submit a new fine-tuning job.
    ft_job = None
    last_error: Optional[BaseException] = None

    for attempt in range(max_retries):
        try:
            logger.info(f"Starting fine-tuning attempt {attempt + 1}/{max_retries}")
            ft_job, _ = await main(config, verbose=True)
            logger.info(f"Finished training model {ft_job.fine_tuned_model}")
            break
        except Exception as e:
            last_error = e
            logger.error(f"Attempt {attempt + 1} failed: {str(e)}")
            # Back off only if another attempt will follow.
            if attempt < max_retries - 1:
                await asyncio.sleep(30)

    if ft_job is None:
        # Chain the last failure so the real cause is not lost.
        raise ValueError("Fine-tuning job failed - returned None") from last_error

    cost_estimate = getattr(ft_job, "cost_estimate", None)

    job_info = {
        "fine_tuned_model": ft_job.fine_tuned_model,
        "base_model": model_name,
        "cost_estimate": cost_estimate,
        "n_examples": len(dataset),
        "training_file": str(train_file),
        "batch_size": batch_size,
        "n_epochs": n_epochs,
        "start_time": training_details["timestamp"],
    }

    # Full job object; default=str stringifies non-JSON fields such as datetimes
    with open(run_dir / "ft_job_details.json", "w", encoding="utf-8") as f:
        json.dump(vars(ft_job), f, indent=2, default=str)

    with open(run_dir / "job_info.json", "w", encoding="utf-8") as f:
        json.dump(job_info, f, indent=2)

    logger.info("Fine-tuning job submitted successfully!")
    logger.info(f"Finetuned model: {ft_job.fine_tuned_model}")
    logger.info(f"Cost estimate: ${cost_estimate if cost_estimate is not None else 'unknown'}")

    return job_info
