## Refactored: the shared code has been moved into utils
from utils import *
import argparse
import IPython


def compute_ppl(dd, model, tokenizer, image_processor, image_root='YOUR_IMAGE_ROOT'):
    """Return ``exp(loss)`` of ``model`` on a single sample.

    Args:
        dd: One sample. Must provide a ``"conversations"`` entry; an
            optional ``"image"`` entry holds a path relative to
            ``image_root``.
        model: Model whose ``forward`` accepts ``images=`` and ``labels=``
            and whose output exposes ``.loss``.
        tokenizer: Tokenizer handed to ``preprocess_v1``.
        image_processor: Image processor handed to ``process_images``.
        image_root: Directory that ``dd["image"]`` is joined onto. Defaults
            to the original placeholder value, so existing callers behave
            as before.

    Returns:
        A tensor holding ``torch.exp(outputs.loss)``.
        NOTE(review): presumably ``outputs.loss`` is the mean token-level
        negative log-likelihood over the labelled positions, which would
        make this value the perplexity -- confirm against the model class.
    """
    has_image = 'image' in dd

    # Deep-copy so that preprocessing cannot modify the caller's sample.
    conversations = copy.deepcopy([dd["conversations"]])
    sources = preprocess_multimodal(conversations)
    res = preprocess_v1(sources, tokenizer, has_image=has_image)

    # Put every input on the model's device (the image tensor already was),
    # instead of relying on the current default CUDA device.
    device = model.device
    input_id = res['input_ids'].to(device)
    label = res['labels'].to(device)

    image_tensor = None
    if has_image:
        image = load_image(os.path.join(image_root, dd['image']))
        # Similar operation in model_worker.py
        image_tensor = process_images([image], image_processor, model.config)
        image_tensor = image_tensor.to(device, dtype=torch.float16)

    with torch.inference_mode():
        outputs = model.forward(
            input_id,
            images=image_tensor,
            labels=label)
        neg_log_likelihood = outputs.loss
    ppl = torch.exp(neg_log_likelihood)
    return ppl


if __name__ == '__main__':
    # Script entry point: score one chunk of the input samples with
    # compute_ppl and write the resulting list of floats to
    # <output_file>/output_ppl_<chunk_idx>.json.
    parser = argparse.ArgumentParser("", add_help=True)
    parser.add_argument("--num_chunks", type=int, default=1)
    parser.add_argument("--chunk_idx", type=int, default=0)
    # Required: without it the script previously failed later with an
    # opaque TypeError from open(None).
    parser.add_argument("--input_file", type=str, required=True)
    parser.add_argument("--model_path", type=str, default="")
    parser.add_argument("--model_base", type=str, default="")
    # Despite its name, this argument is used as the output *directory*.
    parser.add_argument("--output_file", type=str, default='output_chunks')
    args = parser.parse_args()

    target_dir = args.output_file
    if not os.path.exists(target_dir):
        # exist_ok=True: several chunk workers may start at the same time
        # and race between the existence check and the creation.
        os.makedirs(target_dir, exist_ok=True)
        print("目录已创建:", target_dir)
    else:
        print("目录已存在:", target_dir)

    # Load the samples and the model.
    with open(args.input_file, 'r', encoding='utf-8') as input_fp:
        data = json.load(input_fp)
    tokenizer, model, image_processor, _ = load_pretrained_model_lora(
        model_path=args.model_path,
        model_base=args.model_base,
        model_name='llava_lora')
    tokenizer.model_max_length = 2048
    # NOTE(review): conv_mode is not referenced anywhere else in this script.
    conv_mode = "llava_v1"

    # presumably get_chunk returns the chunk_idx-th of num_chunks slices of
    # data -- verify against utils.
    chunk_data = get_chunk(data, args.num_chunks, args.chunk_idx)
    all_ppls = []
    for d in tqdm(chunk_data):
        ppl = compute_ppl(d, model, tokenizer, image_processor)
        # Convert to a plain float so the result is JSON-serialisable.
        all_ppls.append(float(ppl))

    output_path = os.path.join(target_dir, f"output_ppl_{args.chunk_idx}.json")
    with open(output_path, 'w') as output_fp:
        json.dump(all_ppls, output_fp, indent=4)