import os
import sys
import ast
import json 
import yaml
import argparse
from tqdm import tqdm
import torch
from pathlib import Path
from random import random
from dataclasses import dataclass
from typing import Literal, Optional, Union, Tuple
import random
from rich import print
import time
from vllm import LLM, SamplingParams


# Dataset file bundled for each supported ``--task_name`` value.
DATASET_FILES = {
    "glue_rte": 'redacted/glue_rte/datasets_920.json',
    "glue_mnli": 'redacted/glue_mnli/datasets_6704.json',
    "snli": 'redacted/snli/datasets_6690.json',
}


def load_task_dataset(task_name: str):
    """Load and return the parsed JSON dataset registered for *task_name*.

    Args:
        task_name: One of the keys of ``DATASET_FILES``.

    Returns:
        Whatever ``json.load`` produces for the task's dataset file.

    Raises:
        ValueError: If *task_name* is not a key of ``DATASET_FILES``.
    """
    try:
        dataset_file = DATASET_FILES[task_name]
    except KeyError:
        raise ValueError(f"Task name {task_name} not recognized.") from None
    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open(dataset_file, encoding="utf-8") as f:
        return json.load(f)


def resolve_stop_token_ids(
    tokenizer, extra_stop_tokens: Tuple[str, ...] = ("<|eot_id|>",)
) -> list:
    """Return the token ids at which generation should stop.

    The tokenizer's EOS id always comes first. Each token in
    *extra_stop_tokens* is added only when the tokenizer actually knows it:
    ``convert_tokens_to_ids`` answers with the unknown-token id (or ``None``)
    for tokens outside the vocabulary, and passing either of those on as a
    stop id would be invalid or would stop generation on unrelated tokens.

    Args:
        tokenizer: Object exposing ``eos_token_id``, ``convert_tokens_to_ids``
            and, optionally, ``unk_token_id``.
        extra_stop_tokens: Token strings to stop on in addition to EOS.

    Returns:
        A list of distinct, non-``None`` token ids, in first-seen order.
    """
    unk_id = getattr(tokenizer, "unk_token_id", None)
    candidates = [tokenizer.eos_token_id]
    for token in extra_stop_tokens:
        token_id = tokenizer.convert_tokens_to_ids(token)
        # Equal to the unk id (including None == None) means "not in vocab".
        if token_id != unk_id:
            candidates.append(token_id)

    stop_ids = []
    for token_id in candidates:
        if token_id is not None and token_id not in stop_ids:
            stop_ids.append(token_id)
    return stop_ids


def main():
    """Generate greedy completions for every prompt of a task's dataset.

    Parses the command line, loads the dataset for ``--task_name``, runs each
    datapoint's prompts through a vLLM model, and writes one JSON line per
    datapoint (``{"datapoint": ..., "completions": ...}``) to
    ``<output_dir>/<task_name>_<base_model_name>/outputs.jsonl``.

    Raises:
        ValueError: If ``--task_name`` is not a supported task.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default="redacted")
    parser.add_argument("--output_dir", type=str, default="redacted")
    parser.add_argument("--task_name", type=str, default="glue_rte")
    parser.add_argument("--base_model_name", type=str, default="gpt2")
    # Previously hard-coded to 4; exposed so the script also runs on hosts
    # with a different number of GPUs. The default keeps the old behavior.
    parser.add_argument("--tensor_parallel_size", type=int, default=4)
    args = parser.parse_args()

    dataset = load_task_dataset(args.task_name)
    print(f"len(dataset): {len(dataset)}")

    llm = LLM(
        model=args.model_name,
        trust_remote_code=True,
        tensor_parallel_size=args.tensor_parallel_size,
    )
    print(f"llm: {llm}")
    tokenizer = llm.get_tokenizer()
    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=1024,
        # Stop on "<|eot_id|>" as well as EOS (when the tokenizer has it).
        # NOTE(review): presumably needed for Llama-3-style chat models whose
        # turns end with <|eot_id|> rather than EOS -- confirm.
        stop_token_ids=resolve_stop_token_ids(tokenizer),
    )

    output_dir = Path(args.output_dir, f"{args.task_name}_{args.base_model_name}")
    output_dir.mkdir(parents=True, exist_ok=True)
    output_file = output_dir / "outputs.jsonl"

    with open(output_file, "w", encoding="utf-8") as f:
        for datapoint in tqdm(dataset, total=len(dataset)):
            # Each datapoint is consumed as a mapping: keys are used as prompt
            # names, values as the content of a single user message.
            prompt_names = list(datapoint)
            conversations = [
                tokenizer.apply_chat_template(
                    [{'role': 'user', 'content': prompt}],
                    tokenize=False,
                    # Append the assistant-turn header so the model answers
                    # the user instead of continuing the user's message.
                    add_generation_prompt=True,
                )
                for prompt in datapoint.values()
            ]

            # Skip the engine call entirely for a datapoint without prompts.
            outputs = llm.generate(conversations, sampling_params) if conversations else []

            completions = dict()
            # Relies on generate() returning one result per prompt, in the
            # order the prompts were given (the original code assumed the
            # same when it popped names from the front of the list).
            for prompt_name, out in zip(prompt_names, outputs):
                generated_text = out.outputs[0].text
                print(f"generated_text: {generated_text}")
                completions[prompt_name] = generated_text
                print(f"completions: {completions}")

            output = dict(
                datapoint=datapoint,
                completions=completions
            )
            f.write(json.dumps(output) + "\n")


if __name__ == "__main__":
    main()
    
            
            
            
            
            
                