import os
import pickle
import logging
from random import shuffle, randrange, choices, sample
from typing import Optional

import torch
from torch.utils.data import Dataset

import numpy as np
import pandas as pd

from sklearn.preprocessing import LabelEncoder


class StandardScaler:
    """Standardize values using statistics computed by ``np.mean`` / ``np.std``.

    ``fit`` calls ``np.mean(X)`` and ``np.std(X)`` without an ``axis``
    argument, so for plain list/ndarray input the statistics are single
    scalars computed over all elements (not per-feature).

    Attributes:
        mean_: Result of ``np.mean`` on the fitted data, ``None`` before ``fit``.
        std_: Result of ``np.std`` on the fitted data, ``None`` before ``fit``.
    """

    def __init__(self):
        self.mean_ = None
        self.std_ = None

    def fit(self, X):
        """Compute and store the mean and standard deviation of ``X``.

        Returns:
            The scaler itself, so calls can be chained
            (``StandardScaler().fit(x).transform(x)``).
        """
        self.mean_ = np.mean(X)
        self.std_ = np.std(X)
        return self

    def transform(self, X):
        """Return ``(X - mean_) / std_`` as a Python list.

        The input is converted with ``np.array`` first; the result is always
        wrapped in ``list`` regardless of the input type.
        """
        return list((np.array(X) - self.mean_) / self.std_)

    def inverse_transform(self, X):
        """Undo ``transform``: return ``X * std_ + mean_``.

        The container type is preserved for tensors and ndarrays; everything
        else (lists, tuples, ...) is converted through ``np.array`` and
        returned as a list.
        """
        # isinstance (rather than an exact type comparison) keeps Tensor
        # subclasses such as nn.Parameter on the tensor path; routing them
        # through np.array() fails for tensors that require grad.
        if isinstance(X, torch.Tensor):
            return X * self.std_ + self.mean_
        if isinstance(X, np.ndarray):
            return np.array(X) * self.std_ + self.mean_
        return list(np.array(X) * self.std_ + self.mean_)


class ROIPretrainDatasetSplit(Dataset): # UKB Only
    """Per-subject ROI timeseries dataset with random temporal/spatial masks.

    The timeseries are read from a single pickle file located at
    ``<sourcedir>/<dataset>/<name>.pkl`` where ``<name>`` is assembled from the
    dataset, ROI, split, normalization and ``initial_noise`` arguments. The
    pickle is expected to hold a mapping ``subject -> timeseries`` whose values
    expose ``shape[0]`` (timepoints) and ``shape[1]`` (features); presumably
    these are numpy arrays -- confirm against the preprocessing script.

    Each item is ``(subject, obs, obs_times, temp_mask, feat_mask)``.

    Args:
        sourcedir: Root directory; ``dataset`` is joined onto it.
        dataset: Dataset name, used as sub-directory and file-name prefix.
        roi: ROI identifier used in the file name.
        split: Split name used in the file name.
        temporal_resolution: Multiplier turning a timepoint index into an
            observation time.
        initial_noise: Only used to build the file name (``initnoise{N}``)
            and stored on the instance.
        dynamic_length: Number of timepoints returned per item. When falsy,
            every series is truncated to the shortest series length instead.
        temporal_mask_ratio: Fraction of timepoints set to ``False`` in the
            temporal mask.
        spatial_mask_ratio: Fraction of features set to ``False`` in the
            feature mask.
        **kwargs: Optional ``normalize`` (default ``"robust"``),
            ``random_sampling`` (default ``False``),
            ``initial_valid_timesteps`` (default ``1``) and ``ratio``
            (default ``1.0``; stored but not used by this class).
    """

    def __init__(
        self,
        sourcedir, 
        dataset: Optional[str] = "ukb-rest", 
        roi: Optional[str] = "schaefer450", 
        split: Optional[str] = "train",
        temporal_resolution: Optional[float] = 0.735,
        initial_noise: Optional[int] = 10,
        dynamic_length: Optional[int] = 160,
        temporal_mask_ratio: Optional[float] = 0.75,
        spatial_mask_ratio: Optional[float] = 0.0,
        **kwargs
    ):
        super().__init__()

        self.normalize = kwargs.get("normalize", "robust")
        self.random_sampling = kwargs.get("random_sampling", False)
        self.initial_valid_timesteps = kwargs.get("initial_valid_timesteps", 1)
        self.ratio = kwargs.get("ratio", 1.0)

        self.temporal_mask_ratio = temporal_mask_ratio
        self.spatial_mask_ratio = spatial_mask_ratio
        self.temporal_resolution = temporal_resolution

        sourcedir = os.path.join(sourcedir, dataset)
        self.filename = dataset
        if roi == "schaefer450":
            self.timeseries_filename = dataset + '-roi_tian_450'
        else:
            self.timeseries_filename = dataset + f'_roi-{roi}'

        self.timeseries_filename = "-".join([self.timeseries_filename, split, "zeromean", self.normalize, "32bit", f"initnoise{initial_noise}"])
        timeseries_path = os.path.join(sourcedir, f'{self.timeseries_filename}.pkl')
        assert os.path.isfile(timeseries_path), f'timeseries file {self.timeseries_filename}.pkl not found.'

        print(self.timeseries_filename)

        # Load the subject -> timeseries mapping.
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # sourcedir at trusted, locally produced files.
        with open(timeseries_path, 'rb') as f:
            self.timeseries_dict = pickle.load(f)

        ##### set remaining dataset configurations
        timeseries_list = [timeseries.shape[0] for timeseries in self.timeseries_dict.values()]
        self.initial_noise = initial_noise
        self.min_timepoints = min(timeseries_list)

        self.feature_dim = next(iter(self.timeseries_dict.values())).shape[1]
        self.dynamic_length = dynamic_length
        self.full_subject_list = list(self.timeseries_dict.keys())

        # A falsy dynamic_length (None / 0) means "truncate to min_timepoints"
        # in get_timeseries_data, so only validate a real window length.
        if dynamic_length:
            assert dynamic_length <= self.min_timepoints, f'dynamic length {dynamic_length} should be smaller than {self.min_timepoints}'

        logging.info(f'timeseries stat.: min {min(timeseries_list):.1f}, max {max(timeseries_list):.1f}, avg {np.mean(timeseries_list):.1f}')
        logging.info(f'number of subjects: {len(self.full_subject_list)}')

    def __len__(self):
        """Return the number of subjects in the loaded pickle."""
        return len(self.full_subject_list)

    def spatiotemporal_masking(self, obs):
        """Draw random temporal and feature masks for one observation.

        Args:
            obs: Tensor of shape ``[T, N_roi]``. Other ranks (e.g. a mapped
                ``[T, n_networks, max_size]`` layout) are not supported here.

        Returns:
            ``(obs, temp_mask, feat_mask)`` where ``obs`` is returned
            unchanged, ``temp_mask`` is a bool tensor of shape ``[T]`` and
            ``feat_mask`` a bool tensor of shape ``[N_roi]``. ``False`` marks
            a masked position. The first ``initial_valid_timesteps``
            timepoints are never masked.

        Raises:
            ValueError: If ``obs`` is not 2-dimensional.
        """
        if obs.ndim != 2:
            # Fail early with a clear message instead of hitting an unbound
            # feat_mask at the return statement.
            raise ValueError(
                f'spatiotemporal_masking expects obs of shape [T, N_roi], got {tuple(obs.shape)}'
            )

        device = obs.device
        T, F = obs.shape

        temp_mask = torch.ones(T, dtype=torch.bool, device=device)
        if self.temporal_mask_ratio > 0:
            # The number of masked steps is relative to the full length T,
            # but candidates exclude the leading always-valid timesteps.
            k = int(T * self.temporal_mask_ratio)
            # Clamp so that a series shorter than initial_valid_timesteps
            # yields "nothing to mask" rather than a randperm error.
            n_maskable = max(T - self.initial_valid_timesteps, 0)
            idx = torch.randperm(n_maskable, device=device)
            idx = idx + self.initial_valid_timesteps
            temp_mask[idx[:k]] = False

        feat_mask = torch.ones(F, dtype=torch.bool, device=device)
        if self.spatial_mask_ratio > 0:
            kf = int(F * self.spatial_mask_ratio)
            fidx = torch.randperm(F, device=device)[:kf]
            feat_mask[fidx] = False

        return obs, temp_mask, feat_mask

    def get_timeseries_data(self, timeseries, eps=1e-8) :
        """Select a window of ``timeseries`` and build its observation times.

        Behaviour depends on the instance configuration:

        * ``dynamic_length`` set and ``random_sampling`` true: draw
          ``dynamic_length`` distinct timepoints, keep them in temporal order,
          and shift the times so the first one is 0.
        * ``dynamic_length`` set, ``random_sampling`` false: take a contiguous
          window of ``dynamic_length`` timepoints at a random offset; times
          restart at 0.
        * ``dynamic_length`` falsy: keep the first ``min_timepoints`` steps.

        Args:
            timeseries: Indexable array with time on axis 0.
            eps: Unused; kept for interface compatibility.

        Returns:
            ``(obs, obs_times)``: ``obs`` as a float32 tensor and the matching
            times (index * ``temporal_resolution``).
        """
        if self.dynamic_length and self.random_sampling:
            # Times are taken from the original grid so irregular gaps
            # between the sampled points are preserved.
            obs_times = torch.arange(timeseries.shape[0]) * self.temporal_resolution
            indices = sorted(sample(range(len(timeseries)), self.dynamic_length))
            timeseries = timeseries[indices]
            obs_times = obs_times[indices]
            obs_times -= obs_times[0].item()
        else:
            if self.dynamic_length:
                sampling_init = randrange(0, len(timeseries) - self.dynamic_length + 1)
                timeseries = timeseries[sampling_init:sampling_init + self.dynamic_length]
            else:
                timeseries = timeseries[:self.min_timepoints]
            obs_times = torch.arange(timeseries.shape[0]) * self.temporal_resolution

        obs = torch.as_tensor(timeseries, dtype=torch.float32)
        return obs, obs_times

    def __getitem__(self, idx) :
        """Return ``(subject, obs, obs_times, temp_mask, feat_mask)`` for ``idx``."""
        subject = self.full_subject_list[idx]
        obs = self.timeseries_dict[subject]

        obs, obs_times = self.get_timeseries_data(obs)
        obs, temp_mask, feat_mask = self.spatiotemporal_masking(obs)

        return subject, obs, obs_times, temp_mask, feat_mask
