import re
import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainerCallback
import transformers
import tqdm
import wandb


from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer, get_peft_config, ModelConfig
from utils_deepscaler import *
import argparse
import random
import numpy as np
import torch.distributed as dist
from typing import Optional, Sized
from torch.utils.data import Sampler
import os
import sys
import getpass
from pathlib import Path
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import StandardScaler
from sklearn.manifold import TSNE
import tqdm

from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
# cdist: distances from samples to KMeans cluster centres (used when selecting samples per cluster)
from scipy.spatial.distance import cdist


from reward_funcs.reward_fn import score_deepscaler
from bandit import ClusterBanditSelector

# Name of the user running the script; selects the per-user scratch directory.
user_name = getpass.getuser()
# Per-user base directory; run output directories are built under f"{PATH_TO_REPO}/output/".
PATH_TO_REPO = Path(f"/scratch/{user_name}/UncertainReasoning")

# Number of distinct prompts the training sampler selects per iteration; each prompt is
# then repeated num_generations times (EvalTrainer._get_train_sampler passes batch_num=B).
B = 8

# ------------------------------------------------------
# 1. Callback that updates current_R_avg and current_T in memory
# ------------------------------------------------------
class CurriculumUpdateCallback(TrainerCallback):
    """Adapt the trainer's target difficulty ``current_T`` after each logging step.

    Whenever the trainer logs a ``"reward"`` entry, the value is stored as
    ``trainer_ref.current_R_avg`` and the target difficulty is moved by

        T' = clip(T + eta * tanh(sensitivity * (R_avg - beta)), d_min, d_max)

    With ``eta > 0`` the target rises when the reward exceeds ``beta`` and falls
    when it is below. ``trainer_ref`` must be assigned after construction; until
    then, log events are ignored.
    """

    def __init__(self):
        super().__init__()
        # Assigned by the caller once the trainer exists.
        self.trainer_ref = None

    def on_log(self, args, state, control, logs=None, **kwargs):
        trainer = self.trainer_ref
        # Nothing to do without a trainer or without a reward in this log event
        # (the key is "reward", not "avg_reward").
        if not logs or trainer is None or "reward" not in logs:
            return

        reward = float(logs["reward"])
        trainer.current_R_avg = reward

        step = float(trainer.eta * np.tanh(trainer.sensitivity * (reward - trainer.beta)))
        trainer.current_T = float(
            np.clip(trainer.current_T + step, trainer.d_min, trainer.d_max)
        )


class WandbTrainingCallback(TrainerCallback):
    """Mirror every non-empty dict the trainer logs to Weights & Biases.

    The ``logs`` dict is forwarded unchanged via ``wandb.log``. Curriculum state
    kept on the trainer (``current_T`` / ``current_R_avg``) is not added here.
    """

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not logs:
            return
        import wandb
        wandb.log(logs)


# ------------------------------------------------------
# 2. Custom Sampler: selects training prompts according to the curriculum mode
#    (e.g. in "adarft" mode, the B prompts whose difficulty is closest to T)
# ------------------------------------------------------
class RepeatRandomSampler(Sampler):
    """Yield dataset indices for GRPO, each repeated ``repeat_count`` times in a row.

    With ``is_train=True`` the indices are chosen by ``mode``:

    * ``"adarft"`` -- the ``batch_num`` prompts whose ``extra_info["difficulty"]``
      is closest to ``trainer_ref.current_T``. The current reward is appended to
      ``reward_dir`` and the mean selected difficulty to ``cluster_dir``.
    * ``"uncertain"`` -- the same selection as ``"adarft"``, without the log files.
    * ``"no_curr"`` -- a seeded random permutation of the whole dataset.
    * ``"cluster"``, ``"cluster_thompson"``, ``"cluster_thompson_ema"`` -- a
      ``ClusterBanditSelector`` is rewarded with ``trainer_ref.current_R_avg`` for
      its previous pick, then chooses a cluster. Up to ``batch_num`` prompts are
      sampled from that cluster. This requires ``n_clusters`` and
      ``extra_info["cluster_id"]`` on every row.
    * ``"easy_to_hard"`` -- consecutive windows of ``batch_num`` prompts over the
      dataset sorted by increasing difficulty, wrapping around at the end.

    With ``is_train`` falsy, every index is yielded in a seeded random order.
    The trainer state is read on every ``__iter__`` call, so each new pass over the
    sampler sees the latest curriculum values.
    """

    _SUPPORTED_MODES = (
        "adarft",
        "no_curr",
        "uncertain",
        "cluster",
        "cluster_thompson",
        "cluster_thompson_ema",
        "easy_to_hard",
    )

    def __init__(
        self,
        data_source: Sized,
        repeat_count: int,
        trainer_ref,  # the whole Trainer, so the latest current_T / current_R_avg are visible
        seed: Optional[int] = None,
        batch_num: Optional[int] = B,
        is_train: Optional[bool] = None,
        mode: Optional[str] = "uniform",
        n_clusters: Optional[int] = None,  # number of clusters for the bandit modes
    ):
        self.data_source = data_source
        self.repeat_count = repeat_count
        self.batch_num = batch_num
        self.trainer_ref = trainer_ref
        self.mode = mode
        self.seed = seed
        self.num_samples = len(data_source)
        self.is_train = is_train
        # Cursor into ``sorted_indices`` for the easy_to_hard mode.
        self.current_eth = 0

        # Read every row once and cache the difficulties; __iter__ reuses them on every
        # pass instead of re-reading the dataset.
        rows = list(data_source)
        self.difficulties = [row["extra_info"]["difficulty"] for row in rows]
        # Dataset indices ordered from easiest to hardest (ties keep dataset order).
        self.sorted_indices = sorted(range(self.num_samples), key=self.difficulties.__getitem__)
        # The same rows in easy-to-hard order.
        self.sorted_data_source = [rows[i] for i in self.sorted_indices]

        # Only these training-data distributions are supported. The run arguments are
        # taken from the trainer rather than from a module-level global.
        dataset_path = self.trainer_ref.orig_args.dataset_path
        if not any(tag in dataset_path for tag in ("skew_easy", "skew_difficult", "uniform")):
            raise ValueError(f"Unknown dataset path: {dataset_path}. Supported paths: deepscaler_skew_easy, deepscaler_skew_difficult, deepscaler_uniform.")

        # Curriculum log files, appended to by the adarft, easy_to_hard and cluster modes.
        # They are defined for every mode so none of those modes hits a missing attribute.
        # These are placeholders: point them at real files before training.
        self.cluster_dir = "/PATH/TO/CLUSTER/DIR"
        self.reward_dir = "/PATH/TO/REWARD/DIR"

        self.generator = torch.Generator()
        if seed is not None:
            self.generator.manual_seed(seed)

        if n_clusters is not None:
            # cluster id -> dataset indices belonging to that cluster
            self.cluster_map = {cid: [] for cid in range(n_clusters)}
            for idx, row in enumerate(rows):
                self.cluster_map[row["extra_info"]["cluster_id"]].append(idx)

            self.bandit = ClusterBanditSelector(n_clusters=n_clusters, epsilon=0.3)

    def __iter__(self):
        if not self.is_train:
            # Evaluation: every index, once per repetition, in a seeded random order.
            return iter(self._repeat(self._permutation()))

        mode = self.mode
        if mode == "adarft":
            chosen = self._closest_to_target()
            with open(self.reward_dir, "a") as f:
                f.write(f"{self.trainer_ref.current_R_avg}\n")
            indexes = self._repeat(chosen)
            self._log_average_difficulty(indexes)
            return iter(indexes)

        if mode == "no_curr":
            # No curriculum: every index, once per repetition, in a seeded random order.
            return iter(self._repeat(self._permutation()))

        if mode == "uncertain":
            # NOTE: currently the same selection rule as "adarft" (without the log files).
            return iter(self._repeat(self._closest_to_target()))

        if mode == "cluster":
            return iter(self._bandit_step(self.bandit.update, self.bandit.select_cluster))

        if mode == "cluster_thompson":
            return iter(self._bandit_step(
                self.bandit.update_thompson, self.bandit.select_cluster_thompson
            ))

        if mode == "cluster_thompson_ema":
            return iter(self._bandit_step(
                self.bandit.update_thompson_ema, self.bandit.select_cluster_thompson_ema
            ))

        if mode == "easy_to_hard":
            return iter(self._easy_to_hard_step())

        raise ValueError(
            f"Unknown mode: {mode}. Supported modes: {', '.join(self._SUPPORTED_MODES)}."
        )

    def __len__(self):
        return self.repeat_count * self.batch_num if self.is_train else self.num_samples * self.repeat_count

    # -- helpers ---------------------------------------------------------------

    def _repeat(self, indices):
        """Expand ``indices`` so each one appears ``repeat_count`` times consecutively."""
        return [idx for idx in indices for _ in range(self.repeat_count)]

    def _permutation(self):
        """Return a random ordering of all dataset indices drawn from ``self.generator``."""
        return torch.randperm(self.num_samples, generator=self.generator).tolist()

    def _closest_to_target(self):
        """Return the ``batch_num`` indices whose difficulty is closest to ``current_T``."""
        target = float(self.trainer_ref.current_T)
        return sorted(
            range(self.num_samples),
            key=lambda i: abs(self.difficulties[i] - target),
        )[: self.batch_num]

    def _log_average_difficulty(self, indexes):
        """Append the mean difficulty of ``indexes`` to the ``cluster_dir`` log file."""
        average_difficulty = np.mean([self.difficulties[i] for i in indexes])
        with open(self.cluster_dir, "a") as f:
            f.write(f"{average_difficulty}\n")

    def _easy_to_hard_step(self):
        """Return the next difficulty-ordered window (repeated) and advance the cursor.

        The window wraps around to the easiest prompts when it reaches the end of the
        sorted order, so it never indexes past the dataset.
        """
        if self.num_samples == 0:
            return []
        window = [
            self.sorted_indices[(self.current_eth + offset) % self.num_samples]
            for offset in range(self.batch_num)
        ]
        self.current_eth = (self.current_eth + self.batch_num) % self.num_samples
        indexes = self._repeat(window)
        self._log_average_difficulty(indexes)
        return indexes

    def _bandit_step(self, update, select):
        """Run one bandit round and return the repeated indices of the chosen cluster.

        ``update(cluster, reward)`` credits the previously selected cluster with the
        trainer's ``current_R_avg``. This is skipped on the first round, when no cluster
        has been selected yet. ``select()`` then picks the cluster to train on next.
        """
        trainer = self.trainer_ref
        if hasattr(trainer, "last_selected_cluster"):
            update(trainer.last_selected_cluster, trainer.current_R_avg)
            with open(self.reward_dir, "a") as f:
                f.write(f"{trainer.current_R_avg}\n")

        selected_cluster = select()
        with open(self.cluster_dir, "a") as f:
            f.write(f"{selected_cluster}\n")

        candidates = self.cluster_map[selected_cluster]
        sampled = random.sample(candidates, min(self.batch_num, len(candidates)))
        # Remember the pick so the next round can reward it.
        trainer.last_selected_cluster = selected_cluster
        return self._repeat(sampled)


# ------------------------------------------------------
# 3. Helper to run “generate” and return accuracy
# ------------------------------------------------------

def generate_answer(
    model,
    tokenizer,
    tokenized_samples,
    batch_size,
    max_completion_length
):
    """Greedily decode answers for ``tokenized_samples`` and return the accuracy.

    Args:
        model: Causal LM exposing ``device``, ``generate`` and ``eval``/``train``.
        tokenizer: Tokenizer providing ``pad_token_id``, ``eos_token_id`` and ``decode``.
        tokenized_samples: List of ``(prompt_token_ids, gold_answer)`` pairs, as built by
            ``tokenize_validation``.
        batch_size: Number of prompts decoded per ``generate`` call.
        max_completion_length: Maximum number of new tokens per completion.

    Returns:
        The fraction of samples whose extracted answer equals the gold answer. It is
        computed only on the main process: rank 0, or the sole process when
        torch.distributed is not initialised. Other ranks return ``0``. An empty
        sample list yields ``0.0``.
    """
    # Without an initialised process group there is a single process, acting as rank 0.
    distributed = dist.is_available() and dist.is_initialized()
    if distributed and dist.get_rank() != 0:
        return 0

    count = len(tokenized_samples)
    if count == 0:
        return 0.0

    device = model.device
    pad_id = tokenizer.pad_token_id
    generation_config = transformers.GenerationConfig(
        max_new_tokens=max_completion_length,
        do_sample=False,
        repetition_penalty=1.0,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=pad_id,
    )

    predictions = []
    was_training = model.training
    model.eval()
    try:
        with tqdm.tqdm(total=count, desc=f"Correct: 0/{count}") as status:
            for start in range(0, count, batch_size):
                batch = tokenized_samples[start : start + batch_size]
                longest = max(len(ids) for ids, _ in batch)
                # Left-pad so every prompt ends at the same position and the completion
                # starts at index `longest`. The mask is built from the true lengths,
                # not by comparing against pad_token_id, so prompt tokens that happen to
                # equal the pad id are not masked out.
                input_ids = torch.tensor(
                    [[pad_id] * (longest - len(ids)) + list(ids) for ids, _ in batch],
                    device=device,
                )
                attention_mask = torch.tensor(
                    [[0] * (longest - len(ids)) + [1] * len(ids) for ids, _ in batch],
                    device=device,
                )
                with torch.inference_mode():
                    output = model.generate(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        generation_config=generation_config,
                    )

                for (_, answer), generated in zip(batch, output):
                    response = tokenizer.decode(generated[longest:], skip_special_tokens=True)
                    predictions.append(answer == extract_xml_answer(response))

                status.update(len(batch))
                status.set_description(f"Correct: {sum(predictions)}/{count}")
    finally:
        # Restore the caller's mode so training does not continue with dropout disabled.
        model.train(was_training)

    return float(np.mean(predictions))

def tokenize_validation(tokenizer, samples, max_prompt_length):
    """Apply the chat template to each sample's prompt and pair the result with its answer.

    Args:
        tokenizer: Object exposing ``apply_chat_template``.
        samples: Iterable of mappings with ``"prompt"`` and ``"answer"`` keys.
        max_prompt_length: Forwarded as ``max_length``. NOTE(review): ``truncation=False``
            is also passed, so this presumably does not shorten prompts; verify against
            the tokenizer version in use.

    Returns:
        A list of ``(token_ids, answer)`` tuples in input order.
    """
    def _encode(prompt):
        return tokenizer.apply_chat_template(
            prompt,
            add_generation_prompt=True,
            truncation=False,
            max_length=max_prompt_length,
        )

    pairs = ((sample["prompt"], sample["answer"]) for sample in samples)
    return [(_encode(prompt), answer) for prompt, answer in pairs]


# ------------------------------------------------------
# 4. EvalTrainer subclass: holds current_T and current_R_avg; the samplers it
#    builds keep a reference to the trainer and read them whenever they are iterated.
# ------------------------------------------------------

class EvalTrainer(GRPOTrainer):
    """GRPO trainer that carries the ADA-RFT curriculum state and custom samplers.

    It stores the curriculum hyper-parameters copied from the command-line args
    (``orig_args``). It also holds the mutable state ``current_T`` / ``current_R_avg``,
    which a ``CurriculumUpdateCallback`` (when attached) updates and which
    ``RepeatRandomSampler`` reads. Evaluation is replaced by a greedy-decoding
    accuracy check.
    """

    def __init__(self, model, processing_class, reward_funcs, training_args, train_dataset, eval_dataset, orig_args, best_k):
        # Pass only the GRPOConfig (training_args) to super()
        super().__init__(model=model, processing_class=processing_class, reward_funcs=reward_funcs,
                         args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset)

        self.orig_args = orig_args  # original parser args (read by the samplers and for ADA-RFT)
        # Copy ADA-RFT values from the original parser args onto the trainer itself
        self.T = float(orig_args.T)
        self.eta = float(orig_args.eta)
        self.sensitivity = float(orig_args.sensitivity)
        self.beta = float(orig_args.beta)
        self.d_min = float(orig_args.d_min)
        self.d_max = float(orig_args.d_max)

        # Mutable curriculum state: starts at the initial target with zero reward.
        self.current_T = self.T
        self.current_R_avg = 0.0

        self.best_k = best_k  # number of clusters for the bandit modes (None disables them)

    def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix: str = "eval"):
        """Greedy-decode ``self.eval_dataset`` and log ``{metric_key_prefix}_accuracy``.

        NOTE(review): ``eval_dataset`` and ``ignore_keys`` are accepted for signature
        compatibility with ``Trainer.evaluate`` but are ignored; the dataset given at
        construction time is always used. Non-main ranks report 0 (see generate_answer).
        """
        tokenized_samples = tokenize_validation(self.processing_class, self.eval_dataset, self.args.max_prompt_length)
        eval_acc = generate_answer(
            self.model,
            self.processing_class,
            tokenized_samples,
            self.args.per_device_eval_batch_size,
            self.args.max_completion_length,
        )

        output = {
            f"{metric_key_prefix}_accuracy": eval_acc,
            "epoch": self.state.epoch,
        }

        self.log(output)
        return output

    def _get_train_sampler(self, train_dataset=None) -> Sampler:
        """Build the curriculum sampler over ``self.train_dataset``.

        The sampler holds a reference to this trainer, not a snapshot of its values. It
        therefore re-reads ``current_T`` / ``current_R_avg`` every time it is iterated.
        On the first pass ``current_R_avg`` is 0.0 and ``current_T`` is the initial T.
        """
        return RepeatRandomSampler(
            data_source=self.train_dataset,
            repeat_count=self.args.num_generations,
            # This trainer itself, not a module-level global that may not exist or may be
            # a different instance.
            trainer_ref=self,
            seed=self.args.seed,
            batch_num=B,
            is_train=True,
            mode=self.orig_args.mode,
            n_clusters=self.best_k,  # Pass the best K for clustering mode
        )

    def _get_eval_sampler(self, eval_dataset=None) -> Sampler:
        """Build a sampler yielding every eval index in a seeded random order."""
        return RepeatRandomSampler(
            data_source=self.eval_dataset,
            repeat_count=self.args.num_generations,
            trainer_ref=self,
            seed=self.args.seed,
            batch_num=None,
            is_train=False,
            mode=None,  # No need for mode in eval sampler
        )

# ------------------------------------------------------
# 5. Main: parse arguments, set up trainer, and run
# ------------------------------------------------------

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train GRPO with ADA-RFT Curriculum")
    parser.add_argument("--testing", action="store_true", help="Whether to run in testing mode", default=False)
    parser.add_argument("--model_name", type=str, required=False, default="Qwen/Qwen2.5-0.5B-Instruct")
    parser.add_argument("--mode", type=str, required=True, choices=["adarft", "no_curr", "uncertain", "cluster", "cluster_thompson_ema", "cluster_thompson", "easy_to_hard", "HARD_ONLY"])
    parser.add_argument("--sentence_transformers_model", type=str, required=False, default="Qwen/Qwen3-Embedding-0.6B")
    parser.add_argument("--uncertainty_metric", type=str, required=True, default=None, help="The metric used to measure uncertainty")
    parser.add_argument("--dataset_path", type=str, required=True, default=None, help="Path to the dataset (if needed)")
    parser.add_argument("--num_shots", type=int, required=False, default=0)
    parser.add_argument("--nepochs", type=int, required=False, default=10) # equivalent to 10000 samples
    parser.add_argument("--seed", type=int, required=False, default=2025)
    parser.add_argument("--bs", type=int, required=False, default=1)
    parser.add_argument("--gc", type=int, required=False, default=8)
    parser.add_argument("--L", type=int, required=False, default=1200)
    parser.add_argument("--do_eval", type=int, required=False, default=0)
    parser.add_argument("--n_clusters", type=int, required=False, default=10, help="Number of clusters for clustering mode")
    
    parser.add_argument("--n_pca_components", type=int, required=False, default=50, help="Number of PCA components to reduce to")

    parser.add_argument("--test", action="store_true", help="Run in test mode (no training, only evaluation)", default=False)
    parser.add_argument("--cluster_method", type=str, required=False, default="closest", choices=["closest", "diverse", "random"],
                        help="Method to select samples from clusters: 'closest' or 'diverse'")
    parser.add_argument("--remove_solverate", action="store_true", help="Remove solve rate from dataset", default=False)

    # ——— ADA-RFT hyperparameters in TEXT form ———
    parser.add_argument("--T", type=float, required=False, default=0.0, help="Initial target difficulty T ")
    parser.add_argument("--eta", type=float, required=False, default=50, help="Step size eta")
    parser.add_argument("--sensitivity", type=float, required=False, default=2.0, help="Sensitivity sigma (used inside tanh)")
    parser.add_argument("--beta", type=float, required=False, default=0.4, help="Target reward beta")
    parser.add_argument("--d_min", type=float, required=False, default=0.0, help="Lower bound on difficulty")
    parser.add_argument("--d_max", type=float, required=False, default=100.0, help="Upper bound on difficulty")

    orig_args = parser.parse_args()

    # Fix random seeds for reproducibility
    random.seed(orig_args.seed)
    torch.manual_seed(orig_args.seed)
    np.random.seed(orig_args.seed)

    # (Optional) initialize WandB for standard metrics only
    import wandb
    wandb.init(project="GRPO_training_ADARFT_text_notation", name=f"{orig_args.model_name}-shots{orig_args.num_shots}-seed{orig_args.seed}", config=vars(orig_args))

    data_name = "deepscaler"

    # Only reward function is score_deepscaler
    reward_list = [score_deepscaler]


    # Load datasets (each sample must have sample["extra_info"]["difficulty"])
    train_dataset = get_deepscaler_questions(orig_args, split="train", mode=orig_args.mode)
    eval_dataset  = get_gsm8k_questions(split="test")

    if orig_args.test:
        # Step 1: Load data
        prompts = [sample["prompt"] for sample in train_dataset]
        difficulties = [sample["extra_info"]["difficulty"] for sample in train_dataset]  # difficulty = 1 - solve rate

        # Step 2: Encode prompts
        sentence_transformers_model = SentenceTransformer(orig_args.sentence_transformers_model)
        embeddings = sentence_transformers_model.encode(prompts, convert_to_numpy=True)

        # Step 3: Reduce embedding dimensions with PCA
        pca_embed = PCA(n_components=orig_args.n_pca_components)
        reduced_embeddings = pca_embed.fit_transform(embeddings)

        # Step 4: Combine with difficulty (if applicable)
        if orig_args.remove_solverate:
            combined = reduced_embeddings
        else:
            combined = np.concatenate([reduced_embeddings, np.array(difficulties).reshape(-1, 1)], axis=1)

        # Step 5: Standardize all features
        scaler = StandardScaler()
        combined_features = scaler.fit_transform(combined)

        # Step 6: Run KMeans with K clusters
        K = orig_args.n_clusters
        kmeans = KMeans(n_clusters=K, random_state=42)
        cluster_labels = kmeans.fit_predict(combined_features)

        # Compute distances of each sample to each cluster center
        distances = cdist(combined_features, kmeans.cluster_centers_, metric='euclidean')

        # Step 6b: Select samples per cluster
        selected_indices = []
        K_per_cluster = 10
        CANDIDATE_POOL_SIZE = 50

        for cid in range(K):
            cluster_indices = np.where(cluster_labels == cid)[0]
            cluster_dists = distances[cluster_indices, cid]
            sorted_indices = cluster_indices[np.argsort(cluster_dists)]

            if orig_args.cluster_method == "closest":
                # Take top-K closest
                top_k = sorted_indices[:min(K_per_cluster, len(sorted_indices))]
                selected_indices.extend(top_k)

            elif orig_args.cluster_method == "diverse":
                # Take top-N closest, then diverse among them
                top_pool = sorted_indices[:min(CANDIDATE_POOL_SIZE, len(sorted_indices))]
                if len(top_pool) == 0:
                    continue
                diverse_selected = [top_pool[0]]
                for _ in range(1, min(K_per_cluster, len(top_pool))):
                    remaining = list(set(top_pool) - set(diverse_selected))
                    dists = cdist(combined_features[remaining], combined_features[diverse_selected])
                    min_dists = dists.min(axis=1)
                    next_idx = remaining[np.argmax(min_dists)]
                    diverse_selected.append(next_idx)
                selected_indices.extend(diverse_selected)
            elif orig_args.cluster_method == "random":
                # Randomly select K_per_cluster samples from the cluster
                if len(cluster_indices) > 0:
                    selected = np.random.choice(cluster_indices, size=min(K_per_cluster, len(cluster_indices)), replace=False)
                    selected_indices.extend(selected)
            else:
                raise ValueError(f"Unknown cluster_method: {orig_args.cluster_method}")

        # Finalize selection
        selected_indices = [int(idx) for idx in selected_indices]
        train_dataset = [train_dataset[i] for i in selected_indices]
        prompts = [prompts[i] for i in selected_indices]
        difficulties = [difficulties[i] for i in selected_indices]
        reduced_embeddings = reduced_embeddings[selected_indices]
        combined_features = combined_features[selected_indices]
        cluster_labels = cluster_labels[selected_indices]

        # Step 7: t-SNE visualization (on selected samples only)
        if orig_args.n_clusters == 1:
            pass
        else:
            tsne = TSNE(n_components=2, perplexity=30, learning_rate='auto', init='pca', random_state=42)
            tsne_data = tsne.fit_transform(combined_features)

            plt.figure(figsize=(8, 6))
            plt.scatter(tsne_data[:, 0], tsne_data[:, 1], c=cluster_labels, cmap='tab10', s=5)
            plt.title(f"t-SNE of Clusters (K={orig_args.n_clusters}, Method={orig_args.cluster_method})")
            plt.colorbar(label='Cluster ID')
            plt.tight_layout()            

        # Step 8: Print difficulty stats per cluster
        for cid in range(K):
            avg_diff = np.mean([difficulties[i] for i in range(len(difficulties)) if cluster_labels[i] == cid])
            print(f"Cluster {cid}: Avg Difficulty = {avg_diff:.3f}")

        # Step 9: Store cluster labels back into dataset
        new_data = []
        for i, sample in enumerate(train_dataset):
            sample = dict(sample)
            sample["extra_info"]["cluster_id"] = int(cluster_labels[i])
            new_data.append(sample)

        train_dataset = Dataset.from_list(new_data)

        # Step 10: Clean up and set output paths
        del sentence_transformers_model  # Free memory

        if orig_args.sentence_transformers_model == "Qwen/Qwen3-Embedding-0.6B":
            if orig_args.remove_solverate:
                output_dir = f"{PATH_TO_REPO}/output/REMOVE_SOLVERATE_TEST_VISUALIZE_{orig_args.cluster_method.upper()}_{orig_args.model_name}-GRPO-{orig_args.num_shots}-seed{orig_args.seed}-mode{orig_args.mode}-ncluster{orig_args.n_clusters}-uncertainmetric{orig_args.uncertainty_metric}-T{orig_args.T}-eta{orig_args.eta}-sensitivity{orig_args.sensitivity}-beta{orig_args.beta}-d_min{orig_args.d_min}-d_max{orig_args.d_max}"
                run_name = f"REMOVE_SOLVERATE_TEST_VISUALIZE_{orig_args.cluster_method.upper()}_{orig_args.model_name}-GRPO-ADARFT_text_notation-shots{orig_args.num_shots}-seed{orig_args.seed}-mode{orig_args.mode}-ncluster{orig_args.n_clusters}-uncertainmetric{orig_args.uncertainty_metric}-T{orig_args.T}-eta{orig_args.eta}-sensitivity{orig_args.sensitivity}-beta{orig_args.beta}-d_min{orig_args.d_min}-d_max{orig_args.d_max}"
            else:
                if "skew_easy" in orig_args.dataset_path:
                    output_dir = f"{PATH_TO_REPO}/output/TEST_VISUALIZE_{orig_args.cluster_method.upper()}_{data_name}_skew_easy_{orig_args.model_name}-GRPO-{orig_args.num_shots}-seed{orig_args.seed}-mode{orig_args.mode}-ncluster{orig_args.n_clusters}-uncertainmetric{orig_args.uncertainty_metric}-T{orig_args.T}-eta{orig_args.eta}-sensitivity{orig_args.sensitivity}-beta{orig_args.beta}-d_min{orig_args.d_min}-d_max{orig_args.d_max}"
                    run_name = f"TEST_VISUALIZE_{orig_args.cluster_method.upper()}_{data_name}_skew_easy_{orig_args.model_name}-GRPO-ADARFT_text_notation-shots{orig_args.num_shots}-seed{orig_args.seed}-mode{orig_args.mode}-ncluster{orig_args.n_clusters}-uncertainmetric{orig_args.uncertainty_metric}-T{orig_args.T}-eta{orig_args.eta}-sensitivity{orig_args.sensitivity}-beta{orig_args.beta}-d_min{orig_args.d_min}-d_max{orig_args.d_max}"
                elif "skew_difficult" in orig_args.dataset_path:
                    output_dir = f"{PATH_TO_REPO}/output/TEST_VISUALIZE_{orig_args.cluster_method.upper()}_{data_name}_skew_difficult_{orig_args.model_name}-GRPO-{orig_args.num_shots}-seed{orig_args.seed}-mode{orig_args.mode}-ncluster{orig_args.n_clusters}-uncertainmetric{orig_args.uncertainty_metric}-T{orig_args.T}-eta{orig_args.eta}-sensitivity{orig_args.sensitivity}-beta{orig_args.beta}-d_min{orig_args.d_min}-d_max{orig_args.d_max}"
                    run_name = f"TEST_VISUALIZE_{orig_args.cluster_method.upper()}_{data_name}_skew_difficult_{orig_args.model_name}-GRPO-ADARFT_text_notation-shots{orig_args.num_shots}-seed{orig_args.seed}-mode{orig_args.mode}-ncluster{orig_args.n_clusters}-uncertainmetric{orig_args.uncertainty_metric}-T{orig_args.T}-eta{orig_args.eta}-sensitivity{orig_args.sensitivity}-beta{orig_args.beta}-d_min{orig_args.d_min}-d_max{orig_args.d_max}"
                elif "uniform" in orig_args.dataset_path:
                    output_dir = f"{PATH_TO_REPO}/output/TEST_VISUALIZE_{orig_args.cluster_method.upper()}_{data_name}_uniform_{orig_args.model_name}-GRPO-{orig_args.num_shots}-seed{orig_args.seed}-mode{orig_args.mode}-ncluster{orig_args.n_clusters}-uncertainmetric{orig_args.uncertainty_metric}-T{orig_args.T}-eta{orig_args.eta}-sensitivity{orig_args.sensitivity}-beta{orig_args.beta}-d_min{orig_args.d_min}-d_max{orig_args.d_max}"
                    run_name = f"TEST_VISUALIZE_{orig_args.cluster_method.upper()}_{data_name}_uniform_{orig_args.model_name}-GRPO-ADARFT_text_notation-shots{orig_args.num_shots}-seed{orig_args.seed}-mode{orig_args.mode}-ncluster{orig_args.n_clusters}-uncertainmetric{orig_args.uncertainty_metric}-T{orig_args.T}-eta{orig_args.eta}-sensitivity{orig_args.sensitivity}-beta{orig_args.beta}-d_min{orig_args.d_min}-d_max{orig_args.d_max}"
                else:
                    raise ValueError(f"Unknown dataset path: {orig_args.dataset_path}")
                # output_dir = f"{PATH_TO_REPO}/output/TEST_VISUALIZE_{orig_args.cluster_method.upper()}_{orig_args.model_name}-GRPO-{orig_args.num_shots}-seed{orig_args.seed}-mode{orig_args.mode}-ncluster{orig_args.n_clusters}-uncertainmetric{orig_args.uncertainty_metric}-T{orig_args.T}-eta{orig_args.eta}-sensitivity{orig_args.sensitivity}-beta{orig_args.beta}-d_min{orig_args.d_min}-d_max{orig_args.d_max}"
                # run_name = f"TEST_VISUALIZE_{orig_args.cluster_method.upper()}_{orig_args.model_name}-GRPO-ADARFT_text_notation-shots{orig_args.num_shots}-seed{orig_args.seed}-mode{orig_args.mode}-ncluster{orig_args.n_clusters}-uncertainmetric{orig_args.uncertainty_metric}-T{orig_args.T}-eta{orig_args.eta}-sensitivity{orig_args.sensitivity}-beta{orig_args.beta}-d_min{orig_args.d_min}-d_max{orig_args.d_max}"
        else:
            output_dir = f"{PATH_TO_REPO}/output/{orig_args.sentence_transformers_model}_{data_name}_TEST_VISUALIZE_{orig_args.cluster_method.upper()}_{orig_args.n_clusters}clusters_selected_train_samples_seed{orig_args.seed}"
            run_name = f"{orig_args.sentence_transformers_model}_{data_name}_TEST_VISUALIZE_{orig_args.cluster_method.upper()}_{orig_args.n_clusters}clusters_selected_train_samples_seed{orig_args.seed}"
    else:
        # Create output directory if needed
        if "skew_easy" in orig_args.dataset_path:
            output_dir = f"{PATH_TO_REPO}/output/{orig_args.model_name}_skew_easy-GRPO-{orig_args.num_shots}-seed{orig_args.seed}-mode{orig_args.mode}-uncertainmetric{orig_args.uncertainty_metric}-T{orig_args.T}-eta{orig_args.eta}-sensitivity{orig_args.sensitivity}-beta{orig_args.beta}-d_min{orig_args.d_min}-d_max{orig_args.d_max}"
            run_name = f"{orig_args.model_name}-GRPO-ADARFT_text_notation-shots{orig_args.num_shots}-seed{orig_args.seed}-mode{orig_args.mode}-uncertainmetric{orig_args.uncertainty_metric}-T{orig_args.T}-eta{orig_args.eta}-sensitivity{orig_args.sensitivity}-beta{orig_args.beta}-d_min{orig_args.d_min}-d_max{orig_args.d_max}"
        elif "skew_difficult" in orig_args.dataset_path:
            output_dir = f"{PATH_TO_REPO}/output/{orig_args.model_name}_skew_difficult-GRPO-{orig_args.num_shots}-seed{orig_args.seed}-mode{orig_args.mode}-uncertainmetric{orig_args.uncertainty_metric}-T{orig_args.T}-eta{orig_args.eta}-sensitivity{orig_args.sensitivity}-beta{orig_args.beta}-d_min{orig_args.d_min}-d_max{orig_args.d_max}"
            run_name = f"{orig_args.model_name}-GRPO-ADARFT_text_notation-shots{orig_args.num_shots}-seed{orig_args.seed}-mode{orig_args.mode}-uncertainmetric{orig_args.uncertainty_metric}-T{orig_args.T}-eta{orig_args.eta}-sensitivity{orig_args.sensitivity}-beta{orig_args.beta}-d_min{orig_args.d_min}-d_max{orig_args.d_max}"
        elif "uniform" in orig_args.dataset_path:
            output_dir = f"{PATH_TO_REPO}/output/{orig_args.model_name}_uniform-GRPO-{orig_args.num_shots}-seed{orig_args.seed}-mode{orig_args.mode}-uncertainmetric{orig_args.uncertainty_metric}-T{orig_args.T}-eta{orig_args.eta}-sensitivity{orig_args.sensitivity}-beta{orig_args.beta}-d_min{orig_args.d_min}-d_max{orig_args.d_max}"
            run_name = f"{orig_args.model_name}-GRPO-ADARFT_text_notation-shots{orig_args.num_shots}-seed{orig_args.seed}-mode{orig_args.mode}-uncertainmetric{orig_args.uncertainty_metric}-T{orig_args.T}-eta{orig_args.eta}-sensitivity{orig_args.sensitivity}-beta{orig_args.beta}-d_min{orig_args.d_min}-d_max{orig_args.d_max}"
        else:
            raise ValueError(f"Unknown dataset path: {orig_args.dataset_path}")    
    print("SAVING TO:", output_dir)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    do_eval = True
    eval_strategy = "steps"
    if orig_args.do_eval == 0:
        do_eval = False
        eval_strategy = "no"

    # Create GRPOConfig with only its expected fields
    training_args = GRPOConfig(
        output_dir=output_dir,
        run_name=run_name,
        eval_strategy=eval_strategy,
        eval_steps=50,
        do_eval=do_eval,
        learning_rate=5e-6,
        adam_beta1=0.9,
        adam_beta2=0.99,
        weight_decay=0.1,
        warmup_ratio=0.1,
        lr_scheduler_type="cosine",
        logging_steps=1,
        bf16=True,
        per_device_train_batch_size=8,  # equals num_generations
        gradient_accumulation_steps=orig_args.gc,
        num_generations=8,
        max_prompt_length=1024,
        max_completion_length=orig_args.L,
        num_train_epochs=orig_args.nepochs,
        save_steps=200,
        max_grad_norm=0.1,
        log_on_each_node=False,
        use_vllm=True,
        vllm_gpu_memory_utilization=0.1,
        vllm_device="cuda:0",
        report_to="tensorboard",
        seed=orig_args.seed,
    )

    # Load the model (same as before)
    model = AutoModelForCausalLM.from_pretrained(
        orig_args.model_name,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        device_map=None,
        use_cache=False,
        offload_state_dict=True,
    ).to("cuda")

    tokenizer = AutoTokenizer.from_pretrained(orig_args.model_name)
    tokenizer.pad_token = tokenizer.eos_token

    # Instantiate our EvalTrainer, passing both GRPOConfig (training_args) and parser-namespace (orig_args)
    if not orig_args.test:
        best_k = None
    elif orig_args.mode in ["adarft"]:
        best_k = None  # ADA-RFT does not use clustering
    else:
        best_k = K

    trainer = EvalTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=reward_list,
        training_args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        orig_args=orig_args,  # give it access to T, eta, sensitivity, beta, d_min, d_max
        best_k=best_k,
    )

    # Register only the callbacks we need:
    # 1) CurriculumUpdateCallback (updates T and R_avg in memory only)
    # 2) WandbTrainingCallback   (forwards standard loss/avg_reward to WandB)
    curcallback = CurriculumUpdateCallback()
    curcallback.trainer_ref = trainer
    trainer.add_callback(curcallback)

    trainer.add_callback(WandbTrainingCallback())

    # Start training (curriculum updates happen purely in memory)
    trainer.train()