import params
import argparse
from evaluate import eval_single_file, bio_eval_single_file, load_wiki, load_truthfulqa, load_biography
import pandas as pd
import json
from sentence_transformers import SentenceTransformer
from google import genai
from inference_utils import inference


def _str2bool(value):
    """Convert a command-line token such as "true"/"false" or "1"/"0" to a bool.

    argparse's ``type=bool`` simply calls ``bool(token)``, which is True for
    every non-empty string, so ``--save_results False`` used to enable saving.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args(argv=None):
    """Parse the command-line options for an inference + evaluation run.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            ``sys.argv[1:]`` is parsed, matching the original behaviour.

    Returns:
        argparse.Namespace with one attribute per option below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--evaluation_type", type=str, default="gemini")
    parser.add_argument("--method", type=str, default="adaptive")
    parser.add_argument("--base_model", type=str, required=True)
    parser.add_argument("--train_data", type=str, default="truthful_qa")  # data used for precompute
    parser.add_argument("--eval_data", type=str, default="truthful_qa")
    parser.add_argument("--data_split", type=str, default="test")
    parser.add_argument("--max_sample_num", type=int, default=100)
    # Accepts "--save_results", "--save_results True" and "--save_results False".
    parser.add_argument(
        "--save_results", type=_str2bool, nargs="?", const=True, default=False,
        help="Whether to save results (true/false; bare flag means true)",
    )
    parser.add_argument("--standard_prompt_key", type=str, default='zero_shot')  # few_shot_bio for ICL
    parser.add_argument("--max_new_tokens", type=int, default=128)
    parser.add_argument("--noisy_prompt_key", type=str, default=None, help="Key for the noisy prompt template, if not None --> Instructive Decoding")
    parser.add_argument("--embed_model_name", type=str, default="all-MiniLM-L6-v2", help="Name of the embedding model for adaptive decoding")
    parser.add_argument("--prompt_dir", type=str, default="prompt_templates/")
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--batch_size", type=int, default=160)  # 160 for general QA, 32 for bio
    parser.add_argument("--device", type=str, default="cuda:0")

    return parser.parse_args(argv)


args = get_args()

# Generations file for this run, named after the run configuration:
#   <eval_data>_<base_model>_<method>_<data_split>_<max_sample_num>.json
run_tag = "_".join(
    [args.eval_data, args.base_model, args.method, args.data_split, str(args.max_sample_num)]
)
args.eval_data_path = run_tag + ".json"
args.output_path = f"{params.output_dir}/{args.eval_data_path}"

# Gemini client attached to `args`; presumably consumed by the "gemini"
# evaluation path in evaluate.py — confirm there.
args.gemini_client = genai.Client(api_key=params.gemini_api_key)

# Loading data: look up the loader for the requested evaluation set and call it.
data_loaders = {
    "wiki": load_wiki,
    "truthful_qa": load_truthfulqa,
    "bio": load_biography,
}
loader = data_loaders.get(args.eval_data)
if loader is None:
    raise ValueError(f"Unsupported evaluation data: {args.eval_data}")
data = loader(split=args.data_split, max_sample_num=args.max_sample_num)

# Each loaded sample is expected to carry a "question" field.
questions = [row["question"] for row in data]

# 1. Inference
print('Starting inference', args.method, args.train_data, args.base_model, args.eval_data)
if args.method == "adaptive":
    # Adaptive decoding needs a sentence-embedding model plus run metadata.
    # These are handed to `inference` through `args`. The model is placed on the
    # user-selected --device rather than a hard-coded "cuda:0", so
    # `--device cpu` / `--device cuda:1` are honoured.
    embed_model = SentenceTransformer(args.embed_model_name, device=args.device)

    args.adaptive_decoding_config = {
        "embed_model": embed_model,
        "model_name": args.base_model,
        "embed_model_name": args.embed_model_name,
        "train_data": args.train_data,
    }

# Single call for every method. Presumably this writes generations to
# args.output_path, which the evaluation step reads next — confirm in
# inference_utils.
inference(questions, args)

# 2. Evaluation
print('Starting evaluation', args.method, args.train_data, args.base_model, args.eval_data)
# Biography data has its own evaluator; every other dataset uses the generic one.
evaluator = bio_eval_single_file if args.eval_data == "bio" else eval_single_file
result = evaluator(args.output_path, args=args)

# 3. Save results
# pd.Timestamp.now() is naive local time, and Timestamp.timestamp() treats naive
# values as UTC, which shifted the number by the machine's UTC offset. An aware
# UTC "now" yields the true POSIX time for the file name.
timestamp = round(pd.Timestamp.now(tz="UTC").timestamp())
eval_result_file = f"{params.result_dir}/{args.method}__{args.evaluation_type}__{timestamp}.json"
with open(eval_result_file, "w", encoding="utf-8") as f:
    json.dump(result, f, indent=2)

