import fire 
from transformers import AutoModelForSequenceClassification, AutoModelForCausalLM, AutoTokenizer
import torch


def find_all_linear_names(model):
    """Return the distinct short names of every ``torch.nn.Linear`` submodule.

    Each qualified name yielded by ``model.named_modules()`` is reduced to its
    last dotted component (``"a.b.dense"`` -> ``"dense"``); these short names
    are what this script passes to LoRA ``target_modules``. The name
    ``"lm_head"`` is excluded. The result is built from a set, so its order is
    not meaningful.
    """
    short_names = {
        qualified.rsplit('.', 1)[-1]
        for qualified, submodule in model.named_modules()
        if isinstance(submodule, torch.nn.Linear)
    }
    # Original note: excluding the output head is "needed for 16-bit".
    # NOTE(review): GPT-NeoX/Pythia models appear to name their output
    # projection `embed_out`, not `lm_head`, so it is NOT excluded for them —
    # confirm this is intended (changing it would break loading existing
    # LoRA checkpoints with strict=True).
    short_names.discard('lm_head')
    return list(short_names)



def _load_checkpoint(model_path):
    """Load the checkpoint dict stored at *model_path*, mapping tensors to CPU.

    ``save_model`` reads the weights from its ``'state'`` key (and, for the
    SFT/reward paths, ``'step_idx'`` and ``'metrics'``) — presumably the layout
    written by the training script; confirm against it.
    """
    # SECURITY: weights_only=False fully unpickles the file and can execute
    # arbitrary code. Only load checkpoints produced by your own training runs.
    return torch.load(model_path, map_location='cpu', weights_only=False)


def _build_lora_policy(model_name):
    """Wrap the causal LM *model_name* in the LoRA adapter used for PPO.

    The checkpoint is later loaded with ``strict=True``, so this config must
    yield exactly the parameter names/shapes of the training-time model
    (presumably the same LoRA config used in training — keep them in sync).
    """
    # peft is only needed on this path, so import it lazily. The original
    # script used LoraConfig/get_peft_model without importing them, which
    # made the 'policy_ppo' path fail with NameError.
    from peft import LoraConfig, get_peft_model

    # NOTE(review): no explicit torch_dtype here, unlike the bf16 SFT/reward
    # paths — confirm this is intended.
    base_model = AutoModelForCausalLM.from_pretrained(model_name)
    peft_config = LoraConfig(
        task_type="CAUSAL_LM",
        r=8,
        target_modules=find_all_linear_names(base_model),
        lora_alpha=32,
        lora_dropout=0.0,
        bias="none",
        use_rslora=False,
        modules_to_save=None,
    )
    return get_peft_model(base_model, peft_config)


def save_model(
    model_type='policy_sft',
    model_name="EleutherAI/pythia-1b-deduped",
    model_path='/workspace/rlhf-code/pythia1b-oai-summary-ppo-1ep.pt',
    save_path=None,
):
    """Load a fine-tuned checkpoint into a Hugging Face model and publish it.

    Args:
        model_type: ``'reward'`` (sequence classifier, ``num_labels=1``, bf16),
            ``'policy_sft'`` (causal LM, bf16), or ``'policy_ppo'`` (causal LM
            wrapped in LoRA; the adapter is merged into the base weights
            before saving).
        model_name: Hub id or local path of the base model and tokenizer.
        model_path: Path to the ``torch.save``-d checkpoint dict; the weights
            are read from its ``'state'`` key.
        save_path: Hub repo id passed to ``push_to_hub``. For
            ``'policy_ppo'`` it is also used as the local
            ``save_pretrained`` directory. Required.

    Raises:
        ValueError: if *model_type* is unknown or *save_path* is missing.
    """
    valid_types = ('reward', 'policy_sft', 'policy_ppo')
    if model_type not in valid_types:
        # Previously an unknown type silently did nothing.
        raise ValueError(
            f"model_type must be one of {valid_types}, got {model_type!r}")
    if save_path is None:
        # Fail fast: push_to_hub(None) would only error after the (slow)
        # model download and checkpoint load.
        raise ValueError(
            "save_path is required (Hub repo id to push the model to)")

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    if model_type == 'policy_ppo':
        model = _build_lora_policy(model_name)
        checkpoint = _load_checkpoint(model_path)
        model.load_state_dict(checkpoint['state'], strict=True)
        # Fold the LoRA weights into the base model and drop the peft
        # wrapper, so the published model is a plain transformers model.
        model = model.merge_and_unload()
        model.save_pretrained(save_path)
        tokenizer.save_pretrained(save_path)
    else:
        if model_type == 'reward':
            model = AutoModelForSequenceClassification.from_pretrained(
                model_name, torch_dtype=torch.bfloat16, num_labels=1)
        else:  # 'policy_sft'
            model = AutoModelForCausalLM.from_pretrained(
                model_name, torch_dtype=torch.bfloat16)

        checkpoint = _load_checkpoint(model_path)
        step, metrics = checkpoint['step_idx'], checkpoint['metrics']
        print(f"Loaded checkpoint from step {step}; metrics: {metrics}")

        # strict=False tolerates key mismatches, so surface them instead of
        # silently publishing a partially initialised model.
        missing_keys, unexpected_keys = model.load_state_dict(
            checkpoint['state'], strict=False)
        if missing_keys:
            print(f"WARNING: {len(missing_keys)} missing keys, e.g. "
                  f"{missing_keys[:10]}")
        if unexpected_keys:
            print(f"WARNING: {len(unexpected_keys)} unexpected keys, e.g. "
                  f"{unexpected_keys[:10]}")

    model.push_to_hub(save_path)
    tokenizer.push_to_hub(save_path)

    
if __name__ == "__main__":
    # Let python-fire map command-line flags onto save_model's parameters,
    # e.g. `--model_type policy_ppo --save_path <hub-repo-id>`.
    fire.Fire(component=save_model)