from tqdm import tqdm

from llm import LLM
from utils_proof.file import *
from jinja2 import Template

class GraphDataset:
    """Train/test records of one processed dataset plus their LLM-generated FRRs.

    Records are read from ``processed_dataset/<dataset_name>/{train,test}.json``.
    FRRs are stored per split in
    ``frrs/<dataset_name>/<llm_name>/{train,test}.json``.
    """

    def __init__(self, dataset_name: str, llm_name=None, require_frr: bool = False) -> None:
        """Load both splits and build the LLM wrapper.

        Args:
            dataset_name: Directory name under ``processed_dataset/``; also used
                in the prompt, FRR and output paths.
            llm_name: Passed to ``LLM(...)`` and used as a path component.
            require_frr: When true, ``add_frrs()`` is run immediately.
        """
        self.dataset_name = dataset_name
        self.dataset_path = f'processed_dataset/{dataset_name}'
        self.trains = load_json(self.dataset_path + '/train.json')
        self.tests = load_json(self.dataset_path + '/test.json')
        self.llm_name = llm_name
        self.llm = LLM(llm_name)
        if require_frr:
            self.add_frrs()

    def _frr_path(self, split: str) -> str:
        """Return the FRR file path for ``split`` ('train' or 'test')."""
        return f'frrs/{self.dataset_name}/{self.llm_name}/{split}.json'

    def _records(self, split: str):
        """Return the in-memory record list for ``split`` ('train' or 'test')."""
        return self.trains if split == 'train' else self.tests

    def write_dataset(self):
        """Write the train and test records to ``output/<dataset>/<llm>/``."""
        output_train = f'output/{self.dataset_name}/{self.llm_name}/train.json'
        # write_json(path, data): same call convention as the other writes in
        # this class. load_json here would read instead of saving.
        write_json(output_train, self.trains)
        output_test = f'output/{self.dataset_name}/{self.llm_name}/test.json'
        write_json(output_test, self.tests)

    def construct_prompt(self, datas, mode, prompt_path=None):
        """Render a prompt template with the fields of one record.

        Args:
            datas: Mapping whose items are passed to the template as keyword
                arguments.
            mode: Template name; selects ``prompts/<dataset_name>/<mode>.txt``
                when ``prompt_path`` is not given.
            prompt_path: Optional explicit template file path.

        Returns:
            The rendered prompt string.
        """
        if prompt_path is None:
            prompt_path = f'prompts/{self.dataset_name}/{mode}.txt'
        template = Template(read_txt(prompt_path))
        return template.render(**datas)

    def add_frrs(self):
        """Run LLM inference for both splits and store each result under ``'frr'``."""
        for split in ('train', 'test'):
            self._add_split_frrs(split)

    def _add_split_frrs(self, split: str) -> None:
        """Run inference for one split and attach the saved FRRs to its records."""
        records = self._records(split)
        save_path = self._frr_path(split)
        prompts = [self.construct_prompt(record, split) for record in records]
        # The results are read back from the file passed to inference().
        self.llm.inference(prompts, save_path)
        for i, frr in enumerate(load_json(save_path)):
            records[i]['frr'] = frr

    def generate_wrong_graph(self, transform=None):
        """Regenerate every stored FRR that ``transform`` rejects.

        Args:
            transform: Called as ``transform(frr)`` on each stored FRR; raising
                any ``Exception`` marks that entry as invalid. With the default
                ``None`` the call itself raises ``TypeError``, so every entry is
                treated as invalid and regenerated.

        Returns:
            Total number of invalid entries found across train and test.
        """
        n_train_errors = self._regenerate_invalid('train', transform)
        n_test_errors = self._regenerate_invalid('test', transform)
        print(f"Train Errors: {n_train_errors}, Test Errors: {n_test_errors}")
        return n_train_errors + n_test_errors

    def _regenerate_invalid(self, split: str, transform) -> int:
        """Replace the invalid FRRs of one split in its file; return their count."""
        save_path = self._frr_path(split)
        frrs = load_json(save_path)
        wrong_ids = []
        for i, frr in tqdm(enumerate(frrs), total=len(frrs)):
            try:
                transform(frr)
            except Exception:
                # Any ordinary failure counts as invalid. KeyboardInterrupt and
                # SystemExit are deliberately not swallowed.
                wrong_ids.append(i)
        records = self._records(split)
        prompts = [self.construct_prompt(records[i], split) for i in wrong_ids]
        # NOTE(review): 0.5 is presumably the sampling temperature - confirm
        # against LLM.generate.
        preds = self.llm.generate(prompts, 0.5)
        for pos, wrong_id in enumerate(wrong_ids):
            frrs[wrong_id] = preds[pos]
        write_json(save_path, frrs)
        return len(wrong_ids)

    def extract_frr(self, find_frr):
        """Rewrite both FRR files with ``find_frr`` applied to every entry.

        Args:
            find_frr: Called as ``find_frr(frr, split)`` with ``split`` being
                ``'train'`` or ``'test'``; its return value replaces the entry.
        """
        for split in ('train', 'test'):
            save_path = self._frr_path(split)
            frrs = load_json(save_path)
            extracted = [find_frr(frr, split) for frr in tqdm(frrs, total=len(frrs))]
            write_json(save_path, extracted)