import re
import torch
import random
import numpy as np
from pathlib import Path
import scipy.io
from torch.utils.data import DataLoader, TensorDataset
from transformers import PreTrainedTokenizer
from typing import List, Optional, Tuple
from collections import defaultdict

from .base import fMRIDataset

class HarryPotter(fMRIDataset):
    '''
    Harry Potter dataset from *Simultaneously Uncovering the Patterns of
    Brain Regions Involved in Different Story
    Reading Subprocesses* by Wehbe et al. (2014). The dataset contains eight
    subjects, each of whom read Chapter 9 of *Harry
    Potter and the Sorcerer's Stone*.

    The preprocessed data can be downloaded from
    `here <https://drive.google.com/drive/folders/1Q6zVCAJtKuLOh-zWpkS3lH8LBvHcEOE8>`_.
    Every subject's fMRI recording is normalized across time for each voxel.
    Subjects are shown the stimuli word-by-word
    at intervals of 0.5 seconds. Text formatting such as italics or newline characters
    are displayed to the participants separately as `@` and `+`, respectively.
    '''

    dataset_name = "HarryPotter"
    subject_idxs = ["F", "H", "I", "J", "K", "L", "M", "N"]
    roi_names = [
        "PostTemp",
        "AntTemp",
        "AngularG",
        "IFG",
        "MFG",
        "IFGorb",
        "pCingulate",
        "dmpfc",
    ]
    perturbation_types = ["sliding_window", "last_only", "incremental"]
    num_words_per_tr = 4
    feature_categories = {"semantic": ["nnse"],
                          "syntactic": ["part_of_speech", "dependency_role"],
                          "discourse": ["emotion", "motion", "speech", "characters", "verbs"],
                          "visual": ["visual"],}

    def __init__(
        self,
        dataset_dir: str,
        context_length: int,
        tokenizer: PreTrainedTokenizer,
        device: str,
        remove_format_chars: bool = False,
        remove_punc_spacing: bool = False,
        verbose: bool = False
    ):
        """Initializes the dataset.

        Args:
            dataset_dir: Path to the downloaded data directory. It is assumed that the
                subdirectory structure and file naming follows
                `here <https://drive.google.com/drive/folders/1Q6zVCAJtKuLOh-zWpkS3lH8LBvHcEOE8>`__.
            context_length: For a given fMRI measurement, the number of previous tokens that we use
                to compute a given window.
            tokenizer: a HuggingFace tokenizer
            remove_format_chars: Whether or not to remove the special formatting
                characters that were displayed to participants such as `@` and `+`.
            remove_punc_spacing: Punctuation such as ellipses `...` or em-dashes
                `—` were displayed as `. . .` (period-by-period) and ` --- `,
                respectively, to participants. If this flag is true, punctuation is
                reformatted to what is conventional (i.e. `...` and `—` with no spaces
                around it).
        """
        self.context_length = context_length
        self.dataset_dir = Path(dataset_dir)
        self.fmri_dir = self.dataset_dir / "fMRI"
        self.voxel_n = self.dataset_dir / "voxel_neighborhoods"
        self.rois = self.fmri_dir / "HP_subj_roi_inds.npy"
        self.story_features = self.dataset_dir / "story_features.mat"
        self.remove_format_chars = remove_format_chars
        self.remove_punc_spacing = remove_punc_spacing
        self.tokenizer = tokenizer
        if tokenizer is not None:
            self.tokenizer.clean_up_tokenization_spaces = True
        self.verbose = verbose
        self.device = device

        # Load metadata
        if self.verbose:
            print("     1a - Load data from files.")
        self.load_data_from_files()        

        # Extract contexts
        if self.verbose:
            print("     1b - Extract contexts.")
        self.extract_contexts()
    
    def load_data_from_files(self):
        self.words = np.load(self.fmri_dir / "words_fmri.npy")
        self.word_timing = np.load(self.fmri_dir / "time_words_fmri.npy")
        self.fmri_timing = np.load(self.fmri_dir / "time_fmri.npy")
        self.subjects = [np.load(self.fmri_dir / f"data_subject_{i}.npy") for i in self.subject_idxs]
        self.subject_rois = [np.load(self.rois, allow_pickle=True).item()[i] for i in self.subject_idxs]
        runs = np.load(self.fmri_dir / "runs_fmri.npy")

        # Remove the edges of each run
        self.fmri_timing = np.concatenate(
            [self.fmri_timing[runs == i][20:-15] for i in range(1, 5)]
        )
        self.runs_cropped = np.concatenate(
            [runs[runs == i][20:-15] for i in range(1, 5)]
        )
    
    def get_dataset_attributes(self):
        return self.subject_idxs, self.roi_names, self.subject_rois
    
    def get_story_features(self) -> List[Tuple[str, List[str], np.array]]:
        story_data = scipy.io.loadmat(self.story_features)['features'][0]
        discourse_features = []
        story_features = defaultdict(dict)

        for feat_name, subcategories, flags in story_data:
            feat_name = feat_name[0]
            subcategories = [subcat[0] for subcat in subcategories[0]]

            # If it is a discourse feature, we remove sticky features and store the corresponding flag
            if feat_name.lower() in self.feature_categories['discourse']:
                sticky_idxs = [i for i, sub_name in enumerate(subcategories) if 'sticky' in sub_name]
                subcategories = [sub_name for i, sub_name in enumerate(subcategories) if i not in sticky_idxs]
                flags = np.delete(flags, sticky_idxs, axis=1)
                discourse_features.append((feat_name, subcategories, flags))
                story_features['discourse'] = {"feature_names": subcategories, "feature_tags": flags}
            # If it is a semantic feature, we store the corresponding flag
            # and the feature name
            elif feat_name.lower() in self.feature_categories['semantic']:
                story_features['semantic'] = {"feature_names": subcategories, "feature_tags": flags}
            # If it is a syntactic feature, we store the corresponding flag
            # and the feature name
            elif feat_name.lower() in self.feature_categories['syntactic']:
                story_features['syntactic'] = {"feature_names": subcategories, "feature_tags": flags}
            # If it is a visual feature, we store the corresponding flag
            # and the feature name
            elif feat_name.lower() in self.feature_categories['visual']:
                story_features['visual'] = {"feature_names": subcategories, "feature_tags": flags}

        return story_features, discourse_features
    
    def extract_contexts(self):        
        self.contexts, self.tr_to_word_idxs, self.context_word_idxs = [], {}, []
        for i, mri_time in enumerate(self.fmri_timing):
            f = filter(lambda x: x[0] <= mri_time and (x[0] > self.fmri_timing[i-1] if i > 0 else x[0] > mri_time - 2.0), zip(self.word_timing, range(len(self.word_timing))))
            words_in_tr_idxs = list(map(lambda x: x[1], f))
            # Handle gap between runs due to removed TRs (20 at the beginning and 15 at the end):
            # There are 3 occasions in which the number of words between two subsequent TRs is 87
            # (this happens after 305, 622, and 866 TRs). We keep the last 4 words, that are the
            # ones belonging to the current TR.
            if len(words_in_tr_idxs) > 4:
                words_in_tr_idxs = words_in_tr_idxs[-4:]

            # For each word in the TR, get the context_length context having such word as last word.
            self.tr_to_word_idxs[i] = []
            for word_idx in words_in_tr_idxs:
                start_idx = word_idx-self.context_length if word_idx-self.context_length >= 0 else 0
                context_words = list(self.words[start_idx:word_idx])
                self.tr_to_word_idxs[i].append(len(self.contexts))
                self.contexts.append(self.preprocess_context(" ".join(context_words)))
                self.context_word_idxs.append((start_idx, word_idx))
            if len(words_in_tr_idxs) != 4:
                print(f"Warning: TR with less than 4 words {len(context_words), i, mri_time, words_in_tr_idxs}")
    
    def preprocess_context(self, t_proc):
        remove_format_chars = lambda x: re.sub(r"(?<=\S)[@](?=\S)", " ", x)
        unify_em_dash = lambda x: re.sub(r"--", "—", x)
        
        if self.remove_punc_spacing:
            t_proc = unify_em_dash(t_proc)
        if self.remove_format_chars:
            t_proc = remove_format_chars(t_proc)
        return t_proc

    def get_context_token_ids(self, batch_size: int, opt_contexts: Optional[List[List[str]]] = None) -> Tuple[DataLoader, List[dict]]:
        if opt_contexts is not None:
            contexts = opt_contexts
        else:
            contexts = self.contexts
        toks = self.tokenizer.batch_encode_plus(contexts, return_offsets_mapping=True, return_attention_mask=True, padding=True)
        index_mapping = [
            self._get_word_idx_to_tok_idx_mapping(ctx, offsets)
            for ctx, offsets in zip(contexts, toks["offset_mapping"])
        ]
        token_idxs_to_avg = [word_idx_to_tok_idx[max(word_idx_to_tok_idx)] for word_idx_to_tok_idx in index_mapping]
        
        # Create dataloader
        dataset = TensorDataset(torch.tensor(toks["input_ids"]), torch.tensor(toks["attention_mask"]))
        data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        return data_loader, token_idxs_to_avg, index_mapping


    def _get_word_idx_to_tok_idx_mapping(self, text: str, offset_mapping: List[Tuple[int, int]]):
        word_indices = {}
        word_start_end = [(m.start(), m.end()) for m in re.finditer(r'\S+', text)]
        
        found_tok_idx = 0
        for word_idx, (start, end) in enumerate(word_start_end):
            token_indices = []
            for token_idx in range(found_tok_idx, len(offset_mapping)):
                token_start, token_end = offset_mapping[token_idx]
                if token_start == 0 and token_end == 0:
                    continue
                if token_start >= start-1 and token_end <= end:
                    token_indices.append(token_idx)
                    found_tok_idx = token_idx
                elif token_indices:
                    break
            word_indices[word_idx] = token_indices
        
        return word_indices