import argparse
import torch
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
import json
from edit import ModelEditor,load_data,MODEL_PATH_MAP,Attack_MODEL_PATH_MAP
import os
import time
from tqdm import tqdm


def _get_decoder_layers(model):
    """Locate the stack of decoder blocks inside ``model``.

    Probes, in order, ``model.transformer.h``, ``model.model.layers`` and
    ``model.layers``.

    Returns:
        A ``(layers, attr_name)`` tuple where ``layers`` is the indexable
        container of blocks and ``attr_name`` is the dotted path used.

    Raises:
        ValueError: if none of the known layouts is present. The original
            code only printed here and then crashed with a NameError.
    """
    if hasattr(model, 'transformer') and hasattr(model.transformer, 'h'):
        return model.transformer.h, 'transformer.h'
    if hasattr(model, 'model') and hasattr(model.model, 'layers'):
        return model.model.layers, 'model.layers'
    if hasattr(model, 'layers'):
        return model.layers, 'layers'
    raise ValueError(
        f"Cannot determine number of layers, model structure: {model.__class__.__name__}"
    )


def _corpus_name(test_data_path):
    """Derive the corpus tag used in the output file name.

    Takes the 4th '/'-separated component of the path (index 3) and keeps the
    text after its last underscore, matching the historical naming. If the
    path has fewer components, it falls back to the file-name stem instead of
    raising IndexError.
    """
    parts = test_data_path.split("/")
    if len(parts) > 3:
        component = parts[3]
    else:
        component = os.path.splitext(os.path.basename(test_data_path))[0]
    return component.split("_")[-1]


def _register_rep_edit_hooks(layers, editor, start_layer, end_layer, alpha):
    """Attach forward hooks to ``layers[start_layer:end_layer]``.

    Each hook replaces the layer output with ``editor.rep_edit(out, layer, alpha)``.
    Returns the hook handles; the caller is responsible for removing them.
    If registration fails partway, hooks already attached are removed before
    re-raising.
    """
    handles = []
    try:
        for layer in range(start_layer, end_layer):
            # Bind `layer` as a default argument to avoid late-binding closures.
            def hook_fn(module, inp, out, layer=layer):
                return editor.rep_edit(out, layer, alpha)

            handles.append(layers[layer].register_forward_hook(hook_fn))
    except Exception:
        for handle in handles:
            handle.remove()
        raise
    return handles


def main(args):
    """Generate a continuation for every test prompt and append results as JSONL.

    Loads the base model (or the edited model when ``args.model_edit`` is set)
    and wraps it in a ``ModelEditor`` with a pre-computed human-like space.
    For each prompt it calls ``model.generate`` (200 new tokens). When
    ``args.rep_edit`` is set, ``editor.rep_edit`` hooks are active on layers
    ``[start_layer, start_layer + num_layers)`` during generation.

    Each record ``{prompt, human_text, machine_text, attack_text}`` is appended
    to a JSONL file in ``args.output_dir``. Total wall-clock generation time is
    appended to ``<args.time_dir>/generation_time.txt``.

    Raises:
        ValueError: if ``--edited_model_path`` is missing while
            ``--model_edit`` is set, if the model's layer layout is
            unrecognized, or if the requested layer range is invalid.
    """
    # Load test data
    human_texts, machine_texts, prompts = load_data(args.test_data_path)
    if args.model_edit:
        if not args.edited_model_path:
            raise ValueError("--edited_model_path is required when --model_edit is set")
        model_path = args.edited_model_path
    else:
        model_path = MODEL_PATH_MAP[args.model_name]
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map='auto', trust_remote_code=True)
    editor = ModelEditor(model, tokenizer, args.var_threshold, device=model.device, batch_size=args.batch_size)
    space_path = os.path.join(args.human_like_space_dir, args.model_name, 'human_like_space.pt')
    editor.load_human_like_space(space_path)
    print(editor.human_like_spaces)

    if args.model_name in ['llama-13b', 'qwen-14b']:
        model = model.half()

    # Resolve the decoder stack once; it is reused when registering hooks.
    layers, layer_attr = _get_decoder_layers(model)
    total_layers = len(layers)
    print(f"Using {layer_attr}, total layers: {total_layers}")

    end_layer = args.start_layer + args.num_layers
    if args.start_layer < 0 or args.num_layers < 0:
        # Negative values would silently index layers from the end.
        raise ValueError(
            f"start_layer ({args.start_layer}) and num_layers ({args.num_layers}) must be non-negative"
        )
    if end_layer > total_layers:
        raise ValueError(f"Requested layers [{args.start_layer}, {end_layer}) exceed total layers {total_layers}")

    corpus = _corpus_name(args.test_data_path)
    os.makedirs(args.time_dir, exist_ok=True)
    os.makedirs(args.output_dir, exist_ok=True)
    time_save_path = os.path.join(args.time_dir, 'generation_time.txt')
    output_path = os.path.join(args.output_dir, f"{args.model_name}_{corpus}_{args.rep_edit}_{args.model_edit}_{args.var_threshold}_{args.start_layer}_{args.num_layers}_{args.space_samples_number}_alpha{args.alpha}.jsonl")
    print(f"data will be saved at {output_path}")

    total = len(prompts) if hasattr(prompts, "__len__") else None
    st = time.perf_counter()
    # Append mode, flushed per record, so partial results survive a crash.
    with open(output_path, "a", encoding="utf-8") as out_f:
        for index, p in tqdm(enumerate(prompts), total=total):
            inputs = tokenizer(p, return_tensors="pt").to(editor.device)

            if args.rep_edit:
                handles = _register_rep_edit_hooks(layers, editor, args.start_layer, end_layer, args.alpha)
                try:
                    output = model.generate(**inputs, max_new_tokens=200, pad_token_id=tokenizer.eos_token_id)
                finally:
                    # Always detach hooks so a failed generate cannot leave the model edited.
                    for handle in handles:
                        handle.remove()
            else:
                output = model.generate(**inputs, max_new_tokens=200, past_key_values=None, pad_token_id=tokenizer.eos_token_id)

            text = tokenizer.decode(output[0], skip_special_tokens=True)
            res = {"prompt": p, "human_text": human_texts[index], "machine_text": machine_texts[index], "attack_text": text}
            out_f.write(json.dumps(res, ensure_ascii=False) + "\n")
            out_f.flush()
    et = time.perf_counter()
    print("Total_Time:", et - st)
    with open(time_save_path, "a", encoding="utf-8") as f:
        text = f'{args.test_data_path}: Generation Time for Model Edit ({args.model_edit}) Rep Edit ({args.rep_edit}) using {args.model_name} for sample number {args.space_samples_number} from layer {args.start_layer} to {end_layer} is {et-st}.\n'
        f.write(text)

    print(f"Saved results to {output_path}")


if __name__ == "__main__":
    # Command-line entry point: parse options and run generation.
    parser = argparse.ArgumentParser(
        description="Generate continuations for test prompts, optionally with representation/model editing."
    )
    parser.add_argument("--test_data_path", type=str, required=True,
                        help="Path to the test data passed to load_data.")
    parser.add_argument("--model_name", type=str, required=True,
                        help="Key into MODEL_PATH_MAP; also selects the human-like space sub-directory.")
    parser.add_argument("--human_like_space_dir", type=str, default='./space',
                        help="Directory containing <model_name>/human_like_space.pt.")
    parser.add_argument("--start_layer", type=int, default=15,
                        help="First decoder layer to hook when --rep_edit is set.")
    parser.add_argument("--num_layers", type=int, default=25,
                        help="Number of consecutive layers to hook, starting at --start_layer.")
    parser.add_argument("--var_threshold", type=float, default=0.9,
                        help="Variance threshold passed to ModelEditor.")
    parser.add_argument("--alpha", type=float, default=0.5,
                        help="Value passed to editor.rep_edit (presumably the edit strength).")
    parser.add_argument("--batch_size", type=int, default=1,
                        help="Batch size passed to ModelEditor.")
    parser.add_argument("--rep_edit", action="store_true",
                        help="Apply rep_edit forward hooks during generation.")
    parser.add_argument("--model_edit", action="store_true",
                        help="Load weights from --edited_model_path instead of MODEL_PATH_MAP.")
    parser.add_argument("--edited_model_path", type=str,
                        help="Model path to load when --model_edit is set.")
    parser.add_argument("--output_dir", type=str, default="output/qwen_layer",
                        help="Directory for the output JSONL file.")
    parser.add_argument("--time_dir", type=str, default="./final_res",
                        help="Directory where generation_time.txt is appended.")
    parser.add_argument("--space_samples_number", type=int, default=500,
                        help="Recorded in the output file name and timing log only.")
    args = parser.parse_args()
    # Fail fast with a usage message instead of a cryptic from_pretrained(None) error.
    if args.model_edit and not args.edited_model_path:
        parser.error("--edited_model_path is required when --model_edit is set")
    main(args)
