import os
import sys
import ast
import json 
import yaml
import argparse
from tqdm import tqdm
import torch
from pathlib import Path
from random import random
from dataclasses import dataclass
from typing import Literal, Optional, Union, Tuple
import random
from rich import print
import time
from vllm import LLM, SamplingParams
from src.utils import print_gpu_usage


# Prompt-template keys looked up in the YAML config, in run order. For each
# variant 1..12, each prompting style (zero-shot, CoT, explicit CoT) appears
# as a consecutive "no_conflict" / "conflict" pair, e.g.
# no_conflict_zero_shot_1, conflict_zero_shot_1, no_conflict_cot_1, ...
TEMPLATE_NAMES = [
    f"{setting}_{style}_{variant}"
    for variant in range(1, 13)
    for style in ("zero_shot", "cot", "cot_explicit")
    for setting in ("no_conflict", "conflict")
]

def extract_yes_no(text):
    """Classify a short zero-shot completion as "yes", "no" or "unknown".

    ``yes`` / ``no`` are matched as whole words (case-insensitive), and the
    first one that appears wins. Words that merely contain those letters,
    such as "cannot", "know", "unknown" or "eyes", are not treated as answers.

    Args:
        text: Text generated by the model.

    Returns:
        "yes" or "no" for the first standalone occurrence, otherwise "unknown".
    """
    import re

    match = re.search(r"\b(yes|no)\b", text, flags=re.IGNORECASE)
    return match.group(1).lower() if match else "unknown"


def inference(args, queries=None):
    """Generate one completion per prompt with vLLM and write results as JSONL.

    Each line of the output file is a JSON object with keys ``prompt_name``,
    ``generated_text``, ``prompt`` and ``answers``.

    Args:
        args: Parsed CLI namespace. Reads ``model_name`` and ``output_dir``.
            Optionally reads ``tensor_parallel_size`` (default 4) and
            ``output_filename`` (default "Llama3-8B_output.jsonl").
        queries: Mapping of prompt name -> prompt text. Names containing
            "cot" get a 1024-token budget; all others get 10 tokens.

    Raises:
        ValueError: If ``queries`` is None. This is checked before the model
            is loaded, so a bad call does not pay for model start-up.
    """
    if queries is None:
        raise ValueError("inference() requires a mapping of prompt_name -> prompt")

    # rich.print parses "[...]" as console markup; model output and prompts
    # must be escaped or text such as "[/INST]" raises MarkupError.
    from rich.markup import escape

    model_name = args.model_name
    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        tensor_parallel_size=getattr(args, "tensor_parallel_size", 4),
    )
    print(f"llm: {llm}")

    tokenizer = llm.get_tokenizer()
    # Stop on the tokenizer's EOS token and on Llama-3's end-of-turn marker.
    # Ids that resolve to None are dropped so SamplingParams never sees them.
    stop_token_ids = [
        token_id
        for token_id in (
            tokenizer.eos_token_id,
            tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        )
        if token_id is not None
    ]

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    # NOTE(review): the default filename names Llama3-8B regardless of
    # model_name; pass args.output_filename when running other models.
    output_file = output_dir / getattr(args, "output_filename", "Llama3-8B_output.jsonl")

    with open(output_file, "w", encoding="utf-8") as f:
        for prompt_name, prompt in tqdm(queries.items()):
            sampleparams = SamplingParams(
                temperature=0.7,
                top_p=0.9,
                max_tokens=1024 if "cot" in prompt_name else 10,
                stop_token_ids=stop_token_ids,
            )
            # add_generation_prompt=True appends the assistant-turn header, so
            # the model starts answering instead of spending its (very small,
            # for zero-shot) token budget emitting that header itself.
            # NOTE(review): with tokenize=False the rendered string may already
            # start with the BOS token and vLLM may add another when it
            # tokenizes the prompt -- confirm for the model in use.
            conversation = tokenizer.apply_chat_template(
                [{"role": "user", "content": prompt}],
                tokenize=False,
                add_generation_prompt=True,
            )
            outputs = llm.generate([conversation], sampleparams)
            # Exactly one prompt was submitted, so read its first candidate.
            generated_text = outputs[0].outputs[0].text
            print(f"generated_text: {escape(generated_text)}")

            if "zero_shot" in prompt_name:
                answers = [extract_yes_no(generated_text)]
            elif "cot" in prompt_name:
                # Chain-of-thought output is stored verbatim; no answer
                # extraction happens here.
                answers = [generated_text]
            else:
                answers = []

            output = dict(
                prompt_name=prompt_name,
                generated_text=generated_text,
                prompt=prompt,
                answers=answers,
            )
            print(f"output: {escape(str(output))}")
            print_gpu_usage()
            f.write(json.dumps(output) + "\n")
            # Flush per record so results survive a crash mid-run; the cost is
            # negligible next to one LLM generation.
            f.flush()
    
if __name__ == "__main__":
    # CLI entry point: load every template in TEMPLATE_NAMES from the YAML
    # config and run inference over them.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default="redacted")
    parser.add_argument("--output_dir", type=str, default="redacted")
    # Previously hard-coded; the old path is kept as the default so existing
    # invocations behave the same.
    parser.add_argument("--config_path", type=str, default="redacted/configs/problem.yaml")
    args = parser.parse_args()

    with open(args.config_path, "r", encoding="utf-8") as file:
        # safe_load returns None for an empty file; treat it as "no templates"
        # so the missing-key check below reports it clearly.
        template = yaml.safe_load(file) or {}

    # Report every missing key at once rather than failing on the first one.
    missing = [name for name in TEMPLATE_NAMES if name not in template]
    if missing:
        raise KeyError(f"templates missing from {args.config_path}: {missing}")

    queries = {name: template[name] for name in TEMPLATE_NAMES}

    inference(args, queries)

    print("[bold green]Inference completed![/bold green]")