# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import time
from pathlib import Path

from datasets import Dataset, load_dataset
from huggingface_hub import (
    create_branch,
    list_repo_commits,
    repo_exists,
)

from sal.config import Config

# Module-wide logger. `getLogger()` is called without a name, so this is the
# root logger rather than a logger scoped to this module.
logger = logging.getLogger()


def _first_present(row, *keys):
    """Return ``row[k]`` for the first ``k`` in ``keys`` found in ``row``, else ``None``."""
    return next((row[k] for k in keys if k in row), None)


def _standardize_aime_row(row):
    """Map one AIME row onto the common problem/answer/solution/unique_id schema."""
    # Look up both the original-case and lower-case field names to be robust.
    answer = _first_present(row, "Answer", "answer")
    solution = _first_present(row, "Solution", "solution")
    return {
        "problem": _first_present(row, "Problem", "problem"),
        "answer": "" if answer is None else str(answer),
        "solution": "" if solution is None else solution,
        "unique_id": _first_present(row, "ID", "id", "unique_id"),
    }


def get_dataset(config: Config) -> Dataset:
    """Load the configured dataset split, normalise its schema and subsample it.

    Steps, in order:

    1. Load ``config.dataset_name`` / ``config.dataset_split`` via
       ``load_dataset``.
    2. For ``Maxwell-Jia/AIME_2024``, replace the original columns with the
       common schema expected downstream: ``problem`` (str), ``answer`` (str),
       ``solution`` (str, possibly empty) and ``unique_id``. Rows lacking an id
       cause ``unique_id`` to be filled with the row index for every row.
    3. Restrict to the row range ``[dataset_start, dataset_end)``. A bound left
       as ``None`` defaults to the start / end of the dataset, so a one-sided
       range is honoured instead of being silently ignored.
    4. Keep at most ``config.num_samples`` rows when it is set.

    Args:
        config: Run configuration; reads ``dataset_name``, ``dataset_split``,
            ``dataset_start``, ``dataset_end`` and ``num_samples``.

    Returns:
        The loaded (and possibly normalised / sliced) dataset.
    """
    dataset = load_dataset(config.dataset_name, split=config.dataset_split)

    # Normalize different benchmarks to a common schema.
    if config.dataset_name.replace(" ", "") == "Maxwell-Jia/AIME_2024":
        original_columns = dataset.column_names
        dataset = dataset.map(_standardize_aime_row, remove_columns=original_columns)
        # Ensure every row has a usable unique_id; fall back to the row index.
        missing_ids = "unique_id" not in dataset.column_names or any(
            value is None for value in dataset["unique_id"]
        )
        if missing_ids:
            dataset = dataset.map(
                lambda _, idx: {"unique_id": idx}, with_indices=True
            )

    start, end = config.dataset_start, config.dataset_end
    if start is not None or end is not None:
        first = 0 if start is None else start
        stop = len(dataset) if end is None else end
        dataset = dataset.select(range(first, stop))

    if config.num_samples is not None:
        dataset = dataset.select(range(min(len(dataset), config.num_samples)))

    return dataset


def save_dataset(dataset, config):
    """Persist ``dataset`` to the Hugging Face Hub or to a local JSONL file.

    Hub mode (``config.push_to_hub`` truthy): pushes to
    ``config.hub_dataset_id`` on branch ``config.revision``, retrying up to 20
    times with a 5 second pause because concurrent pushes can be rejected.

    Local mode: writes JSON lines to

    * ``<config.calib_output_path>/calibration_completions.jsonl`` when the
      config has a non-``None`` ``calib_output_path``; otherwise
    * ``<config.output_dir>/<config.approach>_completions.jsonl``. When
      ``config.output_dir`` is ``None`` it is first set (on ``config``) to
      ``data/n_<config.n>/<last component of config.model_path>``.

    Args:
        dataset: Object exposing ``push_to_hub`` and ``to_json``.
        config: Run configuration holding the attributes named above.

    Raises:
        RuntimeError: If every Hub push attempt failed; chained from the last
            underlying error.
    """
    if config.push_to_hub:
        max_attempts = 20
        retry_delay_seconds = 5
        last_error = None
        for attempt in range(1, max_attempts + 1):
            try:
                # Create branch from the repo's initial commit.
                # This is needed to avoid branching from a commit on main that already has data
                if repo_exists(config.hub_dataset_id, repo_type="dataset"):
                    initial_commit = list_repo_commits(
                        config.hub_dataset_id, repo_type="dataset"
                    )[-1]
                    create_branch(
                        repo_id=config.hub_dataset_id,
                        branch=config.revision,
                        revision=initial_commit.commit_id,
                        exist_ok=True,
                        repo_type="dataset",
                    )
                url = dataset.push_to_hub(
                    config.hub_dataset_id,
                    revision=config.revision,
                    split="train",
                    private=config.hub_dataset_private,
                    commit_message=f"Add {config.revision}",
                )
            except Exception as e:
                last_error = e
                logger.error(f"Error pushing dataset to the Hub: {e}")
                # No point in waiting once the final attempt has failed.
                if attempt < max_attempts:
                    time.sleep(retry_delay_seconds)
            else:
                logger.info(f"Pushed dataset to {url}")
                return
        # Surface the real failure instead of tripping over an unbound `url`.
        raise RuntimeError(
            f"Failed to push dataset to the Hub after {max_attempts} attempts"
        ) from last_error

    if getattr(config, "calib_output_path", None) is not None:
        output_path = config.calib_output_path
        filename = "calibration_completions.jsonl"
    else:
        if config.output_dir is None:
            # Default location, keyed by n and the model's final path component.
            config.output_dir = f"data/n_{config.n}/{config.model_path.rstrip('/').split('/')[-1]}"
        # Assigned whether or not output_dir had to be defaulted, so an
        # explicitly configured output_dir is honoured too.
        output_path = config.output_dir
        filename = f"{config.approach}_completions.jsonl"

    Path(output_path).mkdir(parents=True, exist_ok=True)
    dataset.to_json(f"{output_path}/{filename}", lines=True)
    logger.info(f"Saved completions to {output_path}/{filename}")
