import numpy as np
from torch.utils.data import Dataset

def scale_data(data, start_min, start_max, end_min, end_max):
    """Linearly map ``data`` from [start_min, start_max] onto [end_min, end_max].

    Works element-wise on numpy arrays as well as on Python scalars.
    Values outside the source range are extrapolated, not clipped.

    Args:
        data: value(s) to rescale.
        start_min, start_max: bounds of the source range; they must differ,
            otherwise the intermediate division is by zero.
        end_min, end_max: bounds of the target range.

    Returns:
        The rescaled value(s), same shape as ``data``.
    """
    unit = (data - start_min) / (start_max - start_min)  # now in [0, 1]
    end_span = end_max - end_min
    return unit * end_span + end_min

class VF_Dataset(Dataset):
    """Dataset of visual-field samples: a series of baseline fields plus a follow-up.

    ``data_dict`` must contain:

    * ``'data'``: mapping of sample id -> per-sample dict with the keys
      ``baseline_<rep>``, ``followup_<rep>``, ``baseline_age`` (a sequence),
      ``followup_age``, ``baseline_aa`` and the extra keys read by
      :meth:`get_stats`.
    * ``'padding_mask'``: 2-D mask where truthy entries mark padding; used
      by :meth:`get_field_mask` and :meth:`get_padding_mask`.

    Args:
        data_dict: the dictionary described above.
        representation: one of ``'hvf'``, ``'td'``, ``'aa'``; selects which
            ``baseline_*``/``followup_*`` entries ``__getitem__`` returns.
        normalize: if True, rescale fields from ``data_range`` to [0, 1].
        data_range: ``(min, max)`` of the raw field values.
        transform: stored on the instance. NOTE(review): it is not applied
            anywhere in this class.

    Raises:
        NotImplementedError: if ``representation`` is not supported.
    """

    def __init__(self, data_dict, representation='td', normalize=False, data_range=(0,50), transform=None):
        if representation not in ['hvf', 'td', 'aa']:
            raise NotImplementedError(f"unsupported representation: {representation!r}")

        self.data = data_dict
        # Freeze the key order once so index -> sample id is deterministic.
        self.sample_ids = list(data_dict['data'].keys())

        self.representation = representation
        self.normalize = normalize
        self.data_range = data_range
        self.transform = transform

    def __len__(self):
        # __getitem__ indexes self.sample_ids, so that list defines the valid
        # index range; a separately stored count could disagree with it.
        return len(self.sample_ids)

    def __getitem__(self, index):
        """Return ``((baselines, start_age, age_deltas, baseline_aa), future, sid)``.

        ``future`` is the follow-up field with a leading singleton axis.
        ``start_age`` is the first baseline age divided by 100.
        ``age_deltas`` holds each baseline age and then the follow-up age,
        all measured relative to the first baseline age and divided by 10.
        """
        sid = self.sample_ids[index]
        sample = self.data['data'][sid]
        baselines = np.array(sample[f'baseline_{self.representation}'])
        future = np.array(sample[f'followup_{self.representation}'])[np.newaxis, :]

        # For the U-Net, pad one zero row and one zero column at the end so the
        # spatial dims become even. This requires baselines to be 3-D and the
        # follow-up field to be 2-D before np.newaxis is applied.
        if self.representation in ('hvf', 'td'):
            pad = ((0, 0), (0, 1), (0, 1))
            baselines = np.pad(baselines, pad_width=pad, mode='constant', constant_values=0)
            future = np.pad(future, pad_width=pad, mode='constant', constant_values=0)

        # Scale to [0, 1].
        if self.normalize:
            baselines = scale_data(baselines, self.data_range[0], self.data_range[1], 0, 1)
            future = scale_data(future, self.data_range[0], self.data_range[1], 0, 1)

        ages = sample['baseline_age']
        followup_age = sample['followup_age']
        first_age = ages[0]
        start_age = first_age / 100
        # Subtract the raw first age, not the /100-scaled start_age, so both
        # operands share units and the first delta is exactly 0.
        age_deltas = [(a - first_age) / 10 for a in ages]
        age_deltas.append((followup_age - first_age) / 10)
        baseline_aa = np.array(sample['baseline_aa'])

        return (baselines, start_age, np.array(age_deltas), baseline_aa), future, sid

    def get_stats(self, sid):
        """Collect raw per-sample metadata for sample ``sid`` (for evaluation/analysis).

        ``'horizon'`` is ``followup_age`` minus the last baseline age.
        """
        sample = self.data['data'][sid]
        stats = {
            'baseline_md': sample['baseline_md'],
            'followup_md': sample['followup_md'],
            'baseline_td': np.array(sample['baseline_td']),
            'future_td': np.array(sample['followup_td']),
            'baseline_aa': sample['baseline_aa'],
            'followup_aa': sample['followup_aa'],
            'gender': sample['gender'],
            'side': sample['side'],
            'baseline_age': sample['baseline_age'],
            'followup_age': sample['followup_age'],
            'horizon': sample['followup_age'] - sample['baseline_age'][-1],
        }
        return stats

    def get_field_mask(self, include_blindspot=False):
        """Return the field mask: True where a pixel carries meaning, False on padding.

        The mask is padded with one extra False row/column to match the padded
        fields returned by ``__getitem__``. By default the blind spot counts as
        padding; with ``include_blindspot=True`` it is marked as meaningful.
        """
        # dtype=bool so that '~' is a logical NOT even if the stored mask is 0/1 ints.
        mask = ~np.array(self.data['padding_mask'], dtype=bool)
        mask = np.pad(mask, pad_width=((0, 1), (0, 1)), mode='constant', constant_values=False)
        if include_blindspot:
            # Assumes the blind spot sits at rows 3-4, column 7 of this grid
            # layout (hard-coded); verify for other field layouts/eyes.
            mask[3:5, 7] = True
        return mask

    def get_padding_mask(self, exclude_blindspot=False):
        """Return the padding mask: True on padding, False where a pixel carries meaning.

        The mask is padded with one extra True row/column to match the padded
        fields returned by ``__getitem__``. By default the blind spot counts as
        padding; with ``exclude_blindspot=True`` it is marked as non-padding.
        """
        mask = np.array(self.data['padding_mask'], dtype=bool)
        mask = np.pad(mask, pad_width=((0, 1), (0, 1)), mode='constant', constant_values=True)
        if exclude_blindspot:
            # Same hard-coded blind-spot location as get_field_mask.
            mask[3:5, 7] = False
        return mask

class MoCap_Dataset(Dataset):
    """Dataset of motion-capture samples split into a history and a horizon window.

    ``data_dict['data']`` maps sample id -> per-sample dict holding
    ``history_<rep>``/``horizon_<rep>`` sequences, the ``history_rt_sin``,
    ``history_rt_cos``, ``horizon_rt_sin`` and ``horizon_rt_cos`` sequences,
    and optionally ``history_aa``.

    Args:
        data_dict: the dictionary described above.
        representation: one of ``'kp_raw'``, ``'kp_norm'``, ``'aa'``.
        normalize: if True, rescale history/horizon from ``data_range`` to
            [0, 1] and the sin/cos ``rt`` features from [-1, 1] to [0, 1].
        data_range: ``(min, max)`` of the raw ``<rep>`` values.
        transform: stored on the instance. NOTE(review): it is not applied
            anywhere in this class.

    Raises:
        NotImplementedError: if ``representation`` is not supported.
    """

    def __init__(self, data_dict, representation='kp_norm', normalize=False, data_range=(0,50), transform=None):
        if representation not in ['kp_raw', 'kp_norm', 'aa']:
            raise NotImplementedError(f"unsupported representation: {representation!r}")

        self.data = data_dict
        # Freeze the key order once so index -> sample id is deterministic.
        self.sample_ids = list(data_dict['data'].keys())

        self.representation = representation
        self.normalize = normalize
        self.data_range = data_range
        self.transform = transform

    def __len__(self):
        # __getitem__ indexes self.sample_ids, so that list defines the valid
        # index range; a separately stored count could disagree with it.
        return len(self.sample_ids)

    @staticmethod
    def _sin_cos_pairs(sample, prefix):
        """Stack ``<prefix>_rt_sin`` and ``<prefix>_rt_cos`` column-wise into an (n, 2) float32 array."""
        sin = np.array(sample[f'{prefix}_rt_sin'], dtype=np.float32)
        cos = np.array(sample[f'{prefix}_rt_cos'], dtype=np.float32)
        return np.stack((sin, cos), axis=1)

    def __getitem__(self, index):
        """Return one sample.

        The return value is ``((history, baseline_aa), horizon, sid)`` for
        ``representation == 'aa'``. Otherwise it is
        ``((history, history_rt, baseline_aa), (horizon, horizon_rt), sid)``.
        ``history``/``horizon`` are the scalar ``0`` when the sample lacks
        ``history_<rep>``, and ``baseline_aa`` is ``0`` when the sample lacks
        ``history_aa``.
        """
        sid = self.sample_ids[index]
        sample = self.data['data'][sid]
        rep = self.representation

        if f'history_{rep}' in sample:
            history = np.array(sample[f'history_{rep}'], dtype=np.float32)  # h x dims
            horizon = np.array(sample[f'horizon_{rep}'], dtype=np.float32)  # H x dims
        else:
            # Placeholder when this representation is absent for the sample.
            history = 0
            horizon = 0

        history_rt = self._sin_cos_pairs(sample, 'history')  # h x 2
        horizon_rt = self._sin_cos_pairs(sample, 'horizon')  # H x 2

        # Scale to [0, 1]; sin/cos values live in [-1, 1].
        if self.normalize:
            history = scale_data(history, self.data_range[0], self.data_range[1], 0, 1)
            horizon = scale_data(horizon, self.data_range[0], self.data_range[1], 0, 1)
            history_rt = scale_data(history_rt, -1, 1, 0, 1)
            horizon_rt = scale_data(horizon_rt, -1, 1, 0, 1)

        if 'history_aa' in sample:
            baseline_aa = np.array(sample['history_aa'], dtype=np.float32)
        else:
            baseline_aa = 0

        if rep == 'aa':
            return (history, baseline_aa), horizon, sid
        return (history, history_rt, baseline_aa), (horizon, horizon_rt), sid

    def get_stats(self, sid):
        """Collect the raw keypoint and ``rt`` sin/cos arrays of sample ``sid`` as numpy arrays."""
        sample = self.data['data'][sid]
        stats = {
            'history_kp_raw': np.array(sample['history_kp_raw']),
            'horizon_kp_raw': np.array(sample['horizon_kp_raw']),
            'history_kp_norm': np.array(sample['history_kp_norm']),  # centered and without y-axis rotation
            'horizon_kp_norm': np.array(sample['horizon_kp_norm']),  # centered and without y-axis rotation
            'history_rt_cos': np.array(sample['history_rt_cos']),
            'history_rt_sin': np.array(sample['history_rt_sin']),
            'horizon_rt_cos': np.array(sample['horizon_rt_cos']),
            'horizon_rt_sin': np.array(sample['horizon_rt_sin']),
        }
        return stats