import torch
import numpy as np

from transformers.data.data_collator import DataCollatorMixin, _torch_collate_batch
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union, Tuple

# Module-level NumPy generator (PCG64, seeded from OS entropy). DataCollatorForMortalityPrediction draws its
# cutoff visits from it; because it has its own state, np.random.seed() does not affect its draws.
rng = np.random.Generator(np.random.PCG64())

@dataclass
class DataCollatorForRoodPrediction(DataCollatorMixin):
    """
    Masked-language-modelling collator that masks a fixed list of code tokens more often than other tokens.
    Inputs are dynamically padded to the maximum length of a batch if they are not all of the same length.

    Ordinary tokens are selected for masking with probability `mlm_probability`; tokens listed in `rood_codes`
    are selected with probability `rood_mlm_probability`; special tokens are never selected. Selected tokens are
    replaced by the mask token 80% of the time, by a random token 10% of the time, and left unchanged 10% of
    the time.

    Args:
        tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
            The tokenizer used for encoding the data.
        eol_threshold (`int`, *optional*):
            Not read by this collator; kept so that existing constructor calls remain valid.
        pad_to_multiple_of (`int`, *optional*):
            If set will pad the sequence to a multiple of the provided value.
        mlm (`bool`, *optional*, defaults to `True`):
            Whether to mask tokens. If False, labels are a copy of `input_ids` with padding set to -100.
        mlm_probability (`float`, *optional*, defaults to 0.15):
            Masking probability for ordinary (non-special, non-ROOD) tokens.
        tf_experimental_compile (`bool`, *optional*, defaults to `False`):
            Not read by this collator.
        return_tensors (`str`, *optional*, defaults to `"pt"`):
            Passed through to `DataCollatorMixin`; only `torch_call` is defined in this class.
        rood_mlm_probability (`float`, *optional*, defaults to 0.5):
            Masking probability for tokens listed in `rood_codes`.
        rood_codes (`Tuple[str, ...]`, *optional*):
            Vocabulary tokens that receive `rood_mlm_probability`. Codes the tokenizer does not know are ignored.
    """

    # Default list of code tokens whose masking probability is raised.
    # NOTE(review): the selection criterion behind this list is not documented here — please add it.
    _DEFAULT_ROOD_CODES = (
        'G249-10', '30550-9', 'G520-10', '95901-9', '9916-9', 'F10129-10', '9100-9', '78097-9', 'R262-10',
        'Z880-10', 'R4182-10', 'E8499-9', 'G20-10', '30500-9', '81600-9', 'F29-10', 'Z818-10', 'Z978-10',
        'E9010-9', 'R471-10', '8020-9', '920-9', '87341-9', 'E9688-9', 'E8889-9', 'E887-9', 'E8498-9',
        'H9222-10', 'Z681-10', 'R636-10', '78906-9', 'G248-10',
    )

    tokenizer: PreTrainedTokenizerBase
    eol_threshold: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None

    mlm: bool = True
    mlm_probability: float = 0.15
    tf_experimental_compile: bool = False
    return_tensors: str = "pt"
    # New fields go last so positional construction of the existing fields is unchanged.
    rood_mlm_probability: float = 0.5
    rood_codes: Tuple[str, ...] = _DEFAULT_ROOD_CODES

    def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        """Pad `examples` into a batch and attach MLM (or causal-copy) labels."""
        # Handle dict or lists with proper padding and conversion to tensor.
        if isinstance(examples[0], Mapping):
            batch = self.tokenizer.pad(examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of)
        else:
            batch = {
                "input_ids": _torch_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
            }

        # If special token mask has been preprocessed, pop it from the dict.
        special_tokens_mask = batch.pop("special_tokens_mask", None)
        if self.mlm:
            batch["input_ids"], batch["labels"] = self.torch_mask_tokens(
                batch["input_ids"], special_tokens_mask=special_tokens_mask
            )
        else:
            labels = batch["input_ids"].clone()
            if self.tokenizer.pad_token_id is not None:
                labels[labels == self.tokenizer.pad_token_id] = -100
            batch["labels"] = labels
        return batch

    def _rood_token_ids(self) -> "torch.Tensor":
        """Return the vocabulary ids of `rood_codes`, skipping codes the tokenizer maps to no id or to [UNK]."""
        unk_id = self.tokenizer.unk_token_id
        ids = [
            token_id
            for token_id in self.tokenizer.convert_tokens_to_ids(list(self.rood_codes))
            # Out-of-vocabulary codes come back as the unknown-token id; keeping them would up-weight every [UNK].
            if token_id is not None and token_id != unk_id
        ]
        return torch.tensor(ids, dtype=torch.long)

    def torch_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = None) -> Tuple[Any, Any]:
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.

        Args:
            inputs: Batch of token ids. It is modified in place and also returned.
            special_tokens_mask: Optional mask marking special tokens (truthy = special). Computed from the
                tokenizer when None.

        Returns:
            `(inputs, labels)`, where `labels` holds the original ids at selected positions and -100 elsewhere.
        """
        labels = inputs.clone()
        # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
        probability_matrix = torch.full(labels.shape, self.mlm_probability)
        if special_tokens_mask is None:
            special_tokens_mask = [
                self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
            ]
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        else:
            special_tokens_mask = special_tokens_mask.bool()

        # Codes of interest get their own (higher by default) selection probability.
        rood_ids = self._rood_token_ids()
        if rood_ids.numel() > 0:
            probability_matrix.masked_fill_(torch.isin(labels, rood_ids), value=self.rood_mlm_probability)

        # Applied last so special tokens are never selected, even if listed in `rood_codes`.
        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 10% of the time (half of the remaining 20%), we replace masked input tokens with a random word
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels

@dataclass
class DataCollatorForMortalityPrediction(DataCollatorMixin):
    """
    Data collator used for mortality prediction. Inputs are dynamically padded to the maximum length of a batch
    if they are not all of the same length.

    For each patient a cutoff visit is drawn uniformly from the patient's visits. Every token after that visit
    is replaced by padding, with a [SEP] at the cutoff point. A binary label is then derived from the number of
    days the patient had left to live.

    Args:
        tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
            The tokenizer used for encoding the data.
        eol_threshold (`int`, *optional*):
            Threshold on the number of days left in a patient's life to consider for mortality prediction.
            E.g. `eol_threshold=30`: patients that died within a month of the visit are considered label 1,
            the remainder are label 0. If None, patients with any non-nan days left in life are labelled 1.
        pad_to_multiple_of (`int`, *optional*):
            If set will pad the sequence to a multiple of the provided value.
        eol_key (`str`, *optional*, defaults to `"days_until_death"`):
            Name of the batch entry holding the days-until-death values.
    """

    tokenizer: PreTrainedTokenizerBase
    eol_threshold: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    eol_key: str = "days_until_death"

    def __post_init__(self):
        self.return_tensors = 'pt'

    def torch_call(self, examples):
        """Pad `examples` into a batch, then truncate each patient and attach mortality labels."""
        # Handle dict or lists with proper padding and conversion to tensor.
        if isinstance(examples[0], Mapping):
            batch = self.tokenizer.pad(examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of)
        else:
            batch = {
                "input_ids": _torch_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
            }

        return self.batch_data_collator(batch)

    def _eol_label(self, days_left) -> int:
        """
        Map a days-until-death value to a binary label.

        With `eol_threshold` set, the label is 1 iff `days_left < eol_threshold` (NaN gives 0). With
        `eol_threshold=None`, the label is 1 iff `days_left` is not NaN.
        """
        days_left = float(days_left)
        if self.eol_threshold is None:
            return int(not np.isnan(days_left))
        return int(days_left < self.eol_threshold)

    def batch_data_collator(self, batch):
        """
        Prepare binary labels for mortality prediction. Dynamically select a visit and remove all inputs after that visit.
        Based on the number of remaining days in the patient's life and the threshold, set the label.

        Rows of `input_ids`, `token_type_ids` and `attention_mask` are edited in place. The returned dict holds those
        three tensors plus a `labels` long tensor of shape (batch,); every other batch entry is dropped.

        Raises:
            ValueError: if a row contains no visit tokens (all token_type_ids are 0).
        """
        input_ids = batch["input_ids"]
        token_type_ids = batch["token_type_ids"]
        attention_mask = batch["attention_mask"]
        eol_days = batch[self.eol_key]
        seq_len = input_ids.shape[1]

        labels = []
        for row in range(input_ids.shape[0]):
            visit_types = token_type_ids[row].numpy()
            num_visits = int(visit_types.max())
            if num_visits < 1:
                raise ValueError(f"Row {row} contains no visit tokens (max token_type_id is {num_visits})")
            # A single-visit patient needs no random draw.
            cutoff_visit = 1 if num_visits == 1 else int(rng.integers(1, num_visits, endpoint=True))
            # Position of the last token of the cutoff visit; everything after it is removed.
            last_visit_ind = int(np.flatnonzero(visit_types == cutoff_visit).max())
            cutoff_ind = last_visit_ind + 1
            # NOTE(review): the -1 offset implies eol_days is shifted one position relative to input_ids
            # (e.g. it has no entry for [CLS]) — confirm against the dataset preprocessing.
            labels.append(self._eol_label(eol_days[row][last_visit_ind - 1]))

            # Remove items after cutoff
            input_ids[row, cutoff_ind:] = self.tokenizer.pad_token_id
            if cutoff_ind < seq_len:
                # Put a [SEP] token at the cutoff point (no room for one if the visit fills the row to the end).
                input_ids[row, cutoff_ind] = self.tokenizer.sep_token_id
            token_type_ids[row, cutoff_ind:] = 0  # [SEP] and [PAD] have token_type_id of 0
            attention_mask[row, cutoff_ind + 1:] = 0  # After [SEP], we mask attention

        return {
            "input_ids": input_ids,
            "token_type_ids": token_type_ids,
            "attention_mask": attention_mask,
            "labels": torch.tensor(labels, dtype=torch.long),
        }

@dataclass
class DataCollatorForDiseasePrediction(DataCollatorMixin):
    """
    Data collator. Accepts batches of ids for patients (input_ids, token_type_ids, attention_mask). Selects a random
    visit number (from the second visit on) and masks tokens after (and including) that visit. The label is the
    ICD-10 chapter (an int in 1..22) of the first code of the removed visit. Derivative of
    "DataCollatorForNextDiseasePrediction".

    Patients with at most one visit keep their inputs unchanged and receive label -100.

    Args:
        tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
            The tokenizer used for encoding the data.
        pad_to_multiple_of (`int`, *optional*):
            If set will pad the sequence to a multiple of the provided value.
    """

    # ICD-9 3-digit category ranges -> ICD-10 chapter covering the same body system. ICD-9-CM orders several
    # chapters differently from ICD-10 (endocrine before blood; genitourinary and pregnancy before skin and
    # musculoskeletal; congenital before perinatal), so the chapter numbers here are not monotonic.
    _ICD9_CHAPTER_RANGES = (
        (1, 139, '1'),      # infectious / parasitic  -> A00-B99
        (140, 239, '2'),    # neoplasms               -> C00-D49
        (240, 279, '4'),    # endocrine / metabolic   -> E00-E99
        (280, 289, '3'),    # blood                   -> D50-D99
        (290, 319, '5'),    # mental                  -> F00-F99
        (320, 359, '6'),    # nervous system          -> G00-G99
        (360, 379, '7'),    # eye                     -> H00-H59
        (380, 389, '8'),    # ear                     -> H60-H99
        (390, 459, '9'),    # circulatory             -> I00-I99
        (460, 519, '10'),   # respiratory             -> J00-J99
        (520, 579, '11'),   # digestive               -> K00-K99
        (580, 629, '14'),   # genitourinary           -> N00-N99
        (630, 679, '15'),   # pregnancy               -> O00-O9A
        (680, 709, '12'),   # skin                    -> L00-L99
        (710, 739, '13'),   # musculoskeletal         -> M00-M99
        (740, 759, '17'),   # congenital              -> Q00-Q99
        (760, 779, '16'),   # perinatal               -> P00-P99
        (780, 799, '18'),   # symptoms / ill-defined  -> R00-R99
        (800, 999, '19'),   # injury / poisoning      -> S00-T99
    )

    # ICD-10 3-character category ranges -> chapter, compared as strings. The pregnancy chapter ends at 'O9A',
    # which sorts after 'O99'.
    _ICD10_CHAPTER_RANGES = (
        ('A00', 'B99', '1'), ('C00', 'D49', '2'), ('D50', 'D99', '3'), ('E00', 'E99', '4'),
        ('F00', 'F99', '5'), ('G00', 'G99', '6'), ('H00', 'H59', '7'), ('H60', 'H99', '8'),
        ('I00', 'I99', '9'), ('J00', 'J99', '10'), ('K00', 'K99', '11'), ('L00', 'L99', '12'),
        ('M00', 'M99', '13'), ('N00', 'N99', '14'), ('O00', 'O9A', '15'), ('P00', 'P99', '16'),
        ('Q00', 'Q99', '17'), ('R00', 'R99', '18'), ('S00', 'T99', '19'), ('V00', 'Y99', '20'),
        ('Z00', 'Z99', '21'), ('U00', 'U99', '22'),
    )

    tokenizer: PreTrainedTokenizerBase
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"

    def __post_init__(self):
        self.return_tensors = 'pt'

    def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        """Pad `examples` into a batch, then truncate each patient and attach chapter labels."""
        # Handle dict or lists with proper padding and conversion to tensor.
        if isinstance(examples[0], Mapping):
            batch = self.tokenizer.pad(examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of)
        else:
            batch = {
                "input_ids": _torch_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
            }

        return self.batch_data_collator(batch)

    def torch_mask_tokens(self, inputs: Any, input_types: Any, attention_mask: Any) -> Tuple[Any, Any, Any, Any]:
        """
        Prepare inputs/labels for next disease prediction: remove all of the last visit and set labels to the
        first disease code of that visit.

        Not called by `torch_call`. `inputs`, `input_types` and `attention_mask` are modified in place.
        Returns `(inputs, input_types, attention_mask, labels)` with `labels` of shape (batch, 1).
        """
        # argmax returns the first position of the highest token type, i.e. the first token of the last visit.
        next_disease_indices = torch.argmax(input_types, dim=1)
        labels = torch.stack([t[i:i+1] for i, t in zip(next_disease_indices, inputs)])

        # Pad out the last visit, then put a [SEP] (attended) at its first position.
        post_next_disease_indices = torch.stack([torch.arange(0, inputs.shape[1]) >= i for i in next_disease_indices])
        inputs[post_next_disease_indices] = self.tokenizer.pad_token_id
        input_types[post_next_disease_indices] = self.tokenizer.pad_token_type_id
        attention_mask[post_next_disease_indices] = 0
        for i, j in enumerate(next_disease_indices):
            inputs[i, j] = self.tokenizer.sep_token_id
            attention_mask[i, j] = 1

        return inputs, input_types, attention_mask, labels

    def lookup_ICD_10_chapter(self, code):
        """
        Accepts ICD-9 or ICD-10 codes, returns ICD-10 chapter.

        Codes carry a version suffix, e.g. '2765-9' (ICD-9, decimal point omitted) or 'G20-10' (ICD-10).
        Returns the chapter as a string ('1'..'22'), or None when the code has no known suffix or matches
        no chapter.
        """
        if code.endswith('-9'):
            body = code[:-2]  # 2765-9 -> 2765
            if body[:1].isalpha():
                # Supplementary ICD-9 classifications: E-codes are external causes (ICD-10 V00-Y99),
                # V-codes are factors influencing health status (ICD-10 Z00-Z99).
                return '20' if body[0].upper() == 'E' else '21'
            if not body.isdigit():
                return None
            # The category is the part before the omitted decimal point: the first three digits of a 4- or
            # 5-digit code (27939 -> 279, 0032 -> 3); other lengths are read as a whole number.
            category = int(body[:3]) if len(body) in (4, 5) else int(body)
            for low, high, chapter in self._ICD9_CHAPTER_RANGES:
                if low <= category <= high:
                    return chapter
            return None

        if code.endswith('-10'):
            category = code[:3]
            for low, high, chapter in self._ICD10_CHAPTER_RANGES:
                if low <= category <= high:
                    return chapter
        return None

    def data_collator(self, batch, i):
        """
        Truncate row `i` of `batch` in place before a randomly chosen visit and append its label to `batch['labels']`.

        The chosen visit number is drawn from 2..max_visit with `np.random`. The first token of that visit becomes
        [SEP], everything after it becomes [PAD], and token types and attention are zeroed from the [SEP] onward.

        Raises:
            ValueError: if the first token of the chosen visit does not map to an ICD-10 chapter.
        """
        input_ids = batch['input_ids'][i]
        token_type_ids = batch['token_type_ids'][i]
        attention_mask = batch['attention_mask'][i]
        visit_types = token_type_ids.tolist()
        max_visit = max(visit_types)
        if max_visit <= 1:
            # Nothing earlier to condition on: mark the row as ignored by the loss.
            batch['labels'].append(-100)
            return batch

        visit_to_not_keep = np.random.randint(2, max_visit + 1)
        first_index_to_not_keep = visit_types.index(visit_to_not_keep)
        token = self.tokenizer.convert_ids_to_tokens(input_ids[first_index_to_not_keep].item())
        chapter = self.lookup_ICD_10_chapter(token)
        if chapter is None:
            raise ValueError(f"Cannot map token {token!r} to an ICD-10 chapter")
        batch['labels'].append(int(chapter))

        # Row tensors are views into the batch tensors, so these writes update the batch directly.
        input_ids[first_index_to_not_keep] = self.tokenizer.convert_tokens_to_ids('[SEP]')
        input_ids[first_index_to_not_keep + 1:] = self.tokenizer.convert_tokens_to_ids('[PAD]')
        token_type_ids[first_index_to_not_keep:] = 0
        # NOTE(review): the inserted [SEP] itself is also excluded from attention — confirm this is intended.
        attention_mask[first_index_to_not_keep:] = 0
        return batch

    def batch_data_collator(self, batch):
        """Truncate every row of `batch` in place and attach `labels` as a long tensor of shape (batch,)."""
        batch['labels'] = []
        for i in range(len(batch['input_ids'])):
            batch = self.data_collator(batch, i)
        batch['labels'] = torch.tensor(batch['labels'], dtype=torch.long)
        return batch

@dataclass
class DataCollatorForMultiLabelsDiseasePrediction(DataCollatorMixin):
    """
    Data collator. Accepts batches of ids for patients (input_ids, token_type_ids, attention_mask). Removes each
    patient's last visit (highest token_type_id) from the inputs and labels the row with a 22-slot multi-hot
    vector of the ICD-10 chapters of that visit's codes. Derivative of "DataCollatorForNextDiseasePrediction".

    Args:
        tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
            The tokenizer used for encoding the data.
        pad_to_multiple_of (`int`, *optional*):
            If set will pad the sequence to a multiple of the provided value.
    """

    _NUM_CHAPTERS = 22

    # ICD-9 3-digit category ranges -> ICD-10 chapter covering the same body system. ICD-9-CM orders several
    # chapters differently from ICD-10 (endocrine before blood; genitourinary and pregnancy before skin and
    # musculoskeletal; congenital before perinatal), so the chapter numbers here are not monotonic.
    _ICD9_CHAPTER_RANGES = (
        (1, 139, 1),      # infectious / parasitic  -> A00-B99
        (140, 239, 2),    # neoplasms               -> C00-D49
        (240, 279, 4),    # endocrine / metabolic   -> E00-E99
        (280, 289, 3),    # blood                   -> D50-D99
        (290, 319, 5),    # mental                  -> F00-F99
        (320, 359, 6),    # nervous system          -> G00-G99
        (360, 379, 7),    # eye                     -> H00-H59
        (380, 389, 8),    # ear                     -> H60-H99
        (390, 459, 9),    # circulatory             -> I00-I99
        (460, 519, 10),   # respiratory             -> J00-J99
        (520, 579, 11),   # digestive               -> K00-K99
        (580, 629, 14),   # genitourinary           -> N00-N99
        (630, 679, 15),   # pregnancy               -> O00-O9A
        (680, 709, 12),   # skin                    -> L00-L99
        (710, 739, 13),   # musculoskeletal         -> M00-M99
        (740, 759, 17),   # congenital              -> Q00-Q99
        (760, 779, 16),   # perinatal               -> P00-P99
        (780, 799, 18),   # symptoms / ill-defined  -> R00-R99
        (800, 999, 19),   # injury / poisoning      -> S00-T99
    )

    # ICD-10 3-character category ranges -> chapter, compared as strings. The pregnancy chapter ends at 'O9A',
    # which sorts after 'O99'.
    _ICD10_CHAPTER_RANGES = (
        ('A00', 'B99', 1), ('C00', 'D49', 2), ('D50', 'D99', 3), ('E00', 'E99', 4),
        ('F00', 'F99', 5), ('G00', 'G99', 6), ('H00', 'H59', 7), ('H60', 'H99', 8),
        ('I00', 'I99', 9), ('J00', 'J99', 10), ('K00', 'K99', 11), ('L00', 'L99', 12),
        ('M00', 'M99', 13), ('N00', 'N99', 14), ('O00', 'O9A', 15), ('P00', 'P99', 16),
        ('Q00', 'Q99', 17), ('R00', 'R99', 18), ('S00', 'T99', 19), ('V00', 'Y99', 20),
        ('Z00', 'Z99', 21), ('U00', 'U99', 22),
    )

    tokenizer: PreTrainedTokenizerBase
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"

    def __post_init__(self):
        self.return_tensors = 'pt'

    def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        """Pad `examples` into a batch, then remove each patient's last visit and attach multi-hot labels."""
        # Handle dict or lists with proper padding and conversion to tensor.
        if isinstance(examples[0], Mapping):
            batch = self.tokenizer.pad(examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of)
        else:
            batch = {
                "input_ids": _torch_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
            }

        return self.batch_data_collator(batch)

    def torch_mask_tokens(self, inputs: Any, input_types: Any, attention_mask: Any) -> Tuple[Any, Any, Any, Any]:
        """
        Prepare inputs/labels for next disease prediction: remove all of the last visit and set labels to every
        token from the start of that visit to the end of the row.

        Not called by `torch_call`. `inputs`, `input_types` and `attention_mask` are modified in place.
        NOTE(review): `torch.stack` requires every row's last visit to start at the same position; otherwise the
        label slices differ in length and stacking fails.
        """
        # argmax returns the first position of the highest token type, i.e. the first token of the last visit.
        next_disease_indices = torch.argmax(input_types, dim=1)
        labels = torch.stack([t[i:] for i, t in zip(next_disease_indices, inputs)])

        # Pad out the last visit, then put a [SEP] (attended) at its first position.
        post_next_disease_indices = torch.stack([torch.arange(0, inputs.shape[1]) >= i for i in next_disease_indices])
        inputs[post_next_disease_indices] = self.tokenizer.pad_token_id
        input_types[post_next_disease_indices] = self.tokenizer.pad_token_type_id
        attention_mask[post_next_disease_indices] = 0
        for i, j in enumerate(next_disease_indices):
            inputs[i, j] = self.tokenizer.sep_token_id
            attention_mask[i, j] = 1

        return inputs, input_types, attention_mask, labels

    def get_chapter_multilabel(self, codes):
        """
        Build a 22-slot multi-hot list of ICD-10 chapters (slot k-1 for chapter k) from the tokens in `codes`.

        Reading stops at the first '[SEP]' or '[PAD]' token. Tokens that map to no chapter (e.g. other special
        tokens) are skipped.
        """
        multilabel = [0] * self._NUM_CHAPTERS
        for code in codes:
            if code in ('[SEP]', '[PAD]'):
                break
            chapter = self.lookup_ICD_10_chapter(code)
            if chapter is not None:
                multilabel[chapter - 1] = 1
        return multilabel

    def lookup_ICD_10_chapter(self, code):
        """
        Accepts ICD-9 or ICD-10 codes, returns ICD-10 chapter.

        Codes carry a version suffix, e.g. '2765-9' (ICD-9, decimal point omitted) or 'G20-10' (ICD-10).
        Returns the chapter as an int (1..22), or None when the code has no known suffix or matches no chapter.
        """
        if code.endswith('-9'):
            body = code[:-2]  # 2765-9 -> 2765
            if body[:1].isalpha():
                # Supplementary ICD-9 classifications: E-codes are external causes (ICD-10 V00-Y99),
                # V-codes are factors influencing health status (ICD-10 Z00-Z99).
                return 20 if body[0].upper() == 'E' else 21
            if not body.isdigit():
                return None
            # The category is the part before the omitted decimal point: the first three digits of a 4- or
            # 5-digit code (27939 -> 279, 0032 -> 3); other lengths are read as a whole number.
            category = int(body[:3]) if len(body) in (4, 5) else int(body)
            for low, high, chapter in self._ICD9_CHAPTER_RANGES:
                if low <= category <= high:
                    return chapter
            return None

        if code.endswith('-10'):
            category = code[:3]
            for low, high, chapter in self._ICD10_CHAPTER_RANGES:
                if low <= category <= high:
                    return chapter
        return None

    def data_collator(self, batch, i):
        """
        Remove the last visit of row `i` in place and append its multi-hot chapter list to `batch['labels']`.

        The first token of the last visit becomes [SEP], everything after it becomes [PAD], and token types and
        attention are zeroed from the [SEP] onward.
        """
        input_ids = batch['input_ids'][i]
        token_type_ids = batch['token_type_ids'][i]
        attention_mask = batch['attention_mask'][i]
        visit_types = token_type_ids.tolist()
        # The target is always the last visit (highest token_type_id).
        # NOTE(review): a row with no visit tokens (max type 0) is cut at its first type-0 token — confirm this
        # cannot occur upstream.
        visit_to_not_keep = max(visit_types)
        first_index_to_not_keep = visit_types.index(visit_to_not_keep)
        visit_tokens = self.tokenizer.convert_ids_to_tokens(input_ids[first_index_to_not_keep:].tolist())
        batch['labels'].append(self.get_chapter_multilabel(visit_tokens))

        # Row tensors are views into the batch tensors, so these writes update the batch directly.
        input_ids[first_index_to_not_keep] = self.tokenizer.convert_tokens_to_ids('[SEP]')
        input_ids[first_index_to_not_keep + 1:] = self.tokenizer.convert_tokens_to_ids('[PAD]')
        token_type_ids[first_index_to_not_keep:] = 0
        # NOTE(review): the inserted [SEP] itself is also excluded from attention — confirm this is intended.
        attention_mask[first_index_to_not_keep:] = 0
        return batch

    def batch_data_collator(self, batch):
        """Process every row of `batch` in place and attach `labels` as a long tensor of shape (batch, 22)."""
        batch['labels'] = []
        for i in range(len(batch['input_ids'])):
            batch = self.data_collator(batch, i)
        # NOTE(review): multi-label losses such as BCEWithLogitsLoss expect float targets; confirm the training
        # loss casts these long labels.
        batch['labels'] = torch.tensor(batch['labels'], dtype=torch.long)
        return batch
