import os
from copy import deepcopy
import re
import pandas as pd
import numpy as np
# from torch.utils.data import Dataset as tDataset
from datasets import Dataset as hfDataset
import datasets as hfsets
from .common import compute_standard_scores, clean_up_special_chars
from bigcodebench.evaluate import check_correctness
from .base_dataset import DatasetWithExactSolution, DatasetWithReference


import tempfile, json
import time
from gradio_client import Client, handle_file
from concurrent.futures._base import CancelledError
import httpx
# URL of the hosted BigCodeBench evaluator; handed to gradio_client.Client for
# remote evaluation and shown to the user in the manual-upload instructions.
remote_execute_api: str = "https://bigcode-bigcodebench-evaluator.hf.space/"


class BCBDataset(DatasetWithExactSolution, DatasetWithReference):
    def __init__(
        self,
        add_few_shot_examples = 0
    ): # add more filters by source, etc if necessary
        """Load the BigCodeBench split and remember the few-shot settings.

        Args:
            add_few_shot_examples: how many few-shot examples ``__getitem__``
                attaches to every returned item; values below 1 turn the
                few-shot fields off.
        """
        self.data = hfsets.load_dataset('bigcode/bigcodebench', split='v0.1.2')
        # becomes True once the user has acknowledged the unsafe-code warning
        self.unsafe_warning_received = False
        # separator appended after every few-shot answer
        self.fs_split_str = "\n\n\n\n"
        self.add_few_shot_examples = add_few_shot_examples

    def __getitem__(self, item):
        if self.add_few_shot_examples < 1:
            # returns structured for our rollout scripts
            dset_item = self.full_item(item)
            dset_item['x'] = dset_item['complete_prompt']
            return dset_item
        else:
            dset_item = self.full_item(item)
            dset_item['x'] = dset_item['complete_prompt']
            dset_item['fs_in'] = []
            dset_item['fs_out'] = []
            for fs_id in range(self.add_few_shot_examples):
                print(fs_id)
                # keep adding with minus index - relatively safe and straightforward
                fs_q = self.get_question(item-fs_id)[0]
                fs_a = self.get_answer(item-fs_id)[0]+self.fs_split_str
                fs_q = re.sub("^( *?import.*?| *?from.*?import.*?)\n", '', fs_q)  # remove imports to not confuse the already confused models
                fs_q = re.sub('task_func', f"unrelated_example_func_{chr(ord('a')+fs_id)}", fs_q) # rename the function to avoid confusion
                fs_a = re.sub('task_func', f"unrelated_example_func_{chr(ord('a')+fs_id)}", fs_a)
                dset_item['fs_in'].append(fs_q)
                dset_item['fs_out'].append(fs_a)
            
            return dset_item
    
    def full_item(self, item):
        return deepcopy(self.data[item])

    def __len__(self):
        return len(self.data)

    def calculate_exact_correctness(
        self, 
        item, 
        answer,  # is just the generated part
        return_dict=True, 
        max_as_limit = 100000,  # max amount of address memory (MB) for the evaluator
        max_data_limit = 100000, # max amount of memory memory (MB) for the evaluator
        max_stack_limit = 100000, # max amount of stack memory (MB) for the evaluator
        min_time_limit = 5.1,
        gt_time_limit = 5.0,
        check_reference_solution = False,  # set to true to also check the reference solution
        include_detailed_info = False,
    ):
        if not self.unsafe_warning_received:
            input("WARNING! RUNNING UNSAFE LLM GENERATED CODE! \
                SOME SANDBOXING IS THERE BUT YOU NEVER KNOW! \
                    PRESS ANY KEY TO CONTINUE:")
            self.unsafe_warning_received = True

        # get the original dataset entry
        original_dset_entry = self.full_item(item)
        # put together the answer using the prompt
        answer = original_dset_entry['complete_prompt'] + answer
        # do the unsafe
        res = check_correctness(
            completion_id = 'arbitrary',
            problem = original_dset_entry,
            solution = answer,
            max_as_limit = max_as_limit,
            max_data_limit = max_data_limit,
            max_stack_limit = max_stack_limit,
            min_time_limit = min_time_limit,
            gt_time_limit = gt_time_limit,
        )
        if check_reference_solution:
            # TODO: check reference solution
            # exclude the errors happening in those (would be the timeout errors mostly)
            # modify the pass/fail appropiately
            return NotImplemented
        
        succ, errors = res['base']
        # TODO: consider that 'timeout' can also be returned
        succ = True if succ=='pass' else False
        nerrors = len(errors)
        ntotal = len(re.findall(r'    def test_', original_dset_entry['test']))

        if return_dict:
            retdict = {
                'correct': succ,
                'nerrors': nerrors,
                'ntotal': ntotal
            }
            if include_detailed_info:
                retdict['original'] = errors
            return retdict
        else:
            return succ


    def generate_submission_transcript_from_answers(
        self,
        answers,
    ):
        considered_ids = []
        n_totals = []
        formatted_solutions = []
        for answer in answers:
            answer_idx = answer['dataset_idx']
            considered_ids.append(answer_idx)
            ds_item = self.full_item(answer_idx)
            
            task_id = ds_item['task_id']
            n_totals.append(len(re.findall(r'    def test_', ds_item['test'])))

            assert isinstance(answer['txt_xy'][0], str), "lolwut?"
            self.get_entry_by_bcb_id(task_id)
            formatted_solutions.append(
                {
                    'task_id': task_id, 
                    'solution': answer['txt_xy'][0]
                }
            )
        # fill the rest of list with the remaining ids
        for unprovided_dset_idx in [i for i in range(len(self)) if i not in considered_ids]:
            ds_item = self.full_item(unprovided_dset_idx)
            formatted_solutions.append(
                {
                    'task_id': ds_item['task_id'],
                    'solution': 'pass'
                }
            )
            
        return formatted_solutions, {
            'n_totals': n_totals,
            'considered_ids': considered_ids,
        }
    
    def pick_up_the_results(
        self,
        results,
        pass_at_k,
        n_totals,
        considered_ids,
    ):
        # use the pass@k file to get the gt (reference solutions passability)
        failed_ids = []
        if 'failed_tasks' in pass_at_k:
            # read the ids and add to the list
            for failed_bcb_id in pass_at_k['failed_tasks']:
                failed_ids.append(self.get_entry_by_bcb_id(failed_bcb_id))

        # take apart the results
        results_extracted = []
        for dset_id, ntotal in zip(considered_ids, n_totals):
            # 0 here is hardcoded, if ever need to evaluate multiple, change that
            res_entry = results['eval'][f"BigCodeBench/{dset_id}"][0]
            # to trigger an assert just in case
            self.get_entry_by_bcb_id(f"BigCodeBench/{dset_id}")
            succ = 1. if res_entry['status']=='pass' else 0.
            if dset_id in failed_ids and succ==0.: # if succeded and failed the gt, then we still pass
                succ = np.nan
            nerrors = ntotal if 'ALL' in res_entry['details'].keys() else len(res_entry['details'])

            retdict = {
                'correct': succ,
                '_nerrors': nerrors,
                '_ntotal': ntotal,
                '_failed_gt': dset_id in failed_ids,
                'dataset_idx': dset_id,
            }
            results_extracted.append(retdict)
        return results_extracted, pass_at_k


    def calculate_exact_batch_correctness(
        self,
        answers, 
        return_dict=True, 
        split: str = 'complete',
        subset: str = 'full',
        pass_k: str = "1,5,10",
        save_pass_rate: bool = True,
        calibrated: bool = True,
        parallel: int = -1,
        min_time_limit: float = 1,
        max_as_limit: int = 30*1024,
        max_data_limit: int = 30*1024,
        max_stack_limit: int = 10,
        check_gt_only: bool = False,
        no_gt: bool = False,
        check_reference_solution = False,  # set to true to also check the reference solution
        include_detailed_info = False,
        manual_upload = False,
        dir_to_preserve_intermediates=None
    ):
        """Evaluate a batch of answers with the remote BigCodeBench evaluator.

        The answers are written as a JSON-lines transcript to a temporary
        file.  That file is either sent to the gradio endpoint at
        ``remote_execute_api`` (retrying every 10 s on read timeouts or
        cancellations) or, with ``manual_upload``, uploaded by the user who
        then supplies the two response files interactively.

        Args:
            answers: see ``generate_submission_transcript_from_answers``.
            split, subset, pass_k, calibrated, parallel, min_time_limit,
                max_as_limit, max_data_limit, max_stack_limit, check_gt_only,
                no_gt: forwarded to the evaluator's ``/predict`` endpoint.
            manual_upload: skip the API call and prompt for manual upload.
            dir_to_preserve_intermediates: if given, the raw evaluator
                responses are saved there as ``gradio_eval.json`` and
                ``pass_at_k_info.json`` (the directory is created if needed).
            return_dict, save_pass_rate, check_reference_solution,
                include_detailed_info: accepted for interface compatibility;
                currently unused.

        Returns:
            tuple: the ``(results_extracted, pass_at_k)`` pair produced by
            ``pick_up_the_results``.
        """
        formatted_solutions, info = self.generate_submission_transcript_from_answers(answers)
        n_totals = info['n_totals']
        considered_ids = info['considered_ids']

        # write into a temp file and do the server call
        with tempfile.NamedTemporaryFile('wt') as fp:
            pd.DataFrame.from_records(formatted_solutions).to_json(fp.name, orient='records', lines=True)
            if not manual_upload:
                # now run the correctness script
                while True:
                    try:
                        client = Client(remote_execute_api)
                        # forward the caller's settings instead of hard-coded
                        # literals (the defaults match the former literals)
                        results, pass_at_k = client.predict(
                            split=split,
                            subset=subset,
                            samples=handle_file(fp.name),
                            pass_k=pass_k,
                            parallel=parallel,
                            min_time_limit=min_time_limit,
                            max_as_limit=max_as_limit,
                            max_data_limit=max_data_limit,
                            max_stack_limit=max_stack_limit,
                            calibrated=calibrated,
                            check_gt_only=check_gt_only,
                            no_gt=no_gt,
                            api_name="/predict"
                        )
                        break
                    except (httpx.ReadTimeout, CancelledError):
                        print("Read timeout error. Retrying in 10s...")
                        time.sleep(10)
            else:
                input(f"The jsonl file is located at {fp.name} . Upload it to the {remote_execute_api} manually. Press any key when done")
                # the same path is read twice on purpose: first the eval
                # result, then (after the user swapped the file) the pass@k info
                input("Place the eval result file into the ./tmp_response.json file. Press any key to continue.")
                with open('./tmp_response.json', 'rt') as f:
                    results = json.load(f)
                input("Place the pass@k file into the ./tmp_response.json file. Press any key to continue.")
                with open('./tmp_response.json', 'rt') as f:
                    pass_at_k = json.load(f)

        # preserve the original evaluation results just in case
        if dir_to_preserve_intermediates is not None:
            os.makedirs(dir_to_preserve_intermediates, exist_ok=True)
            with open(os.path.join(dir_to_preserve_intermediates, 'gradio_eval.json'), 'wt') as f:
                json.dump(results, f)
            with open(os.path.join(dir_to_preserve_intermediates, 'pass_at_k_info.json'), 'wt') as f:
                json.dump(pass_at_k, f)

        return self.pick_up_the_results(
            results,
            pass_at_k,
            n_totals=n_totals,
            considered_ids=considered_ids,
        )

    def get_answer(self, item):
        return [self.full_item(item)['canonical_solution'],]

    def get_question(self, item):
        return [self.full_item(item)['complete_prompt'],]
    
    def get_entry_by_bcb_id(self, bcb_id):
        bcb_id_dset_index = int(bcb_id.split('/')[-1])
        assert self.full_item(bcb_id_dset_index)['task_id'] == bcb_id, f"Messed up ordering on id: {bcb_id}"
        return bcb_id_dset_index

    def calculate_correctness(self, item, answer, **kwargs):
        """Score ``answer`` against the canonical solution of ``item``.

        The reference answer and the question of ``item`` are passed,
        together with any extra keyword arguments, to
        ``compute_standard_scores``, whose result is returned unchanged.
        """
        return compute_standard_scores(
            self.get_answer(item),
            answer,
            question=self.get_question(item),
            **kwargs,
        )

    def get_problem_system_instruction(self):
        return ''