import os
import sys
import argparse
import yaml
import logging
import re
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
import math

from auditor_utils import *

labels_iil = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L"]

correct_answers = {
    0: 'A',
    1: 'B',
    2: 'A',
    3: 'A',
    4: 'B',
    5: 'C',
    6: 'D',
    7: 'A',
    8: 'A',
    9: 'E',
    10: 'A',
    11: 'B',
    12: 'D',
    13: 'C',
    14: 'F',
    15: 'F',
    16: 'A',
    17: 'G',
    18: 'H',
    19: 'B'
}

model_name = "Qwen/Qwen2.5-32B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)

llm = LLM(model=model_name, 
          tensor_parallel_size=torch.cuda.device_count(), 
          gpu_memory_utilization=0.9,
          max_model_len = 12800,
          )

def calc_iil(correct_ans, prior_logprobs, posterior_logprobs):
    idx = tokenizer.convert_tokens_to_ids(correct_ans)
    try:
        prior_logprob = prior_logprobs[idx].logprob
    except:
        prior_logprob = -float('inf')
    try:
        posterior_logprob = posterior_logprobs[idx].logprob
    except:
        posterior_logprob = -float('inf')
    print(f"Posterior Logprob: {posterior_logprob}\nPrior Logprob: {prior_logprob}")
    try:
        for lbl in labels_iil:
           print(f"Posterior Logprob of Label {lbl}: {posterior_logprobs[tokenizer.convert_tokens_to_ids(lbl)].logprob}")
    except:
        print("Error in accessing posterior logprobs for labels")
    return (math.e**posterior_logprob)*(posterior_logprob - prior_logprob)
    
if __name__ == "__main__":
    # Parse arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="0.yaml")
    parser.add_argument("--start_index", type=int, default=0)
    parser.add_argument("--start_offset", type=int, default=0)
    args = parser.parse_args()

    with open(args.config, 'r', encoding='utf-8') as config_file:
        config = yaml.safe_load(config_file)

    input = config['input']
    adversary_role = config['adversary_role']
    target_attribute = config['target_attribute']
    
    try:
        start_index = config['start_index']
    except:
        start_index = args.start_index

    with open(input, 'r', encoding='utf-8') as file:
        data = yaml.safe_load(file)

    output_file = input.replace(".yaml", "_predictions_consistency_iil.out")
    logging.basicConfig(filename=output_file, level=logging.INFO, format='%(message)s')

    prior_logprobs = None
    for person_index in data.keys():
        correct_ans = correct_answers[int(person_index)]
        get_prior = False
        # if int(person_index)<args.start_index:
        #     continue 
        if int(person_index)<start_index + args.start_offset:
            continue
        elif int(person_index)==start_index:
            get_prior = True
        conv_history_list = []
        summary_list = []
        iteration_data = data[person_index]
        for i in range(len(iteration_data)//2):
            print(f'Person Index: {person_index}, Iteration: {i}')
            conv_history_list.extend([iteration_data[2*i], iteration_data[(2*i)+1]])
            if i > 100:
                break
            logging.info(f"\nFor Iteration {person_index}, Round: {i}")
            logging.info("\nUser: %s\nAgent: %s", iteration_data[2*i], iteration_data[(2*i)+1])
            summary, prior_logprobs_new, posterior_logprobs = gen_adv_summary_consistency_iil(conv_history_list, llm, buffer_threshold = 5, target_attribute = target_attribute, adversary_role=adversary_role, seed = person_index, counter = i, get_prior = get_prior)
            if get_prior:
                prior_logprobs = prior_logprobs_new
            else:
                print('Using stored prior logprobs')
            iil = calc_iil(correct_ans, prior_logprobs, posterior_logprobs)
            logging.info("\nSummary: ",summary,'\n IIL: <iil>', iil,'</iil>')
            print('\n IIL: <iil>', iil,'</iil>')