
from collections import defaultdict, Counter
import logging
import math
import numpy as np


_LOG_FORMAT = "SystemLog: [%(asctime)s][%(name)s][%(levelname)s] - %(message)s"
_LOG_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"

# Install a root handler at INFO level with the format above. As with any
# logging.basicConfig call, this is a no-op if the root logger already has
# handlers.
logging.basicConfig(level=logging.INFO, datefmt=_LOG_DATE_FORMAT, format=_LOG_FORMAT)

# Logger named after this module; its records propagate to the root handler.
logger = logging.getLogger(__name__)

class DataManager:
    """Index dual-execution results by task for agreement-based ranking.

    Each sampled occurrence of a solution or test-case string gets its own
    integer id: a string sampled ``k`` times owns a ``range`` of ``k``
    consecutive ids within its task. Only strings that appear in at least
    one passing (solution, test case) execution receive ids.

    Args:
        dual_exec_results: iterable of dicts, each with ``task_id``,
            ``completion``, ``passed``, ``result``, ``result_time`` and
            ``test_cases`` keys.
        sampled_code_by_task: iterable of dicts with ``task_id`` and
            ``completion`` keys, one per sampled solution.
        sampled_test_case_by_task: mapping of task id to a list of
            per-sample test-case sequences.
        limit: slice bound applied to each per-sample test-case sequence
            (``cases[:limit]``); ``None`` keeps every test case.
    """

    def __init__(self, dual_exec_results, sampled_code_by_task, sampled_test_case_by_task, limit):
        logger.info('handling dual exec results')
        self.dual_exec_results = dual_exec_results
        self.sampled_code_by_task = sampled_code_by_task
        self.sampled_test_case_by_task = sampled_test_case_by_task
        self.limit = limit

        # task_id -> Counter(completion string -> number of samples)
        self.solution_frequency_by_task = defaultdict(Counter)
        # task_id -> Counter(test-case string -> number of samples after the limit cut)
        self.test_case_frequency_by_task = dict()
        self.passed_unique_solutions_by_task = defaultdict(set)
        self.passed_unique_test_cases_by_task = defaultdict(set)
        # task_id -> {(completion, test_case)} that passed
        self.passed_solution_test_case_pairs_by_task = defaultdict(set)
        # task_id -> {(completion, test_case, run_time)} that passed
        self.passed_solution_test_case_pairs_by_task_with_time = defaultdict(set)
        self.solution_string_to_id_range_by_task = dict()
        self.test_case_string_to_id_range_by_task = dict()
        self.solution_id_to_string_by_task = dict()
        self.test_case_id_to_string_by_task = dict()
        self.taskid_id_to_string = dict()
        self.taskid_string_to_id = dict()

        # task_id -> [(solution_id, test_case_id), ...] for every passing pair
        self.expanded_passed_solution_test_case_pairs_by_task = defaultdict(list)

        self._get_solution_frequency()
        logger.info('got solution frequency')
        self._get_test_case_frequency()
        logger.info('got test case frequency')
        self._get_passed_solution_test_case_pairs_by_task()
        logger.info('got passed solution test case pairs by task')
        self._get_solution_and_test_case_ids()
        logger.info('got solution and test case ids')
        self._get_expanded_dual_exec_result()
        logger.info('got expanded dual exec results')

    def _get_solution_frequency(self):
        """Count how many times each completion string was sampled, per task."""
        for sample in self.sampled_code_by_task:
            self.solution_frequency_by_task[sample['task_id']][sample['completion']] += 1

    def _get_test_case_frequency(self):
        """Count sampled test-case strings per task, keeping ``[:limit]`` of each sample.

        Flattening with a generator replaces ``sum(lists, [])``, which copied
        the accumulated list on every addition (quadratic in the number of
        test cases). Counts and first-seen order are unchanged.
        """
        for task_id, samples in self.sampled_test_case_by_task.items():
            self.test_case_frequency_by_task[task_id] = Counter(
                case for cases_per_sample in samples for case in cases_per_sample[:self.limit]
            )

    def _get_passed_solution_test_case_pairs_by_task(self):
        """Collect the (completion, test case) pairs that passed, per task.

        A result is skipped when its ``passed`` flag is falsy or its
        ``result`` is not a list. Within a result, a test case counts only if
        its outcome equals ``True`` and it is among the task's sampled test
        cases (after the ``limit`` cut). A task with no sampled test cases
        contributes nothing.

        Raises:
            ValueError: if ``result`` is a list but ``result_time`` is not,
                or, when there are test cases, the two lists differ in
                length. These were ``assert`` statements before, which
                ``python -O`` strips.
            IndexError: if ``test_cases`` is longer than ``result``.
        """
        for result in self.dual_exec_results:
            if not result['passed']:
                continue
            outcomes = result['result']
            if not isinstance(outcomes, list):
                continue
            times = result['result_time']
            if not isinstance(times, list):
                raise ValueError(
                    "type(result['result']) = " + str(type(outcomes))
                    + ", type(result['result_time']) = " + str(type(times))
                )
            test_cases = result['test_cases']
            # Checked once per result instead of once per test case. As
            # before, it only applies when there is at least one test case.
            if test_cases and len(outcomes) != len(times):
                raise ValueError(
                    "len(result['result']), len(result['result_time']) = "
                    + str(len(outcomes)) + ", " + str(len(times))
                    + "\n" + str(outcomes) + "\n" + str(times)
                )
            task_id = result['task_id']
            completion = result['completion']
            sampled_test_cases = self.test_case_frequency_by_task.get(task_id, {})
            for idx, test_case in enumerate(test_cases):
                # Equality with True on purpose (not truthiness): only outcomes
                # equal to True count as a pass.
                if outcomes[idx] != True:  # noqa: E712
                    continue
                if test_case not in sampled_test_cases:
                    continue
                # Safe: outcomes[idx] succeeded and len(times) == len(outcomes).
                run_time = times[idx]
                self.passed_solution_test_case_pairs_by_task[task_id].add((completion, test_case))
                self.passed_solution_test_case_pairs_by_task_with_time[task_id].add((completion, test_case, run_time))
                self.passed_unique_solutions_by_task[task_id].add(completion)
                self.passed_unique_test_cases_by_task[task_id].add(test_case)

    def _build_string_to_id_range(self, frequency_dict, limited_values):
        """Assign consecutive id ranges to the strings in ``limited_values``.

        Strings are visited in ``frequency_dict`` order. Each kept string
        receives ``range(start, start + count)``, so ids stay contiguous
        across the kept strings.
        """
        id_ranges = dict()
        start_id = 0
        for key, count in frequency_dict.items():
            if key not in limited_values:
                continue
            id_ranges[key] = range(start_id, start_id + count)
            start_id += count
        return id_ranges

    def _build_id_to_string(self, str_to_id_range):
        """Invert a ``string -> range(ids)`` mapping into ``id -> string``."""
        return {idx: string for string, ids in str_to_id_range.items() for idx in ids}

    def _get_solution_and_test_case_ids(self):
        """Build the per-task string <-> id maps and the task-id <-> index maps."""
        for task_id, solution_frequency in self.solution_frequency_by_task.items():
            solution_ranges = self._build_string_to_id_range(
                solution_frequency, self.passed_unique_solutions_by_task[task_id])
            # ``.get`` gives a task with sampled code but no sampled tests an
            # empty test-id space instead of raising KeyError.
            test_case_ranges = self._build_string_to_id_range(
                self.test_case_frequency_by_task.get(task_id, Counter()),
                self.passed_unique_test_cases_by_task[task_id])
            self.solution_string_to_id_range_by_task[task_id] = solution_ranges
            self.test_case_string_to_id_range_by_task[task_id] = test_case_ranges
            self.solution_id_to_string_by_task[task_id] = self._build_id_to_string(solution_ranges)
            self.test_case_id_to_string_by_task[task_id] = self._build_id_to_string(test_case_ranges)
        self.taskid_id_to_string = dict(enumerate(self.solution_frequency_by_task))
        self.taskid_string_to_id = {task_id: idx for idx, task_id in self.taskid_id_to_string.items()}

    def _get_expanded_by_id_range(self, solution_id_range, test_case_id_range):
        """Return every (solution_id, test_case_id) pair, with solution ids in the outer loop."""
        return [
            (solution_id, test_case_id)
            for solution_id in solution_id_range
            for test_case_id in test_case_id_range
        ]

    def _get_expanded_dual_exec_result(self):
        """Expand each passing (completion, test case) string pair into id pairs.

        This reads the pair set *without* run times. The time-keyed set can
        hold the same (completion, test case) several times, once per
        distinct run time when a completion was executed more than once.
        Expanding from it duplicated id pairs, and those duplicates inflated
        downstream case-set sizes.
        """
        for task_id, pairs in self.passed_solution_test_case_pairs_by_task.items():
            solution_ranges = self.solution_string_to_id_range_by_task[task_id]
            test_case_ranges = self.test_case_string_to_id_range_by_task[task_id]
            expanded = self.expanded_passed_solution_test_case_pairs_by_task[task_id]
            for solution_str, test_case_str in pairs:
                expanded.extend(self._get_expanded_by_id_range(
                    solution_ranges[solution_str], test_case_ranges[test_case_str]))


class DualAgreement:
    """Score solutions and test cases from the pass matrix of a ``DataManager``.

    ``task_sol_test_matrix[p, c, t]`` is 1 when solution id ``c`` of the task
    at index ``p`` passed test-case id ``t``, and 0 otherwise. Solution ids
    that passed exactly the same test ids are grouped into "case sets". The
    ``get_sorted_solutions_without_iter*`` methods rank those groups; the
    other methods propagate scores through the matrix with ``np.einsum``.
    Every ranking is ``task_id -> [(list of solution/test strings, score),
    ...]``, sorted by score with the highest first.
    """

    def __init__(self, data_manager):
        self.data_manager = data_manager
        # task_id -> [(solution_id, test_case_id), ...] of passing pairs
        self.dual_exec_results_by_task = data_manager.expanded_passed_solution_test_case_pairs_by_task

        self.solution_id_to_string_by_task = data_manager.solution_id_to_string_by_task
        self.test_case_id_to_string_by_task = data_manager.test_case_id_to_string_by_task
        self.solution_string_to_id_range_by_task = data_manager.solution_string_to_id_range_by_task
        self.test_case_string_to_id_range_by_task = data_manager.test_case_string_to_id_range_by_task
        self.taskid_id_to_string = data_manager.taskid_id_to_string
        self.taskid_string_to_id = data_manager.taskid_string_to_id

        # Matrix dimensions: number of tasks x largest per-task solution-id
        # count x largest per-task test-id count. ``default=0`` makes an
        # empty data manager produce an empty matrix instead of ``max()``
        # raising ValueError.
        self.max_task = len(self.taskid_id_to_string)
        self.max_sol = max((len(ids) for ids in self.solution_id_to_string_by_task.values()), default=0)
        self.max_test = max((len(ids) for ids in self.test_case_id_to_string_by_task.values()), default=0)

        # task_id -> {solution_id: [passed test_case_id, ...]}
        self.solution_passed_cases_by_task = defaultdict(defaultdict)
        # task_id -> {sorted tuple of passed test_case_ids: [solution_id, ...]}
        self.caseset_passed_solutions_by_task = defaultdict(defaultdict)

        self.task_sol_test_matrix = np.zeros((self.max_task, self.max_sol, self.max_test))

        self._get_solution_passed_case_set()
        logger.info('got solution passed case sets')
        self._get_caseset_passed_solutions()
        logger.info('got case set passed solutions')

    def _get_solution_passed_case_set(self):
        """Record which test ids each solution id passed, and mark them in the matrix."""
        for task_id, pairs in self.dual_exec_results_by_task.items():
            if not pairs:
                continue
            # Loop-invariant for this task, so look it up once.
            task_idx = self.taskid_string_to_id[task_id]
            passed_cases = self.solution_passed_cases_by_task[task_id]
            for solution, test_case in pairs:
                passed_cases.setdefault(solution, []).append(test_case)
                self.task_sol_test_matrix[task_idx, solution, test_case] = 1

    def _get_caseset_passed_solutions(self):
        """Group solution ids by the sorted tuple of test ids they passed."""
        for task_id, passed_cases in self.solution_passed_cases_by_task.items():
            for solution, cases in passed_cases.items():
                case_set = tuple(sorted(cases))
                self.caseset_passed_solutions_by_task[task_id].setdefault(case_set, []).append(solution)

    def _rank_case_sets(self, solution_count_score):
        """Rank each task's case-set groups by ``len(case_set) * solution_count_score(n)``.

        ``n`` is the number of solution ids in the group. Each entry holds
        the group's solution strings and its score. Ties keep insertion
        order, because ``sorted`` is stable.
        """
        ranked_solutions_by_task = defaultdict(list)
        for task_id, groups in self.caseset_passed_solutions_by_task.items():
            id_to_string = self.solution_id_to_string_by_task[task_id]
            scored = [
                ([id_to_string[solution] for solution in solutions],
                 len(case_set) * solution_count_score(len(solutions)))
                for case_set, solutions in groups.items()
            ]
            ranked_solutions_by_task[task_id] = sorted(scored, key=lambda entry: entry[1], reverse=True)
        return ranked_solutions_by_task

    def _rank_solutions_by_scores(self, solution_scores):
        """Rank each task's solution ids by ``solution_scores[task_index][solution_id]``."""
        ranked_solutions_by_task = defaultdict(list)
        for task_id in self.caseset_passed_solutions_by_task:
            task_scores = solution_scores[self.taskid_string_to_id[task_id]]
            scored = [
                ([solution_string], task_scores[solution_id])
                for solution_id, solution_string in self.solution_id_to_string_by_task[task_id].items()
            ]
            ranked_solutions_by_task[task_id] = sorted(scored, key=lambda entry: entry[1], reverse=True)
        return ranked_solutions_by_task

    def _rank_tests_by_scores(self, test_scores):
        """Rank each task's test-case ids by ``test_scores[task_index][test_id]``."""
        ranked_tests_by_task = defaultdict(list)
        for task_id in self.caseset_passed_solutions_by_task:
            task_scores = test_scores[self.taskid_string_to_id[task_id]]
            scored = [
                ([test_string], task_scores[test_id])
                for test_id, test_string in self.test_case_id_to_string_by_task[task_id].items()
            ]
            ranked_tests_by_task[task_id] = sorted(scored, key=lambda entry: entry[1], reverse=True)
        return ranked_tests_by_task

    def get_sorted_solutions_without_iter(self):
        """Rank case-set groups by ``len(case_set) * sqrt(number of solutions)``."""
        logger.info('Start to get sorted solutions without iter')
        return self._rank_case_sets(math.sqrt)

    def get_sorted_solutions_without_iter_remove_sqrt(self):
        """Rank case-set groups by ``len(case_set) * number of solutions``."""
        logger.info('Start to get sorted solutions without iter')
        return self._rank_case_sets(lambda count: count)

    def get_sorted_solutions_with_iter_page_rank(self, T = 10, beta = 0.85, rtn_test=False):
        """Rank solutions, and optionally tests, by ``T`` rounds of score propagation.

        All scores start at 1. In each round, a test's new score is
        ``(1 - beta) * old + beta * (sum of the scores of solutions passing
        it)``. Each solution is then updated the same way from the fresh
        test scores. Scores are not normalised between rounds.

        Returns:
            The solution ranking, or ``(solution ranking, test ranking)``
            when ``rtn_test`` is true.
        """
        logger.info('Start to get sorted solutions with iter page rank')
        solution_scores = np.ones((self.max_task, self.max_sol))
        test_scores = np.ones((self.max_task, self.max_test))
        for _ in range(T):
            test_scores = test_scores * (1 - beta) + np.einsum("PCT,PC->PT", self.task_sol_test_matrix, solution_scores) * beta
            solution_scores = solution_scores * (1 - beta) + np.einsum("PCT,PT->PC", self.task_sol_test_matrix, test_scores) * beta

        ranked_solutions_by_task = self._rank_solutions_by_scores(solution_scores)
        if not rtn_test:
            return ranked_solutions_by_task
        return ranked_solutions_by_task, self._rank_tests_by_scores(test_scores)

    def get_correct_test_with_highest_code(self):
        """Score tests against the single solution of each task.

        Every solution score is 1, so a test's score is 1 if the task's
        solution passed it and 0 otherwise. Tests scoring above 0.1 are
        reported as "correct".

        Returns:
            ``(solution ranking, test ranking, correct tests by task)``.

        Raises:
            ValueError: unless ``max_sol == 1``. This was an ``assert``
                before, which ``python -O`` strips.
        """
        logger.info('Start to get correct tests with highest code')
        if self.max_sol != 1:
            raise ValueError(
                f"get_correct_test_with_highest_code expects max_sol == 1, got {self.max_sol}")
        solution_scores = np.ones((self.max_task, self.max_sol))
        test_scores = np.einsum("PCT,PC->PT", self.task_sol_test_matrix, solution_scores)

        ranked_solutions_by_task = self._rank_solutions_by_scores(solution_scores)
        ranked_tests_by_task = self._rank_tests_by_scores(test_scores)
        correct_tests_by_task = defaultdict(list)
        for task_id, ranked_tests in ranked_tests_by_task.items():
            correct_tests_by_task[task_id] = [entry for entry in ranked_tests if entry[1] > 0.1]
        return ranked_solutions_by_task, ranked_tests_by_task, correct_tests_by_task

    def get_correct_code_with_correct_test(self):
        """Rank solutions by how many test ids they pass, treating every test as correct.

        Every test score is 1, so a solution's score is the number of test
        ids it passed.

        Returns:
            ``(solution ranking, test ranking)``. The test ranking gives
            every test a score of 1.
        """
        logger.info('Start to get correct code with correct test')
        test_scores = np.ones((self.max_task, self.max_test))
        solution_scores = np.einsum("PCT,PT->PC", self.task_sol_test_matrix, test_scores)
        return self._rank_solutions_by_scores(solution_scores), self._rank_tests_by_scores(test_scores)
