"""Train a Buddy causal LM (PEFT adapter loaded from --output_dir) with GRPOBudgetTrainer and save the result."""
import os
from torch.utils.data import DataLoader

# Disable Hugging Face tokenizers parallelism (set to "true" to re-enable it).
# This is set here, before `transformers` is imported below.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Model, tokenizer, PEFT, training and dataset utilities.
import argparse
import torch

from transformers import AutoTokenizer
from model.buddy_model import BuddyForCausalLM

from peft import PeftModel
from utils.training_utils import extract_policy_state, CausalLMDataCollator, GRPOBudgetTrainer, GRPOConfig

from utils.dataset import load_instruction_dataset
import wandb

import warnings

# Silence UserWarning and FutureWarning output for the whole run.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# Target device, read by load_model() and main(): GPU when available, else CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Fixed global seed so repeated runs are reproducible.
torch.manual_seed(20250925)

# Weights & Biases logging is disabled; "buddy" is intended as the project name.
# It must be passed by keyword: the first positional parameter of wandb.init
# is not `project` (historically it is `job_type`), so a positional "buddy"
# never set the project.
wandb.init(project="buddy", mode="disabled")


def parse_args(argv=None):
    """Parse command-line options for tuning the pruned LLM.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]``, exactly as before. Passing a
            list enables programmatic use and testing.

    Returns:
        An ``argparse.Namespace`` holding every option below, plus
        ``torch_version``. That is the integer *minor* component of
        ``torch.__version__`` (e.g. ``1`` for ``"2.1.0"``), not the major one.
    """
    parser = argparse.ArgumentParser(description='Tuning Pruned LLM')

    # Model type & paths
    parser.add_argument('--base_model', type=str, default="baffo32/decapoda-research-llama-7B-hf",
                        help='base model name')
    parser.add_argument('--data_name', type=str, default="openbookqa", help='data name')
    parser.add_argument('--data_path', type=str, default="openbookqa", help='data path')
    # Directory the PEFT adapter is loaded from and the trained result is saved to.
    parser.add_argument('--output_dir', type=str, default="./tune_log/openbookqa/deputy/", help='output directory')
    parser.add_argument('--lambda_reg', type=float, default=1.0, help='lambda_reg')
    # The literal string "None" disables loading pre-computed sensitivity scores.
    parser.add_argument('--sensitivity_type', type=str, default="taylor", help='sensitivity_type')
    parser.add_argument('--sensitivity_path', type=str,
                        default="utils/sensitivity/llama2_output/ppl/all_ppl_unsorted.csv",
                        help='sensitivity_path')

    # Training hyperparameters
    parser.add_argument('--batch_size', type=int, default=4, help='batch size')
    parser.add_argument('--num_epochs', type=int, default=2, help='number of epochs')
    parser.add_argument('--learning_rate', type=float, default=1e-4, help='learning rate')
    parser.add_argument('--cutoff_len', type=int, default=1024, help='cutoff length')
    args = parser.parse_args(argv)

    # NOTE: index [1] of the dotted version string is the *minor* version
    # number. The value is kept unchanged for existing users of args.torch_version.
    args.torch_version = int(torch.__version__.split('.')[1])
    return args


def load_model(args):
    """Build the tokenizer and the PEFT-wrapped Buddy model.

    The tokenizer and ``BuddyForCausalLM`` both come from ``args.base_model``.
    Unless ``args.sensitivity_type`` is the string ``"None"``, pre-computed
    sensitivity scores are attached to the model. On CUDA the model is cast
    to bfloat16. It is then wrapped with the PEFT adapter stored in
    ``args.output_dir`` and moved to the module-level ``device``.

    Args:
        args: Namespace from ``parse_args``. Reads ``base_model``,
            ``sensitivity_type``, ``sensitivity_path``, ``lambda_reg`` and
            ``output_dir``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(args.base_model)
    base = BuddyForCausalLM.from_pretrained(args.base_model)

    # Pad with token id 0, on the left side of each sequence.
    tokenizer.pad_token_id = 0
    tokenizer.padding_side = "left"

    if args.sensitivity_type != "None":
        # Imported lazily: only needed when sensitivity scores are used.
        from utils.sensitivity_utils import read_pre_sensitivity
        scores = read_pre_sensitivity(
            path=args.sensitivity_path,
            type=args.sensitivity_type
        )
        base.set_sensitivity(scores, args.lambda_reg)

    if device == 'cuda':
        base.to(torch.bfloat16)

    # NOTE(review): per the PEFT docs, from_pretrained loads the adapter as
    # non-trainable unless is_trainable=True is passed. Confirm that the GRPO
    # trainer enables gradients for the parameters it optimizes.
    peft_model = PeftModel.from_pretrained(base, args.output_dir)
    return peft_model.to(device), tokenizer


def main(args):
    """Train the model with GRPOBudgetTrainer and save it to ``args.output_dir``.

    Args:
        args: Namespace from ``parse_args``. Reads the model/data options,
            ``batch_size``, ``learning_rate``, ``num_epochs`` and ``output_dir``.
    """
    model, tokenizer = load_model(args)
    model.train()

    # Only the training split is used here; the validation split is discarded.
    train_data, _ = load_instruction_dataset(
        name=args.data_name,
        path=args.data_path,
        tokenizer=tokenizer,
        max_length=args.cutoff_len,
    )

    # The collator receives the target device (presumably to place batches on it).
    collate_fn = CausalLMDataCollator(tokenizer, device=device)
    # NOTE(review): shuffle is left at DataLoader's default (False). Confirm
    # whether load_instruction_dataset already shuffles the training split.
    train_loader = DataLoader(
        train_data,
        batch_size=args.batch_size,
        collate_fn=collate_fn,
    )

    config = GRPOConfig(
        group_size=4,
        entropy_coef=0.01,
        cost_coef=0.05,
        perf_scale=5.0,
        lr=args.learning_rate,
        device=device,
        grpo_epochs=8,  # number of optimization passes over each batch of data
        clip_eps=0.2,  # PPO-clip clipping range
    )
    trainer = GRPOBudgetTrainer(model, tokenizer, config)
    trainer.train(dataloader=train_loader, num_epochs=args.num_epochs)
    trainer.save(args.output_dir)


if __name__ == "__main__":
    # Script entry point: parse command-line options, then run training.
    args = parse_args()
    main(args=args)
