
import pandas as pd
from ast import literal_eval
from tqdm import tqdm
import numpy as np
from time import time
import os
from utils_llm_consistency import get_context, prune
import argparse
from utils import get_logger
from vllm import LLM, SamplingParams
import json
from itertools import chain, product


# Command-line interface.
#   filename          path of the input CSV to process
#   --nrows           optional cap on the number of CSV rows read
#   --model_name      key selecting the model entry in ../_config.json
#   --context_length  token budget handed to the context builder
#   --cuda_id         GPU index (parsed, but not applied anywhere in this script)
parser = argparse.ArgumentParser()
parser.add_argument("filename", type=str)
parser.add_argument("--nrows", default=None, type=int)
parser.add_argument(
    "--model_name",
    default="llama7b_chat_hf",
    type=str,
    choices=[
        "llama13b_hf",
        "llama13b_chat_hf",
        "llama7b_hf",
        "llama7b_chat_hf",
    ],
)
parser.add_argument("--context_length", default=1000, type=int)
parser.add_argument("--cuda_id", default=1, type=int)
args = parser.parse_args()


save_root = "result"
logger, exp_seq, save_path = get_logger(save_root=save_root, save_tag="complex_query_consistency")
logger.info(f"=======Exp: {exp_seq}=============")
logger.info(f"Model: {args.model_name}")

# NOTE: `--cuda_id` is parsed but not applied here; GPU selection must come
# from the environment the script is launched in.

model_name = args.model_name

# Per-model settings (model path and sampling parameters) are read from a JSON
# file resolved relative to the current working directory.
with open("../_config.json", encoding="utf-8") as f:
    config = json.load(f)
if config is None:
    # The file contains JSON `null`. Abort with a non-zero status and a message
    # instead of the interactive-only `quit()` helper, which exits with 0.
    raise SystemExit("../_config.json contains no configuration (null); aborting.")

model_config = config[model_name]  # KeyError if the model has no config entry
llm = LLM(model=model_config['model_path'])
sampling_params = SamplingParams(
    temperature=model_config['temperature'],
    max_tokens=model_config['max_tokens'],
    top_k=model_config['top_k'],
)



filename = args.filename
# The input must be given as a path containing '/'; its basename (with
# `data_final` swapped for `prompt_commutative_property`) names the output file.
# An explicit raise (unlike `assert`) is not stripped under `python -O`.
if "/" not in filename:
    raise ValueError(f"Expected an input path containing '/', got {filename!r}")
store_filename = filename.split('/')[-1].replace('data_final', 'prompt_commutative_property')
if args.nrows is not None:
    # Tag partial runs with the row limit: `name.csv` -> `name_<nrows>.csv`.
    store_filename = f"{store_filename.split('.csv')[0]}_{args.nrows}.csv"

# Count physical lines (header included) for progress reporting. Binary mode
# avoids decoding, and the context manager closes the handle.
with open(filename, 'rb') as fh:
    total_lines = sum(1 for _ in fh)
logger.info(f"Total lines in {filename}: {total_lines}")
chunksize = 100
chunkidx = 0
# Rows pandas will actually yield: every line except the header, capped by
# --nrows. NOTE(review): assumes no quoted CSV field spans several lines;
# otherwise this over-estimates.
num_data_rows = max(total_lines - 1, 0)
if args.nrows is not None:
    num_data_rows = min(num_data_rows, args.nrows)
total_chunks = -(-num_data_rows // chunksize)  # ceiling division
logger.info(f"Chunksize: {chunksize}")
logger.info(f"Total chunks: {total_chunks}")
import re  # only this part of the script needs regular expressions

# Matches either a complete quoted string literal (group 1, kept verbatim) or a
# bare `nan` token. List-valued CSV cells spell missing values as `nan`, which
# `literal_eval` cannot parse.
_QUOTED_OR_NAN = re.compile(r"""('(?:[^'\\]|\\.)*'|"(?:[^"\\]|\\.)*")|\bnan\b""")


def _parse_list_cell(text):
    """Return ``literal_eval(text)`` with bare ``nan`` tokens read as ``None``.

    Quoted strings are skipped, so entity names that merely contain the letters
    "nan" (e.g. ``'financial_times'``) are not corrupted. A plain
    ``text.replace('nan', 'None')`` would corrupt them.
    """
    return literal_eval(_QUOTED_OR_NAN.sub(lambda m: m.group(1) or "None", text))


# Dataset-specific choices depend only on the input path, so resolve them once.
# Relations are shortened to their last separator-delimited component.
if "FB15k" in filename:
    relation_separator = "/"
elif "NELL" in filename:
    relation_separator = ":"
else:
    raise ValueError(f"{filename} not recognized")

# `u_data_final` files hold OR-queries (true if any sub-query holds);
# `i_data_final` files hold AND-queries (true only if all sub-queries hold).
if "u_data_final" in filename:
    query_connective = " OR \n"
elif "i_data_final" in filename:
    query_connective = " AND \n"
else:
    raise ValueError("Unknown filename")
is_or_query = query_connective == " OR \n"

tokenizer = llm.get_tokenizer()  # loop-invariant
instruction = "Consider the context as a set of triplets where entries are separated by '|' symbol. Answer question according to the context.\n\n"
question_instruction_single_chain = "Do not add additional text. Is the following logic query FACTUALLY CORRECT? Answer with Yes or No.\n\n"
store_path = f'{save_path}/{store_filename}'


def _format_query(triplets, context_text):
    """Render formatted triplets as a yes/no query prompt.

    When ``context_text`` is not None, the context instruction and the context
    itself are prepended.
    """
    prompt = question_instruction_single_chain + query_connective.join(
        f"( {triplet} )" for triplet in triplets
    )
    if context_text is not None:
        prompt = instruction + context_text + "\n\n" + prompt
    return prompt


with pd.read_csv(filename, chunksize=chunksize, nrows=args.nrows) as reader:
    for data in reader:
        logger.info(f"Loaded chunk {chunkidx + 1}/{total_chunks} from {filename} with shape {data.shape}")
        chunkidx += 1
        # Stringified list columns -> Python lists (timing columns are left alone).
        for column in data.columns:
            if ("subgraph" in column or "tail_entities" in column) and "time" not in column:
                data[column] = data[column].apply(_parse_list_cell)

        # One numbered `flipped_tail_entities_<i>` column per sub-query.
        num_subqueries = sum(
            1
            for column in data.columns
            if "flipped_tail_entities" in column and column != "flipped_tail_entities"
        )

        head_entity_columns = [column for column in data.columns if "head_entity" in column]
        relation_columns = [column.replace("head_entity", "relation") for column in head_entity_columns]
        tail_entity_columns = [column for column in data.columns if column.startswith("tail_entities")]

        for index in tqdm(range(data.shape[0]), disable=True):
            global_index = index + (chunkidx - 1) * chunksize
            df_row_pruned = prune(data.iloc[[index]], filename=filename, max_tail_entities=1)
            row = df_row_pruned.iloc[0]

            # Build the context from the row's subgraph, targeted at this
            # query's head entities, relations and tail entities.
            start_time = time()
            context = get_context(row['subgraph'],
                                  target_head_entities=list(row[head_entity_columns].values),
                                  target_relations=list(row[relation_columns].values),
                                  target_tail_entities=list(set(chain.from_iterable(row[tail_entity_columns].values))),
                                  relation_separator=relation_separator,
                                  tokenizer=tokenizer,
                                  max_context_len=args.context_length,
                                  verbose=False)
            context.drop_duplicates(keep='first', inplace=True)
            time_context = time() - start_time
            context_as_string = "\n".join(
                " | ".join([head, relation.split(relation_separator)[-1], tail])
                for head, relation, tail in context.values
            ).replace("_", " ")

            if len(tokenizer.encode(context_as_string)) > 2 * args.context_length:
                logger.info(f"Warning: Context too long for index {global_index}\n{context}")
                continue

            # Skip rows whose answers are missing or contain None; after pruning,
            # each sub-query must have exactly one true and one flipped tail.
            inconsistent_data = False
            for i in range(1, num_subqueries + 1):
                tails = row[f'tail_entities_{i}']
                flipped_tails = row[f'flipped_tail_entities_{i}']
                if len(tails) == 0 or None in tails or None in flipped_tails:
                    inconsistent_data = True
                    break
                assert len(tails) == 1
                assert len(flipped_tails) == 1
            if inconsistent_data:
                continue

            # Candidate answers per sub-query: the true tail plus the flipped one.
            unique_tail_entities = [
                df_row_pruned[f'tail_entities_{i}'].item() + df_row_pruned[f'flipped_tail_entities_{i}'].item()
                for i in range(1, num_subqueries + 1)
            ]
            relations = [df_row_pruned[f'relation_{i}'].item() for i in range(1, num_subqueries + 1)]
            head_entities = [df_row_pruned[f'head_entity_{i}'].item() for i in range(1, num_subqueries + 1)]
            # Timing columns produced by earlier pipeline stages are carried over.
            previous_times = {
                column: row[column] for column in df_row_pruned.columns if column.startswith("time")
            }

            for use_context in [True, False]:
                for tail_entity_tuple in product(*unique_tail_entities):
                    result = {
                        "use_context" : use_context,
                        'prompt_base_query' : None,
                        "raw_response_base_query" : None,
                        "response_base_query" : None,
                        "ground_truth_base_query" : None,
                        "correct_base_query" : None,
                        "logically_consistent": None,
                        "time_context": None,
                        "time_response_base_query": None,
                        "prompt_base_query_reordered": None,
                        "raw_response_base_query_reordered": None,
                        "response_base_query_reordered": None,
                        "ground_truth_base_query_reordered": None,
                        "correct_base_query_reordered": None,
                        "time_response_base_query_reordered": None,
                    }
                    result.update(previous_times)

                    test_triplets_raw = [
                        (head_entities[i], relations[i], tail_entity_tuple[i])
                        for i in range(num_subqueries)
                    ]
                    test_triplets = [
                        " | ".join((head, relation.split(relation_separator)[-1], tail)).replace("_", " ")
                        for head, relation, tail in test_triplets_raw
                    ]

                    if use_context:
                        result['time_context'] = time_context
                        context_text = context_as_string
                    else:
                        context_text = None
                    result["prompt_base_query"] = _format_query(test_triplets, context_text)
                    # Same query with the sub-query order reversed.
                    result["prompt_base_query_reordered"] = _format_query(test_triplets[::-1], context_text)

                    # A true sub-query triplet must appear in the context exactly
                    # once, and a flipped one not at all; otherwise skip this tuple.
                    subquery_verdict_sum = 0
                    inconsistent_context = False
                    for i, (head, relation, tail) in enumerate(test_triplets_raw):
                        mask = (context['head_entity'] == head) & \
                               (context['relation'] == relation) & \
                               (context['tail_entity'] == tail)
                        expected_matches = 1 if tail_entity_tuple[i] in df_row_pruned[f'tail_entities_{i+1}'].item() else 0
                        subquery_verdict_sum += expected_matches
                        if np.sum(mask) != expected_matches:
                            inconsistent_context = True
                            break
                    if inconsistent_context:
                        logger.info(f"Warning: Assertion false for index {global_index}")
                        continue

                    if is_or_query:
                        ground_truth = subquery_verdict_sum > 0
                    else:
                        ground_truth = subquery_verdict_sum == num_subqueries
                    result["ground_truth_base_query"] = ground_truth
                    result['ground_truth_base_query_reordered'] = ground_truth

                    # Keep each record on one CSV line.
                    for key, value in result.items():
                        if value is not None and ("prompt" in key or "raw_response" in key):
                            result[key] = value.replace('\n', '[NEWLINE]')

                    # Append the record; the header is written only on creation.
                    result_df = pd.DataFrame([result])
                    if not os.path.exists(store_path):
                        result_df.to_csv(store_path, index=False)
                    else:
                        result_df.to_csv(store_path, index=False, header=False, mode='a')