import json
from abc import abstractmethod

from tqdm import tqdm

from src.utils.tools import generate_ramdom_sequence
import pandas as pd


class TaskOrigentedDataset:
    """Base class for task-specific datasets derived from a JSON Lines file.

    Constructing an instance immediately reads ``data_path`` and keeps the
    parsed records in ``self.dataset``. Subclasses set ``self.task`` and
    implement :meth:`generate_task_specific_dataset` and
    :meth:`show_specific_dataset_statistic`.

    NOTE: this class does not inherit from ``abc.ABC``, so the
    ``@abstractmethod`` markers below are not enforced at instantiation time.
    """

    def __init__(self, data_path):
        """Store ``data_path`` and eagerly load its records.

        Args:
            data_path: Path of the JSON Lines file to read.
        """
        self.data_path = data_path
        self.task = None
        self.dataset = self.load_dataset()

    def load_dataset(self):
        """Parse ``self.data_path`` as JSON Lines.

        Returns:
            A list with one parsed JSON value per non-blank line, in file
            order. A line containing ``null`` yields ``None``.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If a non-blank line is not valid JSON.
        """
        data = []
        # Decode as UTF-8 explicitly instead of relying on the platform's
        # locale default, which would mis-decode non-ASCII text on some systems.
        with open(self.data_path, 'r', encoding='utf-8') as file:
            for line in file:
                # Tolerate blank lines (e.g. a trailing newline at end of file)
                # rather than failing in json.loads on an empty document.
                if not line.strip():
                    continue
                data.append(json.loads(line))
        return data

    def __getitem__(self, item):
        """Return the loaded record(s) at index or slice ``item``."""
        return self.dataset[item]

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.dataset)

    @abstractmethod
    def generate_task_specific_dataset(self):
        """Build and return the task-specific samples (implemented by subclasses)."""
        pass

    @abstractmethod
    def show_specific_dataset_statistic(self, specific_dataset):
        """Print and return statistics for ``specific_dataset`` (implemented by subclasses)."""
        pass


class SingleInferenceValidityAssessmentDataset(TaskOrigentedDataset):
    """Turns source records into single (precondition, consequent, label) samples."""

    def __init__(self, data_path):
        """Load ``data_path`` through the base class and set the task name."""
        super().__init__(data_path)
        self.task = 'Single Inference Validity Assessment'

    def generate_task_specific_dataset(self):
        """Build the validity-assessment samples.

        Every non-``None`` record yields a ``P -> C`` sample labelled with
        ``Correctness_Orig``. Records with a truthy ``Justification`` also
        yield a ``J + P -> C`` sample whose precondition is the justification
        followed by the original precondition, labelled with
        ``Correctness_Target``.

        Returns:
            A list of dicts with the keys ``Precondition``, ``Consequent``,
            ``Correctness``, ``Source_ID`` and ``ID``.
        """
        specific_dataset = []
        print(f'Generating {self.task} dataset from {self.data_path}...')
        for data in tqdm(self.dataset):
            if data is None:
                continue

            # P -> C
            specific_dataset.append({
                'Precondition': data['Precondition'],
                'Consequent': data['Consequent'],
                'Correctness': data['Correctness_Orig'],
                'Source_ID': data['ID'],
                'ID': f'SIVA_P2C_{generate_ramdom_sequence(20)}'
            })

            # J + P -> C (only when the record carries a justification)
            if data['Justification']:
                specific_dataset.append({
                    'Precondition': data['Justification'] + ' ' + data['Precondition'],
                    'Consequent': data['Consequent'],
                    'Correctness': data['Correctness_Target'],
                    'Source_ID': data['ID'],
                    'ID': f'SIVA_PJ2C_{generate_ramdom_sequence(20)}'
                })

        print(f'Generated {len(specific_dataset)} {self.task} samples')
        print(f'Generated {self.task} dataset from {self.data_path} successfully')
        return specific_dataset

    def show_specific_dataset_statistic(self, specific_dataset):
        """Print and return the number of samples and of true/false labels.

        Args:
            specific_dataset: Samples as produced by
                :meth:`generate_task_specific_dataset`; may be empty.

        Returns:
            A dict with the integer entries ``#samples``, ``#T`` and ``#F``.
        """
        df = pd.DataFrame(specific_dataset)
        # A frame built from an empty list has no 'Correctness' column, so
        # fall back to an empty mapping instead of raising KeyError.
        c_count = df['Correctness'].value_counts() if len(df) else {}
        n_true = int(c_count[True]) if True in c_count else 0
        n_false = int(c_count[False]) if False in c_count else 0

        print('#samples\t#T\t#F')
        print(f'{len(df)}\t{n_true}\t{n_false}')

        return {
            '#samples': len(df),
            '#T': n_true,
            '#F': n_false,
        }

class MultipleChoiceInferenceValidityAssessmentDataset(TaskOrigentedDataset):
    """Groups source records by precondition into multiple-choice samples."""

    def __init__(self, data_path):
        """Load ``data_path`` through the base class and set the task name."""
        super().__init__(data_path)
        self.task = 'Multiple Choice Inference Validity Assessment'

    @staticmethod
    def _unique_options(options):
        """Return ``options`` without duplicates, keeping first occurrences in order."""
        seen = set()
        unique = []
        for option in options:
            # Dicts are unhashable, so two options count as duplicates when
            # their string forms match. No eval() round trip is needed and
            # the resulting order is deterministic.
            key = str(option)
            if key not in seen:
                seen.add(key)
                unique.append(option)
        return unique

    def generate_task_specific_dataset(self):
        """Build one multiple-choice sample per distinct precondition.

        Each record in a group contributes two options for its consequent:
        one labelled ``Correctness_Orig`` with no justification, and one
        labelled ``Correctness_Target`` with the record's justification.
        Duplicate options are removed. ``Source_ID`` is the ``ID`` of the last
        record in the group.

        Returns:
            A list of dicts with the keys ``Precondition``, ``Options``,
            ``Source_ID`` and ``ID``.
        """
        specific_dataset = []
        print(f'Generating {self.task} dataset from {self.data_path}...')

        # Group records by precondition, preserving first-seen order.
        precondition_dict = {}
        for data in tqdm(self.dataset):
            if data is None:
                continue
            precondition_dict.setdefault(data['Precondition'], []).append(data)

        for precondition, datas in precondition_dict.items():
            options = []
            for data in datas:
                # P -> C
                options.append({
                    'Consequent': data['Consequent'],
                    'Label': data['Correctness_Orig'],
                    'Justification': None,
                })
                # J + P -> C
                options.append({
                    'Consequent': data['Consequent'],
                    'Label': data['Correctness_Target'],
                    'Justification': data['Justification'],
                })

            specific_dataset.append({
                'Precondition': precondition,
                # Deduplicate once per group rather than after every record.
                'Options': self._unique_options(options),
                'Source_ID': datas[-1]['ID'],
                'ID': f'MCIVA_{generate_ramdom_sequence(20)}'
            })

        print(f'Generated {len(specific_dataset)} {self.task} samples')
        print(f'Generated {self.task} dataset from {self.data_path} successfully')
        return specific_dataset

    def show_specific_dataset_statistic(self, specific_dataset):
        """Print and return the sample count and average option counts.

        Args:
            specific_dataset: Samples as produced by
                :meth:`generate_task_specific_dataset`; may be empty.

        Returns:
            A dict with ``#samples`` and the per-sample averages
            ``avg.options``, ``avg.options.true`` and ``avg.options.false``
            (all ``0.0`` when there are no samples).
        """
        df = pd.DataFrame(specific_dataset)
        if len(df):
            options = df['Options']
            options_avg = float(options.apply(len).mean())
            # `== True` / `== False` is deliberate: a label that is neither
            # (e.g. None) is counted in neither bucket.
            options_true_avg = float(options.apply(
                lambda opts: len([o for o in opts if o['Label'] == True])).mean())  # noqa: E712
            options_false_avg = float(options.apply(
                lambda opts: len([o for o in opts if o['Label'] == False])).mean())  # noqa: E712
        else:
            # An empty frame has no 'Options' column; report zeros instead of
            # raising KeyError.
            options_avg = options_true_avg = options_false_avg = 0.0

        print('#samples\tavg.options\tavg.options.true\tavg.options.false')
        print(f'{len(df)}\t{options_avg:.1f}\t{options_true_avg:.1f}\t{options_false_avg:.1f}')

        return {
            '#samples': len(df),
            'avg.options': options_avg,
            'avg.options.true': options_true_avg,
            'avg.options.false': options_false_avg,
        }

class ConsequentGenerationDataset(TaskOrigentedDataset):
    """Derives (precondition, consequent) pairs for the consequent-generation task."""

    def __init__(self, data_path):
        """Load ``data_path`` through the base class and set the task name."""
        super().__init__(data_path)
        self.task = 'Consequent Generation'

    def generate_task_specific_dataset(self):
        """Build the consequent-generation samples.

        A non-``None`` record yields a ``P -> C`` sample when its
        ``Correctness_Orig`` is truthy, and a ``J + P -> C`` sample when its
        ``Justification`` is truthy.

        Returns:
            A list of dicts with the keys ``Precondition``, ``Consequent``,
            ``Source_ID`` and ``ID``.
        """
        samples = []
        print(f'Generating {self.task} dataset from {self.data_path}...')
        for record in tqdm(self.dataset):
            if record is None:
                continue

            # Plain precondition -> consequent, kept only if originally correct.
            if record['Correctness_Orig']:
                samples.append({
                    'Precondition': record['Precondition'],
                    'Consequent': record['Consequent'],
                    'Source_ID': record['ID'],
                    'ID': f'CG_P2C_{generate_ramdom_sequence(20)}'
                })

            # Justification-augmented precondition -> consequent.
            # NOTE(review): this branch does not look at 'Correctness_Target',
            # unlike the branch above which checks 'Correctness_Orig' -
            # confirm that is intended.
            if record['Justification']:
                samples.append({
                    'Precondition': record['Justification'] + ' ' + record['Precondition'],
                    'Consequent': record['Consequent'],
                    'Source_ID': record['ID'],
                    'ID': f'CG_PJ2C_{generate_ramdom_sequence(20)}'
                })

        print(f'Generated {len(samples)} {self.task} samples')
        print(f'Generated {self.task} dataset from {self.data_path} successfully')
        return samples

    def show_specific_dataset_statistic(self, specific_dataset):
        """Print and return the number of samples.

        Args:
            specific_dataset: Samples as produced by
                :meth:`generate_task_specific_dataset`.

        Returns:
            A dict with the single entry ``#samples``.
        """
        sample_count = len(pd.DataFrame(specific_dataset))
        print('#samples')
        print(f'{sample_count}')

        return {
            '#samples': sample_count,
        }

class NegativeDefaultJustificationGenerationDataset(TaskOrigentedDataset):
    """Selects records whose justification changes the correctness label."""

    def __init__(self, data_path):
        """Load ``data_path`` through the base class and set the task name."""
        super().__init__(data_path)
        self.task = 'Negative Default Justification Generation'

    def generate_task_specific_dataset(self):
        """Build samples from records where target and original correctness differ.

        Returns:
            A list of dicts with the keys ``Precondition``, ``Justification``,
            ``Consequent``, ``Correctness`` (the target label), ``Source_ID``,
            ``Type`` and ``ID``.
        """
        print(f'Generating {self.task} dataset from {self.data_path}...')
        samples = [
            {
                'Precondition': record['Precondition'],
                'Justification': record['Justification'],
                'Consequent': record['Consequent'],
                'Correctness': record['Correctness_Target'],
                'Source_ID': record['ID'],
                'Type': record['type'],
                'ID': f'NDJG_{generate_ramdom_sequence(20)}'
            }
            for record in tqdm(self.dataset)
            # XOR is truthy exactly when the two labels disagree.
            if record is not None
            and record['Correctness_Target'] ^ record['Correctness_Orig']
        ]

        print(f'Generated {len(samples)} {self.task} samples')
        print(f'Generated {self.task} dataset from {self.data_path} successfully')
        return samples

    def show_specific_dataset_statistic(self, specific_dataset):
        """Print and return the sample count and the ``F2T`` / ``T2F`` type counts.

        Nothing is printed when ``specific_dataset`` yields an empty frame;
        the zero-filled result is returned directly.

        Args:
            specific_dataset: Samples as produced by
                :meth:`generate_task_specific_dataset`.

        Returns:
            A dict with the integer entries ``#samples``, ``#F2T`` and ``#T2F``.
        """
        frame = pd.DataFrame(specific_dataset)
        total = len(frame)
        if not total:
            return {
                '#samples': 0,
                '#F2T': 0,
                '#T2F': 0
            }

        per_type = frame['Type'].value_counts()
        print('#samples\t#F2T\t#T2F')
        # Types that never occur are reported as zero.
        f2t = int(per_type['F2T']) if 'F2T' in per_type else 0
        t2f = int(per_type['T2F']) if 'T2F' in per_type else 0
        print(f'{total}\t{f2t}\t{t2f}')

        return {
            '#samples': total,
            '#F2T': f2t,
            '#T2F': t2f
        }


class NegativeDefaultJustificationSelectionDataset(TaskOrigentedDataset):
    """Groups records by (precondition, consequent) into justification-selection samples."""

    def __init__(self, data_path):
        """Load ``data_path`` through the base class and set the task name."""
        super().__init__(data_path)
        self.task = 'Negative Default Justification Selection'

    @staticmethod
    def _unique_options(options):
        """Return ``options`` without duplicates, keeping first occurrences in order."""
        seen = set()
        unique = []
        for option in options:
            # Dicts are unhashable, so two options count as duplicates when
            # their string forms match. No eval() round trip is needed and
            # the resulting order is deterministic.
            key = str(option)
            if key not in seen:
                seen.add(key)
                unique.append(option)
        return unique

    def generate_task_specific_dataset(self):
        """Build one sample per distinct (precondition, consequent) pair.

        Every record in a group contributes one option made of its
        ``Correctness_Target`` label, ``Justification`` and ``type``;
        duplicate options are removed. ``Source_ID`` and ``Correctness``
        (``Correctness_Orig``) are taken from the last record of the group.

        Returns:
            A list of dicts with the keys ``Precondition``, ``Consequent``,
            ``Options``, ``Source_ID``, ``Correctness`` and ``ID``.
        """
        specific_dataset = []
        print(f'Generating {self.task} dataset from {self.data_path}...')

        # P, C -> J: group by the (precondition, consequent) pair. A tuple key
        # cannot collide the way a space-joined string can.
        groups = {}
        for data in tqdm(self.dataset):
            if data is None:
                continue
            key = (data['Precondition'], data['Consequent'])
            groups.setdefault(key, []).append(data)

        for (precondition, consequent), datas in groups.items():
            options = self._unique_options([
                {
                    'Label': data['Correctness_Target'],
                    'Justification': data['Justification'],
                    'Type': data['type'],
                }
                for data in datas
            ])
            last = datas[-1]
            # Emit a single sample per group, after all of its options have
            # been collected, with the real precondition and this task's own
            # ID prefix.
            specific_dataset.append({
                'Precondition': precondition,
                'Consequent': consequent,
                'Options': options,
                'Source_ID': last['ID'],
                'Correctness': last['Correctness_Orig'],
                'ID': f'NDJS_{generate_ramdom_sequence(20)}'
            })

        print(f'Generated {len(specific_dataset)} {self.task} samples')
        print(f'Generated {self.task} dataset from {self.data_path} successfully')
        return specific_dataset

    def show_specific_dataset_statistic(self, specific_dataset):
        """Print and return label counts and option statistics.

        Args:
            specific_dataset: Samples as produced by
                :meth:`generate_task_specific_dataset`; may be empty.

        Returns:
            A dict with ``#samples``, ``#T``, ``#F``, ``avg.options``,
            ``#options`` and the per-type option counts ``#T2T``, ``#T2F``,
            ``#F2T`` and ``#F2F``. Counts are plain ``int`` values.
        """
        df = pd.DataFrame(specific_dataset)
        n_samples = len(df)
        if n_samples:
            c_count = df['Correctness'].value_counts()
            option_sizes = df['Options'].apply(len)
            options_avg = float(option_sizes.mean())
            n_options = int(option_sizes.sum())
            # Flatten the per-sample option types into one series and count them.
            type_counts = (
                df['Options']
                .apply(lambda opts: [opt['Type'] for opt in opts])
                .explode()
                .value_counts()
                .to_dict()
            )
        else:
            # An empty frame has none of the columns used above; report zeros
            # instead of raising KeyError.
            c_count = {}
            options_avg = 0.0
            n_options = 0
            type_counts = {}

        n_true = int(c_count[True]) if True in c_count else 0
        n_false = int(c_count[False]) if False in c_count else 0

        print('#samples\t#T\t#F')
        print(f'{n_samples}\t{n_true}\t{n_false}')

        print('avg.options\t#options\t' + '\t'.join(str(k) for k in type_counts))
        # n_options is computed beforehand so the f-string holds no nested
        # quotes, which only parse on Python 3.12+.
        print(f'{options_avg:.1f}\t{n_options}\t' + '\t'.join(str(v) for v in type_counts.values()))

        return {
            '#samples': n_samples,
            '#T': n_true,
            '#F': n_false,
            'avg.options': options_avg,
            '#options': n_options,
            '#T2T': int(type_counts.get('T2T', 0)),
            '#T2F': int(type_counts.get('T2F', 0)),
            '#F2T': int(type_counts.get('F2T', 0)),
            '#F2F': int(type_counts.get('F2F', 0)),
        }



class DefaultJustificationGenerationDataset(TaskOrigentedDataset):
    """Selects records whose justification leaves the correctness label unchanged."""

    def __init__(self, data_path):
        """Load ``data_path`` through the base class and set the task name."""
        super().__init__(data_path)
        self.task = 'Default Justification Generation'

    def generate_task_specific_dataset(self):
        """Build samples from records where target and original correctness agree.

        Returns:
            A list of dicts with the keys ``Precondition``, ``Justification``,
            ``Consequent``, ``Correctness_Orig``, ``Correctness_Target``,
            ``Source_ID``, ``Type`` and ``ID``.
        """
        specific_dataset = []
        print(f'Generating {self.task} dataset from {self.data_path}...')
        for data in tqdm(self.dataset):
            if data is None:
                continue
            # XOR is falsy exactly when the two labels agree.
            if not (data['Correctness_Target'] ^ data['Correctness_Orig']):
                specific_dataset.append({
                    'Precondition': data['Precondition'],
                    'Justification': data['Justification'],
                    'Consequent': data['Consequent'],
                    'Correctness_Orig': data['Correctness_Orig'],
                    'Correctness_Target': data['Correctness_Target'],
                    'Source_ID': data['ID'],
                    'Type': data['type'],
                    'ID': f'DJG_{generate_ramdom_sequence(20)}'
                })

        print(f'Generated {len(specific_dataset)} {self.task} samples')
        print(f'Generated {self.task} dataset from {self.data_path} successfully')
        return specific_dataset

    def show_specific_dataset_statistic(self, specific_dataset):
        """Print and return the sample count and the ``T2T`` / ``F2F`` type counts.

        Args:
            specific_dataset: Samples as produced by
                :meth:`generate_task_specific_dataset`; may be empty.

        Returns:
            A dict with the integer entries ``#samples``, ``#T2T`` and ``#F2F``.
        """
        df = pd.DataFrame(specific_dataset)
        # A frame built from an empty list has no 'Type' column, so fall back
        # to an empty mapping instead of raising KeyError.
        type_count = df['Type'].value_counts().to_dict() if len(df) else {}
        n_t2t = int(type_count.get('T2T', 0))
        n_f2f = int(type_count.get('F2F', 0))

        print('#samples\t#T2T\t#F2F')
        print(f'{len(df)}\t{n_t2t}\t{n_f2f}')

        return {
            '#samples': len(df),
            '#T2T': n_t2t,
            '#F2F': n_f2f,
        }