import numpy as np
import torch
import mne
import json
import random

from collections import defaultdict
from itertools import chain
from typing import Tuple
from base.base_dataset import BaseDataset
from utils.ELM_utils import preprocess_report

class TUAB_H5(BaseDataset):
    """Dataset returning one standardized feature array and its label per index.

    ``features``, ``labels``, ``indices``, ``dataset_mean`` and ``dataset_std``
    are not assigned in this class; they are presumably populated by
    ``BaseDataset.__init__`` (verify against the base class).
    """

    def __init__(self, cfg, setting):
        """Forward ``cfg`` and ``setting`` unchanged to ``BaseDataset``."""
        super().__init__(cfg, setting)

    def _normalize(self, features):
        """Return ``features`` shifted by the dataset mean and scaled by its std."""
        centered = features - self.dataset_mean
        return centered / self.dataset_std

    def __getitem__(self, index):
        """Return ``(x, y)`` as float tensors for the entry at ``index``.

        ``x`` is the normalized feature array; ``y`` wraps the label in a
        one-element array before conversion.
        """
        sample = torch.from_numpy(self._normalize(self.features[index]))
        target = torch.from_numpy(np.array([self.labels[index]]))
        return sample.float(), target.float()

    def __len__(self):
        """Return the number of entries in ``self.indices``."""
        return len(self.indices)

class TUAB_H5_features(BaseDataset):
    """Dataset returning a stored feature array and its label without scaling.

    ``features``, ``labels``, ``dataset_mean`` and ``dataset_std`` are not
    assigned in this class; they are presumably populated by
    ``BaseDataset.__init__`` (verify against the base class).
    """

    def __init__(self, cfg, setting):
        """Forward ``cfg`` and ``setting`` unchanged to ``BaseDataset``."""
        super().__init__(cfg, setting)

    def _normalize(self, features):
        """Return ``features`` shifted by the dataset mean and scaled by its std.

        NOTE(review): ``__getitem__`` below does not call this helper —
        confirm that returning the raw features is intentional.
        """
        centered = features - self.dataset_mean
        return centered / self.dataset_std

    def __getitem__(self, index):
        """Return ``(x, y)`` as float tensors for the entry at ``index``."""
        sample = torch.from_numpy(self.features[index])
        sample = sample.float()
        target = torch.from_numpy(np.array([self.labels[index]]))
        return sample, target.float()
    

class H5_MIL(BaseDataset):
    """Dataset producing bags of EEG crops and report texts for one subject.

    Each item is a group of EEG rows and a group of text entries that belong
    to the same subject. How the groups are formed is selected by
    ``cfg["model"]["ELM"]["MIL_positive_sampling"]``:

    * ``"y|x"``  – one EEG row paired with up to ``MIL_max_text_pairs`` texts.
    * ``"x|y"``  – one text paired with up to ``MIL_max_eeg_pairs`` EEG rows.
    * ``"x,y"`` / ``"x,y_b"`` – several EEG rows paired with several texts.

    The pair list is regenerated by ``_generate_batches`` at construction and
    on every ``on_epoch_start`` call.

    ``features``, ``subject_ids``, ``text_data_dict``, ``dataset_mean``,
    ``dataset_std`` and ``cfg`` are not assigned in this class; they are
    presumably populated by ``BaseDataset.__init__`` (verify against the base
    class).
    """

    def __init__(self, cfg, setting):
        """Load the base dataset and build the first epoch's pair ordering.

        Args:
            cfg: Nested configuration mapping. This class reads
                ``cfg["model"]["ELM"]`` (``MIL_max_text_pairs``,
                ``MIL_max_eeg_pairs``, ``MIL_positive_sampling``,
                ``text_data_filename``, ``text_sample_mode``,
                ``text_headings``, ``text_prefix``), ``cfg["dataset"]``
                (``path``, ``train_subsample``) and
                ``cfg["training"]["batch_size"]``.
            setting: Forwarded unchanged to ``BaseDataset.__init__``.
        """
        super().__init__(cfg, setting)

        # Report section headings grouped into four named clusters.
        self.clusters = {
            "history": ["CLINICAL HISTORY", "HISTORY"],
            "medications": ["MEDICATIONS", "AED"],
            "description": ["DESCRIPTION OF THE RECORD", "DESCRIPTION", "EEG BACKGROUND", "RANDOM WAKEFULNESS AND SLEEP", "EVENTS", "EVENT", "EPILEPTIFORM ACTIVITY", "OTHER PAROXYSMAL ACTIVITY (NON-EPILEPTIFORM)"],
            "interpretation": ["IMPRESSION", "CLINICAL CORRELATION", "CORRELATION", "CONCLUSION", "SUMMARY OF FINDINGS", "SUMMARY", "DIAGNOSIS", "INTERPRETATION"],
        }
        self.cluster_to_id = {"history": 0, "medications": 1, "description": 2, "interpretation": 3}

        self.max_text_pairs = self.cfg["model"]["ELM"]["MIL_max_text_pairs"]
        self.max_eeg_pairs = self.cfg["model"]["ELM"]["MIL_max_eeg_pairs"]
        # Unpacking fixes the feature store to exactly three dimensions.
        self.N, self.C, self.L = self.features.shape

        self.subject_indices = self._group_by_subject()
        self.positive_sampling = self.cfg["model"]["ELM"]["MIL_positive_sampling"]

        # Restrict training to the subjects listed in the sub-sample index file.
        self.custom_indices = np.load(cfg["dataset"]["path"] + "/indices/" + cfg["dataset"]["train_subsample"] + "_indices.npy")
        self.train_subject_indices = np.where(np.isin(self.subject_ids, self.custom_indices))[0]
        self.train_subject_ids = self.subject_ids[self.train_subject_indices]

        # Subject id -> row in the text table. ``.item()`` raises unless the
        # subject matches exactly one row.
        self.y_idx_map = {
            sid: np.where(self.text_data_dict["subject_ids"] == sid)[0].item()
            for sid in np.unique(self.subject_ids)
        }

        self.sample_text = True
        # Only the first ``nr`` reports of a multi-report subject are candidates.
        self.nr = 1
        if "llm" in self.cfg["model"]["ELM"]["text_data_filename"]:
            print("LLM-generated reports detected!")
            print("Switching to simple text sampling mode.")
            self.simple = True
        elif "random" in self.cfg["model"]["ELM"]["text_data_filename"]:
            print("Random text detected.")
            print("Switching to simple text sampling mode.")
            self.simple = True
        else:
            self.simple = False

        self._collect_text_per_rec()
        self._generate_batches()

        self.current_epoch = -1
        self.batch_counter = 0

    def _group_by_subject(self):
        """Return a dict mapping each subject id to the list of its row positions."""
        subject_indices = {}
        for i, subject in enumerate(self.subject_ids):
            subject_indices.setdefault(subject, []).append(i)
        return subject_indices

    def _collect_text_per_rec(self):
        """Fill ``self.map_subject_text`` with preprocessed text per training subject.

        The first value returned by ``preprocess_report`` is stored; elsewhere
        in this class it is measured with ``len()`` and indexed by integer
        position, so it is presumably a sequence of text snippets (confirm
        against ``preprocess_report``).
        """
        elm_cfg = self.cfg["model"]["ELM"]
        self.map_subject_text = {}

        for sub in np.unique(self.train_subject_ids):
            y = self.text_data_dict["raw_text"][self.y_idx_map[sub]]

            if isinstance(y, list):
                if len(y) == 1:
                    y = y[0]
                else:
                    # Several reports: draw one from the first ``self.nr`` entries.
                    y = np.random.choice(y[:self.nr])
            assert isinstance(y, str), f"Expected report text to be a str, got {type(y).__name__}"

            self.map_subject_text[sub], _ = preprocess_report(
                y,
                text_sample_mode=elm_cfg["text_sample_mode"],
                requested_headings=elm_cfg["text_headings"],
                sampling=False,
                simple=self.simple,
                prefix=elm_cfg["text_prefix"],
            )

    def toggle_sampling(self, sampling):
        """Set the ``sample_text`` flag to ``sampling``."""
        self.sample_text = sampling

    def _normalize(self, features):
        """Return ``features`` shifted by the dataset mean and scaled by its std."""
        return (features - self.dataset_mean) / self.dataset_std

    def _generate_y_given_x(self):
        """Build pairs of one EEG row with a bag of the same subject's texts.

        Returns:
            ``(x_ix, y_ix, ids)`` as parallel lists of lists: a single EEG row
            position, up to ``max_text_pairs`` text positions, and a single
            subject id.
        """
        x_ix, y_ix, ids = [], [], []

        for subject_id in np.unique(self.train_subject_ids):
            for idx in self.subject_indices[subject_id]:
                x_ix.append([idx])
                ids.append([subject_id])

                y_idx = np.arange(len(self.map_subject_text[subject_id]))
                if len(y_idx) > self.max_text_pairs:
                    y_idx = np.random.choice(y_idx, self.max_text_pairs, replace=False)

                y_ix.append(list(y_idx))

        return x_ix, y_ix, ids  # lists of lists

    def _generate_x_given_y(self):
        """Build pairs of one text with a bag of the same subject's EEG rows.

        Returns:
            ``(x_ix, y_ix, ids)`` as parallel lists of lists: up to
            ``max_eeg_pairs`` EEG row positions, a single text position, and a
            single subject id.
        """
        x_ix, y_ix, ids = [], [], []

        for subject_id in np.unique(self.train_subject_ids):
            # Store integer positions only: ``__getitem__`` uses each entry of
            # ``y_ix`` to index into the subject's text sequence.
            for y_idx in range(len(self.map_subject_text[subject_id])):
                y_ix.append([y_idx])
                ids.append([subject_id])

                x_idx = self.subject_indices[subject_id]
                if len(x_idx) > self.max_eeg_pairs:
                    x_idx = np.random.choice(x_idx, self.max_eeg_pairs, replace=False)

                x_ix.append(list(x_idx))

        return x_ix, y_ix, ids  # lists of lists

    def _generate_x_and_y(self):
        """Build pairs of an EEG bag with a text bag from the same subject.

        Returns:
            ``(x_ix, y_ix, ids)`` as parallel lists of lists: up to
            ``max_eeg_pairs`` EEG row positions, up to ``max_text_pairs`` text
            positions, and a single subject id.
        """
        x_ix, y_ix, ids = [], [], []

        for subject_id in np.unique(self.train_subject_ids):
            eeg_idx = self.subject_indices[subject_id]
            text_idx = np.arange(len(self.map_subject_text[subject_id]))

            # NOTE(review): the number of sets is sized with this fixed value
            # rather than ``self.max_text_pairs`` — confirm this is intended.
            temp_text_pairs = 8

            num_eeg_sets = min(len(eeg_idx), self.max_eeg_pairs)
            num_text_sets = min(len(text_idx), self.max_text_pairs)

            # At least one set per subject; more when either modality has more
            # items than fit into a single bag.
            num_sets = max(1, int(np.ceil(max(len(eeg_idx) / self.max_eeg_pairs,
                                            len(text_idx) / temp_text_pairs))))

            for _ in range(num_sets):
                ids.append([subject_id])

                if len(text_idx) > num_text_sets:
                    sampled_text_idx = np.random.choice(text_idx, num_text_sets, replace=False)
                else:
                    sampled_text_idx = text_idx
                y_ix.append(list(sampled_text_idx))

                if len(eeg_idx) > num_eeg_sets:
                    sampled_eeg_idx = np.random.choice(eeg_idx, num_eeg_sets, replace=False)
                else:
                    sampled_eeg_idx = eeg_idx
                x_ix.append(list(sampled_eeg_idx))

        return x_ix, y_ix, ids  # lists of lists

    def _generate_batches(self):
        """Generate this epoch's pairs and order them into subject-unique groups.

        Pairs are emitted in consecutive groups of at most ``batch_size``
        entries, each group holding at most one pair per subject. The result
        is stored in ``self.current_epoch_pairs`` as ``(x_ix, y_ix, ids)``.

        Raises:
            ValueError: If ``self.positive_sampling`` is not a known strategy.
        """
        if self.positive_sampling == "y|x":
            x_ix, y_ix, pair_ids = self._generate_y_given_x()
        elif self.positive_sampling == "x|y":
            x_ix, y_ix, pair_ids = self._generate_x_given_y()
        elif self.positive_sampling in ["x,y", "x,y_b"]:
            x_ix, y_ix, pair_ids = self._generate_x_and_y()
        else:
            raise ValueError(f"Unknown positive sampling strategy: {self.positive_sampling}")

        print("Amount of pairs: ", len(x_ix))

        # Group pair positions by subject id.
        id_to_pairs = defaultdict(list)
        for i, pair_id in enumerate(pair_ids):
            id_to_pairs[pair_id[0]].append(i)

        # Shuffle the order of the pairs within each subject.
        for pairs in id_to_pairs.values():
            random.shuffle(pairs)

        batch_size = self.cfg['training']['batch_size']
        ordered_pairs = []
        all_ids = list(id_to_pairs.keys())

        while id_to_pairs:
            batch = []
            used_ids = set()
            random.shuffle(all_ids)  # Shuffle IDs before each batch
            for current_id in all_ids:
                if len(batch) == batch_size:
                    break
                if current_id not in used_ids and id_to_pairs[current_id]:
                    pair_index = id_to_pairs[current_id].pop(0)
                    batch.append(pair_index)
                    used_ids.add(current_id)
                # Drop exhausted subjects so the outer loop can terminate.
                if not id_to_pairs[current_id]:
                    id_to_pairs.pop(current_id)

            if not batch:
                if not ordered_pairs:
                    print("Warning: Current setup doesn't allow for unique IDs in the first batch.")
                    # Fallback to original shuffling method
                    indices = list(range(len(x_ix)))
                    random.shuffle(indices)
                    self.current_epoch_pairs = ([x_ix[i] for i in indices],
                                                [y_ix[i] for i in indices],
                                                [pair_ids[i] for i in indices])
                    return
                break

            ordered_pairs.extend(batch)
            all_ids = list(id_to_pairs.keys())  # Update the list of available IDs

            # Stop once a group falls below half the batch size; pairs still
            # left in ``id_to_pairs`` are not scheduled for this epoch.
            if len(batch) < batch_size / 2:
                break

        # Use ordered_pairs to reorder x_ix, y_ix, and ids while maintaining their relationships
        self.current_epoch_pairs = ([x_ix[i] for i in ordered_pairs],
                                    [y_ix[i] for i in ordered_pairs],
                                    [pair_ids[i] for i in ordered_pairs])

    def __len__(self):
        """Return the number of pairs scheduled for the current epoch."""
        return len(self.current_epoch_pairs[0])

    def on_epoch_start(self, save_path):
        """Regenerate the pair ordering and advance the epoch counter.

        Args:
            save_path: Accepted but not used by this method.
        """
        self._generate_batches()
        self.current_epoch += 1
        self.batch_counter = 0

    def __getitem__(self, index):
        """Return the EEG bag, text bag and their subject ids for one pair.

        Returns:
            ``(x, y, x_id, y_id)`` where ``x`` holds the normalized feature
            rows, ``y`` is the list of selected texts, and ``x_id`` / ``y_id``
            carry the subject id once per element on the bag side(s) of the
            active strategy.

        Raises:
            ValueError: If ``self.positive_sampling`` is not a known strategy.
        """
        # Retrieve the current epoch's pairs (EEG indices, text indices, and subject IDs)
        x_ix, y_ix, pair_ids = self.current_epoch_pairs
        current_id = pair_ids[index]

        sorted_indices = np.sort(x_ix[index])  # for .h5 access
        x = self._normalize(self.features[sorted_indices])

        # Retrieve relevant text samples for the current subject
        relevant_text = self.map_subject_text[current_id[0]]
        y = [relevant_text[i] for i in y_ix[index]]

        if self.positive_sampling == "y|x":
            x_id = current_id
            y_id = [current_id] * len(y)
        elif self.positive_sampling == "x|y":
            x_id = [current_id] * x.shape[0]
            y_id = current_id
        elif self.positive_sampling in ["x,y", "x,y_b"]:
            x_id = [current_id] * x.shape[0]
            y_id = [current_id] * len(y)
        else:
            raise ValueError(f"Unknown positive sampling strategy: {self.positive_sampling}")

        return x, y, x_id, y_id

    def collate_fn(self, batch):
        """Flatten a list of ``__getitem__`` results into one batch.

        Returns:
            ``(eeg_crops, text_samples, eeg_id, text_id)``: a float tensor of
            all EEG rows, a flat list of texts, and one id tensor per modality.
        """
        eeg_crops, text_samples, eeg_id, text_id = zip(*batch)

        eeg_crops = list(chain.from_iterable(eeg_crops))
        eeg_id = list(chain.from_iterable(eeg_id))
        text_id = list(chain.from_iterable(text_id))
        text_samples = list(chain.from_iterable(text_samples))

        # NOTE(review): ``squeeze()`` drops every size-1 axis, including the
        # leading one when only a single row is present — confirm that
        # downstream code tolerates this.
        eeg_crops = torch.from_numpy(np.stack(eeg_crops)).float().squeeze()

        eeg_id = torch.from_numpy(np.array(eeg_id).squeeze())
        text_id = torch.from_numpy(np.array(text_id).squeeze())

        self.batch_counter += 1
        return eeg_crops, text_samples, eeg_id, text_id

    def save_text_index_map(self, path):
        """Write ``self.text_index_map`` to ``path`` as JSON with string keys.

        NOTE(review): ``text_index_map`` is not assigned anywhere in this
        class — confirm it is provided elsewhere before calling.
        """
        serializable_map = {str(k): v for k, v in self.text_index_map.items()}
        with open(path, 'w') as f:
            json.dump(serializable_map, f)

    @staticmethod
    def load_text_index_map(load_path):
        """Return the JSON content stored at ``load_path``."""
        with open(load_path, 'r') as f:
            return json.load(f)
    
class H5_ELM(BaseDataset):
    """Dataset pairing each feature row with its subject's preprocessed report.

    ``features``, ``subject_ids``, ``text_data_dict``, ``indices``,
    ``dataset_mean``, ``dataset_std`` and ``cfg`` are not assigned in this
    class; they are presumably populated by ``BaseDataset.__init__`` (verify
    against the base class).
    """

    def __init__(self, cfg, setting):
        """Load the base dataset and index the text table by subject id.

        Args:
            cfg: Configuration mapping; ``cfg["model"]["ELM"]`` is read here
                and in ``__getitem__``.
            setting: Forwarded unchanged to ``BaseDataset.__init__``.
        """
        super().__init__(cfg, setting)

        # Subject id -> row in the text table. ``.item()`` raises unless the
        # subject matches exactly one row.
        row_of_subject = {}
        for subject in np.unique(self.subject_ids):
            matches = np.where(self.text_data_dict["subject_ids"] == subject)[0]
            row_of_subject[subject] = matches.item()
        self.y_idx_map = row_of_subject
        self.sample_text = True

        uses_llm_reports = "llm" in self.cfg["model"]["ELM"]["text_data_filename"]
        if uses_llm_reports:
            print("Detected LLM-generated Reports.")
            print("Switching to simple text sampling mode. (i.e. unprocessed, full reports)")
        self.simple = uses_llm_reports

    def toggle_sampling(self, sampling):
        """Set the ``sample_text`` flag passed to ``preprocess_report``."""
        self.sample_text = sampling

    def _normalize(self, features) -> np.ndarray:
        """Return ``features`` shifted by the dataset mean and scaled by its std."""
        centered = features - self.dataset_mean
        return centered / self.dataset_std

    def __getitem__(self, index) -> Tuple[torch.Tensor, np.ndarray]:
        """Return the normalized features and the processed report for ``index``.

        Raises:
            ValueError: If the subject at ``index`` has no entry in
                ``self.y_idx_map``.
        """
        x = torch.from_numpy(self._normalize(self.features[index])).float()

        sub_id = self.subject_ids[index]
        text_row = self.y_idx_map.get(sub_id)
        if text_row is None:
            raise ValueError(f"No matching text data found for subject id {sub_id}")

        raw_report = self.text_data_dict["raw_text"][text_row]
        elm_cfg = self.cfg["model"]["ELM"]
        # Only the first of the two values returned by the helper is used.
        processed = preprocess_report(
            raw_report,
            text_sample_mode=elm_cfg["text_sample_mode"],
            requested_headings=elm_cfg["text_headings"],
            sampling=self.sample_text,
            simple=self.simple,
            prefix=elm_cfg["text_prefix"],
        )
        y, _ = processed

        return x, y

    def collate_fn(self, batch):
        """Stack the feature tensors and gather the texts into a list."""
        tensors, texts = zip(*batch)
        return torch.stack(tensors), list(texts)

    def __len__(self) -> int:
        """Return the number of entries in ``self.indices``."""
        return len(self.indices)

def filter_data(x: np.ndarray, freq: int = 200, bands=((1, 7), (8, 30), (31, 49)), axis: int = 1,
                norm_stats=(0.2, 0.085, 0.045)):
    """Band-pass filter ``x`` once per band and stack the scaled results.

    Unoptimized band-pass filtering. Don't use this for more than an experiment.

    Args:
        x: NumPy array to filter; it is cast to ``float64`` before filtering.
        freq: Sampling frequency handed to ``mne.filter.filter_data``.
        bands: Sequence of ``(low, high)`` cut-off pairs, one per output band.
        axis: Position of the new band axis in the stacked result.
        norm_stats: Sequence of divisors; entry ``i`` scales the output of
            band ``i``. Must hold at least ``len(bands)`` entries.

    Returns:
        Array with the per-band filtered signals stacked along ``axis``.
    """
    # Cast once instead of once per band. This relies on the filter working on
    # its own copy of the input (its documented default).
    data = x.astype(np.float64)
    filtered = [
        mne.filter.filter_data(data, freq, band[0], band[1], verbose="critical", n_jobs=1) / norm_stats[i]
        for i, band in enumerate(bands)
    ]
    return np.stack(filtered, axis=axis)