# nohup python forget.py > forget.log 2>&1 &
import os
import re
import shutil
import json
import hydra

import torch
import transformers

from pathlib import Path
from omegaconf import OmegaConf
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, set_seed
from peft import LoraConfig, get_peft_model, PeftModel

from data_module import TextForgetDatasetQA, TextForgetDatasetDPOQA
from dataloader import CustomTrainerForgetting, custom_data_collator_forget, custom_data_collator_forget_dpo
from utils import get_model_identifiers_from_yaml, find_all_linear_names, print_trainable_parameters

import warnings
warnings.filterwarnings("ignore")

# Switch off DeepSpeed's custom op builds by forcing the relevant DS_BUILD_*
# environment flags to "0" for this process (and any children it spawns).
# NOTE(review): DeepSpeed documents these flags as install-time build switches;
# confirm they have the intended effect when set at runtime.
os.environ.update({
    "DS_BUILD_OPS": "0",
    "DS_BUILD_SPARSE_ATTN": "0",
    "DS_BUILD_CPU_ADAM": "0",
})

def get_deepspeed_zero_stage(ds_config_path):
    """Return the ZeRO stage declared in a DeepSpeed JSON config file.

    Args:
        ds_config_path: Path to a DeepSpeed JSON config, or ``None`` when
            DeepSpeed is not used.

    Returns:
        The value stored under ``zero_optimization.stage``. Falls back to ``0``
        when no path is given, when the key is absent, or when the file cannot
        be read or does not have the expected JSON object structure.
    """
    if ds_config_path is None:
        return 0

    try:
        with open(ds_config_path, 'r', encoding='utf-8') as f:
            ds_config = json.load(f)
        return ds_config.get("zero_optimization", {}).get("stage", 0)
    except (OSError, ValueError, AttributeError, TypeError) as err:
        # OSError: missing/unreadable file. ValueError: malformed JSON or bad
        # encoding (json.JSONDecodeError and UnicodeDecodeError derive from it).
        # AttributeError/TypeError: the JSON is not shaped like nested objects.
        # Stay best-effort (stage 0), but say why instead of failing silently,
        # and let KeyboardInterrupt/SystemExit propagate.
        print(f"Warning: could not read DeepSpeed ZeRO stage from {ds_config_path!r} ({err}); assuming stage 0")
        return 0

def _select_deepspeed_config(cfg, num_devices):
    """Pick the DeepSpeed JSON config path for this run (None on a single device)."""
    if num_devices <= 1:
        print("Single GPU detected, disabling DeepSpeed")
        return None

    if hasattr(cfg, 'parallelism') and cfg.parallelism.strategy:
        # Known strategies map to dedicated configs; anything else uses the default.
        strategy_configs = {
            "pipeline": 'config/ds_pipeline_config.json',
            "tensor": 'config/ds_tensor_config.json',
            "zero": 'config/ds_zero_config.json',
        }
        ds_config_path = strategy_configs.get(cfg.parallelism.strategy, 'config/ds_config.json')
        print(f"Using parallelism strategy: {cfg.parallelism.strategy}")
        return ds_config_path

    print("Using default parallelism strategy")
    return 'config/ds_config.json'


def _has_model_weights(model_dir):
    """Return True if model_dir lists a file that looks like full model weights."""
    # NOTE(review): re.search is unanchored, so a name such as
    # "adapter_model.safetensors" also matches the second pattern — confirm
    # that adapter-only directories are meant to take the checkpoint path.
    return any(
        re.search(r"pytorch.*\.bin", name) or re.search(r"model.*\.safetensors", name)
        for name in os.listdir(model_dir)
    )


@hydra.main(version_base=None, config_path="config", config_name="forget")
def main(cfg):
    """Run the forgetting (unlearning) fine-tuning job described by ``cfg``.

    Steps, in order:
      1. Read ``WORLD_SIZE`` / ``LOCAL_RANK`` and choose a DeepSpeed config
         (none when only one device is in use).
      2. On local rank 0, create ``cfg.save_dir`` and store the resolved config
         there; exit if the directory exists and ``cfg.overwrite_dir`` is false.
      3. Build the tokenizer and the forget dataset (DPO or QA variant,
         depending on ``cfg.forget_loss``).
      4. Load the model from ``cfg.model_path``: directly when the directory
         contains model weights, otherwise by applying the LoRA adapter stored
         there to the base model, merging it, and saving the merged weights
         back to ``cfg.model_path``. For the "KL" and "dpo" losses a second,
         identically loaded oracle model is created.
      5. Optionally wrap the model with fresh LoRA adapters (``cfg.LoRA.r != 0``).
      6. Train or evaluate with ``CustomTrainerForgetting`` using the score
         table read from ``cfg.score_dict_path``, save the result if requested,
         and delete ``global_step*`` directories inside saved checkpoints.

    Args:
        cfg: Hydra/OmegaConf configuration loaded from ``config/forget``.

    Raises:
        SystemExit: On local rank 0 when ``cfg.save_dir`` already exists and
            ``cfg.overwrite_dir`` is false.
    """
    print("\n\n\n")
    print("Model Path:", cfg.model_path)
    print("Influence Score Path:", cfg.score_dict_path)
    print("Forget on", cfg.split)
    print("Checkpoint will be saved at", cfg.save_dir, "\n\n\n\n\n")

    num_devices = int(os.environ.get('WORLD_SIZE', 1))
    print(f"num_devices: {num_devices}")

    # LOCAL_RANK is only present under a distributed launcher; default to rank 0.
    local_rank = int(os.environ.get('LOCAL_RANK', '0'))

    # DeepSpeed is only used for multi-device runs.
    ds_config_path = _select_deepspeed_config(cfg, num_devices)

    set_seed(cfg.seed)
    os.environ["WANDB_DISABLED"] = "true"
    model_cfg = get_model_identifiers_from_yaml(cfg.model_family)
    model_id = model_cfg["hf_key"]

    # If model_path is not provided, use the default from the model config.
    if cfg.model_path is None:
        cfg.model_path = model_cfg["ft_model_path"]

    print("######################")
    print("Saving to: ", cfg.save_dir)
    print("######################")
    # Save configuration only on the master process.
    if local_rank == 0:
        if os.path.exists(cfg.save_dir):
            print("Directory already exists")
            if not cfg.overwrite_dir:
                # Same effect as the interactive-only exit() helper, without
                # depending on the site module having installed it.
                raise SystemExit
        Path(cfg.save_dir).mkdir(parents=True, exist_ok=True)
        OmegaConf.save(cfg, f"{cfg.save_dir}/config.yaml")

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Ensure the pad token is set (defaulting to eos_token if necessary)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    max_length = 500
    if cfg.forget_loss == "dpo":
        torch_format_dataset = TextForgetDatasetDPOQA(
            cfg.data_path,
            tokenizer=tokenizer,
            model_family=cfg.model_family,
            max_length=max_length,
            split=cfg.split,
        )
    else:
        torch_format_dataset = TextForgetDatasetQA(
            cfg.data_path,
            tokenizer=tokenizer,
            model_family=cfg.model_family,
            max_length=max_length,
            split=cfg.split,
            loss_type=cfg.forget_loss,
        )

    batch_size = cfg.batch_size
    gradient_accumulation_steps = cfg.gradient_accumulation_steps
    # Number of samples consumed by one optimizer step across all devices.
    samples_per_step = batch_size * gradient_accumulation_steps * num_devices
    steps_per_epoch = len(torch_format_dataset) // samples_per_step
    max_steps = int(cfg.num_epochs * len(torch_format_dataset)) // samples_per_step
    print(f"max_steps: {max_steps}")

    training_args = transformers.TrainingArguments(
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        warmup_steps=max(1, steps_per_epoch),  # warm up for one epoch, at least one step
        max_steps=max_steps,
        learning_rate=cfg.lr,
        bf16=True,
        bf16_full_eval=True,
        logging_steps=max(1, max_steps // 20),
        logging_dir=f'{cfg.save_dir}/logs',
        output_dir=cfg.save_dir,
        optim="paged_adamw_32bit",
        save_strategy="steps" if cfg.save_model and (not cfg.eval_only) else "no",
        save_steps=steps_per_epoch,
        save_only_model=True,
        ddp_find_unused_parameters=False,
        deepspeed=ds_config_path,  # Will be None for single GPU
        weight_decay=cfg.weight_decay,
        eval_steps=steps_per_epoch,
        eval_strategy="steps" if cfg.eval_while_train else "no",
        seed=cfg.seed
    )

    # Determine if there is a checkpoint file in cfg.model_path.
    path_found = _has_model_weights(cfg.model_path)

    oracle_model = None
    # The KL and DPO losses compare against a second copy of the starting model.
    needs_oracle = cfg.forget_loss == "KL" or cfg.forget_loss == "dpo"

    # Get DeepSpeed Zero stage
    zero_stage = get_deepspeed_zero_stage(ds_config_path)
    print(f"DeepSpeed Zero stage: {zero_stage}")

    use_flash_attention_2 = cfg.use_flash_attention_2 and model_cfg["flash_attention2"] == "true"
    # `attn_implementation` supersedes the deprecated `use_flash_attention_2`
    # keyword; it is only passed when flash attention is requested so the
    # library default applies otherwise.
    attn_kwargs = {"attn_implementation": "flash_attention_2"} if use_flash_attention_2 else {}

    if path_found:
        config = AutoConfig.from_pretrained(model_id)
        print("Loading from checkpoint")

        # Set device_map based on DeepSpeed Zero-3 compatibility
        model_device_map = None if (zero_stage == 3 and num_devices > 1) else "auto"
        print(f"Using device_map: {model_device_map}")

        def load_from_checkpoint():
            return AutoModelForCausalLM.from_pretrained(
                cfg.model_path,
                config=config,
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
                device_map=model_device_map,  # Set to None for DeepSpeed Zero-3
                **attn_kwargs
            )

        model = load_from_checkpoint()
        if needs_oracle:
            oracle_model = load_from_checkpoint()
    else:
        print("Loading after merge and unload")

        def load_merged_from_adapter():
            base_model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.bfloat16,
                device_map="auto",
                **attn_kwargs
            )
            # Use the checkpoint to add the LoRA modules, then merge the LoRA
            # weights into the base model.
            return PeftModel.from_pretrained(base_model, model_id=cfg.model_path).merge_and_unload()

        model = load_merged_from_adapter()
        if needs_oracle:
            # Mirror the checkpoint branch: without this the KL/DPO losses
            # would be handed oracle_model=None on the LoRA-merge path.
            oracle_model = load_merged_from_adapter()
        # Save the merged model for future use.
        model.save_pretrained(cfg.model_path)

    # Hot fix for Llama 2 finetuning
    model.generation_config.do_sample = True

    # Enable gradient checkpointing if configured.
    if model_cfg["gradient_checkpointing"] == "true":
        model.gradient_checkpointing_enable()

    lora_config = LoraConfig(
        r=cfg.LoRA.r,
        lora_alpha=cfg.LoRA.alpha,
        target_modules=find_all_linear_names(model),
        lora_dropout=cfg.LoRA.dropout,
        bias="none",
        task_type="CAUSAL_LM"
    )
    # A rank of 0 means "train the full model, no LoRA wrapper".
    if cfg.LoRA.r != 0:
        model = get_peft_model(model, lora_config)
        print_trainable_parameters(model)

    # Eve: load score table and initialize the forget trainer with the score table
    with open(cfg.score_dict_path, "r") as f:
        score_dict = json.load(f)

    data_collector = custom_data_collator_forget if cfg.forget_loss != "dpo" else custom_data_collator_forget_dpo

    trainer = CustomTrainerForgetting(
        model=model,
        tokenizer=tokenizer,
        train_dataset=torch_format_dataset,
        eval_dataset=torch_format_dataset,
        compute_metrics=None,  # The callback for computing metrics (None in this case).
        args=training_args,
        data_collator=data_collector,
        oracle_model=oracle_model,
        forget_loss=cfg.forget_loss,
        eval_cfg=cfg.eval,
        score_dict=score_dict
    )

    model.config.use_cache = False  # Disable cache during training to silence warnings.
    if cfg.eval_only:
        trainer.evaluate()
    else:
        trainer.train()

    # Save the model and tokenizer if requested.
    if cfg.save_model and (not cfg.eval_only):
        model.save_pretrained(cfg.save_dir)
        tokenizer.save_pretrained(cfg.save_dir)

    # Delete all "global_step*" directories inside any checkpoint directories.
    if local_rank == 0:
        for checkpoint_dir in Path(cfg.save_dir).glob("checkpoint-*"):
            for global_step_dir in checkpoint_dir.glob("global_step*"):
                shutil.rmtree(global_step_dir)

# Script entry point: run the Hydra-wrapped main() and print a completion
# marker once it returns.
if __name__ == "__main__":
    main()
    print("forget completed.")
