from model import GPT, inference
import argparse 
import torch
import tiktoken
import os
import json
import tqdm
import re
import pandas as pd
from utils import get_info
import torch.multiprocessing as mp
from itertools import chain
# Default torch device for this process: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Local cache directory for tiktoken's BPE files. It is set before get_encoding() below so
# the encoding load uses it. setdefault keeps a directory the user already exported
# instead of silently overriding it.
os.environ.setdefault("TIKTOKEN_CACHE_DIR", "./tiktoken_cache")
# GPT-2 BPE tokenizer, passed to `inference` as `tokenizer=` by the worker processes.
enc = tiktoken.get_encoding("gpt2")

def process_chunk(process_id, gpu_id, model, qas, start_idx, end_idx, return_dict, progress_queue):
    """Generate and score answers for ``qas[start_idx:end_idx]`` inside a worker process.

    Each line is expected to look like ``"Q: <question> A: <answer>"``. The text up to and
    including the first ``"A:"`` is the prompt. Everything after that first ``"A:"`` (stripped)
    is the golden answer. A sample is correct when the stripped generation equals the golden
    answer exactly.

    Args:
        process_id: Key under which this worker stores its output in ``return_dict``.
        gpu_id: CUDA device index this worker makes current.
        model: Model handed to ``inference``.
        qas: Sequence of raw Q/A lines.
        start_idx, end_idx: Half-open index range of ``qas`` handled by this worker.
        return_dict: Shared mapping. Receives ``(records, n_correct)`` under ``process_id``.
        progress_queue: Shared queue. One ``1`` is put per finished sample.

    Raises:
        ValueError: if a line lacks the ``"A:"`` separator or the ``"Q:"`` marker.
    """
    torch.cuda.set_device(gpu_id)
    result = []
    correct = 0

    for i in range(start_idx, end_idx):
        qa = qas[i]
        # Split on the first "A:" only. An answer that itself contains "A:" (e.g. "NASA: ...")
        # is then kept whole instead of being truncated at its second "A:".
        head, a_sep, tail = qa.partition("A:")
        _, q_sep, question = head.partition("Q:")
        if not a_sep or not q_sep:
            raise ValueError(f"Sample {i} is not a 'Q: ... A: ...' line: {qa!r}")
        prompt = head + "A:"
        question = question.strip()
        golden_answer = tail.strip()
        question_type, person = get_info(question)
        # Generation needs no autograd graph. This is harmless if `inference` already disables it.
        # stop_token=198 is "\n" in the GPT-2 vocabulary, so generation stops at end of line.
        with torch.no_grad():
            generated_answer = inference(
                model, prompt, tokenizer=enc, max_new_tokens=10, stop_token=198, temperature=0
            ).strip()
        is_correct = golden_answer == generated_answer
        result.append({
            "prompt": prompt,
            "golden_answer": golden_answer,
            "generated_answer": generated_answer,
            "question_type": question_type,
            "person": person,
            "correct": is_correct,
        })
        correct += is_correct
        progress_queue.put(1)  # Signal that one sample is processed

    return_dict[process_id] = (result, correct)

def _parse_args():
    """Parse the command-line options of this evaluation script."""
    parser = argparse.ArgumentParser(
        description="Run multi-GPU Q/A inference with a GPT checkpoint and score exact-match accuracy."
    )
    parser.add_argument("--model_path", type=str, required=True,
                        help="Checkpoint passed to GPT.from_pretrained.")
    parser.add_argument("--input_path", type=str, required=True,
                        help="Text file with one 'Q: ... A: ...' sample per line.")
    parser.add_argument("--output_path", type=str, required=True,
                        help="JSON summary path; per-sample rows go to the same path with a .csv extension.")
    parser.add_argument("--first_n", type=int, default=-1,
                        help="Evaluate only the first N samples (-1 = all).")
    parser.add_argument("--processes_per_gpu", type=int, default=4,
                        help="Number of worker processes launched per GPU.")
    return parser.parse_args()


def _load_models(model_path, n_gpus):
    """Return ``{gpu_id: model}`` with an independently loaded checkpoint on every GPU.

    ``torch.nn.Module.to`` moves a module in place and returns the same object. Assuming GPT
    is an nn.Module, moving one loaded model in a loop would leave every entry aliasing a
    single module sitting on the last GPU. Loading once per device gives each GPU its own
    weights.
    """
    return {
        gpu_id: GPT.from_pretrained(model_path, f"cuda:{gpu_id}")
        for gpu_id in range(n_gpus)
    }


def _wait_for_progress(progress_queue, processes, total_samples, poll_seconds=5.0):
    """Advance a tqdm bar from worker progress signals until ``total_samples`` arrive.

    Polls with a timeout instead of blocking forever. A worker that died before reporting
    all of its samples (exception, OOM kill, ...) therefore raises here instead of hanging
    the parent indefinitely.

    Raises:
        RuntimeError: if a worker exits with a non-zero exit code before all signals arrive.
    """
    import queue  # Provides queue.Empty, raised by a timed-out get() on the manager queue.

    with tqdm.tqdm(total=total_samples, desc="Total Progress") as pbar:
        completed = 0
        while completed < total_samples:
            try:
                progress_queue.get(timeout=poll_seconds)
            except queue.Empty:
                failed = [p.pid for p in processes if p.exitcode not in (None, 0)]
                if failed:
                    raise RuntimeError(f"Worker process(es) {failed} exited abnormally; aborting.")
                continue
            completed += 1
            pbar.update(1)


def _write_outputs(all_results, accuracy, output_path):
    """Write the JSON summary to ``output_path`` and per-sample rows to a sibling ``.csv``."""
    df = pd.DataFrame(all_results)
    report = {
        "accuracy": accuracy,
        # Mean of the boolean "correct" column = exact-match accuracy per question type.
        "detail_result": df.groupby(["question_type"]).correct.mean().to_dict(),
    }
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(report, f, indent=4)
    # Swap the extension instead of str.replace(".json", ".csv"). The latter leaves a
    # non-.json path unchanged, so the CSV would overwrite the JSON, and it also rewrites
    # ".json" appearing in directory names.
    csv_path = os.path.splitext(output_path)[0] + ".csv"
    if csv_path == output_path:  # e.g. --output_path result.csv
        csv_path = output_path + ".csv"
    df.to_csv(csv_path, index=False)


def main():
    """Evaluate a GPT checkpoint on Q/A lines with worker processes spread over all GPUs.

    Reads ``--input_path`` (one "Q: ... A: ..." sample per line, blank lines skipped). It
    splits the first ``--first_n`` samples into contiguous chunks and runs ``process_chunk``
    on each chunk in its own process. It then prints the exact-match accuracy and writes the
    JSON summary and CSV details.

    Raises:
        RuntimeError: if no CUDA device is visible or a worker process fails.
        ValueError: if ``--processes_per_gpu`` < 1 or there is nothing to evaluate.
    """
    args = _parse_args()
    if args.processes_per_gpu < 1:
        raise ValueError("--processes_per_gpu must be at least 1")

    n_gpus = torch.cuda.device_count()
    if n_gpus == 0:
        # Models are placed on cuda:<id> and workers call torch.cuda.set_device.
        raise RuntimeError("No CUDA device is visible; this script needs at least one GPU.")

    with open(args.input_path, "r", encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing empty line). They hold no Q/A pair and would make
        # a worker crash while parsing.
        qas = [line for line in f if line.strip()]

    total_samples = min(len(qas), args.first_n) if args.first_n != -1 else len(qas)
    if total_samples <= 0:
        raise ValueError(
            f"No samples to evaluate (input={args.input_path!r}, first_n={args.first_n})."
        )

    print("Loading model from disk...")
    models = _load_models(args.model_path, n_gpus)

    total_processes = n_gpus * args.processes_per_gpu
    chunk_size = (total_samples + total_processes - 1) // total_processes  # ceiling division

    manager = mp.Manager()
    return_dict = manager.dict()
    progress_queue = manager.Queue()

    print(f"Starting inference on {n_gpus} GPUs with {args.processes_per_gpu} processes per GPU")
    processes = []
    for gpu_id in range(n_gpus):
        for proc_idx in range(args.processes_per_gpu):
            process_id = gpu_id * args.processes_per_gpu + proc_idx
            start_idx = process_id * chunk_size
            if start_idx >= total_samples:
                continue
            end_idx = min(start_idx + chunk_size, total_samples)
            # Hand each worker only its own slice (re-based to index 0). Process arguments
            # are pickled per child, so passing the full list would copy the whole dataset
            # once per worker.
            chunk = qas[start_idx:end_idx]
            p = mp.Process(
                target=process_chunk,
                args=(process_id, gpu_id, models[gpu_id], chunk, 0, len(chunk),
                      return_dict, progress_queue),
            )
            processes.append(p)
            p.start()

    try:
        _wait_for_progress(progress_queue, processes, total_samples)
    except BaseException:
        # Don't leave surviving workers running after a failure or Ctrl-C.
        for p in processes:
            if p.is_alive():
                p.terminate()
        raise

    for p in processes:
        p.join()
    failed = [p.pid for p in processes if p.exitcode != 0]
    if failed:
        raise RuntimeError(f"Worker process(es) {failed} exited abnormally; results are incomplete.")

    all_results = []
    total_correct = 0
    # Collect in process-id order (= input order) rather than completion order so the
    # CSV rows are deterministic.
    for process_id in sorted(return_dict.keys()):
        result, correct = return_dict[process_id]
        all_results.extend(result)
        total_correct += correct

    accuracy = total_correct / total_samples
    print(f"Accuracy: {accuracy}")
    _write_outputs(all_results, accuracy, args.output_path)

if __name__ == "__main__":
    # Worker processes use CUDA, and CUDA cannot be re-initialised in a forked child,
    # so children are started with the "spawn" method.
    mp.set_start_method(method="spawn", force=False)
    main()
# example command:
# python inference_SFT.py --model_path ckpt/SFT_mix_pretrain_10x.pt --input_path bioS_single/SFT.txt --output_path result_SFT_inference_1k.json --first_n 1000 --processes_per_gpu 4
