## Refactored: the shared helper code now lives in utils
from utils import *
import argparse
import IPython
import torch.nn.functional as F

def _el2n_from_logits(shift_logits, shift_labels, ignore_index=-100):
    """Average the per-token EL2N error over the supervised positions.

    For every position whose label is not ``ignore_index`` the error is the
    L2 norm of ``softmax(logits) - one_hot(label)``.  Only the first batch
    element is scored, matching how ``compute_el2n`` feeds one sample at a
    time.

    Args:
        shift_logits: Logits already aligned with ``shift_labels``; indexed as
            ``[batch, position, vocab]``.
        shift_labels: Target token ids indexed as ``[batch, position]``.
        ignore_index: Label value marking positions excluded from the score.

    Returns:
        A 0-dim float32 CPU tensor holding the mean error.  It is NaN when no
        position is supervised (mean of an empty tensor).
    """
    sample_logits = shift_logits[0]
    # Keep labels on the logits' device so the boolean mask can index them.
    sample_labels = shift_labels[0].to(sample_logits.device)

    supervised = sample_labels != ignore_index
    targets = sample_labels[supervised]

    # Softmax in float32 over the vocabulary axis.  The vocabulary size is
    # taken from the logits themselves instead of a hard-coded constant.
    probs = F.softmax(sample_logits[supervised].float(), dim=-1)

    # Subtracting 1 at the target index is the same as subtracting the one-hot
    # label vector, without materialising a (positions x vocab) one-hot matrix.
    rows = torch.arange(targets.numel(), device=targets.device)
    probs[rows, targets] -= 1.0

    per_token_error = torch.linalg.vector_norm(probs, ord=2, dim=-1)
    return per_token_error.mean().cpu()


def compute_el2n(dd, model, tokenizer, image_processor, image_root='YOUR_IMAGE_ROOT'):
    """Compute the mean EL2N score of a single training sample.

    The sample is preprocessed with ``preprocess_multimodal`` and
    ``preprocess_v1``, run through ``model.forward`` and scored on every
    position whose (multimodal-expanded) label is not ``-100``.

    Args:
        dd: Sample dict.  Must contain ``"conversations"``; when it contains
            ``"image"`` that value is joined onto ``image_root`` and loaded.
        model: Model exposing ``forward(input_ids, images=, labels=)``,
            ``prepare_inputs_labels_for_multimodal``, ``config`` and ``device``.
        tokenizer: Tokenizer passed through to ``preprocess_v1``.
        image_processor: Image processor passed through to ``process_images``.
        image_root: Directory that relative image paths are resolved against.

    Returns:
        A 0-dim float32 CPU tensor with the mean per-token EL2N error (NaN if
        the sample has no supervised position).
    """
    has_image = 'image' in dd

    # Deep-copy so preprocessing cannot mutate the caller's conversations.
    conversations = preprocess_multimodal(copy.deepcopy([dd["conversations"]]),)
    res = preprocess_v1(
        conversations,
        tokenizer,
        has_image=has_image)
    input_id = res['input_ids'].cuda()
    label = res['labels'].cuda()

    if has_image:
        image = load_image(os.path.join(image_root, dd['image']))
        # Similar operation in model_worker.py
        image_tensor = process_images([image], image_processor, model.config)
        image_tensor = image_tensor.to(model.device, dtype=torch.float16)
    else:
        image_tensor = None

    with torch.inference_mode():
        outputs = model.forward(
            input_id,
            images=image_tensor,
            labels=label)

        # Re-run the multimodal input preparation only to recover the labels it
        # returns, which are the ones aligned with the output logits below.
        # Done inside inference mode so this second pass builds no autograd
        # state either.
        _, _, _, _, _, new_labels = model.prepare_inputs_labels_for_multimodal(
            input_ids=input_id,
            position_ids=None,
            attention_mask=None,
            past_key_values=None,
            labels=label,
            images=image_tensor,
            image_sizes=None)

    # Position t of the logits predicts token t + 1 of the labels.
    shift_logits = outputs.logits[..., :-1, :]
    shift_labels = new_labels[..., 1:]
    return _el2n_from_logits(shift_logits, shift_labels)


if __name__ == '__main__':
    # Command-line entry point: score one chunk of the input dataset with
    # compute_el2n and write the scores to
    # <output_file>/output_el2n_<chunk_idx>.json.
    parser = argparse.ArgumentParser("", add_help=True)
    parser.add_argument("--num_chunks", type=int, default=1)
    parser.add_argument("--chunk_idx", type=int, default=0)
    # Required: without it the script has nothing to load.
    parser.add_argument("--input_file", type=str, required=True)
    parser.add_argument("--model_path", type=str, default="")
    parser.add_argument("--model_base", type=str, default="")
    parser.add_argument("--output_file", type=str, default='output_chunks')
    args = parser.parse_args()

    # Despite its name, --output_file is used as the output directory.
    target_dir = args.output_file
    already_there = os.path.exists(target_dir)
    # exist_ok=True: another chunk worker may create the directory between the
    # existence check above and this call.
    os.makedirs(target_dir, exist_ok=True)
    if already_there:
        print("目录已存在:", target_dir)
    else:
        print("目录已创建:", target_dir)

    ## load the data and the model, set the tokenizer length and conversation mode
    with open(args.input_file, 'r', encoding='utf-8') as in_file:
        data = json.load(in_file)
    tokenizer, model, image_processor, _ = load_pretrained_model_lora(model_path=args.model_path, model_base=args.model_base, model_name='llava_lora')
    tokenizer.model_max_length = 2048
    conv_mode = "llava_v1"  # NOTE(review): not read anywhere in this script

    chunk_data = get_chunk(data, args.num_chunks, args.chunk_idx)
    el2ns = [
        float(compute_el2n(d, model, tokenizer, image_processor))
        for d in tqdm(chunk_data)
    ]

    with open(f"{target_dir}/output_el2n_{args.chunk_idx}.json", 'w', encoding='utf-8') as out_file:
        json.dump(el2ns, out_file, indent=4)