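# Standalone inference script for a sequence-classification model:
# interactive single-query prediction, or accuracy evaluation on a JSON file.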

import argparse
import json
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

class InferenceDataset(Dataset):
    """Map-style dataset pairing each raw input text with its class label.

    Samples are returned untokenized; tokenization and padding are left to
    whatever collate function the DataLoader is given.
    """

    def __init__(self, texts, labels):
        # Parallel sequences: labels[i] is the label belonging to texts[i].
        self.texts, self.labels = texts, labels

    def __len__(self):
        """Return the number of samples (the length of ``texts``)."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict with ``"text"`` and ``"label"`` keys."""
        text = self.texts[idx]
        label = self.labels[idx]
        return dict(text=text, label=label)

class DynamicBatchCollator:
    """Collate ``{"text", "label"}`` samples into right-padded tensors.

    Each batch is padded only up to the length of its own longest encoded
    sequence (capped at ``max_length``) rather than to a global fixed length.

    Args:
        tokenizer: Callable tokenizer returning ``"input_ids"`` and
            ``"attention_mask"`` lists; must expose ``pad_token_id``.
        max_length: Upper bound on the encoded sequence length; longer
            inputs are truncated by the tokenizer.
    """

    def __init__(self, tokenizer, max_length=8192):
        self.tokenizer = tokenizer
        self.max_length = max_length

    def _require_pad_token_id(self):
        """Return the tokenizer's pad token id, or raise if it has none."""
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            raise ValueError(
                "Tokenizer has no pad_token_id, but the batch contains "
                "sequences of different lengths; set tokenizer.pad_token "
                "before batching."
            )
        return pad_id

    def __call__(self, features):
        """Encode and pad a list of samples.

        Args:
            features: Non-empty list of dicts with ``"text"`` and ``"label"``.

        Returns:
            Dict with ``input_ids``, ``attention_mask`` (both of shape
            ``(batch, longest_encoded_length)``) and ``labels`` (``(batch,)``),
            all ``torch.long``.

        Raises:
            ValueError: If ``features`` is empty, or if padding is required
                and the tokenizer defines no pad token.
        """
        if not features:
            raise ValueError("Cannot collate an empty batch")

        # Encode every text exactly once, truncating at the configured cap.
        # The batch length is then measured on the *encoded* ids, so any
        # special tokens the tokenizer adds are counted and never cut off
        # (measuring with tokenize() would miss them and truncate the
        # longest sample).
        encodings = [
            self.tokenizer(
                f["text"],
                truncation=True,
                max_length=self.max_length,
                padding=False,
            )
            for f in features
        ]
        batch_max_length = max(len(enc["input_ids"]) for enc in encodings)

        input_ids = []
        attention_mask = []
        pad_id = None  # Looked up lazily: equal-length batches need no padding.
        for enc in encodings:
            ids = list(enc["input_ids"])
            mask = list(enc["attention_mask"])
            pad_len = batch_max_length - len(ids)
            if pad_len > 0:
                if pad_id is None:
                    pad_id = self._require_pad_token_id()
                # Right-pad; padded positions are masked out of attention.
                ids.extend([pad_id] * pad_len)
                mask.extend([0] * pad_len)
            input_ids.append(ids)
            attention_mask.append(mask)

        labels = [f["label"] for f in features]

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }

def evaluate(model, data_loader, device):
    """Compute classification accuracy of ``model`` over ``data_loader``.

    Args:
        model: Model whose forward call accepts ``input_ids`` and
            ``attention_mask`` and returns an object with a ``logits``
            attribute of shape ``(batch, num_classes)``.
        data_loader: Iterable of batches; each batch is a dict holding
            ``"input_ids"``, ``"attention_mask"`` and ``"labels"`` tensors.
        device: Device the batch tensors are moved to before the forward pass.

    Returns:
        Fraction of samples whose arg-max prediction equals the label, as a
        float. Returns ``0.0`` when the loader yields no samples.
    """
    model.eval()
    total_correct = 0
    total_samples = 0

    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating"):
            labels = batch["labels"].to(device)
            # Labels are intentionally not forwarded to the model: only the
            # logits are needed here, and passing labels would just trigger a
            # loss computation whose result is never used.
            outputs = model(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
            )
            preds = torch.argmax(outputs.logits, dim=1)
            total_correct += (preds == labels).sum().item()
            total_samples += labels.size(0)

    # Guard against an empty loader instead of dividing by zero.
    if total_samples == 0:
        return 0.0
    return total_correct / total_samples

def _run_interactive(model, collator, device):
    """Read queries from stdin in a loop and print the prediction for each."""
    print("Entering interactive mode (type 'exit' to quit)")
    while True:
        try:
            text = input("\nInput query: ")
        except (EOFError, KeyboardInterrupt):
            # stdin was closed or the user pressed Ctrl-C at the prompt:
            # leave the loop instead of reporting the same error forever.
            print()
            break

        command = text.strip()
        if command.lower() == "exit":
            break
        if not command:
            # Nothing to classify; prompt again.
            continue

        try:
            # The collator expects a label field; 0 is a placeholder that is
            # never passed to the model.
            batch = collator([{"text": text, "label": 0}])

            with torch.no_grad():
                outputs = model(
                    input_ids=batch["input_ids"].to(device),
                    attention_mask=batch["attention_mask"].to(device)
                )

            logits = outputs.logits.cpu().numpy()[0]
            pred_class = np.argmax(logits)
            confidence = torch.softmax(outputs.logits, dim=1)[0].cpu().numpy()

            print("\nModel output:")
            # Categories are reported 1-based, mirroring the "- 1" applied to
            # the file labels in evaluation mode.
            print(f"Predicted category: {pred_class + 1}")
            print(f"Confidence: {np.max(confidence):.4f}")
            print(f"Raw logits: {logits}")
            print(f"All confidences: {np.round(confidence, 4)}")

        except Exception as e:
            # Best-effort REPL: report the failure and keep accepting queries.
            print(f"Error: {str(e)}")


def _load_eval_samples(eval_file, sample_ratio):
    """Load ``(texts, labels)`` from a JSON file, optionally subsampled.

    The file is read as a JSON sequence of records, each with a ``"query"``
    and a 1-based ``"category"``; labels are returned 0-based.
    """
    with open(eval_file, "r", encoding="utf-8") as f:
        eval_data = json.load(f)

    if eval_data and sample_ratio < 1.0:
        # Keep at least one sample so a tiny ratio cannot empty the dataset.
        sample_size = max(1, int(len(eval_data) * sample_ratio))
        # Draw indices rather than the records themselves, so NumPy never has
        # to coerce the list of dict records into an array.
        chosen = np.random.choice(len(eval_data), size=sample_size, replace=False)
        eval_data = [eval_data[int(i)] for i in chosen]

    texts = [str(item["query"]) for item in eval_data]
    labels = [int(item["category"]) - 1 for item in eval_data]
    return texts, labels


def _run_evaluation(args, model, collator, device):
    """Evaluate the model on ``args.eval_file`` and print accuracy figures."""
    texts, labels = _load_eval_samples(args.eval_file, args.sample_ratio)
    if not texts:
        raise SystemExit(f"No evaluation samples found in {args.eval_file}")

    dataset = InferenceDataset(texts, labels)
    data_loader = DataLoader(
        dataset,
        batch_size=16,
        collate_fn=collator,
        shuffle=False,
        num_workers=4
    )

    accuracy = evaluate(model, data_loader, device)
    print(f"\nEvaluation Results ({len(dataset)} samples):")
    print(f"Accuracy: {accuracy:.4f}")
    # round() rather than int(): accuracy is a float ratio, and truncating
    # accuracy * n can land one below the true count (e.g. 28.999999 -> 28).
    print(f"Correct samples: {round(accuracy * len(dataset))}")
    print(f"Total samples: {len(dataset)}")


def main():
    """Command-line entry point.

    Parses the arguments, loads the tokenizer and model from ``--model_path``
    and runs either the interactive prompt or the file-based evaluation,
    depending on ``--mode``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", required=True, help="Path to trained model")
    parser.add_argument("--mode", choices=["interactive", "evaluate"], required=True)
    parser.add_argument("--eval_file", help="Path to evaluation JSON file")
    parser.add_argument("--sample_ratio", type=float, default=1.0, help="Fraction of samples to use for evaluation")
    args = parser.parse_args()

    # Validate evaluate-mode arguments before the (slow) model load so that
    # mistakes are reported immediately.
    if args.mode == "evaluate":
        if not args.eval_file:
            parser.error("--eval_file is required when --mode is 'evaluate'")
        if args.sample_ratio <= 0:
            parser.error("--sample_ratio must be greater than 0")

    device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    model = AutoModelForSequenceClassification.from_pretrained(args.model_path)
    model = model.to(device).eval()

    collator = DynamicBatchCollator(tokenizer)

    if args.mode == "interactive":
        _run_interactive(model, collator, device)
    elif args.mode == "evaluate":
        _run_evaluation(args, model, collator, device)

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

# python /mnt/new_cpfs/yqs/LLaVA-MPE/inductor/inf_eval.py --model_path /mnt/new_cpfs/yqs/model/HuggingFaceTB/SmolLM2-360M-Instruct/results/5e-5-3e/checkpoint-25395/ --mode interactive
# python /mnt/new_cpfs/yqs/LLaVA-MPE/inductor/inf_eval.py --model_path /mnt/new_cpfs/yqs/model/HuggingFaceTB/SmolLM2-135M-Instruct/results/5e-5-3e/ --mode interactive
# python /mnt/new_cpfs/yqs/LLaVA-MPE/inductor/inf_eval.py --model_path /mnt/new_cpfs/yqs/model/HuggingFaceTB/SmolLM2-360M-Instruct/results/5e-5-3e/checkpoint-25395/ --mode evaluate --eval_file /mnt/new_cpfs/yqs/model/HuggingFaceTB/gen_query/bench_dataset/q_39.json --sample_ratio 1
