import os
import hydra

import torch
import transformers

from pathlib import Path
from omegaconf import OmegaConf
from peft import LoraConfig, get_peft_model
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed

from data_module import TextDatasetQA, custom_data_collator
from dataloader import CustomTrainer
from utils import get_model_identifiers_from_yaml, find_all_linear_names

# Set DeepSpeed's DS_BUILD_* switches to "0" so its optional custom CUDA/C++
# ops are not requested. This runs at import time, i.e. before `main` builds
# the TrainingArguments / Trainer that may bring DeepSpeed in.
# NOTE(review): DS_BUILD_* are documented as install-time (pre-build) flags;
# confirm they have any effect when set at runtime like this.
os.environ.update(
    DS_BUILD_OPS="0",
    DS_BUILD_SPARSE_ATTN="0",
    DS_BUILD_CPU_ADAM="0",
)

def _select_deepspeed_config(cfg, num_devices):
    """Pick the DeepSpeed JSON config for this run, or ``None`` to disable DeepSpeed.

    DeepSpeed is only used when more than one device is present. In that case
    an optional ``cfg.parallelism.strategy`` of ``"pipeline"``, ``"tensor"`` or
    ``"zero"`` selects a strategy-specific file. Any other non-empty strategy,
    or no strategy at all, falls back to ``config/ds_config.json``.

    Args:
        cfg: Hydra/OmegaConf config for the run.
        num_devices: Number of processes in the run (from ``WORLD_SIZE``).

    Returns:
        The DeepSpeed config path, or ``None`` for single-device runs.
    """
    default_config = 'config/ds_config.json'
    if num_devices <= 1:
        print("Single GPU detected, disabling DeepSpeed")
        return None

    # `parallelism` is an optional config section.
    strategy = cfg.parallelism.strategy if hasattr(cfg, 'parallelism') else None
    if not strategy:
        print("Using default parallelism strategy")
        return default_config

    config_by_strategy = {
        "pipeline": 'config/ds_pipeline_config.json',
        "tensor": 'config/ds_tensor_config.json',
        "zero": 'config/ds_zero_config.json',
    }
    print(f"Using parallelism strategy: {strategy}")
    return config_by_strategy.get(strategy, default_config)


@hydra.main(version_base=None, config_path="config", config_name="finetune")
def main(cfg):
    """Fine-tune a causal LM on a QA dataset and save the result to ``cfg.save_dir``.

    Config keys read: ``seed``, ``model_family``, ``save_dir``, ``data_path``,
    ``split``, ``batch_size``, ``gradient_accumulation_steps``, ``num_epochs``,
    ``lr``, ``weight_decay``, ``use_flash_attention_2``, ``LoRA.{r,alpha,dropout}``,
    and optionally ``parallelism.strategy`` and ``profiling.enabled``.

    When ``LoRA.r != 0`` the model is wrapped with a LoRA adapter for training,
    and the adapter is merged back into the base weights before saving.
    """
    num_devices = int(os.environ.get('WORLD_SIZE', 1))
    print(f"num_devices: {num_devices}")

    # Global rank 0 is the single master process across all nodes. Launchers
    # such as torchrun export RANK next to WORLD_SIZE/LOCAL_RANK. Fall back to
    # LOCAL_RANK, then to 0 for a plain single-process run.
    global_rank = int(os.environ.get('RANK', os.environ.get('LOCAL_RANK', '0')))
    is_main_process = global_rank == 0

    # DeepSpeed is disabled (None) for single-device runs.
    ds_config_path = _select_deepspeed_config(cfg, num_devices)

    set_seed(cfg.seed)
    os.environ["WANDB_DISABLED"] = "true"
    model_cfg = get_model_identifiers_from_yaml(cfg.model_family)
    model_id = model_cfg["hf_key"]

    Path(cfg.save_dir).mkdir(parents=True, exist_ok=True)
    # Save the resolved cfg only once, from the master process.
    if is_main_process:
        OmegaConf.save(cfg, f'{cfg.save_dir}/cfg.yaml')

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if tokenizer.pad_token is None:
        # Some causal-LM tokenizers ship without a pad token; batching needs one.
        tokenizer.pad_token = tokenizer.eos_token

    max_length = 500
    torch_format_dataset = TextDatasetQA(
        cfg.data_path,
        tokenizer=tokenizer,
        model_family=cfg.model_family,
        max_length=max_length,
        split=cfg.split
    )

    batch_size = cfg.batch_size
    gradient_accumulation_steps = cfg.gradient_accumulation_steps
    # Optimizer steps for num_epochs passes over the data at the global
    # effective batch size. Clamp to >= 1: Trainer ignores max_steps <= 0 and
    # would silently fall back to its own default epoch count.
    max_steps = max(1, int(cfg.num_epochs * len(torch_format_dataset)) //
                    (batch_size * gradient_accumulation_steps * num_devices))
    print(f"max_steps: {max_steps}")
    # Warm up for roughly one epoch's worth of steps. Cast to int because
    # `//` with a float num_epochs yields a float and warmup_steps expects an int.
    warmup_steps = max(1, int(max_steps // cfg.num_epochs))

    training_args = transformers.TrainingArguments(
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        warmup_steps=warmup_steps,
        max_steps=max_steps,
        learning_rate=cfg.lr,
        bf16=True,
        bf16_full_eval=True,
        logging_steps=max(1, max_steps // 20),
        logging_dir=f'{cfg.save_dir}/logs',
        output_dir=cfg.save_dir,
        optim="adamw_torch",  # Using PyTorch's AdamW optimizer
        # NOTE: save_steps only applies when save_strategy="steps"; with
        # save_strategy="epoch" below it has no effect.
        save_steps=max_steps,
        save_only_model=True,
        ddp_find_unused_parameters=False,
        deepspeed=ds_config_path,  # None for single-device runs
        weight_decay=cfg.weight_decay,
        seed=cfg.seed,
        save_strategy="epoch",
        eval_strategy="no",
    )

    use_flash_attention_2 = cfg.use_flash_attention_2 and model_cfg["flash_attention2"] == "true"
    model_kwargs = {
        "torch_dtype": torch.bfloat16,
        "trust_remote_code": True,
    }
    if use_flash_attention_2:
        # Replacement for the deprecated `use_flash_attention_2=True` kwarg.
        model_kwargs["attn_implementation"] = "flash_attention_2"

    model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)

    # Hot fix for https://discuss.huggingface.co/t/help-with-llama-2-finetuning-setup/50035
    model.generation_config.do_sample = True

    if model_cfg["gradient_checkpointing"] == "true":
        model.gradient_checkpointing_enable()

    if cfg.LoRA.r != 0:
        lora_config = LoraConfig(
            r=cfg.LoRA.r,
            lora_alpha=cfg.LoRA.alpha,
            target_modules=find_all_linear_names(model),
            lora_dropout=cfg.LoRA.dropout,
            bias="none",
            task_type="CAUSAL_LM"
        )
        model = get_peft_model(model, lora_config)
        # Keeps gradients flowing to the adapters when input embeddings are frozen.
        model.enable_input_require_grads()

    # Optional profiling capability; the `profiling` section is optional.
    enable_profiling = hasattr(cfg, 'profiling') and cfg.profiling.enabled

    trainer = CustomTrainer(
        model=model,
        train_dataset=torch_format_dataset,
        eval_dataset=torch_format_dataset,
        args=training_args,
        data_collator=custom_data_collator,
        enable_profiling=enable_profiling,
    )

    model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
    trainer.train()

    # Only one process writes the final artifacts. Otherwise every rank would
    # write the same files in cfg.save_dir concurrently.
    # NOTE(review): under DeepSpeed ZeRO-3 the weights are partitioned across
    # ranks, so they would need gathering before this save; confirm which
    # strategies are actually used.
    if is_main_process:
        if cfg.LoRA.r != 0:
            # Fold the LoRA weights into the base model so a plain checkpoint is saved.
            model = model.merge_and_unload()

        model.save_pretrained(cfg.save_dir)
        tokenizer.save_pretrained(cfg.save_dir)

if __name__ == "__main__":
    # `main` is wrapped by @hydra.main, which builds its `cfg` argument from the
    # config files and command-line overrides, so it is called with no arguments.
    main()
    print("finetune finished.")
