
from dataclasses import dataclass, field
from typing import List, Optional, Dict, Any
import argparse 
import os
import sys
import logging
import yaml
from transformers import (
    HfArgumentParser,
    TrainingArguments,
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
)
from peft import LoraConfig
from tqdm import tqdm
import torch
from datetime import datetime
import wandb
import json
from datasets import load_dataset
from trl import SFTTrainer

# --- Argument Classes ---
@dataclass
class ModelArguments:
    """Options controlling how the base model and tokenizer are loaded.

    Parsed from the selected hyperparameter set by ``HfArgumentParser``.
    """

    # Required (no default): must be present in the hyperparameter set.
    model_name_or_path: str = field(
        metadata=dict(help="Path to the pre-trained model or model identifier from Hugging Face Hub."),
    )
    # main() falls back to model_name_or_path when this is None/empty.
    tokenizer_name_or_path: Optional[str] = field(
        metadata=dict(help="Path to the tokenizer or tokenizer identifier from Hugging Face Hub. If None, defaults to model_name_or_path."),
        default=None,
    )
    # Passed unchanged to AutoModelForCausalLM.from_pretrained(device_map=...).
    device_map: Optional[str] = field(
        metadata=dict(help="Device map for model loading. 'auto' will distribute the model across available devices."),
        default="auto",
    )
    # Name of a torch dtype attribute; main() uses "auto" when no such attribute exists.
    torch_dtype: Optional[str] = field(
        metadata=dict(help="Torch dtype for model loading (e.g., 'bfloat16', 'float16', 'auto')."),
        default="auto",
    )
    # NOTE(review): use_auth_token / auth_token are not read by the visible main().
    use_auth_token: Optional[bool] = field(
        metadata=dict(help="Will use the token given in auth_token to instantiate the model."),
        default=False,
    )
    auth_token: Optional[str] = field(
        metadata=dict(help="The authentication token to use."),
        default=None,
    )

@dataclass
class LoraArguments:
    """LoRA adapter hyperparameters.

    main() maps these onto ``peft.LoraConfig``: ``alpha`` -> ``lora_alpha``,
    ``dropout`` -> ``lora_dropout``, ``rank`` -> ``r``, and ``bias`` /
    ``init_lora_weights`` by the same name.
    """

    alpha: int = field(metadata=dict(help="Alpha value for LoRA."), default=64)
    # main() coerces this with float() after parsing.
    dropout: float = field(metadata=dict(help="Dropout rate for LoRA."), default=0.0)
    rank: int = field(metadata=dict(help="Rank for LoRA."), default=64)
    bias: str = field(metadata=dict(help="Bias for LoRA layers."), default="none")
    init_lora_weights: str = field(metadata=dict(help="Initialization method for LoRA weights."), default="gaussian")

@dataclass
class DataArguments:
    """Primary fine-tuning dataset and prompt-formatting options.

    ``dataset_name``, ``subset_name`` and the split names are passed to
    ``datasets.load_dataset``; rows are turned into training text with
    ``format_prompts`` using ``input_column`` / ``output_column``.
    """

    # Required (no default).
    dataset_name: str = field(
        metadata={"help": "Name of the dataset to use for fine-tuning."}
    )
    subset_name: Optional[str] = field(
        default=None, metadata={"help": "Subset of the dataset to use (if applicable)."}
    )
    # Split *names* (passed as load_dataset(split=...)), not file paths.
    train_split: Optional[str] = field(
        default='train', metadata={"help": "Name of the dataset split to use for training."}
    )
    valid_split: Optional[str] = field(
        default='validation', metadata={"help": "Name of the dataset split to use for validation."}
    )
    input_column: str = field(
        default="question", metadata={"help": "Key in the dataset that contains the prompts."}
    )
    output_column: str = field(
        default="answer", metadata={"help": "Key in the dataset that contains the responses."}
    )
    # Only used when custom_prompt_id is also set.
    custom_prompt_file: Optional[str] = field(
        default=None, metadata={"help": "Path to a JSON file with a top-level 'prompts' list of custom prompt templates."}
    )
    custom_prompt_id: Optional[str] = field(
        default=None, metadata={"help": "ID of the custom prompt (from custom_prompt_file) to use for the dataset."}
    )

@dataclass
class IFTDataArguments:
    """Optional second (instruction-following) dataset mixed into the training set.

    It is only consumed by main() when ``second_dataset_name`` is set; its rows are
    formatted with ``format_prompts_for_ift`` using ``input_columns`` and
    ``second_output_column``.
    """

    # Optional: the default None disables the second dataset.
    second_dataset_name: Optional[str] = field(
        default=None, metadata={"help": "Name of an optional second (instruction-following) dataset mixed into training. Disabled when None."}
    )
    input_columns: List[str] = field(
        default_factory=lambda: ["question", "context"], metadata={"help": "Columns of the second dataset joined with ', ' to form each prompt."}
    )
    second_subset_name: Optional[str] = field(
        default=None, metadata={"help": "Subset of the second dataset to use (if applicable)."}
    )
    # A split *name*, not a file path.
    second_train_split: Optional[str] = field(
        default='train', metadata={"help": "Name of the second dataset's split to use for training."}
    )
    second_output_column: str = field(
        default="answer", metadata={"help": "Key in the second dataset that contains the responses."}
    )

@dataclass
class WandbArguments:
    """Weights & Biases run settings.

    main() only initializes W&B when ``wandb_project`` is truthy.
    """

    wandb_project: str = field(
        metadata=dict(help="Name of the Weights & Biases project."),
        default="finetune_experiment",
    )
    # None lets wandb.init pick the run name.
    wandb_run_name: Optional[str] = field(
        metadata=dict(help="Custom name for the Weights & Biases run."),
        default=None,
    )

@dataclass
class EvalArguments:
    """Output-location settings; main() also writes its training log file here."""

    eval_output_dir: str = field(
        metadata=dict(help="Directory to save evaluation outputs."),
        default="output",
    )

# --- Helper Functions ---
def setup_logging(output_dir: str, model_name: str, dataset_name: str) -> str:
    """Configure root logging to write INFO+ records to stdout and a timestamped file.

    The file is ``<output_dir>/train_<model>_<dataset>_<YYYYmmdd_HHMMSS>.log``,
    with ``/`` in the model and dataset names replaced by ``_``. ``output_dir``
    is created if it does not exist.

    Args:
        output_dir: Directory that receives the log file.
        model_name: Model identifier used in the file name.
        dataset_name: Dataset identifier used in the file name.

    Returns:
        The path of the log file.
    """
    safe_model_name = model_name.replace("/", "_")
    safe_dataset_name = dataset_name.replace("/", "_")
    log_filename = f"train_{safe_model_name}_{safe_dataset_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
    log_filepath = os.path.join(output_dir, log_filename)
    os.makedirs(output_dir, exist_ok=True)

    # force=True: any earlier module-level logging.info()/warning() call installs a
    # default root handler, after which a plain basicConfig() silently does nothing
    # and the file handler below would never be attached.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        handlers=[
            logging.FileHandler(log_filepath, encoding="utf-8"),
            logging.StreamHandler(sys.stdout),
        ],
        force=True,
    )
    logging.info(f"Logging to {log_filepath}")
    return log_filepath

def format_prompts(
    examples: Dict[str, Any],
    input_column: str,
    output_column: str,
    custom_prompt_template: Optional[str] = None,
    eos_token: str = "<|endoftext|>",
) -> Dict[str, List[str]]:
    """Render one training string per (input, output) pair of a batched example dict.

    Args:
        examples: Mapping of column name to a list of values (a batched ``map`` input).
        input_column: Column holding the source sentences.
        output_column: Column holding the target ALIST representations.
        custom_prompt_template: ``str.format`` template with ``{phrase}`` and
            ``{alist_representation}`` placeholders. When falsy (None or ""), a
            built-in template is used instead.
        eos_token: Appended verbatim to every rendered string.

    Returns:
        ``{"text": [...]}``; pairs are formed with ``zip``, so the shorter column
        determines the number of outputs.
    """
    sentences = examples[input_column]
    targets = examples[output_column]

    if custom_prompt_template:
        template = custom_prompt_template
        source_key, target_key = "phrase", "alist_representation"
    else:
        # Built-in template used when no custom one is supplied.
        template = "Translate the following sentence to its ALIST representation: {phrases}. ALIST: {alists}"
        source_key, target_key = "phrases", "alists"

    rendered = [
        template.format(**{source_key: sentence, target_key: target}) + eos_token
        for sentence, target in zip(sentences, targets)
    ]
    return {"text": rendered}

def format_prompts_for_ift(examples: Dict[str, Any], input_columns: List[str], output_column: str, eos_token: str = "<|endoftext|>") -> Dict[str, List[str]]:
    """Render instruction-following training strings from a batched example dict.

    Row ``i`` becomes the values of ``input_columns`` at ``i`` joined with
    ``", "``, then a literal ``"."``, then the output value, then ``eos_token``.
    The input values must be ``str`` (they go through ``str.join``).

    Args:
        examples: Mapping of column name to a list of values (a batched ``map`` input).
        input_columns: Columns whose values are joined to form the prompt.
        output_column: Column holding the response; its length sets the row count.
        eos_token: Appended to every rendered string.

    Returns:
        ``{"text": [...]}`` with one entry per output value.
    """
    answers = examples[output_column]
    # NOTE(review): no space is inserted between the "." and the answer — confirm intended.
    return {
        "text": [
            f"{', '.join(examples[col][row] for col in input_columns)}.{answer}{eos_token}"
            for row, answer in enumerate(answers)
        ]
    }

def _load_hyperparam_set(job_id: int, path: str = "training_hyperparams.yml") -> Dict[str, Any]:
    """Return the ``hyperparam_set`` mapping of the ``job_id``-th (1-based) entry of ``path``.

    Exits the process (status 1, message on stderr) when ``job_id`` is out of
    range or the entry has no ``hyperparam_set``. Logging is not configured yet
    when this runs, hence ``sys.exit(message)`` rather than ``logging``.
    """
    with open(path, "r") as f:
        hyperparams = yaml.safe_load(f)
    # The file is expected to be a top-level list indexed by JOB_NUM - 1.
    if not isinstance(hyperparams, list) or not 1 <= job_id <= len(hyperparams):
        sys.exit(f"Job ID {job_id} not found in {path}")
    entry = hyperparams[job_id - 1]
    hyperparam_set = entry.get("hyperparam_set") if isinstance(entry, dict) else None
    if hyperparam_set is None:
        sys.exit(f"Job ID {job_id} not found in {path}")
    return hyperparam_set


def _load_custom_prompt_template(prompt_file: str, prompt_id: Any) -> Optional[str]:
    """Return the ``template`` of the prompt whose ``prompt_id`` matches, else None.

    ``prompt_file`` is a JSON file with a top-level ``"prompts"`` list of objects
    carrying ``prompt_id`` and ``template`` keys. Failures are logged and yield
    None so the caller falls back to the default template (best-effort by design).
    """
    logging.info(f"Loading custom training prompt template from {prompt_file} for ID {prompt_id}")
    try:
        with open(prompt_file, "r") as f:
            custom_prompts = json.load(f)
        for prompt in custom_prompts["prompts"]:
            # Compare as strings: an unquoted YAML ID (e.g. 3) loads as int while the
            # JSON file may store it as "3", and parse_dict does not coerce types.
            if str(prompt["prompt_id"]) == str(prompt_id):
                template = prompt["template"]
                logging.info(f"Using custom prompt template: {template}")
                return template
    except (OSError, ValueError, KeyError, TypeError) as e:
        logging.error(f"Failed to load custom prompt template: {e}. Using default prompt template.")
        return None
    logging.warning(f"Prompt ID {prompt_id} not found in {prompt_file}. Using default prompt template.")
    return None


def main():
    """Fine-tune a 4-bit quantized causal LM with LoRA, configured by ``JOB_NUM``.

    Reads the 1-based ``JOB_NUM`` environment variable and selects that entry of
    ``training_hyperparams.yml``. It parses the entry into the argument
    dataclasses, then configures logging and (optionally) W&B. Next it loads the
    model, tokenizer and datasets, and formats examples into a ``"text"`` column,
    optionally mixing in a second instruction-following dataset. Finally it trains
    with TRL's ``SFTTrainer``, saves to ``training_args.output_dir`` and
    optionally pushes to the Hub.

    Exits with status 1 on configuration, model-loading or dataset-loading errors.
    Training errors are re-raised after the W&B run is closed.
    """
    # --- Argument Parsing ---
    # Nothing is logged before setup_logging(): the first module-level logging call
    # installs a default root handler, after which a plain basicConfig() is a no-op
    # and the log file would never be attached. Fatal errors here therefore go to
    # stderr via sys.exit(message).
    raw_job_id = os.environ.get("JOB_NUM")
    job_id = int(raw_job_id) if raw_job_id and raw_job_id.isdigit() else None
    if not job_id:  # JOB_NUM is 1-based, so 0 is rejected too
        sys.exit("JOB_NUM must be set to a positive integer in the environment")
    hyperparam_set = _load_hyperparam_set(job_id)

    parser = HfArgumentParser((ModelArguments, TrainingArguments, LoraArguments, DataArguments, IFTDataArguments, WandbArguments, EvalArguments))
    model_args, training_args, lora_args, data_args, ift_data_args, wandb_args, eval_args = parser.parse_dict(hyperparam_set)
    # parse_dict does not coerce types; YAML may load values such as 2e-4 as strings.
    training_args.learning_rate = float(training_args.learning_rate)
    lora_args.dropout = float(lora_args.dropout)

    # --- Setup Logging ---
    setup_logging(eval_args.eval_output_dir, model_args.model_name_or_path, data_args.dataset_name)
    logging.info(f"Running training for JOB_NUM: {job_id}")
    logging.info(f"Using hyperparameter set: {hyperparam_set}")
    logging.info(f"Model: {model_args.model_name_or_path}")
    logging.info(f"Evaluation output directory: {eval_args.eval_output_dir}")
    logging.info(f"Dataset: {data_args.dataset_name} (subset: {data_args.subset_name})")
    logging.info(f"Custom prompt: {data_args.custom_prompt_file} (ID: {data_args.custom_prompt_id})")
    logging.info(f"Training arguments: {training_args}")

    # --- Initialize W&B ---
    if wandb_args.wandb_project:
        try:
            wandb.init(
                project=wandb_args.wandb_project,
                name=wandb_args.wandb_run_name,
                config={
                    "model_args": model_args,
                    "data_args": data_args,
                    # NOTE(review): duplicates eval_args; kept so the logged config keys stay unchanged.
                    "validation_data_args": eval_args,
                    "training_args": training_args,
                    "eval_args": eval_args,
                    "lora_args": lora_args,
                    "ift_data_args": ift_data_args,
                },
            )
            logging.info(f"Logging to W&B project: {wandb_args.wandb_project}")
        except Exception as e:
            logging.warning(f"Could not initialize W&B: {e}. Disabling W&B logging.")
            wandb_args.wandb_project = None  # Disable wandb if init fails

    # --- Load Model and Tokenizer ---
    logging.info(f"Loading model: {model_args.model_name_or_path}")
    tokenizer_path = model_args.tokenizer_name_or_path or model_args.model_name_or_path
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=False,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    try:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, padding_side="right")
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Map names such as "bfloat16" to torch dtypes; anything else (including
        # "auto" and None) falls back to "auto".
        dtype_candidate = getattr(torch, model_args.torch_dtype, None) if model_args.torch_dtype else None
        torch_dtype = dtype_candidate if isinstance(dtype_candidate, torch.dtype) else "auto"

        model = AutoModelForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            quantization_config=bnb_config,
            device_map=model_args.device_map,
            torch_dtype=torch_dtype,
        )
        logging.info("Model and tokenizer loaded successfully.")
    except Exception as e:
        logging.exception(f"Failed to load model or tokenizer: {e}")
        if wandb_args.wandb_project:
            wandb.finish(exit_code=1)
        sys.exit(1)

    # --- Load Dataset ---
    try:
        logging.info(f"Loading dataset: {data_args.dataset_name}, split: {data_args.train_split}")
        training_data = load_dataset(
            data_args.dataset_name,
            data_args.subset_name,
            split=data_args.train_split,
        )
        logging.info(f"Loaded {len(training_data)} training examples.")
        logging.info(f"Loading validation dataset: {data_args.dataset_name}, split: {data_args.valid_split}")
        validation_data = load_dataset(
            data_args.dataset_name,
            data_args.subset_name,
            split=data_args.valid_split,
        )
    except Exception as e:
        logging.exception(f"Failed to load dataset: {e}")
        if wandb_args.wandb_project:
            wandb.finish(exit_code=1)
        sys.exit(1)

    # --- Custom Prompts ---
    custom_prompt_template = None
    if data_args.custom_prompt_file and data_args.custom_prompt_id:
        custom_prompt_template = _load_custom_prompt_template(data_args.custom_prompt_file, data_args.custom_prompt_id)
    else:
        logging.info("No custom prompt file or ID provided for training. Using default prompt template.")

    # --- Data Formatting ---
    eos_token = tokenizer.eos_token

    def _format_primary(examples):
        """Format a batch of the primary dataset into the "text" column."""
        return format_prompts(
            examples,
            input_column=data_args.input_column,
            output_column=data_args.output_column,
            custom_prompt_template=custom_prompt_template,
            eos_token=eos_token,
        )

    training_data = training_data.map(_format_primary, batched=True)
    validation_data = validation_data.map(_format_primary, batched=True)

    if ift_data_args.second_dataset_name:
        from datasets import concatenate_datasets  # only needed when mixing datasets

        try:
            logging.info(f"Loading IFT dataset: {ift_data_args.second_dataset_name}, split: {ift_data_args.second_train_split}")
            ift_training_data = load_dataset(
                ift_data_args.second_dataset_name,
                ift_data_args.second_subset_name,
                split=ift_data_args.second_train_split,
            )
        except Exception as e:
            logging.exception(f"Failed to load IFT dataset: {e}")
            if wandb_args.wandb_project:
                wandb.finish(exit_code=1)
            sys.exit(1)

        # Down-sample to at most the size of the primary training set.
        n_ift = min(len(training_data), len(ift_training_data))
        ift_training_data = ift_training_data.shuffle(seed=42).select(range(n_ift))
        logging.info(f"Loaded {len(ift_training_data)} IFT training examples.")
        ift_training_data = ift_training_data.map(
            lambda examples: format_prompts_for_ift(
                examples,
                input_columns=ift_data_args.input_columns,
                output_column=ift_data_args.second_output_column,
                eos_token=eos_token,
            ),
            batched=True,
            remove_columns=ift_training_data.column_names,
        )
        # concatenate_datasets requires identical features, so keep only "text" on the primary side too.
        extra_columns = [c for c in training_data.column_names if c != "text"]
        if extra_columns:
            training_data = training_data.remove_columns(extra_columns)
        training_data = concatenate_datasets([training_data, ift_training_data])
        training_data = training_data.shuffle(seed=42)

    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
    )
    # --- Setup LoRA Configuration ---
    lora_config = LoraConfig(
        lora_alpha=lora_args.alpha,
        lora_dropout=lora_args.dropout,
        r=lora_args.rank,
        bias=lora_args.bias,
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        init_lora_weights=lora_args.init_lora_weights,
    )

    # --- Training ---
    # NOTE(review): the tokenizer configured above (pad_token fallback) reaches the
    # trainer only via data_collator. Depending on the TRL version, SFTTrainer may
    # load its own tokenizer unless it is passed explicitly (tokenizer= /
    # processing_class=). Confirm against the pinned TRL version.
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=training_data,
        eval_dataset=validation_data,
        peft_config=lora_config,
        data_collator=data_collator,
    )

    torch.cuda.empty_cache()
    model.config.use_cache = False
    logging.info("Starting training...")
    try:
        trainer.train()
    except Exception:
        logging.exception("Training failed.")
        if wandb_args.wandb_project:
            wandb.finish(exit_code=1)
        raise
    logging.info("Training finished. Saving model...")
    trainer.save_model(training_args.output_dir)

    # --- Push to Hugging Face Hub ---
    if training_args.push_to_hub:
        logging.info(f"Pushing model to Hugging Face Hub: {training_args.output_dir}")
        trainer.push_to_hub(commit_message="Training complete")
    else:
        logging.info(f"Model saved to {training_args.output_dir} without pushing to Hugging Face Hub.")
    logging.info("Training complete.")
    if wandb_args.wandb_project:
        wandb.finish()
        logging.info("W&B run finished.")


if __name__ == "__main__":
    main()