import re
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import h5py
import joblib
import torch
from torch.utils.data import DataLoader, TensorDataset
from transformers import PreTrainedTokenizer

from .base import fMRIDataset
from ridge_utils import make_word_ds

class MothRadioHour(fMRIDataset):
    '''
    Moth Radio Hour dataset from:
    "Deniz, F., Nunez-Elizalde, A.O., Huth, A.G. and Gallant, J.L., 2019.
    The representation of semantic information across human cerebral cortex
    during listening versus reading is invariant to stimulus modality.
    Journal of Neuroscience, 39(39), pp.7722-7736."

    The study recorded fMRI responses of 9 subjects who listened to and read
    11 narrative stories from the Moth Radio Hour podcast. When reading, words
    were presented one by one, each for the same duration the word had in the
    spoken story. This class loads only the reading-condition responses
    (``subjectXX_reading_fmri_data_trn.hdf`` / ``..._val.hdf``).

    These data were collected during two 3 h scanning sessions that were
    performed on different days. Each story was presented during a separate
    fMRI scan, with TR=2 seconds. Each scan included 10 s (5 TR) of silence
    both before and after the story.

    The preprocessed data can be downloaded
    `here <https://gin.g-node.org/denizenslab/narratives_reading_listening_fmri>`_.
    Every subject's fMRI recording is normalized across time for each voxel.
    Subjects are shown the stimuli word-by-word
    at intervals of 0.5 seconds. Text formatting such as italics or newline characters
    is displayed to the participants separately as `@` and `+`, respectively.

    NOTE(review): the fixed 0.5 s interval above conflicts with the
    duration-matched presentation described earlier, and nothing in this class
    uses the `@` / `+` characters. Confirm that these two sentences apply to
    this dataset.
    '''

    dataset_name = "MothRadioHour"
    subject_idxs = ["01", "02", "03", "04", "05", "06", "07", "08", "09"]
    story_idx_to_name = {
        '01': "alternateithicatom",
        '02': "avatar",
        '03': "howtodraw",
        '04': "legacy",
        '05': "life",
        '06': "myfirstdaywiththeyankees",
        '07': "naked",
        '08': "odetostepfather",
        '09': "souls",
        '10': "undertheinfluence",
        '11': "wheretheressmoke"
    }

    # Subjects whose training-response arrays carry an extra leading axis that
    # must be dropped (index 0 is kept).
    _SUBJECTS_WITH_EXTRA_TRN_DIM = frozenset({"04", "06", "09"})
    # Whitespace-delimited "word" used to align context strings with tokens.
    _WORD_PATTERN = re.compile(r'\S+')

    def __init__(
        self,
        dataset_dir: str,
        context_length: int,
        tokenizer: PreTrainedTokenizer,
        device: str,
        remove_format_chars: bool = False,
        remove_punc_spacing: bool = False,
        verbose: bool = False
    ):
        """Initializes the dataset: loads alignments/responses, then builds contexts.

        Args:
            dataset_dir: Path to the downloaded data directory. The files under
                ``preprocessed/`` and ``responses/`` are read.
            context_length: Number of preceding *words* (not tokens) included
                in each per-word context, so a context holds up to
                ``context_length + 1`` words ending with the current word. It is
                also the number of pad tokens used for silent TRs that occur
                before the first word.
            tokenizer: A HuggingFace tokenizer. Its
                ``clean_up_tokenization_spaces`` attribute is set to True (the
                passed object is mutated). If None, no pad-token contexts are
                created for leading silent TRs, and ``get_context_token_ids``
                cannot be used.
            device: Stored as ``self.device``; not used within this class.
            remove_format_chars: Whether or not to remove the special formatting
                characters that were displayed to participants such as `@` and `+`.
            remove_punc_spacing: Punctuation such as ellipses `...` or em-dashes
                `—` were displayed as `. . .` (period-by-period) and ` --- `,
                respectively, to participants. If this flag is true, punctuation is
                reformatted to what is conventional (i.e. `...` and `—` with no spaces
                around it).
            verbose: If True, print progress messages.

        NOTE(review): ``remove_format_chars`` and ``remove_punc_spacing`` are
        stored on the instance but not applied anywhere in this class.
        """
        self.context_length = context_length
        self.dataset_dir = Path(dataset_dir)
        self.words_dir = self.dataset_dir / "stimuli"
        self.grid_dir = self.words_dir / "textgrids/stimulus_textgrids"
        self.fmri_dir = self.dataset_dir / "responses"
        self.preprocessed_dir = self.dataset_dir / "preprocessed"
        self.remove_format_chars = remove_format_chars
        self.remove_punc_spacing = remove_punc_spacing
        self.tokenizer = tokenizer
        if tokenizer is not None:
            self.tokenizer.clean_up_tokenization_spaces = True
        self.verbose = verbose
        self.device = device

        # Load metadata
        if self.verbose:
            print("     1a - Load data from files.")
        self.load_data_from_files()

        # Extract contexts
        if self.verbose:
            print("     1b - Extract contexts.")
        self.extract_contexts()

    def load_fmri_data(self, fname, key=None):
        """Read datasets from an HDF5 file fully into memory.

        Args:
            fname: Path of the HDF5 file.
            key: If given, only this key is read; otherwise every top-level key.

        Returns:
            dict mapping each key to its loaded value (``hf[k][()]``).
        """
        # Open read-only explicitly. h5py < 3.0 defaulted to append mode, which
        # creates a missing file and opens existing ones for writing.
        with h5py.File(fname, "r") as hf:
            keys = list(hf.keys()) if key is None else [key]
            return {k: hf[k][()] for k in keys}

    def load_data_from_files(self):
        """Load word/TR alignments for each story and fMRI responses per subject.

        Populates, keyed by story name:
            self.words: the story's word sequence (``DataSequence.data``).
            self.tr_idx_to_num_words: ``DataSequence.split_inds``, the word
                offsets at which the sequence is split into TR chunks.
            self.tr_idx_to_words: ``DataSequence.chunks()``, the words per TR.
        and ``self.subjects``: one dict per entry of ``subject_idxs`` mapping
        response keys to arrays (training stories plus any validation story
        not already present).
        """
        # Load TR info. joblib.load unpickles, so only load trusted files.
        grids = joblib.load(self.preprocessed_dir / "grids_huge.jbl")  # TextGrids containing story annotations
        trfiles = joblib.load(self.preprocessed_dir / "trfiles_huge.jbl")  # TRFiles containing TR information
        # Keep only the stories used by this dataset (set built once for O(1) lookups).
        wanted_stories = set(self.story_idx_to_name.values())
        for story in list(grids):
            if story not in wanted_stories:
                del grids[story]
                del trfiles[story]
        word_seqs = make_word_ds(grids, trfiles)
        self.words, self.tr_idx_to_num_words, self.tr_idx_to_words = defaultdict(list), defaultdict(list), defaultdict(list)
        for story_name in self.story_idx_to_name.values():
            self.words[story_name] = word_seqs[story_name].data
            self.tr_idx_to_num_words[story_name] = word_seqs[story_name].split_inds
            self.tr_idx_to_words[story_name] = word_seqs[story_name].chunks()

        # Load fMRI data for each subject
        self.subjects = []
        for subject_idx in self.subject_idxs:
            # dict: "story_0x" -> (n_tr, n_voxels)
            trndata = self.load_fmri_data(self.fmri_dir / f"subject{subject_idx}_reading_fmri_data_trn.hdf")
            if subject_idx in self._SUBJECTS_WITH_EXTRA_TRN_DIM:
                # Remove extra leading dimension
                trndata = {key: val[0] for key, val in trndata.items()}
            # dict: "story_11" -> (2, n_tr, n_voxels)
            valdata = self.load_fmri_data(self.fmri_dir / f"subject{subject_idx}_reading_fmri_data_val.hdf")
            for key, val in valdata.items():
                if key not in trndata:
                    trndata[key] = val[0]  # Use only the first recording
            self.subjects.append(trndata)

    def extract_contexts(self):
        """Build one text context per word (and per word-less TR) for every story.

        Populates, keyed by story name:
            self.contexts[story]: list of context strings. A TR with words
                contributes one entry per word: that word preceded by up to
                ``context_length`` earlier words of the story, space-joined.
                A word-less TR repeats the previous entry. Before any entry
                exists, it contributes ``context_length`` concatenated pad
                tokens, but only when a tokenizer is set; otherwise nothing.
            self.tr_to_word_idxs[story][tr_idx]: positions in
                ``self.contexts[story]`` produced for that TR.
            self.context_word_idxs[story]: one ``(start, end)`` pair per context
                entry. ``end`` is the entry's own position in
                ``self.contexts[story]`` and ``start = max(0, end - context_length)``.
                Word-less TRs copy the previous pair, or use ``(0, 0)``.

        Raises:
            AssertionError: if a TR's chunk size disagrees with ``split_inds``.
        """
        self.contexts = defaultdict(list)
        self.tr_to_word_idxs = defaultdict(lambda: defaultdict(list))
        self.context_word_idxs = defaultdict(list)
        for story_name in self.story_idx_to_name.values():
            contexts = self.contexts[story_name]
            tr_to_word_idxs = self.tr_to_word_idxs[story_name]
            context_word_idxs = self.context_word_idxs[story_name]
            split_inds = self.tr_idx_to_num_words[story_name]
            num_story_words = len(self.words[story_name])
            seen_words = []

            for tr_idx, words_in_tr in enumerate(self.tr_idx_to_words[story_name]):
                words_in_tr = list(words_in_tr)
                num_words_in_tr = len(words_in_tr)

                # Handle empty TRs (including silence TRs at beginning/end)
                if num_words_in_tr == 0:
                    if contexts:
                        # Reuse the most recent context.
                        tr_to_word_idxs[tr_idx].append(len(contexts))
                        contexts.append(contexts[-1])
                        context_word_idxs.append(context_word_idxs[-1])
                    elif self.tokenizer is not None:
                        # No word seen yet: use a context made only of pad tokens.
                        # NOTE(review): raises TypeError if the tokenizer has no pad_token.
                        tr_to_word_idxs[tr_idx].append(len(contexts))
                        contexts.append(self.tokenizer.pad_token * self.context_length)
                        # contexts and context_word_idxs always grow together,
                        # so both are empty here and there is no pair to copy.
                        context_word_idxs.append((0, 0))
                    continue

                # Global word indices covered by this TR. split_inds holds the
                # boundaries between TR chunks. chunks() can yield one chunk more
                # than there are boundaries (np.split semantics); the final chunk
                # then runs to the end of the story.
                first_word = split_inds[tr_idx - 1] if tr_idx > 0 else 0
                last_word = split_inds[tr_idx] if tr_idx < len(split_inds) else num_story_words
                words_in_tr_idxs = list(range(first_word, last_word))

                assert num_words_in_tr == len(words_in_tr_idxs), \
                    f"Warning: Number of words in TR {tr_idx} ({num_words_in_tr}) does not match the expected number ({len(words_in_tr_idxs)}) for story {story_name}."

                seen_words += words_in_tr

                for word_idx in words_in_tr_idxs:
                    context_pos = len(contexts)
                    start_idx = max(0, word_idx - self.context_length)
                    tr_to_word_idxs[tr_idx].append(context_pos)
                    contexts.append(" ".join(seen_words[start_idx:word_idx + 1]))
                    context_word_idxs.append((max(0, context_pos - self.context_length), context_pos))

    def get_dataset_attributes(self):
        """Return ``(subject_idxs, None, None)``.

        NOTE(review): the two ``None`` slots presumably correspond to attributes
        the base class supports but this dataset lacks. Confirm against fMRIDataset.
        """
        return self.subject_idxs, None, None

    def _tokenize_contexts(self, contexts: List[str], batch_size: int):
        """Tokenize ``contexts`` and wrap input ids / attention masks in a DataLoader.

        Requires a tokenizer that supports ``return_offsets_mapping`` (HF fast
        tokenizers) and has a pad token (``padding=True``).

        Returns:
            ``(data_loader, token_idxs_to_avg, index_mapping)``.
            ``index_mapping[i]`` maps word index to token indices for
            ``contexts[i]``. ``token_idxs_to_avg[i]`` is the token-index list of
            the last word of ``contexts[i]``. The loader is unshuffled.
        """
        toks = self.tokenizer.batch_encode_plus(contexts, return_offsets_mapping=True, return_attention_mask=True, padding=True)
        index_mapping = [
            self._get_word_idx_to_tok_idx_mapping(ctx, offsets)
            for ctx, offsets in zip(contexts, toks["offset_mapping"])
        ]
        # max() over a mapping's keys is its highest word index, i.e. the last word.
        token_idxs_to_avg = [word_idx_to_tok_idx[max(word_idx_to_tok_idx)] for word_idx_to_tok_idx in index_mapping]

        dataset = TensorDataset(torch.tensor(toks["input_ids"]), torch.tensor(toks["attention_mask"]))
        data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
        return data_loader, token_idxs_to_avg, index_mapping

    def get_context_token_ids(
        self, batch_size: int, opt_contexts: Optional[List[str]] = None
    ) -> Union[
        Tuple[DataLoader, List[List[int]], List[Dict[int, List[int]]]],
        Tuple[Dict[str, DataLoader], Dict[str, List[List[int]]], Dict[str, List[Dict[int, List[int]]]]],
    ]:
        """Tokenize contexts into unshuffled DataLoaders plus word/token alignment info.

        Args:
            batch_size: Batch size of the returned DataLoader(s).
            opt_contexts: Optional list of context strings. If given, only these
                are tokenized; otherwise every story in ``self.contexts`` is.

        Returns:
            With ``opt_contexts``: ``(data_loader, token_idxs_to_avg, index_mapping)``.
            Without: a triple of dicts keyed by story name holding those same
            three items per story.
        """
        if opt_contexts is not None:
            return self._tokenize_contexts(opt_contexts, batch_size)

        all_data_loaders, all_token_idxs_to_avg, all_index_mappings = {}, {}, {}
        for story_name, story_context in self.contexts.items():
            data_loader, token_idxs_to_avg, index_mapping = self._tokenize_contexts(story_context, batch_size)
            all_data_loaders[story_name] = data_loader
            all_token_idxs_to_avg[story_name] = token_idxs_to_avg
            all_index_mappings[story_name] = index_mapping

        return all_data_loaders, all_token_idxs_to_avg, all_index_mappings

    def _get_word_idx_to_tok_idx_mapping(self, text: str, offset_mapping: List[Tuple[int, int]]):
        """Map each whitespace-delimited word of ``text`` to its token indices.

        Args:
            text: The string that was tokenized.
            offset_mapping: ``(start, end)`` character offsets per token.
                ``(0, 0)`` entries (special/padding tokens) are skipped.

        Returns:
            dict from word index to a list of token indices (possibly empty).
            A word collects the first contiguous run of tokens that start no
            earlier than one character before the word, so a token that
            includes the preceding space still counts, and end no later than
            the word. Scanning for each word resumes at the last token matched
            for the previous word.
        """
        word_indices = {}
        search_from = 0
        for word_idx, match in enumerate(self._WORD_PATTERN.finditer(text)):
            start, end = match.span()
            token_indices = []
            for token_idx in range(search_from, len(offset_mapping)):
                token_start, token_end = offset_mapping[token_idx]
                if token_start == 0 and token_end == 0:
                    continue
                if token_start >= start - 1 and token_end <= end:
                    token_indices.append(token_idx)
                    search_from = token_idx
                elif token_indices:
                    # End of this word's contiguous run of tokens.
                    break
            word_indices[word_idx] = token_indices

        return word_indices