import os
import argparse
from typing import Dict, Callable

import numpy as np
import pandas as pd
from datasets import Dataset, DatasetDict
from sklearn.model_selection import train_test_split

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model
from trl import DPOTrainer, DPOConfig

import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")))
from utility_functions import *

def calculate_preference(row: pd.Series, util_func: Callable, params: Dict) -> str:
    """Pick the gamble ("A" or "B") with the higher expected utility.

    Each gamble pays ``reward1_<x>`` with probability ``p_<x>`` and
    ``reward2_<x>`` otherwise, where ``<x>`` is ``a`` or ``b``. Rewards are
    passed through ``util_func(reward, **params)`` before weighting.

    Args:
        row: Mapping-like record holding the ``p_*``, ``reward1_*`` and
            ``reward2_*`` fields for both gambles.
        util_func: Utility function applied to each reward.
        params: Extra keyword arguments forwarded to ``util_func``.

    Returns:
        "A" when gamble A has strictly greater expected utility, else "B"
        (so ties resolve to "B").
    """

    def expected_utility(option: str):
        # Probability-weighted sum of the utilities of the two outcomes.
        win_prob = row[f"p_{option}"]
        first_utility = util_func(row[f"reward1_{option}"], **params)
        second_utility = util_func(row[f"reward2_{option}"], **params)
        return win_prob * first_utility + (1 - win_prob) * second_utility

    utility_a = expected_utility("a")
    utility_b = expected_utility("b")
    if utility_a > utility_b:
        return "A"
    return "B"




def build_dpo_dataset(df: pd.DataFrame, util_name: str, params: Dict) -> pd.DataFrame:
    """Turn raw gamble rows into DPO training records.

    The preferred option for every row is derived with ``calculate_preference``
    using the utility function registered under ``util_name``. The question in
    ``prompt_text`` is wrapped in a fixed instruction template.

    Args:
        df: Source frame; read here for ``prompt_text`` plus whatever fields
            ``calculate_preference`` consumes. It is copied, not mutated.
        util_name: Key selecting the utility function.
        params: Keyword arguments forwarded to the utility function.

    Returns:
        A frame with columns ``prompt``, ``chosen``, ``rejected`` and
        ``preferred_choice``.

    Raises:
        KeyError: If ``util_name`` is not a registered utility.
    """
    registry = {
        "linear": linear_utility,
        "power": power_utility,
        "quadratic": quadratic_utility,
        "crra": crra_utility,
        "cara": cara_utility,
        "isoelastic": isoelastic_utility,
        "prospect_theory": prospect_theory_value,
        "hara": hara_utility,
        "expo_power_saha": expo_power_utility_saha,
    }
    selected_utility = registry[util_name]

    # Work on a copy so the caller's frame is left untouched.
    frame = df.copy()
    frame["preferred_choice"] = frame.apply(
        calculate_preference, axis=1, args=(selected_utility, params)
    )

    prompt_head = (
        "You are an economic decision-making agent. "
        "Analyze the options and reply with your choice as a single letter: A or B.\n\n"
        "Question:\n"
    )
    prompt_tail = "\n\nAnswer:"
    questions = frame["prompt_text"].astype(str)
    frame["prompt"] = questions.map(lambda question: prompt_head + question + prompt_tail)

    # The rejected answer is simply the other letter.
    other_letter = {"A": "B"}
    frame["chosen"] = frame["preferred_choice"]
    frame["rejected"] = frame["preferred_choice"].map(
        lambda letter: other_letter.get(letter, "A")
    )

    output_columns = ["prompt", "chosen", "rejected", "preferred_choice"]
    return frame[output_columns]


@torch.inference_mode()
def evaluate_model(model, tokenizer, dataset) -> float:
    """Measure how often the model's generated letter matches the label.

    For every item the prompt is tokenized, up to three new tokens are
    generated, and the first character of the decoded continuation
    (upper-cased) is compared with ``item["preferred_choice"]``. Replies that
    do not start with "A" or "B" count as wrong.

    Args:
        model: Model exposing ``eval()``, ``device`` and ``generate``.
        tokenizer: Tokenizer used to encode prompts and decode generations.
        dataset: Iterable of records with ``prompt`` and ``preferred_choice``.

    Returns:
        Fraction of correctly answered items, or 0.0 for an empty dataset.
    """
    model.eval()
    valid_letters = ("A", "B")
    hits = 0
    seen = 0

    for example in dataset:
        question = example["prompt"]
        expected = example["preferred_choice"]

        encoded = tokenizer(question, return_tensors="pt").to(model.device)
        generated = model.generate(
            **encoded,
            max_new_tokens=3,
            pad_token_id=tokenizer.eos_token_id,
        )

        # Drop the echoed prompt tokens so only the continuation is decoded.
        prompt_length = encoded.input_ids.shape[1]
        continuation = generated[0][prompt_length:]
        reply = tokenizer.decode(continuation, skip_special_tokens=True).strip()

        # Slicing yields "" for an empty reply, which is not a valid letter.
        first_letter = reply[:1].upper()
        if first_letter in valid_letters and first_letter == expected:
            hits += 1
        seen += 1

    if seen > 0:
        return hits / seen
    return 0.0


def main():
    """Run the full DPO fine-tuning pipeline from the command line.

    Steps: parse CLI arguments, derive (prompt, chosen, rejected) preference
    pairs from the source CSV with the selected utility function, split them
    into train/eval sets, load the 4-bit quantized base model, train a LoRA
    adapter with DPO, save it, and write the final eval accuracy to
    ``summary.txt`` inside ``--output_dir``.
    """
    # Which CLI parameters each utility function receives. This single table
    # drives both the ``--utility`` choices and the keyword arguments passed
    # to the utility function.
    utility_param_names = {
        "linear": (),
        "power": ("alpha",),
        "quadratic": ("b",),
        "crra": ("theta",),
        "cara": ("alpha",),
        "isoelastic": ("theta", "lam"),
        "prospect_theory": ("alpha", "beta", "lam", "reference_point"),
        "hara": ("alpha", "beta", "gamma"),
        "expo_power_saha": ("alpha", "theta"),
    }

    parser = argparse.ArgumentParser(description="DPO fine-tuning with a wide range of fixed-utility preferences.")

    # --- Data and Model Args ---
    # The three paths/ids are required: an empty value can only fail later
    # (e.g. os.makedirs("") raises), so reject it at parse time instead.
    parser.add_argument("--gpu_id", type=str, default="0", help="The GPU ID(s) to use for training (e.g., '0' or '0,1').")
    parser.add_argument("--data_csv", type=str, required=True, help="Path to the source CSV file.")
    parser.add_argument("--output_dir", type=str, required=True, help="Directory to save the final model.")
    parser.add_argument("--model_id", type=str, required=True, help="Base model ID from Hugging Face Hub.")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for the train/eval split and training.")

    # --- Utility Function Args ---
    utility_choices = list(utility_param_names)
    parser.add_argument("--utility", type=str, default="crra", choices=utility_choices, help="Utility function for ground truth.")
    parser.add_argument("--alpha", type=float, default=2, help="Parameter for Power, CARA, Prospect Theory, HARA, Expo-Power.")
    parser.add_argument("--beta", type=float, default=0.88, help="Parameter for Prospect Theory, HARA.")
    parser.add_argument("--theta", type=float, default=2.0, help="Parameter for CRRA, Isoelastic, Expo-Power.")
    parser.add_argument("--lam", type=float, default=2.25, help="Lambda (loss aversion) for Isoelastic, Prospect Theory.")
    parser.add_argument("--gamma", type=float, default=0.7, help="Gamma parameter for HARA.")
    parser.add_argument("--b", type=float, default=0.1, help="Parameter for Quadratic utility.")
    parser.add_argument("--reference_point", type=float, default=0.0, help="Reference point for Prospect Theory.")

    # --- Training Args ---
    parser.add_argument("--test_size", type=float, default=0.1, help="Fraction of data to use for evaluation.")
    parser.add_argument("--per_device_train_batch_size", type=int, default=2)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=8)
    parser.add_argument("--learning_rate", type=float, default=5e-6)
    parser.add_argument("--num_train_epochs", type=float, default=2.0)
    parser.add_argument("--dpo_beta", type=float, default=0.1, help="Beta parameter for DPO loss.")
    parser.add_argument("--max_prompt_length", type=int, default=512)
    parser.add_argument("--max_length", type=int, default=1024)
    parser.add_argument("--lora_r", type=int, default=16)
    parser.add_argument("--lora_alpha", type=int, default=32)

    args = parser.parse_args()
    if args.gpu_id:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    os.makedirs(args.output_dir, exist_ok=True)

    # Collect only the parameters the selected utility function accepts.
    util_params = {name: getattr(args, name) for name in utility_param_names[args.utility]}

    # --- Data preparation ---
    print(f"Loading data from '{args.data_csv}'...")
    df = pd.read_csv(args.data_csv)

    print(f"Generating preference data using utility='{args.utility}' with params={util_params}...")
    pref_df = build_dpo_dataset(df, args.utility, util_params)

    train_df, test_df = train_test_split(pref_df, test_size=args.test_size, random_state=args.seed)
    dataset = DatasetDict({"train": Dataset.from_pandas(train_df), "eval": Dataset.from_pandas(test_df)})

    # --- Model and tokenizer ---
    print(f"Loading base model: {args.model_id}...")
    quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
    model = AutoModelForCausalLM.from_pretrained(
        args.model_id,
        quantization_config=quant_config,
        device_map="auto",
        trust_remote_code=True,
    )
    model.config.use_cache = False
    tokenizer = AutoTokenizer.from_pretrained(args.model_id, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Reuse EOS as padding when the tokenizer defines no pad token.
        tokenizer.pad_token = tokenizer.eos_token

    peft_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    )

    # --- DPO configuration ---
    # beta / max_prompt_length / max_length are forwarded here so the CLI
    # flags take effect and the value reported in summary.txt matches the
    # value used for training.
    # NOTE(review): fp16 training is combined with a bfloat16 4-bit compute
    # dtype above — confirm this mix is intended for the target hardware.
    dpo_config = DPOConfig(
        output_dir=args.output_dir,
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        num_train_epochs=args.num_train_epochs,
        learning_rate=args.learning_rate,
        logging_steps=10,
        save_strategy="steps",
        save_steps=500,
        seed=args.seed,
        weight_decay=0.0,
        bf16=False,
        fp16=True,
        padding_value=tokenizer.pad_token_id,
        beta=args.dpo_beta,
        max_prompt_length=args.max_prompt_length,
        max_length=args.max_length,
    )

    trainer = DPOTrainer(
        model,
        args=dpo_config,
        train_dataset=dataset["train"],
        eval_dataset=dataset["eval"],
        peft_config=peft_config,
    )

    # --- Training ---
    print("Starting DPO training...")
    trainer.train()
    print("Training complete. Saving final LoRA adapter...")
    trainer.save_model(args.output_dir)

    # --- Final Evaluation & Summary ---
    print("Performing final evaluation...")
    accuracy = evaluate_model(trainer.model, tokenizer, dataset["eval"])
    print(f"\n✅ Final Accuracy on Test Set: {accuracy:.2%}")

    summary_path = os.path.join(args.output_dir, "summary.txt")
    with open(summary_path, "w", encoding="utf-8") as f:
        f.write(f"Base Model: {args.model_id}\n")
        f.write(f"Utility Function: {args.utility}\n")
        f.write(f"Utility Parameters: {util_params}\n")
        f.write(f"DPO Beta: {args.dpo_beta}\n")
        f.write(f"Final Test Accuracy: {accuracy:.4f}\n")

    print(f"All done. Final model adapters and summary saved in '{args.output_dir}'.")


# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
