import os
import warnings
import logging
import inspect

# Export PYTHONWARNINGS=ignore unless the caller already chose a value, so
# that processes spawned from this one inherit the same policy.
if "PYTHONWARNINGS" not in os.environ:
    os.environ["PYTHONWARNINGS"] = "ignore"
# Install a catch-all "ignore" filter in the current interpreter as well.
# Both calls below register the same blanket rule.
warnings.simplefilter("ignore")
warnings.filterwarnings("ignore")


# Replace warnings.showwarning and warnings.warn with no-ops to silence any direct calls
def _noop_showwarning(*_args, **_kwargs):
    return


def _noop_warn(*_args, **_kwargs):
    return


# Route both warning entry points to the no-op stubs so that code calling
# warnings.warn / warnings.showwarning directly produces no output.
warnings.warn = _noop_warn
warnings.showwarning = _noop_showwarning
# Only ERROR and above should reach the root logger. The explicit setLevel
# is kept next to basicConfig so the threshold is applied to the root logger
# regardless of what basicConfig ends up doing.
logging.basicConfig(level=logging.ERROR)
logging.getLogger().setLevel(logging.ERROR)
# Apply the same ERROR threshold to a handful of named library loggers.
for _name in (
    "transformers",
    "torch",
    "pydantic",
    "torch.distributed.run",
    "accelerate",
):
    try:
        logging.getLogger(_name).setLevel(logging.ERROR)
    except Exception:
        # Best effort only: failing to quiet one logger must not stop the script.
        pass

import torch
import torch.distributed as dist
import wandb

from data_utils import (
    get_countdown_questions,
    get_gsm8k_questions,
    get_math_questions,
    get_sudoku_questions,
    set_random_seed,
    set_trainer_type,
)
from peft import LoraConfig

# Keep a handle on the constructor that was installed before this patch so
# the wrapper below can delegate to it.
_orig_lora_init = LoraConfig.__init__

# Inspect the wrapped constructor once, at import time, instead of on every
# LoraConfig construction.
_lora_init_params = inspect.signature(_orig_lora_init).parameters
_LORA_INIT_PARAM_NAMES = frozenset(_lora_init_params)
# When the wrapped constructor itself takes **kwargs (for example because
# this patch has already been applied once and we are wrapping the previous
# wrapper), every keyword is acceptable to it. Filtering by parameter name
# in that situation would silently throw away all real arguments.
_LORA_INIT_TAKES_VAR_KWARGS = any(
    param.kind is inspect.Parameter.VAR_KEYWORD
    for param in _lora_init_params.values()
)


def _robust_lora_init(self, *args, **kwargs):
    """Initialise a ``LoraConfig``, dropping keyword arguments it does not accept.

    Keyword arguments whose names are not parameters of the wrapped
    ``__init__`` are discarded instead of raising ``TypeError``. If the
    wrapped ``__init__`` declares ``**kwargs``, all keywords are forwarded
    unchanged.

    Args:
        *args: Positional arguments, forwarded as-is.
        **kwargs: Keyword arguments; unknown names are silently dropped.
    """
    if _LORA_INIT_TAKES_VAR_KWARGS:
        clean_kwargs = kwargs
    else:
        clean_kwargs = {
            name: value
            for name, value in kwargs.items()
            if name in _LORA_INIT_PARAM_NAMES
        }

    _orig_lora_init(self, *args, **clean_kwargs)


# Apply the patch
LoraConfig.__init__ = _robust_lora_init


from reward_func import (
    boxed_and_answer_tags_format_reward,
    correctness_reward_func,
    correctness_reward_func_math,
    countdown_reward_func,
    int_reward_func,
    soft_format_reward_func,
    strict_format_reward_func,
    sudoku_reward_func,
    xmlcount_reward_func,
    block_format_reward,
)

from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
from trl import ModelConfig, TrlParser

from b1.trainers.diffu_grpo_config import DiffuGRPOConfig

from b1.trainers.diffu_grpo_trainer import DiffuGRPOTrainer
from b1.trainers.eval_callback import AccuracyEvalCallback
from b1.trainers.rev_grpo_ref_pol_trainer import RevDiffuRefPolGRPOTrainer
from b1.trainers.rev_grpo_trainer import RevDiffuGRPOTrainer
from b1.trainers.rev_grpo_trainer_psr import RevPSRDiffuGRPOTrainer
from b1.trainers.gdpo_trainer import GDPOTrainer

from b1.eval.countdown import CTDDataset
from b1.eval.gsm8k import GSM8KDataset
from b1.eval.math500 import MATH500Dataset
from b1.eval.sudoku import SudokuDataset


# Evaluation dataset class for each supported dataset name
# (classes come from the b1.eval.* modules imported above).
DATASET_MAP = dict(
    gsm8k=GSM8KDataset,
    math=MATH500Dataset,
    countdown=CTDDataset,
    sudoku=SudokuDataset,
)
# Size of the random validation subsample used for quick evaluation,
# keyed by dataset name.
SUB_SAMPLE_MAP = dict(
    gsm8k=100,
    math=100,
    countdown=100,
    sudoku=100,
)


def is_main_process():
    """Return ``True`` if torch.distributed is not initialised or this process is rank 0."""
    if dist.is_initialized():
        return dist.get_rank() == 0
    # No process group: the single running process counts as the main one.
    return True


def _load_training_data(dataset_name):
    """Return ``(dataset, reward_functions)`` for a supported dataset name.

    Raises:
        ValueError: If ``dataset_name`` is not one of ``gsm8k``, ``math``,
            ``countdown`` or ``sudoku``.
    """
    if dataset_name == "gsm8k":
        # Format rewards + correctness reward
        return get_gsm8k_questions("train"), [
            xmlcount_reward_func,
            int_reward_func,
            correctness_reward_func,
        ]
    if dataset_name == "math":
        # Correctness reward + format reward
        return get_math_questions("train"), [
            correctness_reward_func_math,
            boxed_and_answer_tags_format_reward,
        ]
    if dataset_name == "countdown":
        return get_countdown_questions("train"), [countdown_reward_func]
    if dataset_name == "sudoku":
        return get_sudoku_questions(), [sudoku_reward_func]
    raise ValueError(
        f"Unknown dataset {dataset_name!r}; "
        "expected one of: gsm8k, math, countdown, sudoku"
    )


def _resolve_trainer_class(trainer_type):
    """Map a ``trainer_type`` string to the trainer class that implements it.

    Raises:
        ValueError: If ``trainer_type`` is not a known trainer name.
    """
    # Built inside the function so the trainer names are only looked up when
    # a trainer is actually requested.
    trainer_classes = {
        # NSR + PSR + d1 objective
        "wll_d1_neg": RevDiffuGRPOTrainer,
        "b1_wll": RevDiffuGRPOTrainer,
        "wll_d1_pos_only": RevPSRDiffuGRPOTrainer,
        "d1": DiffuGRPOTrainer,
        "b1_d1": DiffuGRPOTrainer,
        # adds reference policy regularisation
        "wll_d1_neg_ref": RevDiffuRefPolGRPOTrainer,
        "gdpo": GDPOTrainer,
        "b1_gdpo": GDPOTrainer,
    }
    try:
        return trainer_classes[trainer_type]
    except KeyError:
        known = ", ".join(trainer_classes)
        raise ValueError(
            f"Unknown trainer type {trainer_type!r}; expected one of: {known}"
        ) from None


def _patch_caching_allocator_warmup():
    """Wrap ``transformers.modeling_utils.caching_allocator_warmup`` (best effort).

    The wrapper sets ``_tp_plan = []`` on the model being loaded when that
    attribute is missing or ``None`` and then delegates to the original
    function. If the module or the function is not available, nothing is
    patched.
    """
    try:
        import transformers.modeling_utils as modeling_utils

        original_warmup = modeling_utils.caching_allocator_warmup
    except (ImportError, AttributeError):
        # This transformers version has nothing to patch.
        return

    def _safe_caching_allocator_warmup(model_to_load, *args, **kwargs):
        try:
            if getattr(model_to_load, "_tp_plan", None) is None:
                # Set an instance attribute so other models of the same
                # class are unaffected.
                model_to_load._tp_plan = []
        except Exception:
            # Deliberately best effort: if the attribute cannot be set, let
            # the original function handle the model or raise its own error.
            pass
        return original_warmup(model_to_load, *args, **kwargs)

    modeling_utils.caching_allocator_warmup = _safe_caching_allocator_warmup


def _patch_grpo_train_sampler():
    """Make ``GRPOTrainer._get_train_sampler`` callable with or without a dataset.

    Some transformers versions call ``_get_train_sampler(train_dataset)``
    while older TRL versions define ``_get_train_sampler(self)``. The wrapper
    accepts an optional ``train_dataset`` and forwards it only when the
    original implementation can take it. The decision is made from the
    signature rather than by catching ``TypeError``, so a ``TypeError``
    raised inside the real implementation is not mistaken for a signature
    mismatch and the implementation is never run twice.
    """
    try:
        import trl.trainer.grpo_trainer as grpo_module

        original = getattr(grpo_module.GRPOTrainer, "_get_train_sampler", None)
    except (ImportError, AttributeError):
        return
    if original is None:
        return

    try:
        kinds = [p.kind for p in inspect.signature(original).parameters.values()]
    except (TypeError, ValueError):
        # Signature cannot be introspected; leave the method untouched.
        return

    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    # ``self`` is the first positional parameter; a second one (or *args)
    # means the implementation can receive the dataset.
    accepts_dataset = (
        inspect.Parameter.VAR_POSITIONAL in kinds
        or sum(kind in positional_kinds for kind in kinds) >= 2
    )

    def _wrapped_get_train_sampler(self, train_dataset=None):
        if accepts_dataset:
            return original(self, train_dataset)
        return original(self)

    grpo_module.GRPOTrainer._get_train_sampler = _wrapped_get_train_sampler


def main(grpo_config, model_config):
    """Load data and model, build the configured trainer, and run training.

    Args:
        grpo_config: Training configuration. This function reads ``seed``,
            ``trainer_type``, ``dataset``, ``model_path``,
            ``max_completion_length``, ``diffusion_steps``, ``block_length``,
            ``per_device_eval_batch_size``, ``wandb_project`` and
            ``run_name``, and passes the whole object to the trainer.
        model_config: Model configuration providing ``lora_r``,
            ``lora_alpha`` and ``lora_dropout``.

    Raises:
        ValueError: If ``grpo_config.trainer_type`` or ``grpo_config.dataset``
            is not supported. Both are checked before the model is loaded.
    """
    # Set seed for reproducibility
    set_random_seed(grpo_config.seed)
    set_trainer_type(grpo_config.trainer_type)

    # Fail fast on a bad trainer type instead of after loading the model.
    trainer_cls = _resolve_trainer_class(grpo_config.trainer_type)

    # Load dataset and reward functions
    dataset, reward_functions = _load_training_data(grpo_config.dataset)
    dataset = dataset.shuffle(seed=grpo_config.seed)

    # For countdown and sudoku the last 500 shuffled examples are held out
    # of the training set.
    if grpo_config.dataset in ("countdown", "sudoku"):
        train_set = dataset.select(range(0, len(dataset) - 500))
    else:
        train_set = dataset

    # Set up device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 4 bit quantization configuration
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    _patch_caching_allocator_warmup()

    # Load model and tokenizer - use SFT path if on top of SFT
    model = AutoModel.from_pretrained(
        grpo_config.model_path,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        quantization_config=bnb_config,
    ).to(device)

    tokenizer = AutoTokenizer.from_pretrained(
        grpo_config.model_path, trust_remote_code=True
    )
    tokenizer.pad_token = tokenizer.eos_token
    model.config.use_cache = False

    val_dataset = DATASET_MAP[grpo_config.dataset](
        tokenizer,
        subsample=SUB_SAMPLE_MAP[grpo_config.dataset],
        num_examples=0,
        add_reasoning=True,  # prefill for all models
    )

    # Configure LoRA for parameter-efficient fine-tuning
    peft_config = LoraConfig(
        r=model_config.lora_r,
        lora_alpha=model_config.lora_alpha,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "up_proj",
            "down_proj",
            "gate_proj",
        ],
        task_type="CAUSAL_LM",
        lora_dropout=model_config.lora_dropout,
    )

    _patch_grpo_train_sampler()

    if is_main_process():
        print("Trainer type is: ", grpo_config.trainer_type)

    # Every trainer type takes the same arguments; only the class differs.
    eval_callback = AccuracyEvalCallback(
        val_dataset,
        tokenizer=tokenizer,
        gen_length=grpo_config.max_completion_length,
        temperature=0.0,
        steps=grpo_config.diffusion_steps,
        block_length=grpo_config.block_length,
        batch_size=grpo_config.per_device_eval_batch_size,
    )
    trainer = trainer_cls(
        args=grpo_config,
        model=model,
        processing_class=tokenizer,
        peft_config=peft_config,
        reward_funcs=reward_functions,
        train_dataset=train_set,
        eval_dataset=val_dataset,
        callbacks=[eval_callback],
    )

    # Only the main process starts the wandb run.
    if is_main_process():
        wandb.init(project=grpo_config.wandb_project, name=grpo_config.run_name)

    trainer.train()


if __name__ == "__main__":
    # Build both config objects (GRPO training config and TRL model config)
    # through TrlParser.
    parser = TrlParser((DiffuGRPOConfig, ModelConfig))
    grpo_config, model_config = parser.parse_args_and_config()
    # Fixed overrides applied regardless of what the parser produced.
    grpo_config.label_names = ["completion_ids"]
    grpo_config.remove_unused_columns = False
    main(grpo_config, model_config)
