import re
from abc import abstractmethod
from argparse import Namespace

import pandas as pd
from tqdm import tqdm
import string
import random
import json
import os

from src.llm.utils.gpt_utils import OpenAI_API
from src.llm.utils.prompt_utils import Prompt_hepler, load_prompt
from src.utils.tools import generate_ramdom_sequence, get_handled_result, simply_re_search


class DefaultReasoningDataset:
    """Base class for the default reasoning datasets.

    Stores the data path and data type, the loaded samples (``self.dataset``,
    ``None`` until a subclass assigns it) and an ``id -> index`` lookup table
    (``self.id2idx``). It also provides the generic "prompt the LLM, parse
    the answer, retry on failure" loop and JSON-lines prediction I/O.

    NOTE: the class does not use ``ABCMeta``, so the ``@abstractmethod``
    markers below are advisory only and are not enforced at instantiation.
    """

    def __init__(self, data_path, dtype):
        self.data_path = data_path
        self.dtype = dtype
        self.dataset = None
        self.id2idx = {}

    def __getitem__(self, item):
        """Return ``self.dataset[item]``."""
        return self.dataset[item]

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.dataset)

    @abstractmethod
    def show_statistic_information(self):
        pass

    @abstractmethod
    def get_samples_for_default_reasoning(self, item: int) -> list[dict]:
        pass

    def get_index_by_id(self, id: str) -> int:
        """Return the index registered for ``id``; raises ``KeyError`` if unknown."""
        return self.id2idx[id]

    def handle_result4LLM(self, text: str, input_data: dict, prompt: dict) -> dict:
        """Extract the Precondition / Consequent / Justification fields from ``text``.

        Each field is the text following ``"<Field>:"`` up to and including
        the first period on the same line (matching is case-insensitive).

        Returns:
            A dict with the three stripped fields plus bookkeeping keys
            (``Source_ID`` is the id of the source sample, ``ID`` is a freshly
            generated id), or ``None`` when any field cannot be extracted.
        """
        extracted = {}
        for field in ('Precondition', 'Consequent', 'Justification'):
            try:
                value = re.search(rf'{field}:(.*?[.])', text, re.IGNORECASE).group(1)
            except (AttributeError, TypeError):
                # AttributeError: no match (re.search returned None).
                # TypeError: ``text`` is not a string (e.g. None).
                return None
            # A missing/empty field invalidates the whole sample.
            if not value:
                return None
            extracted[field] = value.strip()

        correctness_orig = input_data['correctness']
        correctness_target = prompt['correctness']
        return {'Precondition': extracted['Precondition'],
                'Consequent': extracted['Consequent'],
                'Justification': extracted['Justification'],
                'Correctness_Orig': correctness_orig,
                'Correctness_Target': correctness_target,
                'Source_ID': input_data['ID'],
                # presumably a random 20-character identifier; see generate_ramdom_sequence
                'ID': generate_ramdom_sequence(20),
                # First character of each correctness value, e.g. "T2F".
                'type': f'{str(correctness_orig)[0]}2{str(correctness_target)[0]}'}

    def read_prediction(self, prediction_path: str) -> list[dict]:
        """Read a JSON-lines prediction file into a list of dicts.

        Blank lines (such as a trailing empty line) are skipped instead of
        raising ``json.JSONDecodeError``.
        """
        with open(prediction_path, 'r', encoding='utf-8') as file:
            return [json.loads(line) for line in file if line.strip()]

    def save_one_prediction(self, current_data, IO, prediction, LLM_response):
        """Append ``prediction`` as one JSON line to ``IO`` and flush immediately.

        ``current_data`` and ``LLM_response`` are accepted but not used here.
        Returns ``prediction`` unchanged.
        """
        IO.write(json.dumps(prediction) + '\n')
        IO.flush()
        return prediction

    # ``api`` is annotated with a forward-reference string so the annotation
    # is not evaluated when the class body is executed.
    def get_result_by_LLM(self, data, api: 'OpenAI_API', args: Namespace, **kwargs):
        """Query the LLM until its answer can be parsed or the retries run out.

        Up to ``args.error_extraction_count`` attempts are made. An attempt
        is skipped (after printing the error) when the API call raises, and
        retried when ``handle_result4LLM`` returns a falsy result.

        Returns:
            The first truthy parsed result, otherwise the last parse result
            (``None`` when every attempt failed).
        """
        ph = Prompt_hepler()
        prompt = load_prompt(args.prompt_path, args.prompt_id)
        handled_result = None
        for _ in range(args.error_extraction_count):
            msg = ph.replace_with_dict(prompt['content'], data, '{', '}')
            try:
                llm_result = api.chat_without_history(msg)
            except Exception as e:
                # Best-effort: report the failure and spend one retry on it.
                print('error:', e)
                continue

            handled_result = self.handle_result4LLM(llm_result, data, prompt)
            if handled_result:
                break
        return handled_result


class LabelClassification(DefaultReasoningDataset):
    """Dataset whose records carry facts, rules and a list of labelled queries.

    Every query of a record becomes one sample; the LLM is asked to judge the
    query and its answer is reduced to a lower-cased label string.
    """

    def __init__(self, data_path, dtype, **kwargs):
        super().__init__(data_path, dtype)
        self.name = 'LabelClassification'
        self.dataset = self.load_data()

    def load_data(self):
        """Load the JSON-lines file at ``self.data_path``.

        Keeps the ``id``, ``facts``, ``rules`` and ``queries`` columns of each
        row and registers the row's index in ``self.id2idx``.

        Returns:
            A list with one dict per row.
        """
        print(f'Loading data from {self.name}')
        print('Loading data from {}'.format(self.data_path))
        print('Loading data type: {}'.format(self.dtype))

        odataset = pd.read_json(self.data_path, lines=True)
        dataset = []
        for index, odata in tqdm(odataset.iterrows()):
            data = {
                'id': odata['id'],
                'facts': odata['facts'],
                'rules': odata['rules'],
                'queries': odata['queries'],
            }

            dataset.append(data)
            self.id2idx[data['id']] = index

        print('Loading data complete!')
        return dataset

    def get_samples_for_default_reasoning(self, item) -> list[dict]:
        """Expand one record (or a slice of records) into per-query samples.

        Args:
            item: an integer index or a ``slice``. Slices follow normal Python
                semantics (negative bounds, step, clamping to the dataset size).

        Returns:
            A flat list of dicts with keys ``id``, ``facts``, ``rules``,
            ``query`` and ``label``.
        """
        if isinstance(item, slice):
            samples = []
            # slice.indices() normalises None / negative / out-of-range bounds
            # and honours the step, which a hand-rolled range(start, stop) did not.
            for i in range(*item.indices(len(self))):
                samples += self.get_samples_for_default_reasoning(i)
            return samples

        data = self[item]
        return [
            {
                'id': data['id'],
                'facts': data['facts'],
                'rules': data['rules'],
                'query': query['query'],
                'label': query['label'],
            }
            for query in data['queries']
        ]

    @staticmethod
    def _guess_label_from_text(text):
        """Fallback label extraction when no ``<answer>`` tag was found.

        Scans the non-empty lines from last to first. On the first line that
        contains any of ``true`` / ``false`` / ``maybe`` (case-insensitive),
        returns the keyword whose first occurrence sits furthest to the right.
        Returns ``None`` when no line mentions a keyword.
        """
        non_empty = [line.strip() for line in text.split('\n') if line.strip()]
        for line in reversed(non_empty):
            lowered = line.lower()
            positions = [(lowered.find(word), word) for word in ('true', 'false', 'maybe')]
            position, word = max(positions, key=lambda pair: pair[0])
            if position >= 0:
                return word
        return None

    def get_result_by_LLM(self, data, api: OpenAI_API, args: Namespace, **kwargs):
        """Ask the LLM to judge one sample and normalise its answer.

        The prompt is filled with the newline-joined facts and rules and the
        query. The label is taken from an ``<answer>...</answer>`` tag when
        present, otherwise guessed from the trailing lines of the response.

        Returns:
            A dict with ``Source_ID``, ``raw_response``, ``prediction`` and
            ``label``. When no response was obtained, ``raw_response`` and
            ``prediction`` are the string ``'None'``. ``prediction`` is
            ``None`` when no label could be extracted from the response.
        """
        ph = Prompt_hepler()
        msg = {
            'facts': '\n'.join(data['facts']),
            'rules': '\n'.join(data['rules']),
            'query': data['query'],
        }

        prompt = load_prompt(args.prompt_path, args.prompt_id)
        result = get_handled_result(api, ph, msg, prompt, args.error_extraction_count)
        if result is None:
            return {
                'Source_ID': data['id'],
                'raw_response': 'None',
                'prediction': 'None',
                'label': data['label'],
            }

        judge = simply_re_search(r'<answer>(.*?)</answer>', result, re.IGNORECASE)
        if judge is None:
            judge = self._guess_label_from_text(result)
        else:
            # Strip padding inside the tag so "<answer> True </answer>" -> "true".
            judge = judge.strip().lower()

        return {
            'Source_ID': data['id'],
            'raw_response': result,
            'prediction': judge,
            'label': data['label'],
        }



class LabelNMClassification(LabelClassification):
    """Label classification over records that also carry distractor facts/rules.

    Each record provides ``irrelated_facts`` / ``irrelated_rules`` that are
    appended to the real ones, and a query may add a ``new_fact``.
    LLM querying and answer parsing are inherited from ``LabelClassification``.

    NOTE: ``load_data`` runs inside the parent's ``__init__``, i.e. before
    ``self.name`` is reassigned below, so the loading banner prints the
    parent's name.
    """

    def __init__(self, data_path, dtype, **kwargs):
        super().__init__(data_path, dtype)
        self.name = 'LabelNMClassification'

    def load_data(self):
        """Load the JSON-lines file at ``self.data_path``.

        The record id is read from the ``Source_ID`` column. The row's index
        is registered in ``self.id2idx`` under that id.

        Returns:
            A list with one dict per row (keys ``id``, ``facts``,
            ``irrelated_facts``, ``rules``, ``irrelated_rules``, ``queries``).
        """
        print(f'Loading data from {self.name}')
        print('Loading data from {}'.format(self.data_path))
        print('Loading data type: {}'.format(self.dtype))

        odataset = pd.read_json(self.data_path, lines=True)
        dataset = []
        for index, odata in tqdm(odataset.iterrows()):
            data = {
                'id': odata['Source_ID'],
                'facts': odata['facts'],
                'irrelated_facts': odata['irrelated_facts'],
                'rules': odata['rules'],
                'irrelated_rules': odata['irrelated_rules'],
                'queries': odata['queries'],
            }

            dataset.append(data)
            self.id2idx[data['id']] = index

        print('Loading data complete!')
        return dataset

    def get_samples_for_default_reasoning(self, item) -> list[dict]:
        """Expand one record (or a slice of records) into per-query samples.

        Args:
            item: an integer index or a ``slice``. Slices follow normal Python
                semantics (negative bounds, step, clamping to the dataset size).

        Returns:
            A flat list of dicts with keys ``id``, ``facts``, ``rules``,
            ``query`` and ``label``. The id is the record id suffixed with
            ``_new`` when the query carries a ``new_fact`` and ``_orig``
            otherwise. Facts are the record facts, then the optional new
            fact, then the irrelevant facts; rules are the record rules
            followed by the irrelevant rules.
        """
        if isinstance(item, slice):
            samples = []
            # slice.indices() normalises None / negative / out-of-range bounds
            # and honours the step.
            for i in range(*item.indices(len(self))):
                samples += self.get_samples_for_default_reasoning(i)
            return samples

        data = self[item]
        samples = []
        for query in data['queries']:
            has_new_fact = 'new_fact' in query
            suffix = '_new' if has_new_fact else '_orig'

            # Copy so the stored record's fact list is never mutated.
            facts = list(data['facts'])
            if has_new_fact:
                facts.append(query['new_fact'])

            samples.append({
                # f-string so that non-string ids do not raise on concatenation.
                'id': f"{data['id']}{suffix}",
                'facts': facts + data['irrelated_facts'],
                'rules': data['rules'] + data['irrelated_rules'],
                'query': query['query'],
                'label': query['label'],
            })

        return samples

class LabelClassificationSubset(LabelClassification):
    """Label classification over a file whose rows are already single samples.

    Rows are stored as-is (pandas rows) and returned as plain dicts; the
    answer is taken from the LAST ``<answer>...</answer>`` tag of the response.
    """

    def __init__(self, data_path, dtype, **kwargs):
        super().__init__(data_path, dtype)
        self.name = 'LabelClassificationSubset'

    def load_data(self):
        """Load the JSON-lines file at ``self.data_path``.

        Every row is kept unchanged and its index is registered in
        ``self.id2idx`` under the row's ``id``.

        Returns:
            A list of the rows yielded by ``DataFrame.iterrows()``.
        """
        print(f'Loading data from {self.name}')
        print('Loading data from {}'.format(self.data_path))
        print('Loading data type: {}'.format(self.dtype))

        odataset = pd.read_json(self.data_path, lines=True)
        dataset = []
        for index, odata in tqdm(odataset.iterrows()):
            dataset.append(odata)
            self.id2idx[odata['id']] = index

        print('Loading data complete!')
        return dataset

    def get_samples_for_default_reasoning(self, item) -> list[dict]:
        """Return the stored row(s) as plain dicts.

        Args:
            item: an integer index or a ``slice``. Slices follow normal Python
                semantics (negative bounds, step, clamping to the dataset size).

        Returns:
            For a slice, a list with one dict per selected row. For a single
            index, the row's dict itself (not wrapped in a list, despite the
            return annotation).
        """
        if isinstance(item, slice):
            # slice.indices() normalises None / negative / out-of-range bounds
            # and honours the step.
            return [
                self.get_samples_for_default_reasoning(i)
                for i in range(*item.indices(len(self)))
            ]

        return self[item].to_dict()

    def get_result_by_LLM(self, data, api: OpenAI_API, args: Namespace, **kwargs):
        """Ask the LLM to judge one sample and normalise its answer.

        The prompt is filled with the newline-joined facts and rules and the
        query. The label is the content of the last ``<answer>...</answer>``
        tag; without any tag it is guessed from the trailing lines.

        Returns:
            A dict with ``Source_ID``, ``raw_response``, ``prediction`` and
            ``label``. When no response was obtained, ``raw_response`` and
            ``prediction`` are the string ``'None'``. ``prediction`` is
            ``None`` when no label could be extracted from the response.
        """
        ph = Prompt_hepler()
        msg = {
            'facts': '\n'.join(data['facts']),
            'rules': '\n'.join(data['rules']),
            'query': data['query'],
        }

        prompt = load_prompt(args.prompt_path, args.prompt_id)
        result = get_handled_result(api, ph, msg, prompt, args.error_extraction_count)
        if result is None:
            return {
                'Source_ID': data['id'],
                'raw_response': 'None',
                'prediction': 'None',
                'label': data['label'],
            }

        judge = None
        matches = re.findall(r'<answer>(.*?)</answer>', result, re.IGNORECASE)
        if matches:
            # Use the last tag; strip padding so "<answer> True </answer>" -> "true".
            judge = matches[-1].strip().lower()
        else:
            # Fallback: walk the non-empty lines bottom-up and, on the first
            # line mentioning a keyword, take the keyword whose first
            # occurrence sits furthest to the right.
            non_empty = [line.strip() for line in result.split('\n') if line.strip()]
            for line in reversed(non_empty):
                lowered = line.lower()
                positions = [(lowered.find(word), word) for word in ('true', 'false', 'maybe')]
                position, word = max(positions, key=lambda pair: pair[0])
                if position >= 0:
                    judge = word
                    break

        return {
            'Source_ID': data['id'],
            'raw_response': result,
            'prediction': judge,
            'label': data['label'],
        }

class AnswerSetGeneration(DefaultReasoningDataset):
    """Dataset where the LLM must produce the answer set for facts and rules.

    Each record carries ``facts``, ``rules``, the reference ``answers`` and
    their ``answer_types``.
    """

    def __init__(self, data_path, dtype, **kwargs):
        super().__init__(data_path, dtype)
        self.name = 'AnswerSetGeneration'
        self.dataset = self.load_data()

    def load_data(self):
        """Load the JSON-lines file at ``self.data_path``.

        Keeps the ``id``, ``facts``, ``rules``, ``answers`` and
        ``answer_types`` columns of each row and registers the row's index in
        ``self.id2idx``.

        Returns:
            A list with one dict per row.
        """
        print(f'Loading data from {self.name}')
        print('Loading data from {}'.format(self.data_path))
        print('Loading data type: {}'.format(self.dtype))

        odataset = pd.read_json(self.data_path, lines=True)
        dataset = []
        for index, odata in tqdm(odataset.iterrows()):
            data = {
                'id': odata['id'],
                'facts': odata['facts'],
                'rules': odata['rules'],
                'answers': odata['answers'],
                'answer_types': odata['answer_types'],
            }

            dataset.append(data)
            self.id2idx[data['id']] = index

        print('Loading data complete!')
        return dataset

    def get_samples_for_default_reasoning(self, item):
        """Return the sample(s) selected by ``item``.

        Args:
            item: an integer index or a ``slice``. Slices follow normal Python
                semantics (negative bounds, step, clamping to the dataset size).

        Returns:
            For a slice, a list with one dict per selected record. For a
            single index, one dict with keys ``id``, ``facts``, ``rules``,
            ``answers`` and ``answer_types``.
        """
        if isinstance(item, slice):
            # slice.indices() normalises None / negative / out-of-range bounds
            # and honours the step.
            return [
                self.get_samples_for_default_reasoning(i)
                for i in range(*item.indices(len(self)))
            ]

        data = self[item]
        return {
            'id': data['id'],
            'facts': data['facts'],
            'rules': data['rules'],
            'answers': data['answers'],
            'answer_types': data['answer_types'],
        }

    def handle_result4LLM(self, text: str, input_data: dict, prompt: dict) -> dict:
        """Extract the predicted answer set from the LLM response.

        First collects every ``"<subject> is|are [not] <word>;"`` statement.
        If none is found, falls back to every ``"...);"`` span starting with
        a letter or hyphen. Duplicates are removed while keeping
        first-occurrence order, so the prediction is stable between runs.

        Returns:
            A dict with ``Source_ID``, ``raw_response``, ``prediction``,
            ``label`` and ``answer_types``. ``prediction`` is ``None`` when
            ``text`` is not a string (e.g. no response was obtained), and an
            empty list when nothing matched.
        """
        prediction = None
        if isinstance(text, str):
            found = re.findall(r"([A-Za-z_, ]+ (?:is|are) (?:not )?[A-Za-z_]+);", text)
            if not found:
                found = re.findall(r"([A-Za-z-].*?\));", text)
            # dict.fromkeys de-duplicates in first-seen order (set() order is
            # not reproducible across interpreter runs for strings).
            prediction = list(dict.fromkeys(found))

        return {
            'Source_ID': input_data['id'],
            'raw_response': text,
            'prediction': prediction,
            'label': input_data['answers'],
            'answer_types': input_data['answer_types'],
        }

    def get_result_by_LLM(self, data, api: OpenAI_API, args: Namespace, **kwargs):
        """Prompt the LLM with the sample's facts and rules and parse its answer set.

        The prompt is filled with the newline-joined facts and rules; the
        response returned by ``get_handled_result`` is passed to
        ``handle_result4LLM``.
        """
        ph = Prompt_hepler()
        msg = {
            'facts': '\n'.join(data['facts']),
            'rules': '\n'.join(data['rules']),
        }

        prompt = load_prompt(args.prompt_path, args.prompt_id)
        result = get_handled_result(api, ph, msg, prompt, args.error_extraction_count)

        return self.handle_result4LLM(result, data, prompt)