import os
import numpy as np
import pandas as pd
import random
from torch.utils.data import Dataset
from sklearn.preprocessing import StandardScaler
from typing import List
import warnings
import numpy as np
import pandas as pd
from pandas.tseries import offsets
from pandas.tseries.frequencies import to_offset
#from lib.augmentation import augment
import glob
import re
import torch

warnings.filterwarnings('ignore')


def run_augmentation_single(x, y, args):
    """Pass one series (or one batch of series) through the augmentation pipeline.

    No augmentation transform is applied in this version of the function:
    the data is normalised to a batch layout, passed through unchanged and
    restored to its original layout.

    Parameters
    ----------
    x
        Array-like exposing ``.shape``, indexing and ``.squeeze``. Either a
        whole series with fewer than 3 dimensions, e.g.
        ``(sequence_length, num_channels)``, or a batch shaped
        ``(batch_size, sequence_length, num_channels)``.
    y
        Labels / targets; returned untouched.
    args
        Object providing ``seed`` (used to seed NumPy's global RNG) and
        ``extra_tag`` (returned as the augmentation tags).

    Returns
    -------
    tuple
        ``(x_aug, y_aug, augmentation_tags)`` where ``x_aug`` has the same
        shape as ``x``.

    Raises
    ------
    ValueError
        If ``x`` has more than 3 dimensions.
    """
    # print("Augmenting %s"%args.data)
    np.random.seed(args.seed)

    ndim = len(x.shape)
    if ndim > 3:
        raise ValueError("Input must be (batch_size, sequence_length, num_channels) dimensional")

    # Whole-series scenario: treat the input as "one big batch".
    #   Before  -   (sequence_length, num_channels)
    #   After   -   (1, sequence_length, num_channels)
    # Note: 'sequence_length' here is the length of the entire series.
    whole_series = ndim < 3
    x_input = x[np.newaxis, :] if whole_series else x

    # The batched view is what flows through the pipeline. Previously the
    # un-batched `x` was used here, so the squeeze below either raised
    # (leading axis != 1) or silently dropped a real axis (leading axis == 1).
    x_aug = x_input
    y_aug = y

    augmentation_tags = args.extra_tag

    if whole_series:
        # Drop the temporary batch axis added above.
        x_aug = x_aug.squeeze(0)
    return x_aug, y_aug, augmentation_tags


class TimeFeature:
    """Base class for callables that derive one feature from a datetime index.

    The base implementation of ``__call__`` does nothing and returns ``None``;
    subclasses override it with the actual encoding.
    """

    def __init__(self):
        return None

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        # Placeholder only: concrete features provide the computation.
        return None

    def __repr__(self):
        # Rendered as "<ClassName>()" for whichever subclass this is.
        return f"{self.__class__.__name__}()"


class SecondOfMinute(TimeFeature):
    """Second of minute encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        # Scale 0..59 into 0..1, then centre around zero.
        scaled = index.second / 59.0
        return scaled - 0.5


class MinuteOfHour(TimeFeature):
    """Minute of hour encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        # Scale 0..59 into 0..1, then centre around zero.
        scaled = index.minute / 59.0
        return scaled - 0.5


class HourOfDay(TimeFeature):
    """Hour of day encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        # Scale 0..23 into 0..1, then centre around zero.
        scaled = index.hour / 23.0
        return scaled - 0.5


class DayOfWeek(TimeFeature):
    """Day of week encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        # Scale 0..6 into 0..1, then centre around zero.
        scaled = index.dayofweek / 6.0
        return scaled - 0.5


class DayOfMonth(TimeFeature):
    """Day of month encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        # Shift the 1-based day to 0-based before scaling by 30.
        zero_based_day = index.day - 1
        return zero_based_day / 30.0 - 0.5


class DayOfYear(TimeFeature):
    """Day of year encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        # Shift the 1-based day-of-year to 0-based before scaling by 365.
        zero_based_day = index.dayofyear - 1
        return zero_based_day / 365.0 - 0.5


class MonthOfYear(TimeFeature):
    """Month of year encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        # Shift the 1-based month to 0-based before scaling by 11.
        zero_based_month = index.month - 1
        return zero_based_month / 11.0 - 0.5


class WeekOfYear(TimeFeature):
    """Week of year encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        # ISO calendar week number, shifted to 0-based before scaling by 52.
        iso_week = index.isocalendar().week
        return (iso_week - 1) / 52.0 - 0.5


def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]:
    """
    Build the time-feature extractors that suit a given sampling frequency.

    Parameters
    ----------
    freq_str
        Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc.

    Returns
    -------
    A list of freshly constructed ``TimeFeature`` instances (possibly empty,
    e.g. for yearly data).

    Raises
    ------
    RuntimeError
        If the parsed offset matches none of the supported offset types.
    """
    # Ordered (offset type, feature classes) pairs; the first matching type wins.
    candidates = (
        (offsets.YearEnd, []),
        (offsets.QuarterEnd, [MonthOfYear]),
        (offsets.MonthEnd, [MonthOfYear]),
        (offsets.Week, [DayOfMonth, WeekOfYear]),
        (offsets.Day, [DayOfWeek, DayOfMonth, DayOfYear]),
        (offsets.BusinessDay, [DayOfWeek, DayOfMonth, DayOfYear]),
        (offsets.Hour, [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear]),
        (
            offsets.Minute,
            [MinuteOfHour, HourOfDay, DayOfWeek, DayOfMonth, DayOfYear],
        ),
        (
            offsets.Second,
            [SecondOfMinute, MinuteOfHour, HourOfDay, DayOfWeek, DayOfMonth, DayOfYear],
        ),
    )

    parsed = to_offset(freq_str)

    matched = next(
        (classes for kind, classes in candidates if isinstance(parsed, kind)),
        None,
    )
    # An empty list is a valid match (yearly), so test against None explicitly.
    if matched is not None:
        return [feature_cls() for feature_cls in matched]

    raise RuntimeError(f"""
    Unsupported frequency {freq_str}
    The following frequencies are supported:
        Y   - yearly
            alias: A
        M   - monthly
        W   - weekly
        D   - daily
        B   - business days
        H   - hourly
        T   - minutely
            alias: min
        S   - secondly
    """)


def time_features(dates, freq='h'):
    """Evaluate every time feature for ``freq`` on ``dates`` and stack them row-wise.

    The result has one row per feature returned by
    ``time_features_from_frequency_str(freq)``.
    """
    extractors = time_features_from_frequency_str(freq)
    rows = [extract(dates) for extract in extractors]
    return np.vstack(rows)


class Dataset_Custom(Dataset):
    """
    Sliding-window time-series dataset loaded from a CSV file, with optional
    random element masking, random row dropping and per-sample block masking.

    Set the random seeds before using this class so that the random sampling
    (random missingness) is identical on every run; this is usually done once
    at the start of an experiment:

        np.random.seed(seed)
        random.seed(seed)
    """
    def __init__(self,  root_path, flag='train', size=None,
                 features='MS', data_path='BE.csv',
                 target='OT', scale=True, timeenc=0, freq='h',
                 mask_covar_ratio=0, mask_target_ratio=0,
                 down_sample=0, block_mask_len=0):
        """
        Args:
            root_path: dataset folder path
            flag: which split to expose, one of 'train', 'val', 'test'
            size: [seq_len, label_len, pred_len]; defaults to [168, 0, 24]
            features: 'M' or 'MS' keeps every non-date column, 'S' keeps only
                the target column
            data_path: CSV file name inside ``root_path``
            target: name of the target column (placed last)
            scale: standardise the data with statistics of the training split
            timeenc: 0 for raw hour/month/day/weekday stamps, 1 for the
                normalised features produced by ``time_features``
            freq: frequency string forwarded to ``time_features`` (timeenc=1)
            mask_covar_ratio: probability of masking each covariate element
            mask_target_ratio: probability of masking each target element
            down_sample: probability of dropping (zeroing) each whole row
            block_mask_len: if >= 1, length of the contiguous block zeroed per
                channel inside the input window of every sample
        """
        # size [seq_len, label_len, pred_len]
        if size is None:
            self.seq_len = 24 * 7
            self.label_len = 0
            self.pred_len = 24
        else:
            self.seq_len = size[0]
            self.label_len = size[1]
            self.pred_len = size[2]

        assert flag in ['train', 'test', 'val']
        type_map = {'train': 0, 'val': 1, 'test': 2}
        self.set_type = type_map[flag]

        self.features = features
        self.target = target
        self.scale = scale
        self.timeenc = timeenc
        self.freq = freq

        self.block_mask_len = block_mask_len
        self.downsample = down_sample
        self.mask_covar_ratio = mask_covar_ratio
        self.mask_target_ratio = mask_target_ratio
        self.root_path = root_path
        self.data_path = data_path
        self.__read_data__()
        self.__init_windows__()

    @staticmethod
    def __mask__(shape, mask_ratio):
        """Return a float32 0/1 array of ``shape``; each entry is 0 with probability ``mask_ratio``.

        A non-positive ratio yields an all-ones mask without touching the RNG.
        """
        if mask_ratio <= 0:
            return np.ones(shape, dtype=np.float32)
        return np.random.binomial(1, 1 - mask_ratio, size=shape).astype(np.float32)

    @staticmethod
    def __downsmaple__(data, stamp, mask_target_ratio, mask_covar_ratio, down_sample):
        """Randomly drop whole rows and build the element-wise observation mask.

        Args:
            data: 2-D values with the target in the last column
            stamp: time stamps; returned unchanged
            mask_target_ratio: masking probability for the target column
            mask_covar_ratio: masking probability for the other columns
            down_sample: probability of dropping each row

        Returns:
            ``(data, stamp, mask, indexs_mask)`` where ``data`` is a copy of
            the input with dropped rows set to 0, ``mask`` is a float32 0/1
            array shaped like ``data`` and ``indexs_mask`` is the per-row
            keep (1) / drop (0) vector.
        """
        # Work on a copy: the caller usually passes a view of a larger array
        # and zeroing it in place would corrupt the caller's data as well.
        data = data.copy()
        num_rows = len(data)
        num_cols = data.shape[-1]

        indexs_mask = np.random.binomial(1, 1 - down_sample, size=num_rows)
        dropped = indexs_mask == 0
        data[dropped] = 0.0

        # Draw order (target first, then covariates) is kept so seeded runs
        # consume the RNG in the same sequence as before.
        mask_target = Dataset_Custom.__mask__(shape=(num_rows, 1), mask_ratio=mask_target_ratio)
        mask_covar = Dataset_Custom.__mask__(shape=(num_rows, num_cols - 1), mask_ratio=mask_covar_ratio)

        # np.concatenate is available in every NumPy release (np.concat is 2.0+ only).
        mask = np.concatenate([mask_covar, mask_target], axis=1)

        # Rows that were dropped must also be 0 in the mask. Row dropping and
        # random masking are normally not enabled together, but if they are,
        # the mask has to stay consistent with the zeroed rows.
        mask[dropped] = 0.0

        return data, stamp, mask, indexs_mask

    def __read_data__(self):
        """Load the CSV, split, optionally scale, build time stamps and masks for this split."""
        self.scaler = StandardScaler()
        csv_path = os.path.join(self.root_path, self.data_path)
        print(csv_path)

        df_raw = pd.read_csv(csv_path)

        # Reorder columns to: ['date', ...(other features), target feature]
        cols = list(df_raw.columns)
        cols.remove(self.target)
        cols.remove('date')
        df_raw = df_raw[['date'] + cols + [self.target]]

        # Chronological split: 70% train, 20% test, the remainder validation.
        # The val/test ranges start seq_len rows early (presumably so their
        # first window has a full input history).
        num_train = int(len(df_raw) * 0.7)
        num_test = int(len(df_raw) * 0.2)
        num_vali = len(df_raw) - num_train - num_test
        border1s = [0, num_train - self.seq_len, len(df_raw) - num_test - self.seq_len]
        border2s = [num_train, num_train + num_vali, len(df_raw)]
        border1 = border1s[self.set_type]
        border2 = border2s[self.set_type]

        if self.features in ('M', 'MS'):
            df_data = df_raw[df_raw.columns[1:]]
        elif self.features == 'S':
            df_data = df_raw[[self.target]]
        else:
            raise ValueError(f"Unsupported features mode: {self.features!r}")

        if self.timeenc not in (0, 1):
            raise ValueError(f"Unsupported timeenc value: {self.timeenc!r}")

        if self.scale:
            # Fit on the training range only, then transform the whole series.
            train_data = df_data[border1s[0]:border2s[0]]
            self.scaler.fit(train_data.values)
            data = self.scaler.transform(df_data.values)
        else:
            data = df_data.values

        # Copy so the new columns are added to an independent frame rather
        # than to a slice of df_raw.
        df_stamp = df_raw[['date']][border1:border2].copy()
        df_stamp['date'] = pd.to_datetime(df_stamp.date)

        if self.timeenc == 0:
            # The original calls passed a stray positional `1`, which pandas
            # interpreted as `convert_dtype` (deprecated since pandas 2.1).
            df_stamp['hour'] = df_stamp.date.apply(lambda ts: ts.hour)
            df_stamp['month'] = df_stamp.date.apply(lambda ts: ts.month)
            df_stamp['day'] = df_stamp.date.apply(lambda ts: ts.day)
            df_stamp['weekday'] = df_stamp.date.apply(lambda ts: ts.weekday())
            data_stamp = df_stamp.drop(labels=['date'], axis=1).values
        else:
            data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq)
            data_stamp = data_stamp.transpose(1, 0)

        # Values of this split before row dropping; __downsmaple__ works on a
        # copy, so this view is not zeroed by it.
        self.data_start = data[border1:border2]
        self.border1 = border1
        self.border2 = border2
        self.data_stamp_start = data_stamp

        # downsample and random mask
        data, data_stamp, mask, index_mask = Dataset_Custom.__downsmaple__(
            data[border1:border2],
            data_stamp,
            self.mask_target_ratio,
            self.mask_covar_ratio,
            down_sample=self.downsample)

        # data_x and data_y refer to the same array; the mask is kept
        # separately and is not multiplied into the data here.
        self.data_x = data
        self.data_y = data
        self.data_stamp = data_stamp
        self.mask = mask

        self.indexs_mask = index_mask

    def __init_windows__(self):
        """Precompute ``(s_begin, s_end, r_begin, r_end)`` for every sample window."""
        total_length = self.border2 - self.border1
        # Clamp at 0 so a split shorter than one window yields an empty dataset.
        self.window_count = max(0, total_length - self.seq_len - self.pred_len + 1)

        # Start/end indices of each window, relative to this split's data.
        self.windows = []
        for s_begin in range(self.window_count):
            s_end = s_begin + self.seq_len
            r_begin = s_end - self.label_len
            r_end = r_begin + self.label_len + self.pred_len
            self.windows.append((s_begin, s_end, r_begin, r_end))
        print(f"num samples: {len(self.windows)}")

    def __getitem__(self, index):
        """Return the window at ``index``; block masking is also applied here.

        Returns:
            ``(seq_x, seq_y, seq_x_mark, seq_y_mark, seq_mask, seq_y_mask)``
        """
        # Look up the precomputed index range for this window.
        s_begin, s_end, r_begin, r_end = self.windows[index]

        # Time stamps
        seq_x_mark = self.data_stamp[s_begin:s_end]
        seq_y_mark = self.data_stamp[r_begin:r_end]

        if self.block_mask_len >= 1:
            # Copy so the block written below never touches the stored data/mask.
            seq = self.data_x[s_begin:r_end].copy()
            seq_mask_all = self.mask[s_begin:r_end].copy()

            # Each channel gets its own randomly placed block inside the input window.
            for channel in range(seq.shape[-1]):
                block_start = random.randint(0, self.seq_len - self.block_mask_len)
                block_end = block_start + self.block_mask_len
                seq[block_start:block_end, channel] = 0
                seq_mask_all[block_start:block_end, channel] = 0

            # Split once after all channels are processed (also keeps the
            # outputs defined when there are no channels to iterate over).
            seq_x = seq[:self.seq_len, :]
            seq_y = seq[self.seq_len - self.label_len:, :]
            seq_mask = seq_mask_all[:self.seq_len, :]
            seq_y_mask = seq_mask_all[self.seq_len - self.label_len:, :]
        else:
            # Sequence data; data_x and data_y refer to the same array.
            seq_x = self.data_x[s_begin:s_end]
            seq_y = self.data_y[r_begin:r_end]

            # Masks
            seq_mask = self.mask[s_begin:s_end]
            seq_y_mask = self.mask[r_begin:r_end]

        return seq_x, seq_y, seq_x_mark, seq_y_mark, seq_mask, seq_y_mask

    def __len__(self):
        # Number of precomputed windows; never negative and always consistent
        # with what __getitem__ can index.
        return len(self.windows)

    def inverse_transform(self, data):
        """Undo the standardisation applied by the fitted scaler."""
        return self.scaler.inverse_transform(data)
