import os 
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
import torch,json
from transformers import (
    AutoConfig,
    AutoTokenizer,
    LlamaForCausalLM,
    GenerationConfig,
)
import argparse
import numpy as np
from moepeft import Gate,find_and_replace
from peft import (
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_int8_training,
    set_peft_model_state_dict,
)
from datasets import load_from_disk
import json,os
from tqdm import tqdm
from prompter import format,format_squad
from utils import load_moe_model
# Root directory for inference outputs; infer() writes predictions to
# {OUT_BASE}/{task}/{model_name}/{split}.jsonl.
OUT_BASE="/mnt/mednas/jiangyinuo.jyn/code/evaluation/results"

def get_model(base_path,expert_num,adapter_name,mode,ckpt_path=None,layer_type="moe"):
  """Load the 8-bit base Llama model and its tokenizer, optionally attaching MoE adapters.

  Args:
    base_path: Path (or hub id) of the base model; used for both the tokenizer
      and the model weights.
    expert_num: Number of experts, forwarded to ``load_moe_model``.
    adapter_name: Adapter name, forwarded to ``load_moe_model``.
    mode: Adapter mode, forwarded to ``load_moe_model``.
    ckpt_path: Adapter checkpoint path. ``None`` or ``""`` means "no adapter":
      the bare base model is returned.
    layer_type: Layer type, forwarded to ``load_moe_model``.

  Returns:
    ``(model, tokenizer)``. ``model.config.adapter_path`` is set to
    ``ckpt_path`` when an adapter was loaded and to ``None`` otherwise.
  """
  tokenizer = AutoTokenizer.from_pretrained(base_path)
  # Place the whole model on this process's GPU (LOCAL_RANK if set, else 0).
  device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
  # Load from the ``base_path`` argument rather than the script-level
  # ``model_path`` global, so this function also works when imported.
  model = LlamaForCausalLM.from_pretrained(
      base_path,
      load_in_8bit=True,
      torch_dtype=torch.float16,
      device_map=device_map,
  )
  # NOTE(review): this is a training-preparation helper; kept because
  # load_moe_model may rely on its side effects — confirm before removing.
  model = prepare_model_for_int8_training(model)
  model.eval()
  model.config.adapter_path = None
  # Truthiness rather than ``is not None``: the CLI passes "" when no
  # checkpoint is given, and that must not be treated as a path.
  if ckpt_path:
    model = load_moe_model(mode, ckpt_path, model, adapter_name, expert_num, layer_type)
    model.config.adapter_path = ckpt_path
  return model, tokenizer
@torch.no_grad()
def infer(model,tokenizer,test_name):
  """Generate predictions for a saved test split and write them as JSONL.

  ``test_name`` is a ``datasets`` directory (read with ``load_from_disk``)
  whose last two path components are ``<task>/<split>``. Each example is
  formatted with the ``"OneKE"`` template and its ``"sentence"`` field is
  used as the prompt. Up to 512 new tokens are generated. Each example
  produces one line ``{"answer": <generated text>, "gd": <example["entities"]>}``,
  written to ``{OUT_BASE}/{task}/{model_name}/{split}.jsonl``.

  ``model_name`` is the last component of the base model's name. When
  ``model.config.adapter_path`` is set, it is instead
  ``<3rd-from-last>_<last>`` path components of the adapter path.

  The output file is truncated at the start of each run. Re-running a split
  therefore replaces earlier predictions instead of appending duplicates.
  Each line is flushed right after it is written, so partial results survive
  a crash.
  """
  # rstrip('/') so a trailing slash does not produce empty path components.
  task, split = test_name.rstrip("/").split("/")[-2:]
  model_name = model.config._name_or_path.rstrip("/").split("/")[-1]
  adapter_path = model.config.adapter_path
  if adapter_path:
    adapter_parts = adapter_path.rstrip("/").split("/")
    model_name = f"{adapter_parts[-3]}_{adapter_parts[-1]}"
  out_dir = f"{OUT_BASE}/{task}/{model_name}"
  # exist_ok avoids the check-then-create race.
  os.makedirs(out_dir, exist_ok=True)
  out_file_name = f"{out_dir}/{split}.jsonl"

  test = load_from_disk(test_name)
  test = format(test, "OneKE")

  # The generation config is the same for every example, so build it once.
  # max_new_tokens alone bounds the output (it overrides max_length).
  # ``truncation`` is a tokenizer option, not a generation option, so it is
  # not passed here.
  generation_config = GenerationConfig(max_new_tokens=512, return_dict_in_generate=True)

  ### infering
  print(split)
  with open(out_file_name, "w", encoding="utf-8") as out_file:
    for example in tqdm(test):
      instruction = example["sentence"]
      # Explicit truncation=True applies the 2048-token cap without the
      # "truncation not explicitly activated" warning.
      input_ids = tokenizer.encode(
          instruction, max_length=2048, truncation=True, return_tensors="pt"
      ).to(model.device)
      input_length = input_ids.size(1)
      generation_output = model.generate(
          input_ids=input_ids,
          generation_config=generation_config,
          pad_token_id=tokenizer.eos_token_id,
      )
      # Keep only the newly generated tokens; the prompt is echoed first.
      new_tokens = generation_output.sequences[0][input_length:]
      answer = tokenizer.decode(new_tokens, skip_special_tokens=True)
      out_file.write(json.dumps({"answer": answer, "gd": example["entities"]}, ensure_ascii=False) + "\n")
      out_file.flush()
if __name__ == '__main__':
  # Command-line entry point: run OneKE NER inference on one saved test split.
  # Every new option defaults to the value that used to be hard-coded.
  parser = argparse.ArgumentParser(description='inferring')
  parser.add_argument('--test', type=str, required=True,
                      help='"<task>/<split>" path of the test set, relative to --data_root.')
  parser.add_argument('--ckpt', type=str, default="",
                      help='MoE adapter checkpoint path; omit to evaluate the base model alone.')
  parser.add_argument('--model_path', type=str,
                      default='/mnt/gruntdata/rs_nas/workspace2/yinuo/models/OneKE',
                      help='Base model directory.')
  parser.add_argument('--data_root', type=str,
                      default="/mnt/prev_nas/jiangyinuo.jyn/code/evaluation/NER",
                      help='Directory containing the saved test sets.')
  parser.add_argument('--expert_num', type=int, default=16)
  parser.add_argument('--mode', type=str, default="hie")
  parser.add_argument('--adapter_name', type=str, default="id")
  args = parser.parse_args()
  # Assigned at module scope on purpose: keep ``model_path`` a module-level
  # global, since other code in this file may read it directly.
  model_path = args.model_path
  # An empty --ckpt means "no adapter"; pass None so no checkpoint is loaded.
  ckpt_path = args.ckpt or None
  model, tokenizer = get_model(
      model_path, args.expert_num, args.adapter_name, args.mode,
      ckpt_path=ckpt_path, layer_type="moe",
  )
  test_name = f"{args.data_root}/{args.test}"
  infer(model, tokenizer, test_name)
