import os
import math
import shutil
import re
import pandas as pd
import torch
import bitsandbytes as bnb
from tqdm import tqdm
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
    EarlyStoppingCallback,
    set_seed,
    TrainerCallback,
)
from peft import LoraConfig, get_peft_model
import wandb

# ======== Global Constants ======== #
SEED = 202412
# Seed the RNGs once at import time, before any model or data work.
set_seed(SEED)

# NOTE(review): the directory name says r=128, but the LoraConfig in
# train_and_evaluate uses r=64 — confirm which is intended.
PROJ_DIR = 'math_lora_finetune-r=128'
# Replace the local GSM8K train/test parquet file paths below
TRAIN_DATA_PATH = 'path'
TEST_DATA_PATH = 'path'

TRAINING_OUTPUT_DIR = f"/root/path/{PROJ_DIR}/path-qa-peft"
MERGED_MODEL_PATH = f"/root/path/{PROJ_DIR}/path-qa-peft-merged"
MODEL_NAME = "path"

# Global settings
OBJS_DIR = "path"
# Create every output directory up front; existing directories are left alone.
for _required_dir in (OBJS_DIR, TRAINING_OUTPUT_DIR, MERGED_MODEL_PATH):
    os.makedirs(_required_dir, exist_ok=True)
del _required_dir
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Dominant parameter dtype of the loaded model. load_model_and_tokenizer
# overwrites it; train_and_evaluate reads it to choose fp16/bf16 training.
MODEL_DTYPE = torch.float32


def load_model_and_tokenizer(model_name: str):
    """Load the causal LM and its tokenizer, and record the model's main dtype.

    Side effect: sets the module-level ``MODEL_DTYPE`` to the parameter dtype
    that holds the most parameter elements.

    Args:
        model_name: Hub id or local path passed to ``from_pretrained``.

    Returns:
        ``(model, tokenizer)``. The tokenizer pads on the left, uses the EOS
        token as pad token when it has none, and has ``add_eos_token`` set.
    """
    from collections import Counter

    # device_map="auto" already places (and possibly shards/offloads) the
    # weights. An extra ``.to("cuda")`` crashes on CPU-only hosts and moves
    # a dispatched model behind accelerate's back, so it is not done here.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        use_cache=False,
    )

    # Weight each dtype by element count, in a single pass over the
    # parameters, so a few small tensors in another dtype cannot outvote the
    # bulk of the weights.
    elements_per_dtype = Counter()
    for param in model.parameters():
        elements_per_dtype[param.dtype] += param.numel()
    main_dtype = elements_per_dtype.most_common(1)[0][0]
    print(f"[INFO] Main model dtype: {main_dtype}")

    global MODEL_DTYPE
    MODEL_DTYPE = main_dtype

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        # Fall back to EOS for padding; labels must therefore not be masked by
        # pad id alone, or the real EOS is masked too.
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    # NOTE(review): only some tokenizer classes honor add_eos_token; confirm
    # it has an effect for MODEL_NAME.
    tokenizer.add_eos_token = True

    return model, tokenizer


def get_max_length(model: torch.nn.Module, default: int = 1024) -> int:
    """Return the model's maximum input length as declared in its config.

    Checks ``n_positions``, ``max_position_embeddings`` and ``seq_length`` on
    ``model.config`` in that order and returns the first truthy value.

    Args:
        model: Object exposing a ``config`` attribute.
        default: Returned when none of the attributes is set (or all are
            falsy). Defaults to 1024, the previous hard-coded fallback.
    """
    config = model.config
    for attribute in ("n_positions", "max_position_embeddings", "seq_length"):
        value = getattr(config, attribute, None)
        if value:
            return value
    return default


def create_prompt_formats(sample: dict):
    """Build the training prompt for one question/answer pair.

    Returns a dict with the full ``text`` (instruction, question and answer on
    separate ``###``-headed lines) plus the untouched ``question`` and
    ``answer`` fields.
    """
    question = sample["question"]
    answer = sample["answer"]
    instruction = "You are a math expert. Solve the problem step by step, and put your final answer within \\boxed{}."

    sections = (
        f"### Instruction: {instruction}",
        f"### Question: {question}",
        f"### Answer: {answer}",
    )
    return {"text": "\n".join(sections), "question": question, "answer": answer}


def create_model_inputs(example, tokenizer, max_length=1024):
    """Tokenize one formatted example into fixed-length causal-LM inputs.

    The whole ``example["text"]`` (instruction, question and answer) is
    encoded and reused as the label sequence. Only padding positions are
    excluded from the loss. NOTE(review): prompt tokens are *not* masked, so
    the loss also covers the instruction and question.

    Args:
        example: Mapping with a ``"text"`` entry.
        tokenizer: Callable tokenizer returning ``input_ids`` and
            ``attention_mask``.
        max_length: Pad/truncate every sequence to this many tokens.

    Returns:
        Dict of 1-D tensors ``input_ids``, ``attention_mask`` and ``labels``
        (``-100`` at padding positions).
    """
    encoding = tokenizer(
        example["text"],
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    input_ids = encoding["input_ids"]
    attention_mask = encoding["attention_mask"]

    # Mask padding via the attention mask, not by token id. When the pad
    # token is aliased to EOS, masking by id would also hide the real
    # end-of-sequence token, and the model would never learn to stop.
    labels = input_ids.masked_fill(attention_mask == 0, -100)

    return {
        "input_ids": input_ids.flatten(),
        "attention_mask": attention_mask.flatten(),
        "labels": labels.flatten(),
    }


def preprocess_dataset(dataset: Dataset, tokenizer, max_length: int, seed: int) -> Dataset:
    """
    Turn raw question/answer rows into shuffled, tokenized model inputs.

    Each row is first formatted with ``create_prompt_formats``. It is then
    encoded with ``create_model_inputs``, which replaces every original column
    with ``input_ids``/``attention_mask``/``labels``. The result is shuffled
    with ``seed``.
    """
    formatted = dataset.map(create_prompt_formats)
    columns_to_drop = formatted.column_names

    def _encode(example):
        return create_model_inputs(example, tokenizer, max_length)

    encoded = formatted.map(_encode, remove_columns=columns_to_drop)
    return encoded.shuffle(seed=seed)


def print_trainable_parameters(model: torch.nn.Module):
    """
    Print how many of ``model``'s parameter elements are trainable.

    Counts elements (``numel``), not tensors. A model without parameters
    reports 0.00% instead of raising ZeroDivisionError.
    """
    trainable_params = 0
    total_params = 0
    for param in model.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count
    percentage = 100 * trainable_params / total_params if total_params else 0.0
    print(f"Trainable parameters: {trainable_params} / {total_params} ({percentage:.2f}%)")


class SingleCheckpointCallback(TrainerCallback):
    """
    Keep a single copy of the best checkpoint (lowest ``eval_loss``).

    After every evaluation whose ``eval_loss`` beats the best seen so far, the
    model and tokenizer are saved to ``<output_dir>/checkpoint-<global_step>``.
    The previously saved best directory is removed only after that save has
    succeeded.
    """

    def __init__(self, output_dir, trainer):
        self.output_dir = output_dir
        self.best_loss = float("inf")
        self.current_checkpoint = None  # directory of the best checkpoint saved so far
        self.trainer = trainer  # used to reach the model/tokenizer to save

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Save a new best checkpoint when ``metrics["eval_loss"]`` improves."""
        if not metrics:
            return
        # Every process receives this callback; only the main one may touch
        # the filesystem, otherwise ranks race on save/rmtree.
        if not state.is_world_process_zero:
            return

        current_loss = metrics.get("eval_loss", float("inf"))
        # Written as `not <` (rather than `>=`) so a NaN loss is never "best".
        if not current_loss < self.best_loss:
            return
        self.best_loss = current_loss

        previous_checkpoint = self.current_checkpoint
        checkpoint_dir = os.path.join(self.output_dir, f"checkpoint-{state.global_step}")
        print(f"Saving new best checkpoint to: {checkpoint_dir}")

        # Save first, delete second: a failed save must not leave no checkpoint.
        self.trainer.model.save_pretrained(checkpoint_dir)
        # Newer transformers expose the tokenizer as `processing_class`;
        # `tokenizer` is the older (deprecated) name.
        tokenizer = getattr(self.trainer, "processing_class", None)
        if tokenizer is None:
            tokenizer = getattr(self.trainer, "tokenizer", None)
        if tokenizer is not None:
            tokenizer.save_pretrained(checkpoint_dir)
        self.current_checkpoint = checkpoint_dir

        # The old directory may already be gone (e.g. rotated by the Trainer
        # or removed manually); never delete the directory just written.
        if (
            previous_checkpoint
            and previous_checkpoint != checkpoint_dir
            and os.path.isdir(previous_checkpoint)
        ):
            print(f"Deleting previous checkpoint: {previous_checkpoint}")
            shutil.rmtree(previous_checkpoint)


def _accepted_kwarg(target, preferred: str, legacy: str) -> str:
    """Return ``preferred`` if ``target``'s signature accepts it, else ``legacy``."""
    import inspect

    return preferred if preferred in inspect.signature(target).parameters else legacy


def _logging_and_eval_steps(num_examples: int, per_device_batch_size: int, gradient_accumulation_steps: int):
    """Return ``(steps_per_epoch, logging_steps, eval_steps)`` for a single device.

    Logging happens every ~0.05 epoch and evaluation every ~0.5 epoch.
    ``eval_steps`` is rounded up to a multiple of ``logging_steps``.
    """
    steps_per_epoch = math.ceil(num_examples / (per_device_batch_size * gradient_accumulation_steps))
    logging_steps = max(1, int(0.05 * steps_per_epoch))
    eval_steps = max(1, int(0.5 * steps_per_epoch))
    if eval_steps % logging_steps != 0:
        eval_steps = logging_steps * (eval_steps // logging_steps + 1)
    return steps_per_epoch, logging_steps, eval_steps


def _apply_lora(model):
    """Wrap ``model`` with LoRA adapters on the attention projections; only adapters train."""
    lora_config = LoraConfig(
        r=64,  # LoRA rank. NOTE(review): PROJ_DIR says r=128 — confirm which is intended.
        lora_alpha=128,  # LoRA alpha parameter
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # attention q/k/v/o projections
        lora_dropout=0.1,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)
    # Make "only LoRA parameters train" explicit: freeze everything else.
    for name, param in model.named_parameters():
        param.requires_grad = "lora" in name
    return model


def train_and_evaluate(model, tokenizer, dataset_train, dataset_validation, output_dir):
    """
    Fine-tune ``model`` with LoRA, save the best adapter, then merge and save it.

    Args:
        model: Base causal LM.
        tokenizer: Matching tokenizer; saved next to the adapter and merged model.
        dataset_train: Tokenized training set (``input_ids``, ``attention_mask``,
            ``labels`` already padded to a fixed length).
        dataset_validation: Tokenized set used for evaluation, early stopping
            and best-model selection.
        output_dir: Trainer output directory. The best adapter is written to
            ``<output_dir>/best_model``.

    Returns:
        ``(merged_model, tokenizer)``. The merged model has the LoRA weights
        folded in and is also written to ``MERGED_MODEL_PATH``.
    """
    from transformers import default_data_collator

    model = _apply_lora(model)
    print(model)
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"Trainable: {name}")
    print_trainable_parameters(model)

    ## MEMORY OPTIMIZATION ##
    # Mixed precision follows the dtype the base weights were loaded in.
    use_fp16 = MODEL_DTYPE == torch.float16
    use_bf16 = MODEL_DTYPE == torch.bfloat16

    per_device_train_batch_size = 8
    gradient_accumulation_steps = 4
    steps_per_epoch, logging_steps, eval_steps = _logging_and_eval_steps(
        len(dataset_train), per_device_train_batch_size, gradient_accumulation_steps
    )

    print(f"Steps per epoch: {steps_per_epoch}")
    print(f"Logging steps: {logging_steps}")
    print(f"Eval steps: {eval_steps}")

    # `evaluation_strategy` was renamed to `eval_strategy` (the old name is
    # rejected by newer transformers); pick whichever this install accepts.
    eval_strategy_key = _accepted_kwarg(TrainingArguments, "eval_strategy", "evaluation_strategy")
    training_args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=True,
        logging_dir=os.path.join(output_dir, "logs"),
        per_device_train_batch_size=per_device_train_batch_size,
        per_device_eval_batch_size=6,
        gradient_accumulation_steps=gradient_accumulation_steps,
        warmup_ratio=0.05,
        # NOTE(review): 1e-6 is far below typical LoRA learning rates; confirm.
        learning_rate=1e-6,
        bf16=use_bf16,
        fp16=use_fp16,
        fp16_full_eval=use_fp16,
        save_total_limit=1,  # Keep only one checkpoint
        eval_steps=eval_steps,
        logging_strategy="steps",
        logging_steps=logging_steps,
        save_strategy="steps",
        save_steps=eval_steps,  # Save checkpoint after each eval
        num_train_epochs=5,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        seed=SEED,
        report_to="wandb",
        max_grad_norm=0.5,
        optim="adamw_bnb_8bit",
        weight_decay=0.01,
        lr_scheduler_type="cosine",
        **{eval_strategy_key: "steps"}
    )

    # `tokenizer=` was superseded by `processing_class=` in newer releases.
    tokenizer_key = _accepted_kwarg(Trainer, "processing_class", "tokenizer")
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset_train,
        eval_dataset=dataset_validation,
        # Inputs are already padded and carry their own `labels`.
        # DataCollatorForLanguageModeling(mlm=False) would discard those labels
        # and rebuild them from input_ids, so just stack the examples.
        data_collator=default_data_collator,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
        **{tokenizer_key: tokenizer}
    )

    # The callback needs the trainer to reach the model/tokenizer when saving.
    trainer.add_callback(SingleCheckpointCallback(output_dir, trainer))

    trainer.train()

    # With load_best_model_at_end=True the best weights are restored into
    # trainer.model; save from there rather than from the local reference.
    best_model = trainer.model
    best_model_path = os.path.join(output_dir, "best_model")
    os.makedirs(best_model_path, exist_ok=True)
    best_model.save_pretrained(best_model_path)
    tokenizer.save_pretrained(best_model_path)
    print(f"Best model saved to: {best_model_path}")

    # Merge LoRA weights into the base model so inference needs no PEFT wrapper.
    merged_model = best_model.merge_and_unload()
    merged_model.save_pretrained(MERGED_MODEL_PATH)
    tokenizer.save_pretrained(MERGED_MODEL_PATH)
    print(f"Merged model saved to: {MERGED_MODEL_PATH}")

    return merged_model, tokenizer


# Last number in free text. Thousands separators ("1,000") and decimals are
# kept together. A leading "-" counts only when not directly after a digit,
# so "3-4" yields 4 rather than -4.
_NUMBER_PATTERN = re.compile(r"(?<!\d)-?\d+(?:,\d{3})*(?:\.\d+)?")


def _last_boxed_content(text):
    """Return the non-empty contents of the last closed ``\\boxed{...}`` in ``text``, or None.

    Braces are matched by depth, so nested content such as ``\\frac{7}{2}``
    is returned whole.
    """
    marker = "\\boxed{"
    start = text.rfind(marker)
    while start != -1:
        content_start = start + len(marker)
        depth = 1
        pos = content_start
        while pos < len(text) and depth:
            char = text[pos]
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
            pos += 1
        if depth == 0:
            content = text[content_start:pos - 1]
            if content.strip():
                return content
        # Unclosed (e.g. generation cut off) or empty: try an earlier one.
        start = text.rfind(marker, 0, start)
    return None


def generate_answer(question, model, tokenizer, device):
    """Generate a solution for ``question`` and extract its final answer.

    Returns:
        ``(full_response, final_answer)``. ``full_response`` is the decoded
        prompt plus generation. ``final_answer`` is taken from the *generated*
        part only: the content of its last ``\\boxed{}``, else its last number
        (commas removed), else ``"-1"``.
    """
    # zero shot
    # NOTE(review): this prompt differs from the "### Instruction/### Question/
    # ### Answer" template used for fine-tuning — confirm that is intended.
    input_text = f"You are a math expert. Solve the problem step by step, and put your final answer within \\boxed{{}}. Question: {question}"
    encoding = tokenizer(input_text, return_tensors="pt", max_length=256, truncation=True)
    input_ids = encoding.input_ids
    attention_mask = getattr(encoding, "attention_mask", None)

    # With add_eos_token enabled, the prompt ends in EOS, which tells the
    # model the text is already finished. Drop it before generating.
    eos_id = tokenizer.eos_token_id
    if eos_id is not None and input_ids.shape[-1] > 1 and input_ids[0, -1].item() == eos_id:
        input_ids = input_ids[:, :-1]
        if attention_mask is not None:
            attention_mask = attention_mask[:, :-1]

    input_ids = input_ids.to(device)
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)

    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=200,
        use_cache=True,
        pad_token_id=tokenizer.pad_token_id,
    )
    full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Only the newly generated tokens. The prompt itself contains "\boxed{}"
    # and the question's numbers, which would otherwise be extracted.
    completion = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)

    boxed = _last_boxed_content(completion)
    if boxed is not None:
        final_answer = boxed
    else:
        numbers = _NUMBER_PATTERN.findall(completion)
        final_answer = numbers[-1].replace(",", "") if numbers else "-1"

    return full_response, final_answer


def inference(model, tokenizer, device, test_data_path, output_csv_file):
    """Answer every question in a parquet file and stream the results to CSV.

    The CSV has columns ``question``, ``answer`` (reference, as a string),
    ``predicted_answer`` (full decoded response) and ``final_answer``
    (extracted answer). Rows are appended as they are produced, so partial
    results survive an interrupted run. A header-only file is written when
    the input has no rows.
    """
    columns = ["question", "answer", "predicted_answer", "final_answer"]
    df = pd.read_parquet(test_data_path)
    # generate() does not change the module's mode; make sure dropout is off.
    model.eval()

    wrote_rows = False
    for _, row in tqdm(df.iterrows(), total=len(df), desc="Generating Answers"):
        question = row["question"]
        answer = f"{row['answer']}"
        predicted_answer, final_answer = generate_answer(question, model, tokenizer, device)

        record = pd.DataFrame(
            [{"question": question, "answer": answer, "predicted_answer": predicted_answer, "final_answer": final_answer}],
            columns=columns,
        )
        # Append one row per question (header only on the first). The old code
        # re-concatenated and rewrote the whole file every iteration (O(n^2)).
        record.to_csv(
            output_csv_file,
            mode="a" if wrote_rows else "w",
            header=not wrote_rows,
            index=False,
        )
        wrote_rows = True

    if not wrote_rows:
        pd.DataFrame(columns=columns).to_csv(output_csv_file, index=False)

    print(f"Inference results saved to {output_csv_file}")


def main():
    """Run the pipeline: load GSM8K, fine-tune with LoRA, then run inference."""
    # 1) Read GSM8K dataset train.parquet and test.parquet
    train_df = pd.read_parquet(TRAIN_DATA_PATH)
    test_df = pd.read_parquet(TEST_DATA_PATH)

    # 2) Convert to HuggingFace Dataset
    train_dataset = Dataset.from_pandas(train_df)
    test_dataset = Dataset.from_pandas(test_df)

    # 3) Load model and tokenizer
    model, tokenizer = load_model_and_tokenizer(MODEL_NAME)

    # 4) Preprocess dataset.
    # A 128-token cap is shorter than most "instruction + question + worked
    # solution" sequences. With the tokenizer's default right-side truncation,
    # the end of the solution (where GSM8K puts the final answer) would be cut
    # off. Use 512 tokens, or the model's context length if that is smaller.
    # NOTE(review): this uses more activation memory than 128; lower the
    # training batch size if it runs out of memory.
    max_length = min(512, get_max_length(model))
    dataset_train = preprocess_dataset(train_dataset, tokenizer, max_length, SEED)
    # NOTE(review): the test split doubles as the validation set that drives
    # early stopping and best-checkpoint selection, and is reused for
    # inference, so reported accuracy is optimistically biased. Consider
    # holding out part of the training split instead.
    dataset_val = preprocess_dataset(test_dataset, tokenizer, max_length, SEED)

    print(f"Training data size: {len(dataset_train)}")
    print(f"Validation data size: {len(dataset_val)}")

    # 5) Train and evaluate
    model, tokenizer = train_and_evaluate(model, tokenizer, dataset_train, dataset_val, TRAINING_OUTPUT_DIR)

    # 6) Inference
    output_csv_file = "path"
    inference(model, tokenizer, DEVICE, TEST_DATA_PATH, output_csv_file)


if __name__ == "__main__":
    main()
