import numpy as np
import pandas as pd




def preprocess_ETTm(ori_df, task_type):
    """Normalize an ETTm frame and derive a high-/low-resolution pair from it.

    Every column except ``date`` is z-scored with its own mean and standard
    deviation. Then, depending on ``task_type``:

    * ``"SSR"``: the low-resolution frame keeps only the rows whose minute
      field (second ``":"``-separated part of ``date``) is ``"00"``. The
      high-resolution frame is the full normalized data.
    * ``"ASR"``: row 0 is skipped. The low-resolution frame holds the column
      means of each complete, non-overlapping window of 4 normalized rows.
      The high-resolution frame is the normalized data from row 1 onward.

    Args:
        ori_df: DataFrame with a string ``date`` column plus numeric feature
            columns (assumed "YYYY-MM-DD HH:MM:SS" timestamps; confirm).
        task_type: ``"SSR"`` or ``"ASR"``.

    Returns:
        ``(hr_df, lr_df, target_mask, mask_index)``.
        ``target_mask`` is a list of 0/1 ints with one entry per row of
        ``hr_df``. ``mask_index`` lists the positions whose flag is 1.
    """
    assert task_type in ["SSR", "ASR"]

    # normalizing: z-score every feature column; "date" is not a feature
    data_df = ori_df.drop(columns=["date"])
    data_df = (data_df - data_df.mean()) / data_df.std()

    if task_type == "SSR":
        target_mask = [
            1 if stamp.split(":")[1] == "00" else 0
            for stamp in ori_df["date"].values
        ]
        mask_index = [i for i, flag in enumerate(target_mask) if flag == 1]
        lr_df = data_df.iloc[mask_index, :].reset_index(drop=True)
        hr_df = data_df.copy().reset_index(drop=True)

    elif task_type == "ASR":
        start = 1
        k = 4
        n_rows = ori_df.shape[0]
        target_mask = [1 if i % k == 0 else 0 for i in range(start, n_rows)]
        mask_index = [i for i, flag in enumerate(target_mask) if flag == 1]
        # Average the *normalized* features. ori_df still carries the "date"
        # string column (mean() would fail on it) and is on the raw scale,
        # which would not match hr_df. Stopping at n_rows - k + 1 keeps only
        # complete windows (i + k <= n_rows).
        mean_rows = [
            data_df.iloc[i : i + k, :].mean()
            for i in range(start, n_rows - k + 1, k)
        ]
        lr_df = pd.DataFrame(mean_rows).reset_index(drop=True)
        hr_df = data_df.iloc[start:].reset_index(drop=True)

    return hr_df, lr_df, target_mask, mask_index




def preprocess_ETTh(ori_df, task_type):
    """Normalize an ETTh frame and derive a high-/low-resolution pair from it.

    Every column except ``date`` is z-scored with its own mean and standard
    deviation. Then, depending on ``task_type``:

    * ``"SSR"``: the low-resolution frame keeps only the rows whose hour
      field (the two characters just before the first ``":"`` of ``date``)
      is ``"00"``. The high-resolution frame is the full normalized data.
    * ``"ASR"``: row 0 is skipped. The low-resolution frame holds the column
      means of each complete, non-overlapping window of 24 normalized rows.
      The high-resolution frame is the normalized data from row 1 onward.

    Args:
        ori_df: DataFrame with a string ``date`` column plus numeric feature
            columns (assumed "YYYY-MM-DD HH:MM:SS" timestamps; confirm).
        task_type: ``"SSR"`` or ``"ASR"``.

    Returns:
        ``(hr_df, lr_df, target_mask, mask_index)``.
        ``target_mask`` is a list of 0/1 ints with one entry per row of
        ``hr_df``. ``mask_index`` lists the positions whose flag is 1.
    """
    assert task_type in ["SSR", "ASR"]

    # z-score each feature column; the timestamp column is not a feature
    features = ori_df.drop(columns=["date"])
    features = (features - features.mean()) / features.std()

    if task_type == "SSR":
        hour_fields = (stamp.split(":")[0][-2:] for stamp in ori_df["date"].values)
        target_mask = [int(hour == "00") for hour in hour_fields]
        mask_index = [pos for pos, flag in enumerate(target_mask) if flag == 1]
        hr_df = features.copy().reset_index(drop=True)
        lr_df = features.iloc[mask_index, :].reset_index(drop=True)

    elif task_type == "ASR":
        offset, window = 1, 24
        total = ori_df.shape[0]
        target_mask = [int(pos % window == 0) for pos in range(offset, total)]
        mask_index = [pos for pos, flag in enumerate(target_mask) if flag == 1]
        # only complete windows: the last admissible start is total - window
        window_means = [
            features.iloc[lo : lo + window, :].mean()
            for lo in range(offset, total - window + 1, window)
        ]
        lr_df = pd.DataFrame(window_means).reset_index(drop=True)
        hr_df = features.iloc[offset:].reset_index(drop=True)

    return hr_df, lr_df, target_mask, mask_index




def preprocess_weather(ori_df, task_type, with_date=False):
    """Normalize a weather frame and derive a high-/low-resolution pair from it.

    Every column except ``date`` is z-scored with its own mean and standard
    deviation, computed over all rows. Then, depending on ``task_type``:

    * ``"SSR"``: the low-resolution frame keeps six selected columns for the
      rows whose minute field (second ``":"``-separated part of ``date``) is
      ``"00"``. The high-resolution frame is the full normalized data.
    * ``"ASR"``: the first 5 rows are dropped, then row 0 of the remainder
      is skipped. The low-resolution frame holds six selected columns'
      means over each complete, non-overlapping window of 6 normalized rows.
      The high-resolution frame is the remaining normalized data.

    Args:
        ori_df: DataFrame with a string ``date`` column plus numeric feature
            columns (assumed "YYYY-MM-DD HH:MM:SS" timestamps; confirm).
        task_type: ``"SSR"`` or ``"ASR"``.
        with_date: currently unused; kept for interface compatibility.

    Returns:
        ``(hr_df, lr_df, target_mask, mask_index)``.
        ``target_mask`` is a list of 0/1 ints with one entry per row of
        ``hr_df``. ``mask_index`` lists the positions whose flag is 1.
    """
    assert task_type in ["SSR", "ASR"]

    # normalizing: z-score every feature column; "date" is not a feature
    data_df = ori_df.drop(columns=["date"])
    data_df = (data_df - data_df.mean()) / data_df.std()

    if task_type == "SSR":
        target_cols = ["p (mbar)", "T (degC)", "rh (%)", "VPact (mbar)", "wd (deg)", "Tlog (degC)"]
        target_mask = [
            1 if stamp.split(":")[1] == "00" else 0
            for stamp in ori_df["date"].values
        ]
        mask_index = [i for i, flag in enumerate(target_mask) if flag == 1]
        lr_df = data_df.iloc[mask_index, :][target_cols].reset_index(drop=True)
        hr_df = data_df.copy().reset_index(drop=True)

    elif task_type == "ASR":
        # Trim the first 5 rows from BOTH frames. The mask length, the
        # averaging windows and hr_df must all index the same rows; trimming
        # only ori_df would leave data_df 5 rows out of phase.
        ori_df = ori_df.iloc[5:, :].reset_index(drop=True)
        data_df = data_df.iloc[5:, :].reset_index(drop=True)
        start = 1
        k = 6
        # NOTE(review): "SWDR (W/m???)" looks like a mis-encoded unit
        # (probably W/m^2). It must match the CSV header byte-for-byte;
        # verify against the data file.
        target_cols = ["p (mbar)", "T (degC)", "rho (g/m**3)", "wv (m/s)", "rain (mm)", "SWDR (W/m???)"]
        n_rows = ori_df.shape[0]
        target_mask = [1 if i % k == 0 else 0 for i in range(start, n_rows)]
        mask_index = [i for i, flag in enumerate(target_mask) if flag == 1]
        # Stopping at n_rows - k + 1 keeps only complete windows (i + k <= n_rows).
        mean_rows = [
            data_df.iloc[i : i + k, :].mean()
            for i in range(start, n_rows - k + 1, k)
        ]
        lr_df = pd.DataFrame(mean_rows)[target_cols].reset_index(drop=True)
        hr_df = data_df.iloc[start:].reset_index(drop=True)

    return hr_df, lr_df, target_mask, mask_index




def _read_ts_split(path, subject_class):
    """Parse one ``.ts`` file and stack the samples labelled ``subject_class``.

    Lines before the ``@data`` marker are header lines and are skipped. Each
    data line holds one sample: its ``":"``-separated fields are the
    comma-separated values of each dimension, and the last field is the
    label.

    Returns:
        Array of shape ``(n_samples, seq_len, n_dims)``.

    Raises:
        ValueError: if no sample in the file carries ``subject_class``.
    """
    samples = []
    in_data = False
    with open(path, "r", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if line.lower().startswith("@data"):
                in_data = True
                continue
            if not (in_data and line):
                continue
            *dim_fields, label = line.split(":")
            if label != subject_class:
                continue
            # rows = dimensions, columns = time steps; transposed below
            samples.append(
                np.array([[float(x) for x in field.split(",")] for field in dim_fields])
            )
    if not samples:
        # np.array([]).transpose(0, 2, 1) would fail with an opaque axes error
        raise ValueError(f"no samples with label {subject_class!r} in {path}")
    return np.array(samples).transpose(0, 2, 1)


def read_UEA(data_name, subject_class, root="data"):
    """Load the TRAIN and TEST samples of one class from a UEA ``.ts`` dataset.

    Reads ``{root}/{data_name}/{data_name}_TRAIN.ts`` and ``..._TEST.ts``.
    Only samples whose label equals ``str(subject_class)`` are kept, so
    integer class ids also work.

    Args:
        data_name: dataset directory / file stem.
        subject_class: label to keep (string or value whose ``str`` is the label).
        root: directory that contains the dataset folder.

    Returns:
        ``(train_data, test_data)``, each of shape ``(n_samples, seq_len, n_dims)``.

    Raises:
        ValueError: if a split contains no sample of ``subject_class``.
    """
    label = str(subject_class)
    train_data = _read_ts_split(f"{root}/{data_name}/{data_name}_TRAIN.ts", label)
    test_data = _read_ts_split(f"{root}/{data_name}/{data_name}_TEST.ts", label)
    return train_data, test_data

