"""
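Supervised fine-tuning (SFT) script that trains a causal language model to choose between two
two-outcome gambles, labelled A and B. Ground-truth labels come from a fixed utility function:
the preferred gamble is the one with the higher expected utility. The utility function and its
parameters are configurable via command-line arguments.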
"""

import os
import argparse
from typing import Dict, Callable

import numpy as np
import pandas as pd
from datasets import Dataset, DatasetDict
from sklearn.model_selection import train_test_split

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig

import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")))
from utility_functions import *

# Registry of selectable utility functions, keyed by the value accepted by
# the --utility CLI flag. The callables come from the `utility_functions`
# star import at the top of the file.
UTILITY_FUNCTIONS: Dict[str, Callable] = dict(
    linear=linear_utility,
    power=power_utility,
    quadratic=quadratic_utility,
    crra=crra_utility,
    cara=cara_utility,
    isoelastic=isoelastic_utility,
    prospect_theory=prospect_theory_value,
    hara=hara_utility,
    expo_power_saha=expo_power_utility_saha,
)

# For each utility, the names of the CLI arguments (attributes on the parsed
# argparse namespace) that are forwarded to it as keyword arguments.
UTILITY_PARAMS = dict(
    linear=[],
    power=["alpha"],
    quadratic=["b"],
    crra=["theta"],
    cara=["alpha"],
    isoelastic=["theta", "lam"],
    prospect_theory=["alpha", "beta", "lam", "reference_point"],
    hara=["alpha", "beta", "gamma"],
    expo_power_saha=["alpha", "theta"],
)


def calculate_preference(row: pd.Series, util_func: Callable, params: Dict) -> str:
    """Label which of two two-outcome gambles has the higher expected utility.

    Gamble X (X in {a, b}) pays ``reward1_X`` with probability ``p_X`` and
    ``reward2_X`` otherwise. Its expected utility is
    ``p_X * u(reward1_X) + (1 - p_X) * u(reward2_X)``, where
    ``u = util_func(., **params)``.

    Returns:
        "A" only when gamble a's expected utility is strictly greater; ties
        (and any comparison involving NaN) yield "B".
    """
    def expected_utility(tag: str) -> float:
        prob = row[f"p_{tag}"]
        u_first = util_func(row[f"reward1_{tag}"], **params)
        u_second = util_func(row[f"reward2_{tag}"], **params)
        return prob * u_first + (1 - prob) * u_second

    eu_a = expected_utility("a")
    eu_b = expected_utility("b")
    if eu_a > eu_b:
        return "A"
    return "B"

def build_sft_dataset(df: pd.DataFrame, util_name: str, params: dict) -> pd.DataFrame:
    """Turn a gamble table into prompt/completion pairs for SFT.

    Args:
        df: table with the gamble columns read by ``calculate_preference``
            plus a ``prompt_text`` column holding the question text. It is
            not modified.
        util_name: key into ``UTILITY_FUNCTIONS``.
        params: keyword arguments forwarded to the chosen utility function.

    Returns:
        A new frame with columns ``prompt`` (instruction + question, ending in
        "Answer:\\n"), ``completion`` (the label), and ``preferred_choice``
        (the same label, kept for evaluation).
    """
    labelled = df.copy()
    chosen_util = UTILITY_FUNCTIONS[util_name]

    labelled["preferred_choice"] = labelled.apply(
        lambda r: calculate_preference(r, chosen_util, params), axis=1
    )

    def render(question: str) -> str:
        instruction = (
            "You are an economic decision-making agent. "
            "Analyze the options and reply with your choice as a single letter: A or B.\n\n"
        )
        body = f"Question:\n{question}\n\n"
        return instruction + body + "Answer:" + "\n"

    labelled["prompt"] = [render(q) for q in labelled["prompt_text"].astype(str)]
    # The supervised target is simply the label letter.
    labelled["completion"] = labelled["preferred_choice"]

    wanted = ["prompt", "completion", "preferred_choice"]
    return labelled[wanted]


@torch.inference_mode()
def evaluate_model(model, tokenizer, dataset) -> float:
    """Return the fraction of examples whose generated answer matches the label.

    Each prompt is run through ``model.generate`` on its own (batch size 1).
    The first character of the stripped, decoded continuation is compared
    case-insensitively against ``item["preferred_choice"]``. Continuations
    that do not start with A or B count as wrong.

    Args:
        model: causal LM exposing ``eval``, ``device`` and ``generate``.
        tokenizer: tokenizer matching ``model``.
        dataset: iterable of mappings with "prompt" and "preferred_choice" keys.

    Returns:
        Accuracy in [0, 1], or 0.0 for an empty dataset.
    """
    model.eval()
    correct = 0
    total = 0
    for item in dataset:
        inputs = tokenizer(item["prompt"], return_tensors="pt").to(model.device)
        prompt_len = inputs.input_ids.shape[1]
        # do_sample=False forces greedy decoding regardless of the model's own
        # generation_config, so the reported accuracy is deterministic.
        outputs = model.generate(
            **inputs,
            max_new_tokens=3,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )
        # Decode only the newly generated tokens, not the echoed prompt.
        response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()
        answer = response[:1].upper()
        if answer in ("A", "B") and answer == item["preferred_choice"]:
            correct += 1
        total += 1
    return correct / total if total else 0.0


def _parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse ``sys.argv``.

    Returns:
        Parsed arguments covering data/model locations, utility-function
        selection and parameters, and training hyperparameters.
    """
    parser = argparse.ArgumentParser(description="SFT fine-tuning with a wide range of fixed-utility preferences.")

    # --- Data and Model Args ---
    parser.add_argument("--gpu_id", type=str, default="0")
    # These used to default to "", which only failed later inside
    # pd.read_csv / os.makedirs / from_pretrained; fail fast instead.
    parser.add_argument("--data_csv", type=str, required=True)
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--model_id", type=str, required=True)

    # --- Utility Function Args ---
    # Only the parameters named in UTILITY_PARAMS[--utility] are forwarded.
    parser.add_argument("--utility", type=str, choices=list(UTILITY_FUNCTIONS.keys()), default="crra")
    parser.add_argument("--alpha", type=float, default=2.0)
    parser.add_argument("--beta", type=float, default=0.88)
    parser.add_argument("--theta", type=float, default=2.0)
    parser.add_argument("--lam", type=float, default=2.25)
    parser.add_argument("--gamma", type=float, default=0.7)
    parser.add_argument("--b", type=float, default=0.1)
    parser.add_argument("--reference_point", type=float, default=0.0)

    # --- Training Hyperparameters ---
    parser.add_argument("--test_size", type=float, default=0.1)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--per_device_train_batch_size", type=int, default=2)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=8)
    parser.add_argument("--learning_rate", type=float, default=5e-6)
    parser.add_argument("--num_train_epochs", type=float, default=2.0)
    parser.add_argument("--lora_r", type=int, default=16)
    parser.add_argument("--lora_alpha", type=int, default=32)

    return parser.parse_args()


def main():
    """Label the gamble data, LoRA-fine-tune a 4-bit model on it, and report accuracy.

    Steps:
    1. Parse the CLI.
    2. Label every CSV row with the fixed-utility preferred choice and split
       it into train/eval.
    3. Fine-tune with TRL's SFTTrainer and save the result to --output_dir.
    4. Print the accuracy from ``evaluate_model`` on the eval split.
    """
    args = _parse_args()
    # torch initialises CUDA lazily, so setting this after `import torch` still
    # works as long as no CUDA call has happened yet.
    if args.gpu_id:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    os.makedirs(args.output_dir, exist_ok=True)

    # --- Data: label with the fixed utility, then split ---
    util_params = {p: getattr(args, p) for p in UTILITY_PARAMS[args.utility]}
    df = pd.read_csv(args.data_csv)
    pref_df = build_sft_dataset(df, args.utility, util_params)
    train_df, test_df = train_test_split(pref_df, test_size=args.test_size, random_state=args.seed)
    # preserve_index=False: the shuffled split leaves a non-default index that
    # would otherwise become a spurious "__index_level_0__" dataset column.
    dataset = DatasetDict({
        "train": Dataset.from_pandas(train_df, preserve_index=False),
        "eval": Dataset.from_pandas(test_df, preserve_index=False),
    })

    # --- Precision: use one dtype for both 4-bit compute and mixed precision ---
    # Previously the 4-bit compute dtype was bfloat16 while training ran fp16
    # AMP. Pick bf16 when the GPU supports it, otherwise fp16, and use it for both.
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    compute_dtype = torch.bfloat16 if use_bf16 else torch.float16

    # --- Model & tokenizer ---
    quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=compute_dtype)
    model = AutoModelForCausalLM.from_pretrained(
        args.model_id, quantization_config=quant_config, device_map="auto", trust_remote_code=True
    )
    # The KV cache is not needed during training.
    model.config.use_cache = False
    tokenizer = AutoTokenizer.from_pretrained(args.model_id, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    peft_config = LoraConfig(
        r=args.lora_r, lora_alpha=args.lora_alpha, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    )

    sft_config = SFTConfig(
        output_dir=args.output_dir,
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        num_train_epochs=args.num_train_epochs,
        learning_rate=args.learning_rate,
        logging_steps=10,
        save_strategy="steps",
        save_steps=500,
        seed=args.seed,
        weight_decay=0.0,
        bf16=use_bf16,
        fp16=not use_bf16,
        packing=False,
    )

    # Pass the tokenizer configured above (with its pad token) so training and
    # the evaluation below use the same tokenizer.
    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=dataset["train"],
        eval_dataset=dataset["eval"],
        args=sft_config,
        peft_config=peft_config,
    )

    trainer.train()
    trainer.save_model(args.output_dir)

    accuracy = evaluate_model(trainer.model, tokenizer, dataset["eval"])
    print(f"\n✅ Final Accuracy on Test Set: {accuracy:.2%}")


# Script entry point: run the full label -> fine-tune -> evaluate pipeline
# only when executed directly, not when imported.
if __name__ == "__main__":
    main()
