from data_module import TextForgetDatasetQA, TextForgetDatasetDPOQA
from dataloader import CustomTrainerForgetting, custom_data_collator_forget
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, set_seed
import deepspeed
import json
import hydra 
import transformers
import os
from peft import LoraConfig, get_peft_model, PeftModel, set_peft_model_state_dict
from sinelora_config import sinLoraConfig, sinDoraConfig, create_sine_lora_model

from pathlib import Path
from utils import get_model_identifiers_from_yaml
from omegaconf import OmegaConf
from functools import reduce

import random
import numpy as np

def find_all_linear_names(model):
    """Collect the short names of every ``torch.nn.Linear`` submodule of *model*.

    The fully qualified name of each matching submodule is printed. Only the
    last dotted component is kept, so ``layers.0.self_attn.q_proj``
    contributes ``q_proj``. ``lm_head`` is always excluded.

    Returns:
        list[str]: the unique short names, in set iteration order.
    """
    short_names = set()
    for qualified_name, submodule in model.named_modules():
        if not isinstance(submodule, torch.nn.Linear):
            continue
        print(qualified_name)
        # rsplit keeps the whole name when there is no dot, the last part otherwise.
        short_names.add(qualified_name.rsplit('.', 1)[-1])
    # Original note: excluding lm_head is "needed for 16-bit".
    short_names.discard('lm_head')
    return list(short_names)


def print_trainable_parameters(model):
    """Print the number of trainable parameters versus all parameters of *model*.

    A model with no parameters reports ``trainable%: 0.0`` instead of raising
    ``ZeroDivisionError``.

    Returns:
        tuple[int, int]: ``(trainable_params, all_params)``. Existing callers
        that ignore the return value are unaffected.
    """
    trainable_params = 0
    all_param = 0
    for param in model.parameters():
        numel = param.numel()
        all_param += numel
        if param.requires_grad:
            trainable_params += numel
    # Guard the ratio so a parameter-free module does not crash the report.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct}"
    )
    return trainable_params, all_param

_LLAMA_FAMILIES = ('llama2-7b', 'llama3.1-8b', 'llama3.2-1b', 'llama3.2-3b', 'llama3.2-8b')

# (attention projections, MLP projections) per architecture; "all" is their concatenation.
_LORA_TARGETS_BY_ARCH = {
    'llama': (['q_proj', 'k_proj', 'v_proj', 'o_proj'], ['gate_proj', 'up_proj', 'down_proj']),
    'phi': (['q_proj', 'k_proj', 'v_proj', 'dense'], ['fc1', 'fc2']),
}


def _resolve_lora_targets(model_family, targets):
    """Map ``(model_family, cfg.LoRA.targets)`` to the module names LoRA should wrap.

    ``targets`` must be ``"all"``, ``"self_attn"`` or ``"mlp"``. Unsupported
    families or target groups raise ``NotImplementedError``.
    """
    if model_family in _LLAMA_FAMILIES:
        attn, mlp = _LORA_TARGETS_BY_ARCH['llama']
    elif model_family == 'phi':
        attn, mlp = _LORA_TARGETS_BY_ARCH['phi']
    else:
        raise NotImplementedError(f"Unsupported model_family: {model_family!r}")

    # Return fresh lists so callers cannot mutate the module-level tables.
    if targets == "all":
        return attn + mlp
    if targets == "self_attn":
        return list(attn)
    if targets == "mlp":
        return list(mlp)
    raise NotImplementedError(f"Unsupported LoRA targets: {targets!r}")


def _load_causal_lm(path, *, flash_attention, **extra_kwargs):
    """Load a bfloat16 causal LM from *path*, optionally with FlashAttention-2.

    Extra keyword arguments (e.g. ``config``, ``low_cpu_mem_usage``) are
    forwarded unchanged to ``AutoModelForCausalLM.from_pretrained``.
    """
    kwargs = dict(torch_dtype=torch.bfloat16, trust_remote_code=True, **extra_kwargs)
    if flash_attention:
        kwargs["attn_implementation"] = "flash_attention_2"
    return AutoModelForCausalLM.from_pretrained(path, **kwargs)


def _load_models(cfg, model_cfg, model_id):
    """Load the model to unlearn and, when the loss needs one, a reference copy.

    Returns:
        ``(model, oracle_model)``. ``oracle_model`` is ``None`` unless
        ``cfg.forget_loss == "KL"`` or ``"npo"`` occurs in ``cfg.forget_loss``.
        Both are loaded from ``cfg.model_path``.
    """
    print("Loading from checkpoint")
    print(f"Model path: {cfg.model_path}")
    print(f"Flash attention: {model_cfg['flash_attention2']}")
    # Kept as ``== True``: a string value such as "true" does NOT enable flash attention.
    use_flash = model_cfg["flash_attention2"] == True  # noqa: E712

    if use_flash:
        model = _load_causal_lm(cfg.model_path, flash_attention=True, low_cpu_mem_usage=True)
    elif cfg.model_family == 'phi':
        print("Loading Phi model with base model config...")
        # Load config from base model to avoid missing files
        base_config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
        model = _load_causal_lm(
            cfg.model_path, flash_attention=False, config=base_config, low_cpu_mem_usage=True
        )
    else:
        model = _load_causal_lm(cfg.model_path, flash_attention=False)

    oracle_model = None
    if cfg.forget_loss == "KL" or "npo" in cfg.forget_loss:
        # NOTE(review): unlike the main model, the Phi oracle is loaded without
        # base_config; confirm this succeeds for Phi checkpoints.
        oracle_model = _load_causal_lm(cfg.model_path, flash_attention=use_flash, low_cpu_mem_usage=True)
    return model, oracle_model


def _build_lora_config(cfg, lora_targets, use_sinelora):
    """Build a sine-LoRA (``sinLoraConfig``) or standard ``LoraConfig`` from ``cfg.LoRA``."""
    if use_sinelora:
        print("Using Sine based Parameter-Efficient configuration...")
        return sinLoraConfig(
            r=cfg.LoRA.r,
            lora_alpha=cfg.LoRA.alpha,
            target_modules=lora_targets,
            lora_dropout=cfg.LoRA.dropout,
            bias="none",
            task_type="CAUSAL_LM",
            s=getattr(cfg, 'sine_scale', 1.0),
            freq=getattr(cfg, 'sine_freq', 1),
        )
    print("Using standard LoRA configuration...")
    return LoraConfig(
        r=cfg.LoRA.r,
        lora_alpha=cfg.LoRA.alpha,
        target_modules=lora_targets,
        lora_dropout=cfg.LoRA.dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )


def _report_sine_lora_modules(model):
    """Print every submodule whose type name contains ``SineLoraLinear``."""
    print("VERIFYING SINE LORA IMPLEMENTATION...")
    sine_modules = []
    for name, module in model.named_modules():
        if 'SineLoraLinear' not in str(type(module)):
            continue
        sine_modules.append(name)
        print(f"   SineLoraLinear found: {name}")
        if hasattr(module, 'freq') and hasattr(module, 's'):
            print(f"     Frequency (omega): {module.freq}")
            print(f"     Scale (s): {module.s}")

    if sine_modules:
        print(f"SUCCESS: {len(sine_modules)} Sine LoRA layers applied!")
        print("sin(omega * AB^T) transformation is active.")
    else:
        print("WARNING: No SineLoraLinear modules found!")
        print("Standard LoRA may be running instead of Sine LoRA.")
    print("-" * 50)


def _apply_lora(model, config, lora_targets, use_sinelora, r):
    """Wrap *model* with (sine-)LoRA adapters, print sanity checks, return the wrapped model."""
    print(f"Applying LoRA with r={r}, targets={lora_targets}")
    print(f"Using Sine based Parameter-Efficient approach: {use_sinelora}")

    if use_sinelora:
        print("APPLYING CUSTOM SINE LORA IMPLEMENTATION...")
        print(f"   Sine frequency (omega): {config.freq}")
        print(f"   Sine scale (s): {config.s}")
        model = create_sine_lora_model(model, config)
        print("Custom Sine LoRA applied successfully!")
    else:
        model = get_peft_model(model, config)
        print("Standard LoRA applied successfully!")
    print_trainable_parameters(model)

    # Verify LoRA was actually applied: adapters should leave only a small trainable fraction.
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    lora_percentage = 100 * trainable_params / total_params
    if lora_percentage > 5.0:  # Should be much less than 5% for LoRA
        print(f"WARNING: {lora_percentage:.2f}% trainable params is too high for LoRA!")
        print("This suggests LoRA was not applied correctly.")
    else:
        print(f"LoRA successfully applied: {lora_percentage:.2f}% trainable params")

    if use_sinelora:
        _report_sine_lora_modules(model)
    return model


def _get_module_by_name(module, access_string):
    """Resolve a dotted attribute path (e.g. ``"a.b.c"``) starting at *module*."""
    return reduce(getattr, access_string.split('.'), module)


def _apply_fila_init(model, importance_file, lora_targets, rank):
    """FILA: re-initialise LoRA adapters from an importance-weighted low-rank factorisation.

    *importance_file* holds a dict with ``f_cnt``, ``r_cnt``, ``importance_f``
    and ``importance_r``. The last two map parameter names to tensors. For each
    targeted weight ``W``:

    - The rows of ``W`` are weighted by sqrt(row-sum of forget/retain
      importance), and a rank-*rank* randomized SVD is taken.
    - ``lora_A``/``lora_B`` are set from that factorisation.
    - The base weight is replaced by the residual ``W - scaling * B @ A``.

    The adapted layer's effective weight therefore starts at ``W``, up to
    rounding from casting back to the original dtype. Requires *model* to be a
    PEFT LoRA model (module paths under ``base_model.model.``).
    """
    print(f'loading importance file from {importance_file}')
    # NOTE(security): torch.load unpickles arbitrary objects; only use trusted files here.
    imp_file = torch.load(importance_file, map_location='cpu')

    f_cnt = imp_file['f_cnt']
    r_cnt = imp_file['r_cnt']
    importance_f = imp_file['importance_f']
    importance_r = imp_file['importance_r']

    for old_name in importance_f:
        if not any(target_name in old_name for target_name in lora_targets):
            continue
        # Forget importance relative to retain importance; 1e-5 avoids division by zero.
        importance = torch.div(importance_f[old_name] / f_cnt, 1e-5 + (importance_r[old_name] / r_cnt))

        name = old_name.replace("module.", '')
        prefix = 'base_model.model.' + name.replace(".weight", '')
        lora_A = _get_module_by_name(model, prefix + '.lora_A')
        lora_B = _get_module_by_name(model, prefix + '.lora_B')
        base_layer = _get_module_by_name(model, prefix + '.base_layer')
        scaling = _get_module_by_name(model, prefix + '.scaling')

        orig_shape = base_layer.weight.shape
        W = base_layer.weight.data.reshape(orig_shape)
        dtype = W.dtype
        W = W.to(torch.float32)

        # Solve row-wise weighted low-rank approximation
        row_importance = importance.sum(dim=1).sqrt().to(W.device)  # row-wise sum
        U, S, V = torch.svd_lowrank(row_importance[:, None] * W, q=rank)

        # Pre-divide by the LoRA scaling so that scaling * B @ A reproduces the approximation.
        S = S / scaling['default']

        new_lora_A = (V * torch.sqrt(S)).t()
        new_lora_B = (1 / (row_importance + 1e-5))[:, None] * (U * torch.sqrt(S))
        new_residual = base_layer.weight.data.reshape(orig_shape) - scaling['default'] * new_lora_B @ new_lora_A

        lora_A['default'].weight.data = new_lora_A.contiguous().to(dtype)
        lora_B['default'].weight.data = new_lora_B.contiguous().to(dtype)
        base_layer.weight.data = new_residual.contiguous().to(dtype)


def _remove_global_step_dirs(save_dir):
    """Delete ``global_step*`` directories inside every ``checkpoint-*`` under *save_dir*.

    Presumably these hold DeepSpeed engine state written next to each
    checkpoint (training uses ``config/ds_config.json``). Only the model files
    are kept.
    """
    import shutil

    for checkpoint_dir in Path(save_dir).glob("checkpoint-*"):
        for global_step_dir in checkpoint_dir.glob("global_step*"):
            if global_step_dir.is_dir():
                shutil.rmtree(global_step_dir)


@hydra.main(version_base=None, config_path="config", config_name="forget")
def main(cfg):
    """Run one LoRA-based unlearning job (standard LoRA, FILA or sine-LoRA) from a Hydra config.

    Steps:
    1. Tag ``cfg.save_dir`` with the method variant.
    2. Skip the run if that directory already exists (unless ``cfg.overwrite_dir``).
    3. Build the forget dataset and load the model, plus an oracle for KL/NPO losses.
    4. Attach adapters, apply the optional FILA initialisation, and train (or only evaluate).
    5. Optionally save the result.
    """
    use_sinelora = getattr(cfg, 'use_sinelora', False)

    # Handle save directory naming ONCE, before anything is written.
    if cfg.importance_file:  # if using FILA
        print("Using FILA...")
        cfg.save_dir = cfg.save_dir.replace(cfg.forget_loss, cfg.forget_loss + "_FILA")
    elif use_sinelora:
        print("Using Sine based Parameter-Efficient approach...")
        cfg.save_dir = cfg.save_dir.replace(cfg.forget_loss, cfg.forget_loss + "_SineParameterEfficient")

    print("######################")
    print("Saving to: ", cfg.save_dir)
    print("######################")

    num_devices = int(os.environ.get('WORLD_SIZE', 1))
    print(f"num_devices: {num_devices}")
    # Single-process runs have no LOCAL_RANK; treat them as rank 0 (was a NameError before).
    local_rank = int(os.environ.get('LOCAL_RANK', '0'))

    set_seed(cfg.seed)

    os.environ["WANDB_DISABLED"] = "true"
    model_cfg = get_model_identifiers_from_yaml(cfg.model_family)
    model_id = model_cfg["hf_key"]
    if cfg.model_path is None:
        cfg.model_path = model_cfg["ft_model_path"]

    # Check BEFORE creating the directory; checking afterwards always succeeded,
    # so every run without overwrite_dir exited.
    # NOTE(review): with multiple processes, a rank that reaches this check after
    # another rank has created save_dir exits alone; confirm this is acceptable
    # for fresh multi-GPU runs (or set overwrite_dir).
    if os.path.exists(cfg.save_dir):
        print("Directory already exists")
        if not cfg.overwrite_dir:
            return

    Path(cfg.save_dir).mkdir(parents=True, exist_ok=True)
    # save cfg in cfg.save_dir
    if local_rank == 0:
        with open(f"{cfg.save_dir}/config.yaml", "w") as file:
            OmegaConf.save(cfg, file)

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token

    max_length = 500
    if cfg.forget_loss == "dpo":
        torch_format_dataset = TextForgetDatasetDPOQA(
            cfg.data_path,
            tokenizer=tokenizer,
            model_family=cfg.model_family,
            max_length=max_length,
            split=cfg.split,
        )
    else:
        torch_format_dataset = TextForgetDatasetQA(
            cfg.data_path,
            tokenizer=tokenizer,
            model_family=cfg.model_family,
            max_length=max_length,
            split=cfg.split,
            loss_type=cfg.forget_loss,
        )

    batch_size = cfg.batch_size
    gradient_accumulation_steps = cfg.gradient_accumulation_steps
    effective_batch = batch_size * gradient_accumulation_steps * num_devices
    steps_per_epoch = len(torch_format_dataset) // effective_batch
    # TrainingArguments treats max_steps <= 0 as "unset" and silently falls back to
    # num_train_epochs (default 3), so tiny datasets are clamped to one step.
    max_steps = max(1, int(cfg.num_epochs * len(torch_format_dataset)) // effective_batch)
    print(f"max_steps: {max_steps}")
    # Warmup / save / eval interval: one epoch of optimizer steps, never 0.
    epoch_interval = max(1, steps_per_epoch)

    training_args = transformers.TrainingArguments(
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        warmup_steps=epoch_interval,
        max_steps=max_steps,
        learning_rate=cfg.lr,
        bf16=True,
        bf16_full_eval=True,
        logging_steps=max(1, max_steps // 20),
        logging_dir=f'{cfg.save_dir}/logs',
        output_dir=cfg.save_dir,
        optim="paged_adamw_32bit",
        save_strategy="steps" if cfg.save_model and (not cfg.eval_only) else "no",
        save_steps=epoch_interval,
        save_only_model=True,
        ddp_find_unused_parameters=False,
        deepspeed='config/ds_config.json',
        weight_decay=cfg.weight_decay,
        eval_steps=epoch_interval,
        eval_strategy="steps" if cfg.eval_while_train else "no",
        seed=cfg.seed,
    )

    model, oracle_model = _load_models(cfg, model_cfg, model_id)

    # Hot fix for https://discuss.huggingface.co/t/help-with-llama-2-finetuning-setup/50035
    model.generation_config.do_sample = True

    # NOTE: Gradient checkpointing is intentionally NOT enabled, even when
    # model_cfg["gradient_checkpointing"] is set: the Sine based Parameter-Efficient
    # approach does not work properly with it (notably on llama3.2 models).

    # DEBUG: Print actual config values
    print(f"🔍 DEBUG: model_family = {cfg.model_family}")
    print(f"🔍 DEBUG: cfg.LoRA.targets = '{cfg.LoRA.targets}' (type: {type(cfg.LoRA.targets)})")
    print(f"🔍 DEBUG: Full LoRA config = {cfg.LoRA}")

    lora_targets = _resolve_lora_targets(cfg.model_family, cfg.LoRA.targets)
    config = _build_lora_config(cfg, lora_targets, use_sinelora)

    if cfg.LoRA.r != 0:
        model = _apply_lora(model, config, lora_targets, use_sinelora, cfg.LoRA.r)

    ### MODIFY MODEL WEIGHTS BASED ON IMPORTANCES (FILA ONLY)
    if cfg.importance_file:
        _apply_fila_init(model, cfg.importance_file, lora_targets, cfg.LoRA.r)

    trainer = CustomTrainerForgetting(
        model=model,
        tokenizer=tokenizer,
        train_dataset=torch_format_dataset,
        eval_dataset=torch_format_dataset,
        compute_metrics=None,  # metrics are computed by the trainer's own eval logic
        args=training_args,
        data_collator=custom_data_collator_forget,
        oracle_model=oracle_model,
        forget_loss=cfg.forget_loss,
        eval_cfg=cfg.eval,
        seed=cfg.seed,  # for NPO
        ref_policy=cfg.ref_policy,
        beta=cfg.beta,
        npo_coeff=cfg.npo_coeff,
        grad_diff_coeff=cfg.grad_diff_coeff,
        KL_coeff=cfg.KL_coeff,
    )

    model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
    if cfg.eval_only:
        trainer.evaluate()
    else:
        trainer.train()

    if cfg.save_model and (not cfg.eval_only):
        if cfg.importance_file:
            # FILA replaced the base weights with residuals, so the trained LoRA
            # deltas must be folded back in. ``unload()`` would discard them and
            # save only the residual weights.
            model = model.merge_and_unload()
        model.save_pretrained(cfg.save_dir)
        tokenizer.save_pretrained(cfg.save_dir)

    # delete all "global_step*" directories in the save_dir/checkpoint-*/ directories
    if local_rank == 0:
        _remove_global_step_dirs(cfg.save_dir)

if __name__ == "__main__":
    # Called without arguments: the @hydra.main decorator loads the "forget"
    # config from ./config (plus any CLI overrides) and passes it in as cfg.
    main()