import json
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import SFTTrainer, TrlParser, ModelConfig, SFTConfig, ScriptArguments
import torch

# Parse CLI flags / YAML config into TRL's script, training and model argument groups.
parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
script_args, training_args, model_config = parser.parse_args_and_config()


# `dataset_name` is used as a path to a local JSON file (not a Hub dataset id).
# Presumably a list of records with "question", "process" and "critic" keys — TODO confirm.
# The context manager closes the handle; the explicit encoding makes non-ASCII text
# decode identically regardless of the platform's default locale (JSON is UTF-8).
with open(script_args.dataset_name, encoding="utf-8") as dataset_file:
    raw_data = json.load(dataset_file)


# Critique prompt given to the model for every training example. The `{Question}` and
# `{Solution}` placeholders are filled with str.format when the corpus is built.
# NOTE: the leading newline and the 12-space indentation of every line are part of the
# prompt text the model is trained on; keep this template byte-identical to whatever
# template is used at inference time (presumably defined elsewhere — verify).
# Any literal `{` or `}` added here must be doubled (`{{` / `}}`) because of str.format.
ANALYSIS_PROMPT_TEMPLATE = """
            You are an expert in Operations Research (OR). You will be given an optimization problem and (optionally) a step-by-step solution, which may or may not include code.

            Task:
            Review the solution. Analyze each applicable part in order. Be concise — only highlight critical errors or omissions. Skip any section if the input doesn't contain it (e.g., no code → skip Code Analysis).

            Evaluate in this order:

            1. Variable Definitions
            2. Objective Function & Constraints
            3. Code Implementation (if provided)
            4. Final Answer / Output

            Question:
            {Question}

            Solution Steps:
            {Solution}

            Output Format (be brief and precise):

            1. Variable Definition Analysis
            - Intent: [e.g., Define decision variables]
            - Analysis: [Only note missing, redundant, or misdefined variables]
            - Judgement: [Correct/Incorrect]

            2. Objective & Constraint Analysis
            - Intent: [e.g., Formulate model]
            - Objective: [Correct? Brief reason if wrong]
            - Constraints: [Missing/incorrect? List only key issues]
            - Judgement: [Correct/Incorrect]

            3. Code Analysis (Skip if no code)
            - Intent: Implement model in Pyomo/Python
            - Analysis: [Only flag mismatches: missing vars/constraints, wrong indexing, type errors]
            - Judgement: [Correct/Incorrect or Skipped]

            4. Final Answer Analysis
            - Intent: [e.g., Report solution or error]
            - Analysis: [Plausible? Error meaningful? Root cause if wrong]
            - Judgement: [Correct/Incorrect]

            Corrected Step (Only if any part above is Incorrect)
            - [Rewrite only the first incorrect section — e.g., fix constraints or variables — in full, clearly labeled.]
            """

# Build one {"prompt", "completion"} record per raw example (TRL's prompt-completion
# dataset format). The prompt is the filled-in critique template. The completion is
# the reference critique the model should learn to produce.
corpus = []
for d in raw_data:
    question = d["question"]
    steps = d["process"]
    # `process` is presumably a list of solution-step strings, joined with blank lines.
    # A plain string is used as-is: "\n\n".join on a str would interleave blank lines
    # between its individual characters.
    process = steps if isinstance(steps, str) else "\n\n".join(steps)

    corpus.append({
        "prompt": ANALYSIS_PROMPT_TEMPLATE.format(Question=question, Solution=process),
        "completion": d["critic"],
    })

dataset = Dataset.from_list(corpus)

# Load the base model in bfloat16 with FlashAttention-2 (needs the flash-attn package
# and a supported GPU).
# NOTE(review): trust_remote_code, dtype and attention implementation are hard-coded
# here, so the corresponding ModelConfig CLI options are ignored — confirm intended.
model = AutoModelForCausalLM.from_pretrained(
    model_config.model_name_or_path,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)

tokenizer = AutoTokenizer.from_pretrained(
    model_config.model_name_or_path,
    trust_remote_code=True,
)
# Some causal-LM tokenizers ship without a pad token; reuse EOS so batches can be padded.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Pass the configured tokenizer explicitly. Without `processing_class`, SFTTrainer
# loads its own tokenizer from the model id. The pad-token / padding-side setup above
# would then be ignored during training and could differ from the tokenizer saved below.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    processing_class=tokenizer,
)

trainer.train()

# Save the fine-tuned weights next to the exact tokenizer used for training.
trainer.save_model(training_args.output_dir)
tokenizer.save_pretrained(training_args.output_dir)


