import os
import json
import argparse

import torch
import numpy as np
from tqdm import tqdm
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer


def calculate_gradient(model, corpus, device="cuda", batch_size=1):
    """Accumulate per-parameter gradient statistics over ``corpus``.

    The corpus is consumed in consecutive slices of ``batch_size`` samples.
    One forward/backward pass is run per slice and, for every parameter that
    received a gradient, the gradient and its element-wise sign are added to
    running totals kept on the CPU.

    Args:
        model: Module called as ``model(input_ids=..., labels=...)`` whose
            result exposes a scalar ``loss``. It must also provide ``train``,
            ``zero_grad`` and ``named_parameters``.
        corpus: Sliceable sequence of mappings with ``"input_ids"`` and
            ``"labels"`` entries. The samples of one batch are handed to
            ``torch.tensor`` together, so they must have equal lengths.
        device: Device the batch tensors are moved to.
        batch_size: Number of samples per forward/backward pass.

    Returns:
        A pair ``(all_gradients_avg, max_sign_same_rate)`` of dicts keyed by
        parameter name. ``all_gradients_avg[name]`` is the summed gradient
        divided by the number of batches. ``max_sign_same_rate[name]`` is a
        tuple ``(sign, rate)`` where ``sign`` is the sign of the summed
        per-batch gradient signs and ``rate`` is the absolute value of that
        sum divided by the number of batches (1.0 when every batch produced
        the same sign for an element).
    """
    # NOTE(review): train mode is kept from the original; if the model has
    # dropout layers the gradients will be stochastic — confirm this is intended.
    model.train()
    total_loss = 0
    all_gradients_avg = {}
    max_sign_same_rate = {}
    # Every backward pass contributes exactly one gradient to the totals, so
    # the totals are normalised by the number of batches, not samples.
    # Dividing by the sample count shrinks both statistics by a factor of
    # ``batch_size`` (this assumes ``output.loss`` is already averaged within
    # a batch — TODO confirm for the model in use).
    num_batches = 0

    for idx in tqdm(range(0, len(corpus), batch_size)):
        batch = corpus[idx:idx + batch_size]
        input_ids = torch.tensor([sample["input_ids"] for sample in batch]).to(device)
        labels = torch.tensor([sample["labels"] for sample in batch]).to(device)

        # Drop gradients left over from the previous batch.
        model.zero_grad()

        output = model(input_ids=input_ids, labels=labels)
        loss = output.loss
        loss_value = loss.item()
        print(loss_value)
        total_loss += loss_value
        loss.backward()

        for name, param in model.named_parameters():
            if param.grad is None:
                continue
            # Accumulate on the CPU so the running totals do not occupy
            # device memory.
            grad = param.grad.detach().cpu()
            if name not in all_gradients_avg:
                all_gradients_avg[name] = torch.zeros_like(grad)
                max_sign_same_rate[name] = torch.zeros_like(grad)
            all_gradients_avg[name] += grad
            max_sign_same_rate[name] += torch.sign(grad)

        num_batches += 1

    print(f"Total loss: {total_loss}")
    # Compute the average gradient and the maximum sign-agreement rate.
    for name in all_gradients_avg:
        all_gradients_avg[name] /= num_batches
        summed_signs = max_sign_same_rate[name]
        max_sign_same_rate[name] = (
            torch.sign(summed_signs),
            torch.abs(summed_signs) / num_batches,
        )

    return all_gradients_avg, max_sign_same_rate


def main():
    """Load a JSONL corpus and a causal LM, then save gradient statistics.

    Reads up to ``--data_size`` JSON objects (one per line) from ``--corpus``,
    loads the model at ``--model_path`` onto ``--device``, runs
    ``calculate_gradient`` over the samples, and writes the two resulting
    dicts to ``all_gradients_avg.pt`` and ``max_sign_same_rate.pt`` inside
    ``--save_dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--corpus", type=str, required=True)
    parser.add_argument("--data_size", type=int, default=1000, help="Number of data samples to use for inference.")
    # Required: the results are always written there, and a missing value
    # would otherwise only fail after the whole gradient pass has run.
    parser.add_argument("--save_dir", type=str, required=True, help="Path to save inference results.")
    parser.add_argument("--device", type=str, default="cuda:0", help="Device to use for inference.")

    args = parser.parse_args()

    # Create the output directory first so an unusable path fails fast.
    os.makedirs(args.save_dir, exist_ok=True)

    corpus = []
    with open(args.corpus, 'r', encoding="utf-8") as f:
        for line in tqdm(f):
            # Check the limit before appending so ``--data_size 0`` (or a
            # negative value) loads nothing instead of one sample.
            if len(corpus) >= args.data_size:
                break
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline) in the JSONL file.
                continue
            corpus.append(json.loads(line))

    model = AutoModelForCausalLM.from_pretrained(args.model_path).to(args.device)
    # NOTE(review): this only matches the exact string "cuda"; the default
    # "cuda:0" keeps the model in its loaded precision. Left unchanged because
    # switching to half precision would alter the computed gradients — confirm
    # which behaviour is intended.
    if args.device == "cuda":
        model = model.half()

    all_gradients_avg, max_sign_same_rate = calculate_gradient(model, corpus, device=args.device)

    torch.save(all_gradients_avg, os.path.join(args.save_dir, "all_gradients_avg.pt"))
    torch.save(max_sign_same_rate, os.path.join(args.save_dir, "max_sign_same_rate.pt"))


if __name__ == "__main__":
    main()
