import json
import random
from pathlib import Path
import os
import math


def get_context_allowed_shots(dataset='subj', context_len=1024):
    """Return the number of demonstration shots allowed for ``dataset``.

    Args:
        dataset: Dataset key; one of 'agnews', 'carer', 'mr', 'mrpc', 'sst5',
            'subj', 'trec', 'webss'.
        context_len: Model context length. Exactly 1024 selects the smaller
            table; any other value selects the table with doubled counts.

    Returns:
        The shot count for ``dataset`` under the chosen context length.

    Raises:
        ValueError: If ``dataset`` is not a known key.
    """
    if context_len == 1024:
        allowed_shots = {'agnews': 2, 'carer': 4, 'mr': 8, 'mrpc': 4, 'sst5': 4, 'subj': 8, 'trec': 8, 'webss': 2}
    else:
        allowed_shots = {'agnews': 4, 'carer': 8, 'mr': 16, 'mrpc': 8, 'sst5': 8, 'subj': 16, 'trec': 16, 'webss': 4}
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if dataset not in allowed_shots:
        raise ValueError(
            f"unknown dataset {dataset!r}; expected one of {sorted(allowed_shots)}"
        )
    return allowed_shots[dataset]


def load_dataset(data_dir='../data', dataset='subj'):
    """Construct the training and evaluation splits for ``dataset``.

    Binary tasks: subj, mr, mrpc. Multi-class tasks: agnews, carer, sst5,
    trec, webss.

    Args:
        data_dir: Root directory; each dataset lives in ``<data_dir>/<dataset>``.
        dataset: Dataset key.

    Returns:
        ``(train_data, dev_data)``. Note that ``dev_data`` is built from the
        'test' split.

    Raises:
        NotImplementedError: If ``dataset`` is not a supported key.
    """
    # Built at call time: the dataset classes are defined later in the module.
    registry = (
        ('subj', SUBJDataset),       # binary
        ('mr', MRDataset),           # binary
        ('mrpc', MRPCDataset),       # binary
        ('agnews', AGNEWSDataset),   # multi-class
        ('carer', CARERDataset),     # multi-class
        ('sst5', SST5Dataset),       # multi-class
        ('trec', TRECDataset),       # multi-class
        ('webss', WebSSDataset),     # multi-class
    )
    dataset_cls = next((cls for key, cls in registry if dataset == key), None)
    if dataset_cls is None:
        raise NotImplementedError

    split_dir = os.path.join(data_dir, dataset)
    return dataset_cls(split_dir, mode='train'), dataset_cls(split_dir, mode='test')


class BASEDataset:
    """Base container for one split of a text-classification dataset.

    Instances are dicts stored in ``self.data``. Subclasses override
    ``label2id`` (raw label string -> integer class id); grouping and
    subsampling look labels up through it.
    """

    def __init__(
        self,
        data_dir,
        mode,
        is_jsonl=True
    ):
        """Load ``<data_dir>/<mode>.jsonl`` into ``self.data``.

        Args:
            data_dir: Directory containing the split files.
            mode: Split name; 'dev' is mapped to the 'dev_subsample' file.
            is_jsonl: When False nothing is read here and the subclass is
                expected to fill ``self.data`` itself.
        """
        super().__init__()
        if mode == 'dev':
            mode = 'dev_subsample'
        self.data = []
        # customize your own label map in inheritance
        self.id2label = {0: 'negative', 1: 'positive'}
        self.label2id = {'negative': 0, 'positive': 1}
        # by default, read from jsonl files
        if is_jsonl:
            data_file = os.path.join(data_dir, mode + '.jsonl')
            with open(data_file, 'r', encoding='utf-8') as f:
                for line in f:
                    line = line.strip()
                    # Blank lines (e.g. a trailing empty line) would make json.loads raise.
                    if line:
                        self.data.append(json.loads(line))

    def __len__(self):
        return len(self.data)

    def _group_by_class(self):
        """Return ``{class_id: [instances]}`` in first-seen class order.

        Class ids come from ``self.label2id[instance['label']]``; an unmapped
        label raises KeyError.
        """
        data_by_cls = {}
        for instance in self.data:
            data_by_cls.setdefault(self.label2id[instance['label']], []).append(instance)
        return data_by_cls

    def arrange(self):
        """Reorder ``self.data`` so instances are contiguous by ascending class id."""
        data_by_cls = self._group_by_class()
        self.data = [ins for cls in sorted(data_by_cls) for ins in data_by_cls[cls]]

    def shuffle(self):
        """Shuffle ``self.data`` in place using the global ``random`` state."""
        random.shuffle(self.data)

    def subsamplebyshot(self, n_shot, seed=42, exclude=None):
        """Keep at most ``n_shot`` randomly chosen instances per class, shuffled.

        Args:
            n_shot: Maximum number of instances kept for each class.
            seed: Seed applied to the global ``random`` module before sampling
                (this also affects later global ``random`` calls).
            exclude: Optional iterable of instances removed from ``self.data``
                (first equal occurrence each) before sampling; an instance that
                is not present raises ValueError.
        """
        if exclude is not None:
            for ins in exclude:
                self.data.remove(ins)
        random.seed(seed)
        data_subsample = []
        for cls_data in self._group_by_class().values():
            data_subsample.extend(random.sample(cls_data, min(n_shot, len(cls_data))))
        random.shuffle(data_subsample)
        self.data = data_subsample


class SUBJDataset(BASEDataset):
    """Subjectivity task with raw labels '0' (subjective) and '1' (objective).

    Records are read by the base class from ``<mode>.jsonl``; expected keys
    are ``sentence`` and ``label`` ('0'/'1'). Data is shuffled on load.
    """

    def __init__(self, data_dir, mode):
        super().__init__(data_dir, mode)
        # NOTE(review): original comment says subj only has a test set — confirm.
        verbs = ['subjective', 'objective']
        self.label2id = {str(idx): idx for idx in range(len(verbs))}
        self.label2verb = {str(idx): verb for idx, verb in enumerate(verbs)}
        self.id2verb = verbs
        self.shuffle()


class MRDataset(BASEDataset):
    """Binary sentiment task with raw labels '0' (negative) and '1' (positive).

    Records are read by the base class from ``<mode>.jsonl``; expected keys
    are ``sentence`` and ``label`` ('0'/'1'). Data is shuffled on load.
    """

    def __init__(self, data_dir, mode):
        super().__init__(data_dir, mode)
        self.id2verb = ['negative', 'positive']
        self.label2verb = dict(zip(('0', '1'), self.id2verb))
        self.label2id = {'0': 0, '1': 1}
        self.shuffle()


class MRPCDataset(BASEDataset):
    """Sentence-pair paraphrase task read from ``msr_paraphrase_<mode>.txt``.

    The first line of the file is a header and is skipped. Each remaining line
    has five tab-separated fields: label, two ignored columns, sentence 1 and
    sentence 2. Instances have keys ``sentence_1``, ``sentence_2`` and
    ``label`` ('0' = not_equivalent, '1' = equivalent). Data is shuffled on load.
    """

    def __init__(
        self,
        data_dir,
        mode
    ):
        super().__init__(data_dir, mode, is_jsonl=False)
        self.label2id = {'0': 0, '1': 1}
        self.label2verb = {'0': 'not_equivalent', '1': 'equivalent'}
        self.id2verb = ['not_equivalent', 'equivalent']
        # Uses the raw ``mode``: the base class's 'dev' -> 'dev_subsample'
        # remapping only affects its own local variable.
        file_path = os.path.join(data_dir, 'msr_paraphrase_' + mode + '.txt')
        with open(file_path, 'r') as file:
            lines = file.readlines()
        for line in lines[1:]:
            # readlines() keeps the terminator; without stripping it, the
            # newline ends up glued to the end of sentence 2.
            line = line.rstrip('\r\n')
            if not line:
                continue  # tolerate blank lines (e.g. at end of file)
            label, _, _, sent1, sent2 = line.split('\t')
            instance = {'sentence_1': sent1, 'sentence_2': sent2, 'label': label}
            self.data.append(instance)
        self.shuffle()


class AGNEWSDataset(BASEDataset):
    """Four-way topic task; raw labels '1'..'4' map to class ids 0..3.

    Records are read by the base class from ``<mode>.jsonl``. Data is
    shuffled on load.
    """

    def __init__(self, data_dir, mode):
        super().__init__(data_dir, mode)
        topics = ['world', 'sports', 'business', 'technology']
        self.label2id = {}
        self.label2verb = {}
        for class_id, topic in enumerate(topics):
            raw_label = str(class_id + 1)  # raw labels are 1-based
            self.label2id[raw_label] = class_id
            self.label2verb[raw_label] = topic
        self.id2verb = topics
        self.shuffle()


class CARERDataset(BASEDataset):
    """Six-way emotion task read from ``<mode>.jsonl`` ('dev' -> 'validation').

    Each raw record's ``text`` field is renamed to ``sentence`` and its
    ``label`` is coerced with ``str()`` so it matches the '0'..'5' keys below.
    Data is shuffled on load.
    """

    def __init__(self, data_dir, mode):
        split = 'validation' if mode == 'dev' else mode
        super().__init__(data_dir, split, is_jsonl=False)
        emotions = ['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']
        self.label2id = {str(i): i for i in range(len(emotions))}
        self.label2verb = {str(i): emotion for i, emotion in enumerate(emotions)}
        self.id2verb = emotions
        with open(os.path.join(data_dir, split + '.jsonl'), 'r') as handle:
            for raw_line in handle:
                record = json.loads(raw_line.strip())
                record['label'] = str(record['label'])
                record['sentence'] = record.pop('text')
                self.data.append(record)
        self.shuffle()


class SST5Dataset(BASEDataset):
    """Five-way sentiment task read from ``<mode>.jsonl``.

    Each raw record's ``text`` field is renamed to ``sentence``, ``label`` is
    coerced with ``str()``, and ``label_text`` is dropped. Data is shuffled
    on load.
    """

    def __init__(self, data_dir, mode):
        super().__init__(data_dir, mode, is_jsonl=False)
        # Label ids 0..4 run from very negative ('terrible') to very positive ('great').
        grades = ['terrible', 'bad', 'okay', 'good', 'great']
        self.label2id = {str(i): i for i in range(len(grades))}
        self.label2verb = {str(i): grade for i, grade in enumerate(grades)}
        self.id2verb = grades
        with open(os.path.join(data_dir, mode + '.jsonl'), 'r') as handle:
            for raw_line in handle:
                record = json.loads(raw_line.strip())
                record['label'] = str(record['label'])
                record['sentence'] = record.pop('text')
                del record['label_text']
                self.data.append(record)
        self.shuffle()


class TRECDataset(BASEDataset):
    """Six-way question-type task; raw labels '0'..'5' map to class ids 0..5.

    Records are read by the base class from ``<mode>.jsonl``. Data is
    shuffled on load.
    """

    def __init__(self, data_dir, mode):
        super().__init__(data_dir, mode)
        question_types = ('description', 'entity', 'expression', 'human', 'location', 'number')
        self.label2id = {}
        self.label2verb = {}
        for idx, name in enumerate(question_types):
            self.label2id[str(idx)] = idx
            self.label2verb[str(idx)] = name
        self.id2verb = list(question_types)
        self.shuffle()


class WebSSDataset(BASEDataset):
    """Eight-way topic task read from whitespace-separated ``<mode>.txt``.

    On each non-blank line the last token is the topic name (one of the verbs
    in ``label2verb``). The preceding tokens, re-joined with single spaces,
    form the ``sentence``. Raw labels '1'..'8' map to class ids 0..7. Data is
    shuffled on load.
    """

    def __init__(
        self,
        data_dir,
        mode
    ):
        super().__init__(data_dir, mode, is_jsonl=False)
        self.label2id = {'1': 0, '2': 1, '3': 2, '4': 3, '5': 4, '6': 5, '7': 6, '8': 7}
        self.label2verb = {'1': 'business', '2': 'computers', '3': 'culture-arts-entertainment',
                           '4': 'education-science', '5': 'engineering', '6': 'health',
                           '7': 'politics-society', '8': 'sports'}
        self.id2verb = ['business', 'computers', 'culture-arts-entertainment',
                        'education-science', 'engineering', 'health', 'politics-society', 'sports']
        # Reverse lookup: topic name as written in the file -> raw label string.
        verb2label = {verb: label for label, verb in self.label2verb.items()}
        file_path = os.path.join(data_dir, mode + '.txt')
        with open(file_path, 'r') as file:
            for line in file:
                tokens = line.split()
                if not tokens:
                    # A blank line (e.g. at EOF) would otherwise raise IndexError below.
                    continue
                label = verb2label[tokens[-1]]
                comment = ' '.join(tokens[:-1])
                self.data.append({'sentence': comment, 'label': label})
        self.shuffle()
