import json
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import Dataset
import torch
from transformers import TrainingArguments, Trainer
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import numpy as np
import re

# JSON files holding lists of records with at least 'text', 'CoT' and 'Label'
# fields; their contents are concatenated into one DataFrame.
json_files = [
    'xxx.json'
]

all_data = []

for file in json_files:
    # Explicit UTF-8 so non-ASCII text does not depend on the platform locale.
    with open(file, 'r', encoding='utf-8') as f:
        all_data.extend(json.load(f))

df = pd.DataFrame(all_data)

# Drop incomplete rows BEFORE casting to str: astype(str) turns NaN/None into
# the literal strings "nan"/"None", which dropna would no longer recognise as
# missing, letting broken rows into the training set.
required_columns = ['text', 'CoT', 'Label']
df = df.dropna(subset=required_columns).astype(
    {column: str for column in required_columns}
)

model_path = "Llama-3.1-8B"

# Load the tokenizer for the checkpoint; padding is applied on the right
# (trailing positions) when sequences are padded.
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    padding_side="right",
    trust_remote_code=True,
)

# The tokenizer may ship without a pad token; fall back to reusing EOS so that
# fixed-length padding is possible.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

def tokenize_function(examples):
    """Build sarcasm-classification prompts for a batch and tokenize them.

    Intended for ``Dataset.map(..., batched=True)``.

    Args:
        examples: batch dict with parallel ``'text'``, ``'CoT'`` and
            ``'Label'`` lists.

    Returns:
        The tokenizer output for the prompts, padded and truncated to 2048
        tokens, plus ``labels``. ``labels`` is a copy of ``input_ids`` with
        every padding position (``attention_mask == 0``) set to -100 so the
        causal-LM loss ignores it.
    """
    sarcasm_definition = "(1=sarcasm: contains features like surface praise with underlying criticism, contextual incongruity, exaggerated contrast, etc. 0=not sarcasm)"
    prompts = [
        f"### Sarcasm Classification Task:\nAnalyze the sarcasm of this text step by step\nText content: {text}\nReasoning steps: {cot}\nSarcasm Label: {label}(0-1 as defined in {sarcasm_definition})"
        for text, cot, label in zip(examples['text'], examples['CoT'], examples['Label'])
    ]

    tokenized_inputs = tokenizer(
        prompts,
        padding="max_length",
        truncation=True,
        max_length=2048,
        return_tensors="pt"
    )

    # Every sequence is padded to 2048 tokens (and the pad token may be an
    # alias of EOS). Without masking, most of the loss would be spent
    # predicting padding. Mask by attention_mask rather than by token id, so
    # a genuine EOS inside the text is still trained on.
    labels = tokenized_inputs["input_ids"].clone()
    labels[tokenized_inputs["attention_mask"] == 0] = -100
    tokenized_inputs["labels"] = labels
    return tokenized_inputs

# Convert the cleaned DataFrame to a Hugging Face Dataset.
# preserve_index=False: after dropna the index can have gaps, and from_pandas
# would otherwise store it as an extra "__index_level_0__" column.
train_dataset = Dataset.from_pandas(df, preserve_index=False)

# Tokenize in batches. The raw source columns are dropped so the dataset holds
# only what tokenize_function returns.
train_dataset = train_dataset.map(
    tokenize_function,
    batched=True,
    batch_size=8,
    remove_columns=train_dataset.column_names,
)

# Load the causal LM from the same checkpoint path as the tokenizer, so the
# two cannot silently drift apart if the path is changed.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    # NOTE(review): native Transformers architectures such as Llama should not
    # need remote code; consider dropping this flag (it allows executing code
    # shipped with the checkpoint) — confirm for this checkpoint.
    trust_remote_code=True,
    # The KV cache is not used during training and conflicts with gradient
    # checkpointing, so it is disabled.
    use_cache=False,
)
# Trade extra compute for lower activation memory during backprop.
model.gradient_checkpointing_enable()

# Hyperparameters for full fine-tuning. Per-device effective batch size is
# per_device_train_batch_size * gradient_accumulation_steps = 2 * 8.
training_args = TrainingArguments(
    # Optimisation schedule
    num_train_epochs=3,
    learning_rate=2e-5,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    # Precision and memory
    bf16=True,
    gradient_checkpointing=True,
    # Data handling: drop dataset columns the model's forward() does not accept
    remove_unused_columns=True,
    # Output, logging and checkpointing
    output_dir="./results",
    logging_steps=10,
    save_steps=5000,
)

# Wire the model, hyperparameters, tokenized data and tokenizer into a Trainer.
trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=train_dataset,
)

# Run fine-tuning; the value returned by train() is kept in train_result.
train_result = trainer.train()

# Write the fine-tuned model to its output directory.
trainer.save_model(output_dir="best_model_llama_8B_sarcasm")