import json
import argparse
import logging
import os
from typing import Dict, List, Optional
from datasets import load_dataset, Dataset, DatasetDict
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer, SFTConfig
import torch
from torch.utils.data import Sampler
from transformers import TrainerCallback
from peft import LoraConfig, get_peft_model
import random
import wandb

from common import build_validation_from_cfg

# Configure the root logger at import time: INFO level, "<time> - <LEVEL> - <message>" lines.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")


# ========== Utilities ==========

def even_split(total: int, parts: int) -> List[int]:
    """Split `total` into `parts` integer buckets, giving the leftover units to the leading buckets.

    Example: even_split(10, 3) -> [4, 3, 3].
    """
    quotient, leftover = divmod(total, parts)
    sizes = [quotient] * parts
    # The first `leftover` buckets each absorb one extra unit.
    for slot in range(leftover):
        sizes[slot] += 1
    return sizes

def set_seed_all(seed: int):
    """Seed Python's `random`, torch, and torch's CUDA generators; do nothing when `seed` is None."""
    if seed is not None:
        # Same order as before: Python RNG, torch RNG, then all CUDA devices.
        for seeder in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
            seeder(seed)


# ========== Soft-switching Sampler ==========

class SoftSwitchCurriculumSampler(Sampler):
    """
    Soft-switching curriculum sampler.

    Yields dataset indices forever, one "chunk" at a time. A chunk holds
    batch_size * world_size * chunk_steps indices, and the shared `stage_state`
    dict is re-read at the start of every chunk, so a stage change only takes
    effect from the next chunk.

    Modes:
      - strict_mix=True: EXACT PER-STEP MIXING. Each yielded 'chunk' (~one optimizer step worth of samples)
        contains ≈ major_prob from current stage and the remainder split uniformly across previous stages.
      - strict_mix=False: PROBABILISTIC PER-SAMPLE MIXING. Each sample independently chooses current vs previous
        with major_prob / (1-major_prob).

    Notes:
      - Previous stages are provided by stage_state["prev"], set by the callback.
      - Empty previous stages are skipped with probabilities renormalized.
      - If no previous stages exist yet, sampling is 100% from current.
      - If the current stage is empty but some previous stage is not, sampling is
        100% from the previous stages in both modes, whatever major_prob is.
      - Each stage's indices are shuffled once at construction and then served
        round-robin, wrapping around when exhausted.
      - __len__ is nominal (total number of indices); iteration never terminates
        on its own.

    Args:
        stage_indices: Maps stage id -> dataset indices belonging to that stage (copied).
        stage_state: Shared mutable dict; "current" (stage id) and optional "prev"
            (list of stage ids) are read on every chunk.
        batch_size: Batch size factor used to size a chunk.
        world_size: World size factor used to size a chunk.
        chunk_steps: Number of steps per chunk; clamped to at least 1.
        seed: Seed for the sampler's private RNG.
        major_prob: Share of samples drawn from the current stage; clamped to [0, 1].
        strict_mix: Select exact per-chunk mixing instead of per-sample mixing.

    Raises (while iterating):
        RuntimeError: if neither the current stage nor any previous stage has samples.
    """
    def __init__(
        self,
        stage_indices: Dict[int, List[int]],
        stage_state: Dict[str, int],
        batch_size: int,
        world_size: int,
        chunk_steps: int = 1,
        seed: int = 42,
        major_prob: float = 0.8,
        strict_mix: bool = False,
    ):
        # Copy the index lists so shuffling never mutates the caller's data.
        self.stage_indices = {k: list(v) for k, v in stage_indices.items()}
        self.stage_state = stage_state
        self.batch_size = batch_size
        self.world_size = world_size
        self.chunk_steps = max(1, int(chunk_steps))
        self.major_prob = max(0.0, min(1.0, float(major_prob)))
        self.strict_mix = bool(strict_mix)

        self.seed = seed
        self.rng = random.Random(seed)

        # Shuffle per stage deterministically
        for lst in self.stage_indices.values():
            self.rng.shuffle(lst)

        # Round-robin cursors per stage
        self.cursors = {k: 0 for k in self.stage_indices}

        self._nominal_len = sum(len(v) for v in self.stage_indices.values())

    def _has_samples(self, stage_id: int) -> bool:
        """Return True when `stage_id` is known and has at least one index."""
        return bool(self.stage_indices.get(stage_id))

    def _next_from_stage(self, stage_id: int) -> int:
        """Return the next index of `stage_id` in round-robin order and advance its cursor."""
        pool = self.stage_indices.get(stage_id, [])
        if not pool:
            raise RuntimeError(
                f"Sampler asked to draw from empty stage {stage_id}. "
                "Check bucket construction / schedule."
            )
        cur = self.cursors[stage_id]
        idx = pool[cur % len(pool)]
        self.cursors[stage_id] = cur + 1
        return idx

    def _choose_stage_probabilistic(self, current: int, prev_list: List[int]) -> int:
        """Pick the stage for one sample: `current` with major_prob, else a uniform non-empty previous stage."""
        nonempty_prev = [s for s in prev_list if self._has_samples(s)]
        if not nonempty_prev:
            return current

        # An empty current stage must be handled BEFORE the major_prob shortcut:
        # otherwise major_prob == 1.0 would select the empty stage and the draw
        # would fail even though previous stages could serve the sample.
        if not self._has_samples(current):
            return self.rng.choice(nonempty_prev)

        if self.major_prob >= 1.0:
            return current

        stages = [current] + nonempty_prev
        p_current = self.major_prob
        p_prev_each = (1.0 - p_current) / len(nonempty_prev)
        probs = [p_current] + [p_prev_each] * len(nonempty_prev)

        # Inverse-CDF draw over [current, *previous].
        r = self.rng.random()
        cum = 0.0
        for s, p in zip(stages, probs):
            cum += p
            if r <= cum:
                return s
        # Floating-point round-off can leave cum slightly below 1.0.
        return stages[-1]

    def __iter__(self):
        while True:
            # Re-read the shared state each chunk so stage switches are picked up.
            current_stage = self.stage_state["current"]
            prev_stages = self.stage_state.get("prev", [])

            nonempty_prev = [s for s in prev_stages if self._has_samples(s)]
            current_has = self._has_samples(current_stage)
            if not current_has and not nonempty_prev:
                raise RuntimeError(
                    f"No available samples: stage {current_stage} empty and no non-empty previous stages."
                )

            chunk_size = self.batch_size * self.world_size * self.chunk_steps

            if not self.strict_mix:
                # ---- PROBABILISTIC PER-SAMPLE MIXING ----
                # The whole chunk is drawn up front, then handed out.
                out = []
                for _ in range(chunk_size):
                    s = self._choose_stage_probabilistic(current_stage, prev_stages)
                    out.append(self._next_from_stage(s))
                yield from out
                continue

            # ---- EXACT PER-STEP MIXING (strict_mix=True) ----
            if not nonempty_prev or self.major_prob >= 1.0 or not current_has:
                # Degenerate: draw all from current if possible; otherwise cycle
                # through the non-empty previous stages only.
                if current_has:
                    stages_plan = [current_stage] * chunk_size
                else:
                    stages_plan = [nonempty_prev[i % len(nonempty_prev)] for i in range(chunk_size)]
            else:
                n_cur = int(round(self.major_prob * chunk_size))
                n_prev_total = max(0, chunk_size - n_cur)
                # Spread the remainder evenly over previous stages, extras to the first ones.
                base, rem = divmod(n_prev_total, len(nonempty_prev))
                stages_plan = [current_stage] * n_cur
                for i, s in enumerate(nonempty_prev):
                    stages_plan.extend([s] * (base + (1 if i < rem else 0)))
                # De-block so batches aren't grouped by stage
                self.rng.shuffle(stages_plan)

            out = [self._next_from_stage(s) for s in stages_plan]
            yield from out

    def __len__(self):
        """Return the total number of indices across all stages (nominal length)."""
        return self._nominal_len


# ========== Callback that schedules stages by steps (and exposes prev set) ==========

class StagedCurriculumCallback(TrainerCallback):
    """
    Step-scheduled stage switcher that drives the shared `stage_state` dict.

      - `max_steps` is split evenly across `num_stages` (see `even_split`).
      - At the start of each step the stage for `state.global_step` is looked up;
        when it differs from the active one, `stage_state['current']` is updated.
      - `stage_state['prev']` is set to every stage that precedes the current one
        in the traversal `stage_order`.
      - `stage_state['order']` is filled with the traversal order at construction.
      - Traversal order is selected by `order` in {"normal","reverse","random"}:
        * normal: [0,1,2,...,K-1]
        * reverse: [K-1,...,2,1,0]
        * random: one permutation for the whole run, shuffled with `seed`
      - With `save_each_stage`, the main process saves the model into
        `<output_dir>/stage_<old>` when leaving a stage (failures are only logged).
    """
    def __init__(
        self,
        stage_state: Dict[str, int],
        num_stages: int,
        max_steps: int,
        output_dir: str,
        save_each_stage: bool = False,
        order: str = "normal",
        seed: Optional[int] = None,
    ):
        self.stage_state = stage_state
        self.num_stages = num_stages
        self.max_steps = max_steps
        self.output_dir = output_dir
        self.save_each_stage = save_each_stage
        self.seed = seed

        # Normalise the requested order; anything unrecognised falls back to "normal".
        order = (order or "normal").lower()
        if order not in {"normal", "reverse", "random"}:
            logging.warning(f"[Curriculum] Unknown order='{order}', defaulting to 'normal'.")
            order = "normal"
        self.order = order

        # Per-stage step budgets and their cumulative end steps.
        self.steps_per_stage = even_split(max_steps, num_stages)
        self.boundaries = []
        running_total = 0
        for budget in self.steps_per_stage:
            running_total += budget
            self.boundaries.append(running_total)

        # Traversal order: schedule position -> stage id.
        if self.order == "reverse":
            self.stage_order = list(reversed(range(self.num_stages)))
        else:
            self.stage_order = list(range(self.num_stages))
            if self.order == "random":
                random.Random(self.seed).shuffle(self.stage_order)

        # -1 means "no stage activated yet"; used to detect changes and log once.
        self._current_stage = -1

        # Expose the order through the shared state (handy for debugging).
        self.stage_state["order"] = list(self.stage_order)

        logging.info(
            f"Scheduling (steps): per-stage = {self.steps_per_stage}, "
            f"boundaries = {self.boundaries}, order={self.order}, "
            f"order_seq={self.stage_order}, seed={self.seed}"
        )

    def _stage_for_step(self, global_step: int) -> int:
        """Return the schedule position whose step range contains `global_step` (last position if past the end)."""
        for position, end_step in enumerate(self.boundaries):
            if global_step < end_step:
                return position
        return self.num_stages - 1

    def _map_stage(self, base_index: int) -> int:
        """Translate a schedule position into the stage id at that position of `stage_order`."""
        return self.stage_order[base_index]

    def _compute_prev_list(self, current_stage: int) -> List[int]:
        """All stages that come *before* current_stage in stage_order ([] if it is not in the order)."""
        if current_stage not in self.stage_order:
            return []
        return self.stage_order[: self.stage_order.index(current_stage)]

    def _activate(self, stage: int):
        """Publish `stage` and its predecessors to the shared state and remember it as active."""
        self.stage_state["current"] = stage
        self.stage_state["prev"] = self._compute_prev_list(stage)
        self._current_stage = stage

    def on_train_begin(self, args, state, control, **kwargs):
        # Start on whichever stage sits first in the traversal order.
        self._set_stage(self._map_stage(0), state)
        return control

    def on_step_begin(self, args, state, control, **kwargs):
        scheduled = self._map_stage(self._stage_for_step(state.global_step))
        if scheduled != self._current_stage:
            self._switch(scheduled, state, **kwargs)
        return control

    def _set_stage(self, stage: int, state):
        """Activate `stage` without any snapshotting; logs on the main process."""
        self._activate(stage)
        if state.is_world_process_zero:
            logging.info(f"[Curriculum] Using stage {stage} | prev={self.stage_state['prev']}")

    def _switch(self, new_stage: int, state, **kwargs):
        """Move from the active stage to `new_stage`, optionally snapshotting the model for the stage being left."""
        old = self._current_stage
        self._activate(new_stage)

        # Logging and snapshotting happen on the main process only.
        if not state.is_world_process_zero:
            return
        logging.info(f"[Curriculum] Switching stage {old} -> {new_stage} | prev={self.stage_state['prev']}")

        if not (self.save_each_stage and old >= 0):
            return
        model = kwargs.get("model", None)
        if model is None:
            return

        stage_dir = os.path.join(self.output_dir, f"stage_{old}")
        os.makedirs(stage_dir, exist_ok=True)
        logging.info(f"[Curriculum] Saving snapshot for stage {old} to {stage_dir}")
        try:
            model.save_pretrained(stage_dir)
        except Exception as e:
            # Best effort: a failed snapshot must not interrupt training.
            logging.warning(f"[Curriculum] Could not save snapshot for stage {old}: {e}")


# ---- Per-bucket evaluation & logging ----
class StagedEvalLoggerCallback(TrainerCallback):
    """
    Periodically evaluate every validation bucket (stage) and log accuracies to wandb.

    When the step cadence in `on_step_end` fires, the main process samples up to
    `eval_questions` indices from each bucket, scores them through
    `gen_eval.evaluate_indices`, and logs:
      - validation/accuracy_bucket_{i}
      - validation/accuracy_overall  (unweighted mean over buckets; an empty bucket counts as 0.0)
      - validation/current_stage     (only when a stage-state dict containing "current" is available)
    This does NOT change sampling weights or stages; it only logs.

    Args:
        model: Model switched to eval mode during validation and back to train mode afterwards.
        tokenizer: Stored on the instance; not read by this class.
        val_bucketed_indices: Maps bucket id -> validation-set indices in that bucket.
        val_dataset: Validation dataset handed to `gen_eval.evaluate_indices`.
        eval_steps: Evaluation cadence in steps; clamped to at least 1.
        eval_questions: Maximum samples per bucket per evaluation; clamped to at least 1.
        max_new_tokens: Stored on the instance; not read by this class.
        gen_eval: Object exposing `evaluate_indices(dataset=, indices=, verify_mode=)`.
        verifier_type: Passed to `evaluate_indices` as `verify_mode`.
        stage_state: Optional shared curriculum state dict. When given, its "current"
            entry is logged; otherwise a `stage_state` entry in the callback kwargs is
            used if one is present.
    """
    def __init__(
        self,
        model,
        tokenizer,
        val_bucketed_indices: Dict[int, List[int]],
        val_dataset,  # HuggingFace Dataset
        eval_steps: int = 10,
        eval_questions: int = 50,
        max_new_tokens: int = 1024,
        gen_eval: Optional[object] = None,
        verifier_type: str = "math",
        stage_state: Optional[Dict] = None,
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.val_bucketed_indices = val_bucketed_indices
        self.val_dataset = val_dataset
        self.eval_steps = max(1, int(eval_steps))
        self.eval_questions = max(1, int(eval_questions))
        self.max_new_tokens = max_new_tokens
        self.gen_eval = gen_eval
        self.verifier_type = verifier_type
        self.stage_state = stage_state

    @torch.no_grad()
    def _run_custom_validation(self) -> Dict[int, float]:
        """Return {bucket_id: accuracy} for a fresh random sample of each bucket.

        Raises:
            ValueError: if no `gen_eval` was provided.
        """
        if self.gen_eval is None:
            raise ValueError("GenerationEvaluator (gen_eval) is not provided for validation.")

        logging.info("Running validation...")
        self.model.eval()
        accuracies = {}

        try:
            for bucket_id, indices in self.val_bucketed_indices.items():
                if not indices:
                    # Nothing to evaluate; report 0.0 so the bucket still shows up in the logs.
                    accuracies[bucket_id] = 0.0
                    continue

                # Uses the global `random` module, so the subset differs between evaluations.
                sample_indices = random.sample(indices, min(self.eval_questions, len(indices)))
                acc = self.gen_eval.evaluate_indices(
                    dataset=self.val_dataset,
                    indices=sample_indices,
                    verify_mode=self.verifier_type,
                )
                accuracies[bucket_id] = acc
                logging.info(f"  Bucket {bucket_id}: {len(sample_indices)} samples, accuracy = {acc:.2f}%")
        finally:
            # Always release cached GPU memory and restore train mode, even if evaluation failed.
            torch.cuda.empty_cache()
            logging.debug("Cleared CUDA cache after validation.")
            self.model.train()
        return accuracies

    def on_step_end(self, args, state, control, **kwargs):
        # NOTE(review): this fires when (global_step + 1) is a multiple of eval_steps.
        # If global_step is already incremented by the time on_step_end runs, the
        # evaluation lands one step before each multiple - confirm the intended cadence.
        if state.global_step > 0 and (state.global_step + 1) % self.eval_steps == 0:
            if state.is_world_process_zero:
                logging.info("\n--- Running per-bucket validation...")
                accs = self._run_custom_validation()
                log_data = {f"validation/accuracy_bucket_{bkt}": a for bkt, a in accs.items()}
                log_data["validation/accuracy_overall"] = sum(accs.values()) / len(accs) if accs else 0.0

                # Prefer the state dict given at construction; fall back to callback kwargs.
                stage_state = self.stage_state if self.stage_state is not None else kwargs.get("stage_state")
                if isinstance(stage_state, dict) and "current" in stage_state:
                    log_data["validation/current_stage"] = stage_state["current"]
                try:
                    wandb.log(log_data)
                except Exception as e:
                    # Logging is best effort; never let a wandb failure stop training.
                    logging.warning(f"wandb.log failed: {e}")
        return control


# ========== Trainer subclass to plug our sampler ==========

class StagedSFTTrainer(SFTTrainer):
    """SFTTrainer whose training sampler is a SoftSwitchCurriculumSampler.

    The extra keyword arguments (`stage_indices`, `stage_state`, `stage_chunk_steps`,
    `major_prob`, `strict_mix`) are kept on the instance and forwarded to the sampler;
    everything else goes to SFTTrainer unchanged.
    """
    def __init__(self, *args, stage_indices=None, stage_state=None, stage_chunk_steps=1,
                 major_prob=0.8, strict_mix=False, **kwargs):
        # Curriculum settings are stored before SFTTrainer.__init__ runs.
        self._stage_indices, self._stage_state = stage_indices, stage_state
        self._stage_chunk_steps, self._major_prob = stage_chunk_steps, major_prob
        self._strict_mix = bool(strict_mix)
        super().__init__(*args, **kwargs)

    def _get_train_sampler(self, train_dataset=None):
        """Build the curriculum sampler; `train_dataset` is accepted but not used."""
        sampler_kwargs = dict(
            stage_indices=self._stage_indices,
            stage_state=self._stage_state,
            batch_size=self._train_batch_size,
            world_size=self.args.world_size,
            chunk_steps=self._stage_chunk_steps,
            seed=getattr(self.args, "seed", 42),  # fall back to 42 if args has no seed
            major_prob=self._major_prob,
            strict_mix=self._strict_mix,
        )
        return SoftSwitchCurriculumSampler(**sampler_kwargs)

# ========== Main training ==========

def train(config):
    """Run step-scheduled, soft-switching curriculum SFT as described by `config`.

    Args:
        config: Parsed JSON configuration. Keys read here:
            - dataset.trainset_path / dataset.valset_path / dataset.name
            - models.student
            - training: forwarded as keyword arguments to SFTConfig (max_steps must be > 0)
            - validation (optional): forwarded to build_validation_from_cfg
            - peft (optional): forwarded as keyword arguments to LoraConfig
            - curriculum (optional): order, reverse (deprecated), order_seed,
              save_each_stage, eval_steps, eval_questions, eval_max_new_tokens,
              soft_major_prob, strict_mix, stage_chunk_steps

    Raises:
        ValueError: if the training split is empty or training.max_steps is not positive.
    """
    # Load the dataset
    dataset = load_dataset("json", data_files={
        "train": config["dataset"]["trainset_path"],
        "validation": config["dataset"]["valset_path"],
    })

    # Create bucket indices of each difficulty level.
    # Only the "difficulty" column is read, once per split.
    train_difficulty = list(dataset["train"]["difficulty"])
    if not train_difficulty:
        raise ValueError("Training split is empty; cannot build curriculum buckets.")
    n_buckets = max(train_difficulty) + 1
    bucket_indices = {i: [] for i in range(n_buckets)}
    for row, level in enumerate(train_difficulty):
        bucket_indices[level].append(row)
    val_bucket_indices = {i: [] for i in range(n_buckets)}
    for row, level in enumerate(dataset["validation"]["difficulty"]):
        # A validation-only difficulty gets its own bucket instead of crashing.
        val_bucket_indices.setdefault(level, []).append(row)

    logging.debug(f"bucket_indices: {bucket_indices}")
    logging.debug(f"val_bucket_indices: {val_bucket_indices}")

    # ---- Model & tokenizer ----
    logging.info("--- Initializing model & tokenizer ---")
    model = AutoModelForCausalLM.from_pretrained(
        config["models"]["student"],
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        config["models"]["student"],
        trust_remote_code=True,
    )
    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # ---- Validation setup ----
    validation_args = config.get("validation", {})
    # Dataset-name heuristic: GSM-style datasets use the math verifier, others "nl".
    default_verify_mode = "math" if "gsm" in config["dataset"]["name"].lower() else "nl"
    verify_mode = validation_args.get("verify_mode", default_verify_mode)

    # Only the generation evaluator is used below; the other returned values are ignored.
    _, _, _, gen_eval = build_validation_from_cfg(
        model=model,
        tokenizer=tokenizer,
        val_cfg=validation_args,
        default_mode=verify_mode,
    )

    # Optional PEFT
    peft_config = None
    if "peft" in config:
        peft_config = LoraConfig(**config["peft"])
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()

    # ---- Training args ----
    train_args = SFTConfig(**config["training"])
    # `<= 0` rather than a truthiness test: a negative max_steps is truthy and
    # would otherwise slip through and produce a negative step schedule.
    if getattr(train_args, "max_steps", 0) <= 0:
        raise ValueError("This staged curriculum is step-based. Please set training.max_steps > 0 in your config.")

    # ---- Shared stage state (callback <-> sampler) ----
    stage_state = {"current": -1, "prev": [], "order": []}

    # ---- Curriculum settings (the whole section is optional) ----
    curriculum_cfg = config.get("curriculum", {})
    order = curriculum_cfg.get("order", "normal")
    # Backward-compat: if 'order' is not provided, respect old boolean 'reverse'
    if "order" not in curriculum_cfg and curriculum_cfg.get("reverse", False):
        logging.warning("[Curriculum] 'reverse' is deprecated. Use 'order': 'reverse'. Applying reverse behavior.")
        order = "reverse"

    # Use the training seed if present, or allow an explicit order_seed.
    order_seed = curriculum_cfg.get("order_seed", getattr(train_args, "seed", None))

    # Majority probability for the current stage and mixing mode.
    major_prob = float(curriculum_cfg.get("soft_major_prob", 0.8))
    strict_mix = bool(curriculum_cfg.get("strict_mix", False))

    # ---- Callback computes schedule and flips stages on step boundaries ----
    stage_cb = StagedCurriculumCallback(
        stage_state=stage_state,
        num_stages=n_buckets,
        max_steps=train_args.max_steps,
        output_dir=train_args.output_dir,
        save_each_stage=curriculum_cfg.get("save_each_stage", False),
        order=order,
        seed=order_seed,
    )

    # ---- Eval logger callback ----
    # NOTE(review): verifier_type uses the dataset-name heuristic and ignores an explicit
    # validation.verify_mode from the config - confirm whether `verify_mode` should be used here.
    eval_cb = StagedEvalLoggerCallback(
        model=model,
        tokenizer=tokenizer,
        val_bucketed_indices=val_bucket_indices,
        val_dataset=dataset["validation"],
        verifier_type=default_verify_mode,
        eval_steps=curriculum_cfg.get("eval_steps", 50),
        eval_questions=curriculum_cfg.get("eval_questions", 50),
        max_new_tokens=curriculum_cfg.get("eval_max_new_tokens", 256),
        gen_eval=gen_eval,
    )

    # ---- Trainer ----
    trainer = StagedSFTTrainer(
        model=model,
        processing_class=tokenizer,
        args=train_args,
        train_dataset=dataset["train"],
        peft_config=peft_config,
        stage_indices=bucket_indices,
        stage_state=stage_state,
        stage_chunk_steps=curriculum_cfg.get("stage_chunk_steps", 1),
        major_prob=major_prob,
        strict_mix=strict_mix,
        callbacks=[stage_cb, eval_cb],
    )

    logging.info(f"--- Starting staged curriculum training run (soft-switching, major_prob={major_prob:.2f}) ---")
    trainer.train()

    # Final save
    logging.info("--- Training complete ---")
    trainer.save_model(train_args.output_dir)
    tokenizer.save_pretrained(train_args.output_dir)
    logging.info("All curriculum learning stages completed.")


def main():
    """CLI entry point: load the JSON config given by --config, set a unique run name, and train.

    The run name is "<dataset name>_<student model basename>_<config file stem>_<YYYYmmdd_HHMMSS>"
    and is stored in config["training"]["run_name"] before `train` is called.
    """
    from datetime import datetime

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True, help="Path to the json config file")
    args = parser.parse_args()

    # Context manager so the config file handle is closed deterministically.
    with open(args.config, encoding="utf-8") as config_file:
        config = json.load(config_file)

    # Create a single run name string
    student_name = config["models"]["student"].split("/")[-1]
    config_stem = os.path.basename(args.config).split(".")[0]
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_name = f"{config['dataset']['name']}_{student_name}_{config_stem}_{timestamp}"
    config["training"]["run_name"] = run_name

    train(config)

# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
