import re
import random

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, Sampler
from nltk.tokenize import RegexpTokenizer


from file_io import Files


def reformat(code, is_diag):
    """Normalise an ICD-9 code string to its dotted form.

    Any existing dots are removed first, then a single dot is inserted:

    * diagnosis codes starting with ``'E'``: after the 4th character, only if
      the code is longer than 4 characters;
    * other diagnosis codes: after the 3rd character, only if longer than 3;
    * procedure codes (``is_diag`` false): always after the 2nd character.

    Args:
        code: the ICD-9 code string, with or without a dot.
        is_diag: True for a diagnosis code, False for a procedure code.

    Returns:
        The dotted code string.
    """
    bare = code.replace('.', '')

    if not is_diag:
        # Procedure codes always receive a dot after two characters.
        return f'{bare[:2]}.{bare[2:]}'

    split_at = 4 if bare.startswith('E') else 3
    if len(bare) <= split_at:
        return bare
    return f'{bare[:split_at]}.{bare[split_at:]}'


def load_code_descriptions():
    """Build a lookup from ICD-9 code string to its text description.

    Entries are gathered from three sources, in this order:

    1. MIMIC diagnosis descriptions (``ICD9_CODE`` / ``LONG_TITLE`` columns),
       keyed by ``reformat(code, True)``.
    2. MIMIC procedure descriptions (same columns), keyed by
       ``reformat(code, False)``.
    3. The plain-text ICD-9 descriptions file. Each non-blank line is read as
       ``<code> <description words ...>``. These entries only fill codes that
       are not already present.

    Returns:
        dict mapping code string -> description string.
    """
    desc_dict = {}
    diagnosis_file = Files.mimic_code_descriptions['diagnosis'].load()
    procedure_file = Files.mimic_code_descriptions['procedure'].load()
    icd9_descriptions_file = Files.icd9_descriptions.load()

    for _, row in diagnosis_file.iterrows():
        desc_dict[reformat(row['ICD9_CODE'], True)] = row['LONG_TITLE']

    for _, row in procedure_file.iterrows():
        # Written unconditionally. reformat() places the procedure dot after
        # two characters, while diagnosis keys carry their dot after three or
        # four characters (or none), so the two key sets do not collide in
        # practice.
        desc_dict[reformat(row['ICD9_CODE'], False)] = row['LONG_TITLE']

    for line in icd9_descriptions_file:
        parts = line.rstrip().split()
        if not parts:
            # Blank line: nothing to record (previously raised IndexError).
            continue
        code = parts[0]
        if code not in desc_dict:
            desc_dict[code] = ' '.join(parts[1:])
    return desc_dict


def load_full_codes(version):
    """Collect the label vocabulary from the train/dev/test label files.

    Every ``';'``-separated code in the ``LABELS`` column of the three splits is
    kept, provided it is non-empty and has a description in
    ``load_code_descriptions()``.

    Args:
        version: key into ``Files.label_csvs`` selecting the dataset version.

    Returns:
        ``(ind2c, c2ind, desc_dict)``:

        * ``ind2c`` maps a contiguous index (0..n-1) to a code, with codes
          sorted lexicographically.
        * ``c2ind`` is the inverse mapping.
        * ``desc_dict`` is the description lookup.
    """
    desc_dict = load_code_descriptions()
    codes = set()
    for type_ in ['train', 'dev', 'test']:
        label_csv = Files.label_csvs[version][type_].load()
        for _, row in label_csv.iterrows():
            # str() tolerates non-string cells (presumably a NaN for a note
            # without codes), which would otherwise crash on .split().
            for code in str(row['LABELS']).split(';'):
                if code and code in desc_dict:
                    codes.add(code)
    ind2c = dict(enumerate(sorted(codes)))
    c2ind = {c: i for i, c in ind2c.items()}
    return ind2c, c2ind, desc_dict


def load_vocab():
    """Build the word <-> id lookup tables.

    Ids follow the word2vec model's key order, restricted to words present in
    the word-count dict. Any of ``**UNK**``, ``**PAD**``, ``**MASK**`` not
    already in the vocabulary are appended in that order.

    Returns:
        ``(word2id, id2word)`` dicts.
    """
    special_tokens = ['**UNK**', '**PAD**', '**MASK**']

    model = Files.word2vec_model.load()
    vocab_words = list(model.wv.key_to_index)
    # Only the key order is needed, so drop the model before loading the counts.
    del model

    counts = Files.word_count_dict.load()
    vocab_words = [word for word in vocab_words if word in counts]

    for token in special_tokens:
        if token not in vocab_words:
            vocab_words.append(token)

    id2word = dict(enumerate(vocab_words))
    word2id = {word: idx for idx, word in id2word.items()}
    return word2id, id2word


def load_embeddings():
    """Build the embedding matrix whose row order matches ``load_vocab``'s ids.

    The vocabulary consists of the word2vec keys that appear in the word-count
    dict, followed by any missing ``**UNK**`` / ``**PAD**`` / ``**MASK**``
    tokens. Row contents are:

    * ordinary words: their word2vec vectors;
    * ``**UNK**`` and ``**MASK**``: ``np.random.randn`` draws;
    * ``**PAD**``: zeros.

    Returns:
        ``np.ndarray`` with one row per vocabulary entry.
    """
    special_tokens = ['**UNK**', '**PAD**', '**MASK**']

    word2vec_model = Files.word2vec_model.load()
    words = list(word2vec_model.wv.key_to_index)
    original_word_count = len(words)
    # Take the width from the model instead of the previous row, so a special
    # token in the first position (or an empty filtered vocabulary) no longer
    # raises IndexError.
    embedding_dim = word2vec_model.wv.vector_size

    word_count_dict = Files.word_count_dict.load()
    words = [w for w in words if w in word_count_dict]
    for token in special_tokens:
        if token not in words:
            words.append(token)

    new_weights = []
    for word in words:
        if word == '**UNK**':
            print('adding unk embedding')
            new_weights.append(np.random.randn(embedding_dim))
        elif word == '**MASK**':
            print('adding mask embedding')
            new_weights.append(np.random.randn(embedding_dim))
        elif word == '**PAD**':
            print('adding pad embedding')
            new_weights.append(np.zeros(embedding_dim))
        else:
            new_weights.append(word2vec_model.wv[word])
    new_weights = np.array(new_weights)

    print(f'Word count: {len(words)}')
    print(f'Load embedding count: {len(new_weights)}')
    print(
        f'Original word count: {original_word_count}/{len(word_count_dict)}')
    del word2vec_model
    return new_weights


class KFold:
    """Draw a fresh random train/dev split from pooled data on each :meth:`choice`.

    Despite the name, this does not iterate over k folds. Every call reshuffles
    the shared index array and cuts it in two:

    * the training side takes ``len(train_dataset)`` items from the section
      pool;
    * the dev side takes the remaining indices from the plain pool.

    Both pools are indexed with the same indices. They are therefore assumed to
    list the same documents in the same order.
    NOTE(review): confirm this alignment with the caller.
    """

    def __init__(self, train_section_dataset, dev_section_dataset, train_dataset, dev_dataset):
        # Inputs are concatenated with `+`; presumably plain lists.
        self.section_dataset = train_section_dataset + dev_section_dataset
        self.dataset = train_dataset + dev_dataset
        self.train_length = len(train_dataset)
        self.indices = np.arange(len(self.section_dataset))

    def choice(self):
        """Shuffle ``self.indices`` in place and return ``(train_list, dev_list)``."""
        np.random.shuffle(self.indices)
        train_idx, dev_idx = np.split(self.indices, [self.train_length])
        train_split = [self.section_dataset[k] for k in train_idx]
        dev_split = [self.dataset[k] for k in dev_idx]
        return train_split, dev_split


class MimicFullDataset(Dataset):
    """Torch dataset yielding ``(input_ids, word_mask, binary_label)`` per record.

    Records are attached later with :meth:`set_dataset`; ``self.dataset`` is
    ``None`` until then. How a record is read depends on the mode:

    * ``'train'``: ``record['sections']`` (a mapping from section title to
      text) and ``record['labels']``;
    * any other mode: ``record['TEXT']`` and ``record['LABELS']``.

    In both cases labels are ``';'``-separated code strings.

    Args:
        config: provides ``version``, ``truncate_length``,
            ``label_truncate_length`` and ``term_count``.
        mode: ``'train'`` enables random section dropout and pre-computes label
            features (``c_input_ids`` / ``c_word_mask``); other values do
            neither.
        length: the value returned by ``__len__``.
    """

    # Runs of 2+ newlines or 2+ spaces are treated as segment breaks by split().
    _SEGMENT_BREAK_RE = re.compile(r'\n\n+|  +')

    def __init__(self, config, mode, length):
        self.mode = mode
        self.version = config.version

        self.dataset = None

        self.word2id, self.id2word = load_vocab()

        self.truncate_length = config.truncate_length

        self.section_titles = list(Files.sections.load().keys())

        self.ind2c, self.c2ind, self.desc_dict = load_full_codes(self.version)
        self.code_count = len(self.ind2c)
        if mode == 'train':
            print(f'Code count: {self.code_count}')

        self.tokenizer = RegexpTokenizer(r'\w+')

        self.len = length

        self.label_truncate_length = config.label_truncate_length
        self.term_count = config.term_count
        # Label features are only pre-computed for the training dataset.
        if self.mode == 'train':
            self.c_input_ids, self.c_word_mask = self.prepare_label_feature()

        self.label_synonyms = self.extract_label_synonyms()

    def set_dataset(self, dataset):
        """Attach the record list this dataset indexes into."""
        self.dataset = dataset

    def __len__(self):
        return self.len

    def get_text(self, index):
        """Return the note text for record ``index``.

        In train mode each non-empty section is kept when ``random.random() >= 0.2``
        (i.e. with probability 0.8). If no section survives, all sections are
        joined instead, so the text is never emptied by the dropout.
        """
        if self.mode != 'train':
            return self.dataset[index]['TEXT']

        sections = self.dataset[index]['sections']
        kept_sections = []
        for title in self.section_titles:
            content = sections[title]
            # random.random() is only drawn for non-empty sections.
            if len(content) > 0 and random.random() >= 0.2:
                kept_sections.append(content)
        if not kept_sections:
            return ' '.join(sections[title] for title in self.section_titles)
        return ' '.join(kept_sections)

    def split(self, text):
        """Break ``text`` into non-empty, stripped segments.

        Segment separators are:

        * runs of two or more newlines or spaces;
        * the characters ``!``, ``?``, ``.``;
        * existing tabs.

        Single newlines become spaces.
        """
        sp = self._SEGMENT_BREAK_RE.sub('\t', text.strip())
        sp = sp.replace('\n', ' ').replace('!', '\t').replace('?', '\t').replace('.', '\t')
        return [s.strip() for s in sp.split('\t') if s.strip()]

    def tokenize(self, text):
        """Return lower-cased word tokens of ``text``, dropping purely numeric tokens."""
        words = []
        for segment in self.split(text):
            words.extend(w.lower() for w in self.tokenizer.tokenize(segment) if not w.isnumeric())
        return words

    def get_text_label(self, index):
        """Return ``(text, labels)`` for record ``index``; labels is a list of code strings."""
        text = self.get_text(index)
        label_key = 'labels' if self.mode == 'train' else 'LABELS'
        labels = str(self.dataset[index][label_key]).split(';')
        return text, labels

    def pad(self, input_ids, word_mask, truncate_length):
        """Truncate or right-pad ``input_ids`` / ``word_mask`` to ``truncate_length``.

        Ids are padded with the ``**PAD**`` id and the mask with 0.
        """
        if len(input_ids) >= truncate_length:
            return input_ids[:truncate_length], word_mask[:truncate_length]

        pad_token_id = self.word2id['**PAD**']
        input_ids = input_ids + [pad_token_id] * (truncate_length - len(input_ids))
        word_mask = word_mask + [0] * (truncate_length - len(word_mask))
        return input_ids, word_mask

    def text2feature(self, text):
        """Map ``text`` to ``(input_ids, word_mask)``; unknown words map to ``**UNK**``."""
        unk_id = self.word2id['**UNK**']  # looked up once, not per word
        input_ids = [self.word2id.get(word, unk_id) for word in self.tokenize(text)]
        word_mask = [1] * len(input_ids)
        return input_ids, word_mask

    def process(self, text, labels):
        """Return ``(input_ids, word_mask, binary_label)``; labels outside ``c2ind`` are ignored."""
        input_ids, word_mask = self.text2feature(text)

        binary_label = [0] * self.code_count
        for label in labels:
            position = self.c2ind.get(label)
            if position is not None:
                binary_label[position] = 1

        return input_ids, word_mask, binary_label

    def __getitem__(self, index):
        text, labels = self.get_text_label(index)
        return self.process(text, labels)

    def extract_label_synonyms(self, code_synonyms=None):
        """Map every code in ``c2ind`` to ``[description] + synonyms``.

        Args:
            code_synonyms: optional mapping from code to a list of synonym
                strings. When omitted it is loaded from ``Files.code_synonyms``.
                A code with no entry gets no synonyms instead of raising
                ``KeyError``; this matches ``prepare_label_feature``.
        """
        if code_synonyms is None:
            code_synonyms = Files.code_synonyms.load()
        label_synonyms = {}
        for code in self.c2ind:
            description = self.desc_dict.get(code, code)
            label_synonyms[code] = [description] + code_synonyms.get(code, [])
        return label_synonyms

    def extract_label_desc(self):
        """Return descriptions in ``ind2c`` order, falling back to the code itself."""
        desc_list = []
        for i in self.ind2c:
            code = self.ind2c[i]
            if code not in self.desc_dict:
                print(f'Not find desc of {code}')
            desc_list.append(self.desc_dict.get(code, code))
        return desc_list

    def prepare_label_feature(self):
        """Tokenize and pad the strings used to represent each label.

        With ``term_count == 1`` each code contributes only its description.
        With ``term_count > 1`` each code contributes exactly ``term_count``
        strings: its description followed by up to ``term_count - 1`` synonyms.
        When there are too few synonyms, that list is repeated to fill the
        quota. Strings are laid out code by code in ``ind2c`` order.

        Returns:
            ``(c_input_ids, c_word_mask)``: lists of id lists and mask lists,
            each padded or truncated to ``label_truncate_length``.
        """
        print('Prepare Label Feature')
        desc_list = self.extract_label_desc()
        if self.term_count == 1:
            c_desc_list = desc_list
        else:
            c_desc_list = []
            code_synonyms = Files.code_synonyms.load()
            for i, code in self.ind2c.items():
                terms = [desc_list[i]]
                new_terms = code_synonyms.get(code, [])
                if len(new_terms) >= self.term_count - 1:
                    terms.extend(new_terms[0:self.term_count - 1])
                else:
                    terms.extend(new_terms)
                    repeat_count = int(self.term_count / len(terms)) + 1
                    terms = (terms * repeat_count)[0:self.term_count]
                if i < 5:
                    print(code, terms)
                c_desc_list.extend(terms)

        c_input_ids = []
        c_word_mask = []
        for desc in c_desc_list:
            input_ids, word_mask = self.text2feature(desc)
            input_ids, word_mask = self.pad(input_ids, word_mask, truncate_length=self.label_truncate_length)
            c_input_ids.append(input_ids)
            c_word_mask.append(word_mask)

        return c_input_ids, c_word_mask


class DataCollator:
    """Collate ``(input_ids, word_mask, labels)`` samples into batched tensors.

    Every sample is padded or truncated through ``train_dataset.pad`` to the
    longest sequence in the batch, capped at ``train_dataset.truncate_length``.
    """

    def __init__(self, train_dataset):
        self.train_dataset = train_dataset

    def __call__(self, batch):
        """Return a dict with tensors ``input_ids`` (long), ``word_mask`` (float32) and ``labels`` (long)."""
        source = self.train_dataset
        longest = max(len(sample[0]) for sample in batch)
        target_len = min(longest, source.truncate_length)

        padded = [source.pad(ids, mask, target_len) for ids, mask, _ in batch]
        ids_rows = [ids for ids, _ in padded]
        mask_rows = [mask for _, mask in padded]
        label_rows = [labels for _, _, labels in batch]

        return {
            'input_ids': torch.tensor(ids_rows, dtype=torch.long),
            'word_mask': torch.tensor(mask_rows, dtype=torch.float32),
            'labels': torch.tensor(label_rows, dtype=torch.long),
        }


def build_dataloader(config, dataset, train_dataset):
    """Wrap ``dataset`` in a DataLoader whose batches are padded via ``train_dataset``.

    ``config`` supplies ``batch_size``, ``shuffle`` and ``num_workers``. The
    collate function is a ``DataCollator``, which uses ``train_dataset.pad``
    and ``train_dataset.truncate_length``.
    """
    collator = DataCollator(train_dataset)
    return DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=config.shuffle,
        num_workers=config.num_workers,
        collate_fn=collator,
    )
