

import json
import argparse
import logging
import os
from datasets import load_dataset, Dataset, DatasetDict
from typing import Callable, Optional, Dict, Union, List
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer,SFTConfig
import torch
from torch.utils.data import Sampler
from transformers import TrainerCallback
from trl import SFTTrainer
import random
import numpy as np
from peft import LoraConfig, get_peft_model

import wandb


from common import build_validation_from_cfg

# Root logger: INFO and above, each record prefixed with a timestamp and level name.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)


class DynamicCurriculumSampler(Sampler):
    """Index sampler that draws training examples bucket-by-bucket using live weights.

    Indices are produced in *chunks*. Each chunk is sized to cover the optimizer
    steps remaining until the next curriculum update (read from the shared
    ``clock``). Every index in a chunk is drawn with the bucket weights as they
    were when the chunk was generated. ``bucket_weights``, ``cursors`` and
    ``clock`` are the caller's dict objects (not copies), so in-place changes
    made elsewhere are picked up at the next chunk.

    Args:
        dataset: Training dataset; only its length is used (checked against the buckets).
        bucket_indices: bucket id -> list of dataset indices in that bucket. The
            lists must together hold exactly ``len(dataset)`` entries.
        bucket_weights: bucket id -> sampling weight (negative values are treated
            as 0; need not be normalized).
        cursors: bucket id -> number of indices already emitted from that bucket in
            ``'sequential'`` mode; advanced in place.
        per_device_batch_size: Micro-batch size per process.
        grad_accum_steps: Gradient-accumulation steps per optimizer step (min 1).
        num_replicas: Number of processes. NOTE(review): the step arithmetic assumes
            every process iterates this full index stream and keeps its own shard
            of the batches — TODO confirm against the DataLoader wrapper in use.
        rank: This process's rank; only rank 0 logs.
        clock: Shared dict ``{'eval_steps': int, 'global_step': int}``.
        sampling_strategy: ``'sequential'`` walks each bucket in order (wrapping
            around); any other value picks uniformly at random within the bucket.

    Raises:
        ValueError: if the bucket lists do not cover exactly ``len(dataset)`` entries.
    """

    def __init__(
        self,
        dataset,
        bucket_indices: dict,
        bucket_weights: dict,
        cursors: dict,
        per_device_batch_size: int,
        grad_accum_steps: int,
        num_replicas: int,
        rank: int,
        clock: dict,                          # {'eval_steps': int, 'global_step': int}
        sampling_strategy='sequential',
    ):
        self.data = dataset
        self.bucket_indices = bucket_indices
        self.bucket_weights = bucket_weights
        self.cursors = cursors
        self.per_device_batch_size = per_device_batch_size
        self.grad_accum_steps = max(1, grad_accum_steps)
        self.num_replicas = num_replicas
        self.rank = rank
        self.clock = clock
        self.sampling_strategy = sampling_strategy

        # One epoch yields exactly as many indices as the buckets hold.
        self.total_size = sum(len(v) for v in self.bucket_indices.values())
        if self.total_size != len(dataset):
            raise ValueError(
                f"bucket_indices cover {self.total_size} examples but the dataset has {len(dataset)}."
            )

        # Unused; kept so code that reads this attribute keeps working.
        # Non-empty buckets are recomputed for every chunk.
        self._nonempty_cache = None

    def _nonempty_bucket_ids(self):
        """Return ids of buckets that currently contain at least one index."""
        # Recomputed on every call in case bucket_indices is mutated elsewhere.
        return [b for b, idxs in self.bucket_indices.items() if len(idxs) > 0]

    def _steps_until_next_update(self):
        """Return how many optimizer steps remain until the next curriculum update.

        The curriculum callback updates the weights at the end of the optimizer
        step whose ``global_step`` satisfies ``(global_step + 1) % eval_steps == 0``.
        From a clock reading of ``gs`` completed steps, the update fires after ``k``
        more steps. Here ``k`` is the smallest value >= 1 with
        ``(gs + k + 1) % eval_steps == 0``. The result always lies in ``[1, eval_steps]``.
        """
        gs = int(self.clock.get('global_step', 0))
        period = max(1, int(self.clock.get('eval_steps', 50)))  # guard against 0 / negatives
        # (gs + 1) % period lies in [0, period - 1], so the result lies in [1, period].
        # A remainder of 0 means an update has just fired and a full interval is ahead.
        return period - ((gs + 1) % period)

    def _make_chunk_indices(self, total_chunk_samples: int):
        """Draw ``total_chunk_samples`` dataset indices using the current bucket weights.

        Only non-empty buckets are eligible. If their weights sum to <= 0, the
        buckets are sampled uniformly.

        Raises:
            RuntimeError: if every bucket is empty.
        """
        bucket_ids = self._nonempty_bucket_ids()
        if not bucket_ids:
            raise RuntimeError("All training buckets are empty.")

        # Probabilities over non-empty buckets only; negative weights count as 0.
        w = np.array(
            [max(0.0, float(self.bucket_weights.get(bid, 0.0))) for bid in bucket_ids],
            dtype=np.float64,
        )
        s = w.sum()
        w = np.ones_like(w) / len(w) if s <= 0 else w / s

        # Draw every bucket choice for the chunk in one vectorized call.
        picks = np.random.choice(len(bucket_ids), size=total_chunk_samples, p=w)

        chunk = []
        for pos in picks:
            bid = bucket_ids[pos]
            idxs = self.bucket_indices[bid]  # non-empty by construction
            if self.sampling_strategy == 'sequential':
                cur = self.cursors.get(bid, 0)
                chunk.append(idxs[cur % len(idxs)])
                self.cursors[bid] = cur + 1
            else:
                chunk.append(random.choice(idxs))
        return chunk

    def __iter__(self):
        """Yield ``total_size`` indices, regenerating a chunk at each curriculum interval.

        NOTE(review): the DataLoader may pull indices ahead of training, so the
        clock read here can lag the sampler's position. Chunk boundaries are
        therefore approximate.
        """
        # One optimizer step consumes one micro-batch on every replica for each
        # gradient-accumulation step.
        samples_per_step = self.per_device_batch_size * self.num_replicas * self.grad_accum_steps
        yielded = 0

        while yielded < self.total_size:
            steps_until = self._steps_until_next_update()
            # Never draw past the end of the epoch. In 'sequential' mode every draw
            # advances a cursor, so surplus draws would skip examples next epoch.
            # max(1, ...) keeps a degenerate zero-sized chunk from looping forever.
            samples_in_chunk = min(max(1, samples_per_step * steps_until), self.total_size - yielded)
            if self.rank == 0:
                logging.info(
                    f"Sampler: global_step={self.clock.get('global_step')} eval_steps={self.clock.get('eval_steps')} "
                    f"=> steps_until_update={steps_until} | using weights {dict(self.bucket_weights)}"
                    f" | generating chunk of {samples_in_chunk} samples"
                )

            for idx in self._make_chunk_indices(samples_in_chunk):
                yield idx
            yielded += samples_in_chunk

    def __len__(self):
        return self.total_size
    

class CurriculumTrainer(SFTTrainer):
    """SFTTrainer whose training sampler is a DynamicCurriculumSampler.

    Extra keyword arguments, popped before ``SFTTrainer.__init__`` runs:
        bucket_indices: bucket id -> list of training-set indices.
        bucketed_weights: bucket id -> sampling weight. The same dict object is
            handed to the sampler, so in-place updates (e.g. from a callback) are seen.
        cursors: bucket id -> sequential-sampling cursor.
        clock: shared ``{'eval_steps', 'global_step'}`` dict.

    NOTE(review): ``sampling_strategy`` is the first positional parameter, so any
    positional argument intended for SFTTrainer would bind to it. Pass everything
    by keyword.
    """

    def __init__(self, sampling_strategy='sequential', *args, **kwargs):
        # Pop the custom args before the parent constructor sees them.
        self.bucket_indices = kwargs.pop("bucket_indices")
        self.bucket_weights = kwargs.pop("bucketed_weights")
        self.cursors = kwargs.pop("cursors")
        self.sampling_strategy = sampling_strategy
        self.clock = kwargs.pop("clock")
        super().__init__(*args, **kwargs)

    def _get_train_sampler(self, train_dataset=None) -> torch.utils.data.Sampler:
        """Build the curriculum sampler for ``train_dataset`` (defaults to ``self.train_dataset``).

        Raises:
            NotImplementedError: when not running with more than one process.
        """
        if self.args.world_size <= 1:
            raise NotImplementedError("This implementation is for distributed training.")

        # Some transformers versions call this hook without a dataset argument.
        # Passing None on would make the sampler's len(dataset) check fail.
        if train_dataset is None:
            train_dataset = self.train_dataset

        return DynamicCurriculumSampler(
            dataset=train_dataset,
            bucket_indices=self.bucket_indices,
            bucket_weights=self.bucket_weights,
            cursors=self.cursors,
            per_device_batch_size=self._train_batch_size,
            grad_accum_steps=self.args.gradient_accumulation_steps,
            num_replicas=self.args.world_size,
            rank=self.args.process_index,
            sampling_strategy=self.sampling_strategy,
            clock=self.clock,
        )
        

class CurriculumCallback(TrainerCallback):
    """Bandit-driven curriculum that periodically re-weights training buckets.

    On curriculum steps (see ``on_step_end``), rank 0 evaluates a sample of each
    validation bucket. It turns the change in accuracy against an EMA baseline
    into a reward and smooths that reward into a per-bucket Q-value. It then
    rewrites the shared ``bucket_weights`` dict with a Boltzmann (softmax) or
    epsilon-greedy policy. The new weights and the clock are then broadcast from
    rank 0 to all ranks.

    Args:
        model, tokenizer: Model under training and its tokenizer. The model is
            switched to eval mode for validation and back to train mode afterwards.
        val_bucket_indices: bucket id -> list of validation-set indices.
        bucket_weights: bucket id -> sampling weight; mutated in place.
        val_dataset: Dataset passed to ``gen_eval.evaluate_indices``.
        policy: ``'boltzmann'`` or ``'egreedy'``.
        ema_alpha: Smoothing factor of the Q-value (EMA of rewards).
        ema_beta: Smoothing factor of the accuracy baseline.
        temperature: Boltzmann temperature.
        epsilon: Exploration mass for e-greedy; probability floor for Boltzmann.
        eval_steps: Update period in optimizer steps (must be >= 1).
        eval_questions: Max validation questions sampled per bucket.
        clock: Shared ``{'eval_steps', 'global_step'}`` dict (created if None).
        gen_eval: Object exposing ``evaluate_indices(dataset=, indices=, verify_mode=)``.
        verifier_type: ``'math'`` or ``'nl'``.
        train_bucket_indices: bucket id -> training indices. Buckets with no
            training data always receive weight 0.

    Raises:
        ValueError: on an unknown ``policy`` / ``verifier_type`` or ``eval_steps < 1``.
    """

    def __init__(self, model, tokenizer, val_bucket_indices: dict, bucket_weights: dict, val_dataset,
                policy='boltzmann', ema_alpha=0.5, ema_beta=0.5, epsilon=0.1, temperature=1.0,
                eval_steps=10, eval_questions=50, clock: dict = None,
                gen_eval: Optional[object] = None, verifier_type: str = "math",
                train_bucket_indices: dict = None,
                **kwargs):
        # Validate up front so a bad value fails at construction rather than
        # on rank 0 in the middle of training.
        if verifier_type not in ("math", "nl"):
            raise ValueError("verifier_type must be 'math' or 'nl'")
        if policy not in ("boltzmann", "egreedy"):
            raise ValueError(f"Invalid policy: {policy}. Must be 'boltzmann' or 'egreedy'.")
        if int(eval_steps) < 1:
            raise ValueError(f"eval_steps must be >= 1, got {eval_steps}")

        self.model = model
        self.tokenizer = tokenizer
        self.val_bucket_indices = val_bucket_indices
        self.bucket_weights = bucket_weights  # caller's dict; written in place
        self.val_dataset = val_dataset        # full dataset, needed for ground truth
        self.eval_steps = int(eval_steps)
        self.eval_questions = eval_questions
        # NOTE(review): without train_bucket_indices every bucket counts as empty,
        # so weight updates are always skipped — callers should pass it.
        self.train_bucket_indices = train_bucket_indices or {b: [] for b in bucket_weights.keys()}

        self.clock = clock if clock is not None else {'eval_steps': self.eval_steps, 'global_step': 0}
        self.clock['eval_steps'] = self.eval_steps  # keep the sampler's view consistent

        # --- Validation ---
        self.gen_eval = gen_eval
        self.verifier_type = verifier_type

        # --- Bandit parameters ---
        self.policy = policy
        self.ema_alpha = ema_alpha
        self.ema_beta = ema_beta
        self.temperature = temperature
        self.epsilon = epsilon

        # --- Bandit state: Q-values (EMA reward) and accuracy baselines, all start at 0 ---
        self.ema_reward = {bucket_id: 0.0 for bucket_id in self.val_bucket_indices.keys()}
        self.ema_acc = {bucket_id: 0.0 for bucket_id in self.val_bucket_indices.keys()}

    def _alive_ids(self):
        """Return ids of buckets that have at least one training example."""
        return [b for b, idxs in self.train_bucket_indices.items() if len(idxs) > 0]

    def _run_custom_validation(self, model, tokenizer):
        """Evaluate a random sample of each validation bucket.

        Returns:
            dict bucket id -> accuracy as returned by ``gen_eval.evaluate_indices``
            (0.0 for buckets without validation data).

        Raises:
            ValueError: if no ``gen_eval`` was provided.
        """
        if self.gen_eval is None:
            raise ValueError("GenerationEvaluator (gen_eval) is not provided for validation.")

        logging.info("Running custom validation...")
        model.eval()
        accuracies = {}

        try:  # finally: always clear the cache and restore train mode
            for bucket_id, indices in self.val_bucket_indices.items():
                if not indices:
                    accuracies[bucket_id] = 0.0
                    continue

                # Sample n questions (or fewer if not enough are available).
                sample_indices = random.sample(indices, min(self.eval_questions, len(indices)))
                acc = self.gen_eval.evaluate_indices(
                    dataset=self.val_dataset,
                    indices=sample_indices,
                    verify_mode=self.verifier_type,
                )
                accuracies[bucket_id] = acc
                logging.info(f"  Bucket {bucket_id}: {len(sample_indices)} samples, accuracy = {acc:.2f}%")
        finally:
            torch.cuda.empty_cache()
            logging.debug("Cleared CUDA cache after validation.")
            model.train()

        return accuracies

    def on_step_end(self, args, state, control, **kwargs):
        """Refresh the shared clock and, on curriculum steps, update and broadcast weights."""
        # The Trainer calls on_step_end once per optimizer step (micro-steps fire
        # on_substep_end), so global_step is already in optimizer-step units.
        self.clock['global_step'] = int(state.global_step)

        # Trigger right *before* step 50/100/150... (i.e., at the end of 49/99/149...)
        if not (state.global_step > 0 and (state.global_step + 1) % self.eval_steps == 0):
            return

        if state.is_world_process_zero:
            self._curriculum_update(state)

        # Collective call: every rank must reach it on the same step. This includes
        # steps where rank 0 skipped the weight update; otherwise the other ranks block.
        self._sync_weights_and_clock()

    def _curriculum_update(self, state):
        """Rank 0: observe, compute rewards, update Q-values, act (set new weights), log."""
        logging.info(f"\n--- (Rank 0) Curriculum step {state.global_step}: Observe, Reward, Update Q, Act ---")
        log_data = {}

        # 1. Observe: current validation accuracy per bucket.
        current_acc = self._run_custom_validation(self.model, self.tokenizer)

        # 2. Reward + Q-value update.
        self._update_q_values(current_acc, log_data)

        # 3. Act: choose new bucket weights over buckets that have training data.
        logging.info(f"Updating weights with {self.policy} policy...")
        alive = self._alive_ids()
        if not alive:
            logging.warning("No alive (non-empty) training buckets; skipping weight update.")
            return
        self._apply_policy(alive)

        for bucket_id in self.bucket_weights.keys():
            log_data[f"curriculum/weight_bucket_{bucket_id}"] = self.bucket_weights.get(bucket_id, 0.0)
        # Average only over buckets that actually have validation data. Buckets
        # without any report 0.0 and would otherwise drag the mean down.
        evaluated = [b for b, idxs in self.val_bucket_indices.items() if idxs and b in current_acc]
        log_data["validation/accuracy_overall"] = (
            sum(current_acc[b] for b in evaluated) / len(evaluated) if evaluated else 0.0
        )

        # wandb.log raises when no run has been initialised (e.g. report_to excludes wandb).
        if wandb.run is not None:
            # NOTE(review): HF's WandbCallback uses "train/global_step" as its step
            # metric; including it aligns these curves — verify for the installed version.
            log_data["train/global_step"] = state.global_step
            wandb.log(log_data)
            logging.debug("Logged custom curriculum metrics to Weights & Biases.")
        else:
            logging.debug("No active wandb run; skipping wandb.log for curriculum metrics.")

        logging.info(f"--- New weights for next epoch: {self.bucket_weights} ---")

    def _update_q_values(self, current_acc, log_data):
        """Update the EMA accuracy baseline and EMA reward (Q) of every validation bucket."""
        logging.info("Updating EMA Rewards (EMA baseline + EMA Q)...")
        for b in self.ema_reward.keys():
            acc = current_acc.get(b, 0.0)
            prev_baseline = self.ema_acc[b]
            # Advantage-style reward against the *previous* baseline.
            reward = acc - prev_baseline
            baseline = (1 - self.ema_beta) * prev_baseline + self.ema_beta * acc
            self.ema_acc[b] = baseline
            ema_reward_old = self.ema_reward[b]
            ema_reward_new = (1 - self.ema_alpha) * ema_reward_old + self.ema_alpha * reward
            self.ema_reward[b] = ema_reward_new
            logging.info(
                f"  Bucket {b}: acc={acc:.2f}, baseline(prev)={prev_baseline:.2f}, "
                f"reward={reward:.2f}, baseline(new)={baseline:.2f}, Q: {ema_reward_old:.2f}->{ema_reward_new:.2f}"
            )
            log_data[f"validation/accuracy_bucket_{b}"] = acc
            log_data[f"curriculum/baseline_acc_bucket_{b}"] = baseline
            log_data[f"curriculum/reward_bucket_{b}"] = reward
            log_data[f"curriculum/ema_reward_bucket_{b}"] = ema_reward_new

    def _apply_policy(self, alive):
        """Write new weights into ``bucket_weights``: a distribution over ``alive``, 0 elsewhere."""
        if self.policy == 'boltzmann':
            # p(a) = exp(Q(a)/tau) / sum_i exp(Q(i)/tau), over alive buckets only.
            # Buckets with no validation entry have no Q-value yet; treat them as 0.
            ema_reward_vec = torch.tensor([self.ema_reward.get(b, 0.0) for b in alive], dtype=torch.float32)
            probs = torch.softmax(ema_reward_vec / max(1e-8, float(self.temperature)), dim=0)
            logging.debug(f"  Raw Boltzmann probabilities: {probs.numpy().tolist()}")
            floor = min(self.epsilon, 1.0 / len(alive))
            probs = torch.clamp(probs, min=floor)
            logging.debug(f"  Clamped probabilities with floor {floor}: {probs.numpy().tolist()}")
            probs = probs / probs.sum()
            logging.debug(f"  Renormalized probabilities: {probs.numpy().tolist()}")
            new_weights = {b: probs[i].item() for i, b in enumerate(alive)}

        elif self.policy == 'egreedy':
            if len(alive) == 1:
                # Only one place to sample from.
                new_weights = {alive[0]: 1.0}
            else:
                # Best bucket gets 1 - epsilon; epsilon is spread over the other alive buckets.
                best_bucket_id = max(alive, key=lambda b: self.ema_reward.get(b, 0.0))
                floor = self.epsilon / (len(alive) - 1)
                new_weights = {b: floor for b in alive}
                new_weights[best_bucket_id] = 1.0 - self.epsilon

        else:
            raise ValueError(f"Invalid policy: {self.policy}. Must be 'boltzmann' or 'egreedy'.")

        # Update in place: the sampler reads this same dict object.
        for b in self.bucket_weights:
            if b not in new_weights:
                self.bucket_weights[b] = 0.0
        self.bucket_weights.update(new_weights)

    def _sync_weights_and_clock(self):
        """Broadcast rank 0's ``bucket_weights`` and ``clock`` to all ranks (no-op without torch.distributed)."""
        import torch.distributed as dist
        if dist.is_available() and dist.is_initialized():
            obj_list = [(self.bucket_weights, self.clock) if dist.get_rank() == 0 else None]
            dist.broadcast_object_list(obj_list, src=0)
            if dist.get_rank() != 0:
                new_w, new_clock = obj_list[0]
                # Update in place so the sampler's references stay valid.
                self.bucket_weights.update(new_w)
                self.clock.update(new_clock)
    

def train(config):
    """Build data, model, curriculum callback and trainer from ``config``, then train and save.

    Config keys read here: ``job_type``; ``dataset.{trainset_path, valset_path, name}``;
    ``training`` (SFTConfig kwargs; must include ``output_dir`` and ``max_steps > 0``);
    ``models.student``; optional ``validation``; optional ``peft`` (LoraConfig kwargs);
    ``curriculum`` (``eval_steps``, ``eval_questions``, ``policy``, ``alpha``, ``beta``,
    ``epsilon``, ``temperature``, ``sampling_strategy``).

    Each JSON record must have an integer ``difficulty`` field (>= 0), used as its bucket id.

    Raises:
        ValueError: if ``training.max_steps`` is not a positive number.
    """
    # Check the job type before any expensive loading. An invalid type is logged
    # and the function returns without training.
    job_type = config["job_type"]
    if "self_evolve_cl" not in job_type:
        logging.error(f"Invalid job type: {job_type}")
        logging.error(f"Training job terminated: Invalid job type: {job_type}")
        return

    dataset = load_dataset("json", data_files={
        "train": config["dataset"]["trainset_path"],
        "validation": config["dataset"]["valset_path"],
    })

    # Bucket ids come from the 'difficulty' column (read as a column, not row by row).
    # Size the buckets over both splits, so a difficulty seen only in validation
    # gets an (empty) training bucket instead of raising KeyError.
    train_difficulties = list(dataset['train']['difficulty'])
    val_difficulties = list(dataset['validation']['difficulty'])
    n_buckets = max(train_difficulties + val_difficulties) + 1
    bucket_indices = {i: [] for i in range(n_buckets)}
    for i, difficulty in enumerate(train_difficulties):
        bucket_indices[difficulty].append(i)
    val_bucket_indices = {i: [] for i in range(n_buckets)}
    for i, difficulty in enumerate(val_difficulties):
        val_bucket_indices[difficulty].append(i)

    logging.debug(f"bucket_indices: {bucket_indices}")
    logging.debug(f"val_bucket_indices: {val_bucket_indices}")

    training_arguments = SFTConfig(**config["training"])
    # TrainingArguments defaults max_steps to -1 (epoch-based). -1 is truthy,
    # so compare against 0 explicitly.
    if (getattr(training_arguments, "max_steps", None) or 0) <= 0:
        raise ValueError("This self-evolve curriculum is step-based. Please set training.max_steps > 0 in your config.")

    student_model = AutoModelForCausalLM.from_pretrained(
        config["models"]["student"],
        trust_remote_code=True
    )
    student_tokenizer = AutoTokenizer.from_pretrained(
        config["models"]["student"],
        trust_remote_code=True
    )
    student_tokenizer.padding_side = 'left'
    if student_tokenizer.pad_token is None:
        student_tokenizer.pad_token = student_tokenizer.eos_token  # reuse EOS as padding

    validation_arguments = config.get("validation", {})
    verify_mode = 'math' if 'gsm' in config["dataset"]["name"].lower() else 'nl'

    # NOTE(review): `mode` is not used below; the callback uses the dataset-derived
    # verify_mode. Confirm whether a validation-config override should win.
    mode, sbert_backend, verifier, gen_eval = build_validation_from_cfg(
        model=student_model,
        tokenizer=student_tokenizer,
        val_cfg=validation_arguments,
        default_mode=verify_mode,
    )

    # LoRA is applied once, by SFTTrainer via peft_config. Wrapping the model
    # here as well would inject the adapters a second time.
    peft_config = LoraConfig(**config["peft"]) if "peft" in config else None

    # ---------------- Callback and Sampler Shared State ----------------
    # Empty buckets get no probability; their keys are kept for logging consistency.
    bucket_weights = {b: (1.0 if idxs else 0.0) for b, idxs in bucket_indices.items()}
    cursors = {b: 0 for b in bucket_indices}
    curriculum_cfg = config['curriculum']
    eval_steps = curriculum_cfg.get('eval_steps', 10)
    shared_clock = {'eval_steps': int(eval_steps), 'global_step': 0}

    curriculum_callback = CurriculumCallback(
        model=student_model,
        tokenizer=student_tokenizer,
        val_bucket_indices=val_bucket_indices,
        train_bucket_indices=bucket_indices,
        bucket_weights=bucket_weights,
        val_dataset=dataset['validation'],
        verifier_type=verify_mode,
        gen_eval=gen_eval,
        eval_steps=eval_steps,
        eval_questions=curriculum_cfg.get('eval_questions', 50),
        policy=curriculum_cfg.get('policy', 'boltzmann'),
        ema_alpha=curriculum_cfg.get('alpha', 0.5),
        ema_beta=curriculum_cfg.get('beta', 0.5),
        epsilon=curriculum_cfg.get('epsilon', 0.1),
        temperature=curriculum_cfg.get('temperature', 1.0),
        clock=shared_clock,
    )

    # Constructor errors (e.g. from SFTTrainer) propagate instead of being swallowed.
    trainer = CurriculumTrainer(
        model=student_model,
        processing_class=student_tokenizer,
        args=training_arguments,
        train_dataset=dataset['train'],
        peft_config=peft_config,
        bucket_indices=bucket_indices,
        bucketed_weights=bucket_weights,
        cursors=cursors,
        sampling_strategy=curriculum_cfg.get('sampling_strategy', 'sequential'),
        clock=shared_clock,
        callbacks=[curriculum_callback],
    )
    if peft_config is not None and hasattr(trainer.model, "print_trainable_parameters"):
        trainer.model.print_trainable_parameters()

    trainer.train()
    trainer.save_model(config["training"]["output_dir"])
    student_tokenizer.save_pretrained(config["training"]["output_dir"])


def main():
    """Parse ``--config``, derive a timestamped run name, and launch training."""
    from datetime import datetime

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, required=True, help='path to the json config file')
    args = parser.parse_args()
    # Context manager so the config file handle is closed.
    with open(args.config, encoding="utf-8") as f:
        config = json.load(f)

    # splitext keeps dotted names intact ("exp.v2.json" -> "exp.v2").
    config_stem = os.path.splitext(os.path.basename(args.config))[0]
    run_name = (
        f"{config['dataset']['name']}_{config['models']['student'].split('/')[-1]}"
        f"_{config_stem}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    )
    # Consumed by SFTConfig/TrainingArguments.
    config["training"]["run_name"] = run_name

    train(config)


if __name__ == "__main__":
    main()