import json
import argparse
import logging
import os
import math
import random
from datasets import load_dataset, Dataset
from typing import Optional, Dict, Union, List, Iterable
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import TrainerCallback
from trl import SFTTrainer, SFTConfig
import copy
import torch
from torch.utils.data import IterableDataset


# Configure the root logger at import time: INFO level, timestamped lines.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)


class CurriculumIterableDataset(IterableDataset):
    """An iterable dataset that streams examples stage by stage, in curriculum order.

    Stage ``i`` yields exactly ``examples_needed_by_stage[i]`` rows drawn from
    ``indices_by_stage[i]``. The stage's index pool is shuffled with a
    ``random.Random(seed)`` generator and cycled, reshuffling on every pass,
    until enough rows have been produced. Stages with an empty pool are skipped.

    The stream depends only on the constructor arguments, so every ``iter()``
    call replays the same sequence. Inside ``DataLoader`` worker processes, each
    worker yields only every ``num_workers``-th position of that stream, offset
    by its worker id. Otherwise every worker replica would repeat the whole
    stream.

    Args:
        hf_dataset: Row-indexable collection; ``hf_dataset[i]`` is yielded as-is.
        indices_by_stage: For each stage, the row indices that belong to it.
        examples_needed_by_stage: For each stage, how many rows to yield.
        seed: Seed for the per-stage shuffling.

    Raises:
        ValueError: if the two per-stage lists have different lengths.
    """

    def __init__(
        self,
        hf_dataset: "Dataset",
        indices_by_stage: List[List[int]],
        examples_needed_by_stage: List[int],
        seed: int = 42,
    ):
        # zip() would silently drop the unmatched tail, so fail loudly instead.
        if len(indices_by_stage) != len(examples_needed_by_stage):
            raise ValueError(
                f"indices_by_stage has {len(indices_by_stage)} stages but "
                f"examples_needed_by_stage has {len(examples_needed_by_stage)}"
            )
        self.hf_dataset = hf_dataset
        self.indices_by_stage = indices_by_stage
        self.examples_needed_by_stage = examples_needed_by_stage
        self.seed = seed

    def get_epoch_iter(self, indices: List[int]) -> Iterable[int]:
        """Return an infinite iterator cycling through ``indices``, reshuffled on every pass."""
        rng = random.Random(self.seed)
        pool = list(indices)
        while True:
            rng.shuffle(pool)
            for idx in pool:
                yield idx

    def _iter_stage_indices(self) -> Iterable[int]:
        """Yield the row indices of the full (unsharded) curriculum stream, stage by stage."""
        for indices, needed in zip(self.indices_by_stage, self.examples_needed_by_stage):
            if len(indices) == 0:
                continue
            epoch_iter = self.get_epoch_iter(indices)
            for _ in range(needed):
                yield int(next(epoch_iter))

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # Main-process loading: this iterator owns the whole stream.
            worker_id, num_workers = 0, 1
        else:
            worker_id, num_workers = worker_info.id, worker_info.num_workers
        for position, idx in enumerate(self._iter_stage_indices()):
            # Every worker computes the same deterministic index stream and keeps
            # only its share. Rows are fetched only for the kept positions.
            if position % num_workers == worker_id:
                yield self.hf_dataset[idx]


class CurriculumCheckpointCallback(TrainerCallback):
    """Save model/tokenizer when training reaches predefined global steps (stage ends).

    Trainer callback events receive ``model``, ``processing_class``,
    ``optimizer``, etc. as keyword arguments, but not the Trainer itself. Attach
    the trainer explicitly, either through the ``trainer`` argument or by
    assigning ``callback.trainer`` once the trainer exists.

    With a trainer attached, checkpoints go through ``trainer.save_model``.
    Without one, the callback falls back to ``model.save_pretrained`` on the
    saving process. NOTE(review): that fallback is best effort and may not
    gather sharded weights (e.g. DeepSpeed ZeRO-3 / FSDP); prefer attaching
    the trainer.

    Args:
        stage_boundaries: Global steps at which a stage ends. The i-th distinct
            boundary in ascending order is written to ``<base_output_dir>/stage_<i>``.
        base_output_dir: Parent directory of the per-stage folders.
        tokenizer: Tokenizer saved when neither the trainer nor the event
            kwargs supply one.
        logger: Logger-like object providing ``info``/``warning``/``exception``.
        trainer: Optional trainer used for saving (see above).
    """

    def __init__(self, stage_boundaries: List[int], base_output_dir: str, tokenizer, logger=logging, trainer=None):
        super().__init__()
        self.stage_boundaries = set(stage_boundaries)
        self.base_output_dir = base_output_dir
        self.tokenizer = tokenizer
        self.logger = logger
        self.trainer = trainer
        self._saved = set()  # boundaries already checkpointed

    def on_step_end(self, args, state, control, **kwargs):
        global_step = int(state.global_step)
        # Save only once per boundary.
        if global_step not in self.stage_boundaries or global_step in self._saved:
            return
        # Stage index = rank of this boundary among the distinct boundaries.
        stage_idx = sorted(self.stage_boundaries).index(global_step)
        out_dir = os.path.join(self.base_output_dir, f"stage_{stage_idx}")
        should_save = getattr(args, "should_save", True)

        trainer = self.trainer if self.trainer is not None else kwargs.get("trainer")
        if trainer is not None:
            self.logger.info(f"Checkpointing at global step {global_step} -> {out_dir}")
            # Called on every process: save_model handles distributed saving and
            # only writes from the process that should save.
            trainer.save_model(out_dir)
            tokenizer = getattr(trainer, "processing_class", None)
        else:
            model = kwargs.get("model")
            if model is None:
                self.logger.warning(
                    f"No trainer or model available at global step {global_step}; "
                    f"skipping stage checkpoint {out_dir}"
                )
                return
            self.logger.info(f"Checkpointing at global step {global_step} -> {out_dir}")
            if should_save:
                model.save_pretrained(out_dir)
            tokenizer = kwargs.get("processing_class")
            if tokenizer is None:
                tokenizer = kwargs.get("tokenizer")  # older transformers versions
        if tokenizer is None:
            tokenizer = self.tokenizer

        # Write the tokenizer only from the saving process to avoid concurrent writes.
        if tokenizer is not None and should_save:
            try:
                tokenizer.save_pretrained(out_dir)
            except Exception:
                self.logger.exception("Failed to save tokenizer at stage checkpoint")
        self._saved.add(global_step)


def main():
    """Run single-pass curriculum SFT from a JSON config (``--config``).

    Config keys read here:
    - ``dataset.trainset_path`` / ``dataset.valset_path``: JSON data files.
    - ``dataset.seed``: optional, default 42.
    - ``curriculum.stages``: one collection of ``difficulty`` values per stage.
    - ``models.student``: model/tokenizer path.
    - ``training``: keyword arguments for ``SFTConfig``; ``max_steps`` must be > 0.
    - ``job_type``: optional; must contain ``"naive_cl"``.

    The ``max_steps`` optimizer-step budget is split as evenly as possible over
    the stages that have matching data. Each stage streams exactly the number
    of examples its steps consume. The model is saved to
    ``<output_dir>/stage_<i>`` as each stage ends and once more after training.

    Raises:
        ValueError: on an invalid ``job_type``, a non-positive ``max_steps``,
            or when no training example matches any stage.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, required=True, help='path to the json config file')
    args = parser.parse_args()
    with open(args.config, encoding="utf-8") as config_file:
        config = json.load(config_file)

    job_type = config.get("job_type", "naive_cl_local")
    if "naive_cl" not in job_type:
        raise ValueError(f"Invalid job type: {job_type}")

    # Resolve the training arguments first. The curriculum schedule must use the
    # same effective values the Trainer will use, including SFTConfig defaults
    # for keys the config omits.
    training_arguments = SFTConfig(**config["training"])
    max_steps = int(training_arguments.max_steps)
    if max_steps <= 0:
        raise ValueError("max_steps must be > 0 in training config")
    # Examples consumed per optimizer step across all processes. This is the
    # Trainer's total train batch size; train_batch_size already includes n_gpu.
    examples_per_step = (
        training_arguments.train_batch_size
        * training_arguments.gradient_accumulation_steps
        * training_arguments.world_size
    )

    # --- Load full dataset once ---
    logging.info(f"Loading full dataset from {config['dataset']['trainset_path']} and {config['dataset']['valset_path']}")
    full_dataset = load_dataset("json", data_files={
        "train": config["dataset"]["trainset_path"],
        "validation": config["dataset"]["valset_path"],
    })
    train_split = full_dataset["train"]

    # --- Curriculum Learning Setup ---
    stages = config["curriculum"]["stages"]
    initial_model_path = config["models"]["student"]
    base_output_dir = training_arguments.output_dir
    seed = config["dataset"].get("seed", 42)
    logging.info(f"Found {len(stages)} stages in the curriculum.")

    # Read the difficulty column once instead of re-scanning every row per stage.
    # Rows without a difficulty (or a split without the column) count as None.
    if "difficulty" in train_split.column_names:
        difficulties = list(train_split["difficulty"])
    else:
        difficulties = [None] * len(train_split)

    all_indices_by_stage = []
    for i, difficulty_levels in enumerate(stages):
        logging.info(f"Stage {i}: Filtering dataset for difficulties: {difficulty_levels}")
        idxs = [idx for idx, difficulty in enumerate(difficulties) if difficulty in difficulty_levels]
        logging.info(f"Stage {i}: Found {len(idxs)} examples for difficulties {difficulty_levels}")
        all_indices_by_stage.append(idxs)

    # Drop stages with no data before planning steps. A stream shorter than
    # max_steps would make the Trainer start a new pass over the curriculum,
    # and the stage checkpoints would be mislabeled.
    active_stage_ids = [i for i, idxs in enumerate(all_indices_by_stage) if idxs]
    for i, idxs in enumerate(all_indices_by_stage):
        if not idxs:
            logging.warning(f"Stage {i} has no data; it is dropped from the curriculum.")
    if not active_stage_ids:
        raise ValueError("No training examples match any curriculum stage")
    if len(active_stage_ids) != len(stages):
        for new_idx, config_idx in enumerate(active_stage_ids):
            logging.info(f"Checkpoint stage_{new_idx} corresponds to config stage {config_idx}")
    indices_by_stage = [all_indices_by_stage[i] for i in active_stage_ids]
    num_stages = len(indices_by_stage)

    # Distribute optimizer steps as evenly as possible; earlier stages absorb the remainder.
    base_steps, remainder = divmod(max_steps, num_stages)
    steps_per_stage = [base_steps + (1 if i < remainder else 0) for i in range(num_stages)]
    examples_needed_by_stage = [steps * examples_per_step for steps in steps_per_stage]

    # Global steps at which each stage ends (used for stage checkpoints).
    stage_boundaries = []
    cumulative = 0
    for stage_idx, steps in enumerate(steps_per_stage):
        cumulative += steps
        stage_boundaries.append(cumulative)
        logging.info(f"Stage {stage_idx}: {steps} optimizer steps, {examples_needed_by_stage[stage_idx]} examples")

    curriculum_dataset = CurriculumIterableDataset(
        hf_dataset=train_split,
        indices_by_stage=indices_by_stage,
        examples_needed_by_stage=examples_needed_by_stage,
        seed=seed,
    )

    # --- Load model & tokenizer once ---
    logging.info(f"Loading initial model/tokenizer from {initial_model_path}")
    student_tokenizer = AutoTokenizer.from_pretrained(initial_model_path, trust_remote_code=True)
    if student_tokenizer.pad_token is None:
        student_tokenizer.pad_token = student_tokenizer.eos_token

    student_model = AutoModelForCausalLM.from_pretrained(initial_model_path, trust_remote_code=True)

    # --- Setup Trainer once ---
    # NOTE(review): curriculum_dataset is a torch IterableDataset, not a
    # `datasets` object. Confirm the installed TRL version accepts it without
    # its own dataset preparation.
    trainer = SFTTrainer(
        model=student_model,
        processing_class=student_tokenizer,
        args=training_arguments,
        train_dataset=curriculum_dataset,
    )

    # Callback events are not given the Trainer, so attach it for save_model().
    callback = CurriculumCheckpointCallback(stage_boundaries=stage_boundaries, base_output_dir=base_output_dir, tokenizer=student_tokenizer)
    callback.trainer = trainer
    trainer.add_callback(callback)

    # Single training call
    logging.info("Starting single-pass curriculum training with one SFTTrainer.train() call")
    trainer.train()

    # Final save
    final_model_path = os.path.join(base_output_dir, f"stage_{num_stages - 1}")
    logging.info(f"Saving final model to {final_model_path}")
    trainer.save_model(final_model_path)
    # save_model() writes only from the saving process; do the same for the tokenizer.
    if training_arguments.should_save:
        student_tokenizer.save_pretrained(final_model_path)
    logging.info("All curriculum learning stages are complete.")


if __name__ == '__main__':
    # Run the training pipeline only when executed as a script, not on import.
    main()