import os
import json
import argparse

import torch
import numpy as np
from tqdm import tqdm
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer


def inference(model, corpus, device="cuda", batch_size=1):
    """Return the model's next-token distributions at every supervised position.

    Args:
        model: causal language model called as ``model(input_ids=...)``; its
            output must expose ``.logits`` indexed as (batch, seq, vocab).
        corpus: sequence of dicts with ``"input_ids"`` and ``"labels"`` lists of
            equal length. Positions whose label is -100 are ignored. Samples
            that end up in the same batch must have identical lengths, because
            no padding is applied.
        device: device the input tensors are placed on.
        batch_size: number of samples per forward pass.

    Returns:
        A list with one float32 numpy array per sample, shaped
        ``(num_label_tokens, vocab)``. Row i is the distribution the model
        assigned to the i-th non-ignored label token.
    """
    model.eval()
    inference_results = []
    with torch.no_grad():
        for start in tqdm(range(0, len(corpus), batch_size)):
            batch = corpus[start:start + batch_size]
            input_ids = torch.tensor([s["input_ids"] for s in batch], device=device)
            labels = torch.tensor([s["labels"] for s in batch], device=device)

            # The loss is not needed, so labels are not passed to the model.
            logits = model(input_ids=input_ids).logits

            # In a causal LM the logits at position t predict token t + 1.
            # Shift by one so each label lines up with the distribution that
            # generated it. This is the same alignment the Hugging Face loss
            # uses. The label at position 0 has no preceding context and drops out.
            shift_logits = logits[:, :-1, :]
            shift_labels = labels[:, 1:]

            # Select per row: a flat index list shared across the batch would
            # mix positions from different samples.
            for row_logits, row_labels in zip(shift_logits, shift_labels):
                selected = row_logits[row_labels != -100]
                # Softmax in float32. In fp16, small probabilities underflow to 0.
                probabilities = F.softmax(selected.float(), dim=-1)
                inference_results.append(probabilities.cpu().numpy())

    return inference_results


def js_divergence(probs_p, probs_q):
    """Return the Jensen-Shannon divergence between two sets of distributions.

    Args:
        probs_p: array-like (numpy array, tensor or nested list) whose last
            axis holds probability distributions, e.g. ``(positions, vocab)``.
            A single 1-D distribution is also accepted.
        probs_q: array-like with the same shape as ``probs_p``.

    Returns:
        The JS divergence in nats (natural log), computed per distribution and
        averaged over all leading positions. Each per-distribution value lies
        in [0, ln 2]. An input with zero positions yields NaN (mean of nothing).
    """
    # Upcast to float64. Inputs may be float16 when the models ran in half
    # precision. A small-eps clamp rounds to 0 in fp16, and 0 * log(0) then
    # turned the whole result into NaN.
    p = torch.as_tensor(probs_p, dtype=torch.float64)
    q = torch.as_tensor(probs_q, dtype=torch.float64)
    m = (p + q) / 2

    # xlogy(x, y) = x * log(y), defined as 0 when x == 0. Zero-probability
    # entries therefore contribute nothing. Wherever p > 0, m >= p / 2 > 0, so
    # the log is finite.
    kl_p_m = torch.xlogy(p, p) - torch.xlogy(p, m)
    kl_q_m = torch.xlogy(q, q) - torch.xlogy(q, m)

    # Sum over the distribution axis, then average over positions.
    return (0.5 * (kl_p_m + kl_q_m).sum(dim=-1)).mean().item()


def load_corpus(path, limit):
    """Read up to ``limit`` JSON records from the JSON-lines file at ``path``.

    The file is decoded as UTF-8 and blank lines are skipped. A non-positive
    ``limit`` yields an empty list.
    """
    records = []
    if limit <= 0:
        return records
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            records.append(json.loads(line))
            if len(records) >= limit:
                break
    return records


def load_model(name_or_path, device):
    """Load a causal LM onto ``device``, in float16 on any non-CPU device."""
    model = AutoModelForCausalLM.from_pretrained(name_or_path)
    if device != "cpu":
        # Cast before moving so the fp32 weights never occupy device memory.
        model = model.half()
    return model.to(device)


def main():
    """Parse CLI args, score every corpus sample, and append a summary record."""
    parser = argparse.ArgumentParser(description="Calculate js Divergence between two models for each sentence in corpus.")
    parser.add_argument("--model1", type=str, required=True, help="Path or name of first model.")
    parser.add_argument("--model2", type=str, required=True, help="Path or name of second model.")
    parser.add_argument("--corpus", type=str, required=True, help="Path to corpus JSON file.")
    parser.add_argument("--data_size", type=int, default=100, help="Number of data samples to use for inference.")
    parser.add_argument("--device", type=str, default="cuda", help="Device to use for inference.")
    parser.add_argument("--output", type=str, default="js_divergences.json", help="JSON-lines file the result record is appended to.")

    args = parser.parse_args()

    corpus = load_corpus(args.corpus, args.data_size)
    if not corpus:
        parser.error(f"no samples read from {args.corpus}")

    model1 = load_model(args.model1, args.device)
    model2 = load_model(args.model2, args.device)

    # One sample at a time keeps only two (positions, vocab) arrays in memory.
    js_divergences = []
    for sample in corpus:
        probs1 = inference(model1, [sample], args.device)[0]
        probs2 = inference(model2, [sample], args.device)[0]
        js_divergences.append(js_divergence(probs1, probs2))

    data = {
        "model1": args.model1,
        "model2": args.model2,
        "corpus": args.corpus,
        "data_size": args.data_size,
        # The file may hold fewer lines than requested; record what was used.
        "num_samples": len(corpus),
        "js_divergences": js_divergences,
        "mean_js_divergence": float(np.mean(js_divergences)),
        "std_js_divergence": float(np.std(js_divergences)),
    }

    # Append so repeated runs accumulate one JSON record per line.
    with open(args.output, "a", encoding="utf-8") as f:
        f.write(json.dumps(data) + "\n")


if __name__ == "__main__":
    main()
