import os
from openai import OpenAI
from datetime import datetime
import logging

# Root logger: default stderr output at INFO level.
logging.basicConfig(level=logging.INFO)

# Also mirror this module's records into a log file; mode="w" truncates it on every run.
fh = logging.FileHandler(filename="chatgpt_finetune.log", mode="w")
fh.setLevel(logging.INFO)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(fh)


class OpenAIFT:
    """Thin wrapper around an OpenAI client for uploading fine-tuning data
    and launching fine-tuning jobs."""

    # Quoted so the annotation is not evaluated when the class is defined.
    client: "OpenAI"

    def __init__(self, client):
        """Store the OpenAI client used for every subsequent API call."""
        self.client = client

    def upload_file(self, file_path):
        """Upload a JSONL file to OpenAI with purpose ``"fine-tune"``.

        Args:
            file_path: Path of the local file to upload.

        Returns:
            The ID of the uploaded file object.
        """
        # Context manager guarantees the handle is closed, even if the request raises.
        with open(file_path, "rb") as fh:
            response = self.client.files.create(file=fh, purpose="fine-tune")
        return response.id

    def fine_tune_model(
        self,
        training_file_id: str,
        validation_file_id: str,
        method_type: str,
        suffix: str,
        hyperparameters: dict,
        seed: int = 42,
    ):
        """Create a fine-tuning job, log its ID and return it.

        Args:
            training_file_id: File ID (from ``upload_file``) of the training set.
            validation_file_id: Accepted for interface compatibility but NOT sent
                to the API (see the note at the call below).
            method_type: Fine-tuning method; used both as ``method["type"]`` and as
                the key that holds its hyperparameters (e.g. ``"supervised"``, ``"dpo"``).
            suffix: Suffix for the resulting fine-tuned model name.
            hyperparameters: Mapping with keys ``"model"``, ``"n_epochs"``,
                ``"learning_rate_multiplier"`` and ``"batch_size"``.
            seed: Seed passed to the job (default 42).

        Returns:
            The ID of the created fine-tuning job.
        """
        method_hyperparameters = {
            "n_epochs": hyperparameters["n_epochs"],
            "learning_rate_multiplier": hyperparameters["learning_rate_multiplier"],
            "batch_size": hyperparameters["batch_size"],
        }
        job_info = self.client.fine_tuning.jobs.create(
            training_file=training_file_id,
            # validation_file is intentionally not forwarded (author's note:
            # "violation of policy").
            model=hyperparameters["model"],
            method={
                "type": method_type,
                method_type: {"hyperparameters": method_hyperparameters},
            },
            suffix=suffix,
            seed=seed,
        )
        job_id = job_info.id
        logger.info(f"Job ID: {job_id}")
        return job_id


def get_api_key():
    """Return the OpenAI API key.

    Lookup order:
      1. The ``OPENAI_API_KEY`` environment variable, if set and non-empty.
      2. The contents of ``openai-key.txt`` in the current working directory,
         with surrounding whitespace stripped. A key read from this file is
         also exported to ``os.environ['OPENAI_API_KEY']``.

    Raises:
        ValueError: If neither source yields a non-empty key.
    """
    api_key = os.getenv('OPENAI_API_KEY')
    if api_key:
        return api_key
    # Fall back to openai-key.txt, resolved relative to the current working directory.
    try:
        with open('openai-key.txt', 'r', encoding='utf-8') as key_file:
            key = key_file.read().strip()
    except FileNotFoundError:
        raise ValueError("OpenAI API key not found. Set OPENAI_API_KEY environment variable or create openai-key.txt file") from None
    if not key:
        # Refuse to return (or export) an empty key read from a blank file.
        raise ValueError("openai-key.txt is empty. Put your OpenAI API key in it or set OPENAI_API_KEY environment variable")
    os.environ['OPENAI_API_KEY'] = key
    return key


_DEFAULT_DATA_ROOT = "/mnt/dv/wid/projects3/XXXX-3-XXXX-5-human-ai/mini-twitter-llm-agent-binary/data"

# Data-directory / model-suffix prefix for each supported fine-tuning method.
_METHOD_PREFIX = {"supervised": "sft", "dpo": "dpo"}


def _split_filename(model_name: str, loss_type: str, collection: str, split_type: str, part: str) -> str:
    """Return the JSONL filename of one split; ``part`` is ``"train"`` or ``"test"``.

    DPO files are prefixed with the model name; supervised files are not.
    """
    name = f"{collection}_{split_type}_{part}.jsonl"
    return f"{model_name}_{name}" if loss_type == "dpo" else name


def main(model_name: str, loss_type: str, collection: str, split_type: str,
         data_root: str = _DEFAULT_DATA_ROOT):
    """Upload one train/test split and launch a fine-tuning job for it.

    Args:
        model_name: Base model snapshot to fine-tune.
        loss_type: ``"supervised"`` or ``"dpo"``; also used as the fine-tuning
            method type.
        collection: Dataset collection name (e.g. ``"depth"``, ``"breadth"``).
        split_type: Split name (e.g. ``"round"``, ``"topic"``, ``"group"``).
        data_root: Directory containing the ``sft_data_formatted`` and
            ``dpo_data_formatted`` folders.

    Raises:
        ValueError: If ``loss_type`` is not a supported method.
    """
    # Validate before any upload so an unknown method cannot mix the dpo
    # directory with the supervised filename pattern.
    if loss_type not in _METHOD_PREFIX:
        raise ValueError(f"Unsupported loss_type {loss_type!r}; expected one of {sorted(_METHOD_PREFIX)}")
    prefix = _METHOD_PREFIX[loss_type]
    input_dir = os.path.join(data_root, f"{prefix}_data_formatted")

    client = OpenAI(api_key=get_api_key())
    oaift = OpenAIFT(client)
    train_path = os.path.join(input_dir, _split_filename(model_name, loss_type, collection, split_type, "train"))
    test_path = os.path.join(input_dir, _split_filename(model_name, loss_type, collection, split_type, "test"))
    file_id_train = oaift.upload_file(train_path)
    # NOTE: the test split is uploaded and logged, but it reaches fine_tune_model
    # only as validation_file_id, which that method does not currently send.
    file_id_test = oaift.upload_file(test_path)
    logger.info(f"Train File ID: {file_id_train}")
    logger.info(f"Test File ID: {file_id_test}")

    fine_tune_hyperparameters = {
        "model": model_name,
        "n_epochs": 5,
        "batch_size": 256 if loss_type == "supervised" else 16,
        "learning_rate_multiplier": "auto",
    }

    date_str = datetime.now().strftime("%Y%m%d")
    n_epochs = fine_tune_hyperparameters["n_epochs"]
    # Plain placeholders only: nested same-quote f-strings need Python 3.12+.
    suffix = f"{prefix}-{date_str}:{collection}-{split_type}-{n_epochs}epochs"
    logger.info(f"Model Name: {model_name}")
    logger.info(f"Suffix: {suffix}")

    oaift.fine_tune_model(file_id_train, file_id_test, loss_type, suffix, fine_tune_hyperparameters)


if __name__ == "__main__":
    # Loss types launched per base model: only the gpt-4.1-nano snapshot also gets DPO.
    loss_types_by_model = {
        "gpt-4o-mini-2024-07-18": ["supervised"],
        "gpt-4.1-nano-2025-04-14": ["supervised", "dpo"],
    }
    # Every (model, loss, collection, split) combination, in launch order.
    runs = [
        (model, loss, coll, split)
        for model, losses in loss_types_by_model.items()
        for loss in losses
        for coll in ("depth", "breadth")
        for split in ("round", "topic", "group")
    ]
    for model, loss, coll, split in runs:
        main(model_name=model, loss_type=loss, collection=coll, split_type=split)
