import os
import openai
import sys
import pandas as pd

# Resolve the parent of the current working directory to an absolute path and
# register it on the import path (only once) so that modules located one level
# up can be imported by the statements that follow.
module_path = os.path.abspath(os.pardir)
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM

# Pick the compute device once at start-up: CUDA when torch reports it as
# available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# OpenAI credentials. A key that is already exported in the environment is kept
# as-is; otherwise a quoted placeholder is installed so the file stays valid
# Python (the previous bare placeholder was a syntax error). Replace the
# placeholder below, or export OPENAI_API_KEY before launching the script.
# Do not commit a real key to version control.
_OPENAI_KEY_PLACEHOLDER = "<Enter your key here>"
if os.environ.setdefault('OPENAI_API_KEY', _OPENAI_KEY_PLACEHOLDER) == _OPENAI_KEY_PLACEHOLDER:
    print("WARNING: OPENAI_API_KEY is not set; "
          "provide a real key before making OpenAI requests.")

import re
import json

from IPython.core.display import HTML
from functools import partial

from utils import ProgramGenerator, ProgramInterpreter
from tasks.extract_data import create_prompt

# Options frozen into the prompt builder. method='all' means all task examples
# are considered when the prompt is created.
prompt_options = {
    "method": 'all',
    "func_pool": False,
    "example_type": "comb_tasks",
}
prompter = partial(create_prompt, **prompt_options)

# Program generator driven by the prompt builder configured above.
generator = ProgramGenerator(prompter=prompter)

# Dataset table; the processing loop reads its "Question", "Program" and
# "Answer" columns.
all_df = pd.read_csv("data/STReason_FullDataset.csv")

# Checkpoint identifier and local cache directory shared by the model and the
# tokenizer, so both are guaranteed to come from the same place.
LLAMA2_CHECKPOINT = "meta-llama/Llama-2-7b-chat-hf"
LLAMA2_CACHE_DIR = "llama_base"

# Causal language model; device placement is delegated to device_map='auto'.
model = AutoModelForCausalLM.from_pretrained(
    LLAMA2_CHECKPOINT,
    cache_dir=LLAMA2_CACHE_DIR,
    device_map='auto',
)

# Tokenizer matching the checkpoint loaded above.
tokenizer = AutoTokenizer.from_pretrained(
    LLAMA2_CHECKPOINT,
    cache_dir=LLAMA2_CACHE_DIR,
)

def get_llama2_reponse(prompt, max_new_tokens=4096):
    """Generate a completion for ``prompt`` with the module-level Llama-2 model.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        prompt: Full text prompt sent to the model.
        max_new_tokens: Upper bound on the number of newly generated tokens.

    Returns:
        The newly generated text only (the prompt is not included), with
        leading and trailing whitespace removed.
    """
    # The model is loaded with device_map='auto', so its weights are not
    # guaranteed to live on the global ``device``; send the inputs to the
    # device the model itself reports.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        # Temperature only has an effect when sampling is enabled; state it
        # explicitly instead of depending on the checkpoint's defaults.
        do_sample=True,
        temperature=0.7,
    )

    # The returned sequence starts with the prompt tokens. Drop them by
    # position rather than splitting the decoded text on "A:", which would
    # cut the answer short whenever the model itself writes "A:".
    prompt_length = inputs["input_ids"].shape[-1]
    generated_tokens = outputs[0][prompt_length:]
    response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
    return response.strip()

# Result accumulators. A row is appended to all five lists together, and only
# when every step for that row succeeded, so the lists always stay aligned.
gen_programs = []
queries = []
ground_truth_programs = []
ground_truth_answers = []
gen_ans_llama2 = []


def _expert_role(row_idx):
    """Return the expertise phrase inserted into the prompt for row ``row_idx``.

    Rows 0-49 use the analysis phrase, rows 50-99 the forecasting phrase and
    all later rows the anomaly-detection phrase.
    NOTE(review): presumably this mirrors the task ordering of the dataset;
    confirm against the CSV.
    """
    if row_idx < 50:
        return "spatio temporal data analysis"
    if row_idx < 100:
        return "spatiotemporal data forecasting"
    return "spatiotemporal data anomaly detection"


# Iterate positionally over the three columns; this does not depend on the
# DataFrame carrying a default 0..n-1 index.
rows = zip(all_df["Question"], all_df["Program"], all_df["Answer"])
for i, (query, ground_truth_program, ground_truth_answer) in enumerate(rows):
    # Reset per row so the error report below never touches an unbound name
    # (failure on the first row) or shows the previous row's program.
    gen_prog = None

    try:
        # Step 1: produce a program for the question.
        gen_prog, _ = generator.generate(dict(question=query))

        # Step 2: run the program from an empty state; only the result is
        # forwarded to the language model.
        interpreter = ProgramInterpreter()
        result, _, _ = interpreter.execute(gen_prog, {}, inspect=True)

        # Step 3: ask Llama-2 for the final answer given the program output.
        prompt = (f"Q:You are an expert in {_expert_role(i)}. {query}.\n"
                  f"{result}\n"
                  f"A:")
        gen_answer = get_llama2_reponse(prompt, max_new_tokens=4096)

    except Exception as e:
        # Best-effort batch run: report the failing row and move on to the next.
        print(f"Error processing query at {i}: {str(e)}")
        print(f"Query:{query}")
        print(f"Generated Program:{gen_prog}")

    else:
        gen_programs.append(gen_prog)
        gen_ans_llama2.append(gen_answer)
        queries.append(query)
        ground_truth_programs.append(ground_truth_program)
        ground_truth_answers.append(ground_truth_answer)
        print(f"Answer generated at {i}")

# Print the number of collected entries per result list, one count per line.
for collected in (queries,
                  ground_truth_programs,
                  ground_truth_answers,
                  gen_programs,
                  gen_ans_llama2):
    print(len(collected))

# Assemble the collected lists into a table, one column per list, and write it
# to disk without the index column.
result_columns = {
    'Query': queries,
    'Program': ground_truth_programs,
    'Answer': ground_truth_answers,
    'Gen_Program': gen_programs,
    'Gen_Answer': gen_ans_llama2,
}
llama2_df = pd.DataFrame(result_columns)
llama2_df.to_csv('data/llama2_answers.csv', index=False)