# traine_direction.py
from datasets import Dataset
from huggingface_hub.utils import disable_progress_bars
import json
import os, subprocess
import torch
from trl.trainer import SPOConfig, SPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer
import argparse
from huggingface_hub import login, HfApi
from transformers.integrations import MLflowCallback
from dotenv import load_dotenv
import shutil

# Disable progress bars for cleaner output
# The three environment variables are set before disable_progress_bars() is
# called; the call additionally switches the bars off through the
# huggingface_hub API.
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
os.environ["GIT_LFS_PROGRESS"] = "false"
os.environ["TQDM_DISABLE"] = "1"  # tqdm disable
disable_progress_bars()
# subprocess.run(["git", "config", "lfs.progress", "false"], check=True) # only affect current repository

# Load settings from the local .env file, then authenticate against the Hub.
# NOTE: this runs at import time, so importing this module already triggers a login.
load_dotenv(dotenv_path='./.env')
hf_token = os.environ.get("HF_TOKEN")  # None when HF_TOKEN is not set
login(token=hf_token)
# Download cache location: CACHE_DIR if set, otherwise ~/.cache ("~" is expanded below).
cache_dir = os.getenv("CACHE_DIR", "~/.cache")
cache_dir = os.path.expanduser(cache_dir)
api = HfApi()  # module-level Hub client, used by success_repo()


def success_repo(hub_model_id):
    """Tell whether the model repo ``hub_model_id`` contains ``tokenizer.json``.

    The repo's file list is fetched through the module-level ``api`` client.
    Any error raised while doing so (for example when the repository cannot
    be listed) is treated as "file not present".

    Args:
        hub_model_id: Repository id of a model repo on the Hub.

    Returns:
        True if ``tokenizer.json`` is among the repo's files, False otherwise
        (including when the lookup fails).
    """
    try:
        repo_files = api.list_repo_files(repo_id=hub_model_id, repo_type="model")
        return 'tokenizer.json' in repo_files
    except Exception:
        # Best-effort check: a failed lookup counts as "not there".
        return False
    
def train_direction(
    seed: int = 0,
    dataset: str | None = None,
    instance_num: int = 5000,
    learning_rate: float = 1e-5,
    beta: float = 1.0,
    constrained_logp: bool = False,
    mname: str | None = None,
    languages: list[str] | None = None,
    lang1_learning_strength: float | None = None,
    lang2_learning_strength: float | None = None,
    use_false_examples: bool = False,
    do_logging: bool = False,
) -> None:
    # if os.path.exists(cache_dir): shutil.rmtree(cache_dir)
    
    # Load the model and tokenizer
    model = AutoModelForCausalLM.from_pretrained(mname, torch_dtype=torch.bfloat16, device_map="auto", cache_dir=cache_dir)
    tokenizer = AutoTokenizer.from_pretrained(mname)
    # Ensure the tokenizer has a pad token
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print(languages)
    print(type(languages))

    if not use_false_examples:
        # Load the training dataset
        train_dataset = Dataset.from_list(json.load(open(f"data/seed{seed}_sample{instance_num}_{dataset}/{'-'.join(languages)}.json")))
        hub_model_id = f"seed{seed}_sample{instance_num}_{dataset}_{mname.replace('/', '-')}_{'-'.join(languages)}_{lang1_learning_strength}-{lang2_learning_strength}_{beta}"
    else:
        # Load the training dataset with false examples
        train_dataset = Dataset.from_list(json.load(open(f"data/seed{seed}_sample{instance_num}_{dataset}_false/{'-'.join(languages)}.json")))
        hub_model_id = f"seed{seed}_sample{instance_num}_{dataset}_false_{mname.replace('/', '-')}_{'-'.join(languages)}_{lang1_learning_strength}-{lang2_learning_strength}_{beta}"

    if dataset == 'bmlama':
        per_device_train_batch_size = 4
        gradient_accumulation_steps = 1
        learning_rate = 1e-5
        num_train_epochs = 1
    elif dataset == 'mmmlu':
        per_device_train_batch_size = 4
        gradient_accumulation_steps = 1
        learning_rate = 1e-5
        num_train_epochs = 1
    elif dataset == 'xcsqa':
        per_device_train_batch_size = 4
        gradient_accumulation_steps = 1
        num_train_epochs = 1
        if "gemma-3-1b-pt" in mname.lower():
            learning_rate = 1e-6
        elif "gemma-3-4b-pt" in mname.lower():
            learning_rate = 1e-6
        elif "gemma-3-12b-pt" in mname.lower():
            learning_rate = 1e-6
        elif "llama-3.1-8b" in mname.lower():
            learning_rate = 1e-6
        else:
            learning_rate = 1e-5
    else:
        raise ValueError(f"Unsupported dataset: {dataset}")

    training_args = SPOConfig(
        learning_rate=learning_rate,
        beta=beta,
        constrained_logp=constrained_logp,
        output_dir=f"checkpoints/{hub_model_id}",
        bf16=True,
        precompute_ref_log_probs=False,
        lang1_learning_strength=lang1_learning_strength,
        lang2_learning_strength=lang2_learning_strength,
        per_device_train_batch_size=per_device_train_batch_size,
        num_train_epochs=num_train_epochs,
        gradient_accumulation_steps=gradient_accumulation_steps,
        save_strategy='no',
        # save_strategy='steps',
        # save_steps=10000,
        # logging_strategy="epoch",
        logging_strategy="steps",
        logging_steps=100,
        disable_tqdm=not do_logging,
        optim="adamw_torch_fused",
        push_to_hub=True,
        report_to=[],
        # report_to="wandb",
        # hub_private_repo=True, # a private repo
        hub_model_id=f"{YOUR_HUB_ID}/{hub_model_id}",
        hub_strategy="end",
        hub_token=hf_token,
    )

    ref_model = AutoModelForCausalLM.from_pretrained(mname, torch_dtype=torch.bfloat16, device_map="auto")

    trainer = SPOTrainer(
        model=model,
        ref_model=ref_model,
        args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
    # Remove the MLflow integration entirely
    if not do_logging:
        print("Disabling MLflow logging")
        trainer.remove_callback(MLflowCallback)  # cuts out MLflow logging

    trainer.train()
    trainer.push_to_hub()

    if os.path.exists(f"checkpoints/{hub_model_id}"): shutil.rmtree(f"checkpoints/{hub_model_id}")

    print(f"Finished training {hub_model_id} on {languages} with strengths {lang1_learning_strength}-{lang2_learning_strength}")
    print('====')
    return
    
if __name__ == "__main__":
    # Command-line entry point: declare the options, parse them, start one training run.
    option_specs = [
        ('--seed', dict(type=int, default=0, help='random seed for data generation')),
        ('--dataset', dict(type=str, default='bmlama', help='dataset name')),
        ('--beta', dict(type=float, default=1.0, help='beta value')),
        ('--learning_rate', dict(type=float, default=1e-5, help='learning rate')),
        ('--constrained_logp', dict(action='store_true', help='whether to constrain logp')),
        ('--instance_num', dict(type=int, default=5000, help='number of instances')),
        ('--mname', dict(type=str, default='Qwen/Qwen3-4B', help='model name')),
        ('--languages', dict(nargs='+', default=['en', 'fr'], help='languages')),
        ('--lang1_learning_strength', dict(type=float, default=1.0, help='learning strength for language 1')),
        ('--lang2_learning_strength', dict(type=float, default=1.0, help='learning strength for language 2')),
        ('--use_false_examples', dict(action='store_true', help='whether to use false examples')),
        ('--do_logging', dict(action='store_true', help='whether to do MLflow logging')),
    ]
    cli = argparse.ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    opts = cli.parse_args()

    train_direction(
        seed=opts.seed,
        beta=opts.beta,
        learning_rate=opts.learning_rate,
        # NOTE: --constrained_logp is accepted on the command line but its
        # value is not used; constrained log-probs are always switched on here.
        constrained_logp=True,
        dataset=opts.dataset,
        instance_num=opts.instance_num,
        mname=opts.mname,
        languages=opts.languages,
        lang1_learning_strength=opts.lang1_learning_strength,
        lang2_learning_strength=opts.lang2_learning_strength,
        use_false_examples=opts.use_false_examples,
        do_logging=opts.do_logging,
    )
    