import numpy as np
import torch
import pandas as pd
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d, gaussian_filter



def normalize_2d(matrix):
    """
    Min-max normalize each column of a matrix to the range [0, 1].

    Parameters:
        matrix (array-like): Input whose columns (axis 0) are normalized independently.

    Returns:
        The column-wise normalized values. A column whose values are all equal
        maps to 0 instead of NaN.
    """
    col_min = np.min(matrix, axis=0)
    col_range = np.max(matrix, axis=0) - col_min
    # A constant column has a zero range; dividing by it would produce NaN (0/0).
    # Substituting 1 leaves the numerator (all zeros) unchanged.
    col_range = np.where(col_range == 0, 1, col_range)
    return (matrix - col_min) / col_range

def normalize_columns_to_range(arr, min_val=0.5, max_val=1):
    """
    Rescale every column of a 2D array linearly into [min_val, max_val].

    Parameters:
        arr (numpy.ndarray): The 2D input; it is copied into a new NumPy array first.
        min_val (float): Lower bound of the target range.
        max_val (float): Upper bound of the target range.

    Returns:
        numpy.ndarray: Array of the same shape where each column's minimum maps to
        min_val and its maximum to max_val. A constant column maps entirely to min_val.
    """
    values = np.array(arr)

    # Per-column lower bound and spread
    lows = values.min(axis=0)
    spans = values.max(axis=0) - lows

    # A constant column has zero spread; use 1 so the division below is well defined
    spans[spans == 0] = 1

    # First map to [0, 1], then stretch/shift to the requested range
    unit_scaled = (values - lows) / spans
    return unit_scaled * (max_val - min_val) + min_val



def get_load_oncor():
    """
    Build a high-rate / low-rate series pair from './CNEXP0006_kWh_two_week_data.csv'.

    Steps performed:
      1. Read the CSV, keep its first 200 rows and parse the "kWh" and "time"
         cells with ast.literal_eval.
      2. Keep only rows whose kWh entry has the maximum length, then drop rows
         containing more than one zero.
      3. Write the surviving kWh rows to "Load_profile1.csv" (side effect).
      4. Replace the time axis by the integer index 0..len-1, transpose the data
         so that rows are time samples, and apply a Gaussian filter (sigma 0.2).
      5. Take column 16 as the high-rate series and column 20 as the series that
         is only observed at 160 evenly spaced indices; both are normalized
         column-wise with normalize_columns_to_range (default target range).
      6. Print some diagnostics and create (but do not show) a matplotlib figure.

    Returns:
        tuple: (t_high, y_high_split, t_low, y_low_split, y_low_high_res,
        t_low_index, LR_interval) where t_low_index are the indices into the
        high-rate axis used for the low-rate samples and LR_interval is
        L_high // L_low.
    """
    import ast
    path = './CNEXP0006_kWh_two_week_data.csv'
    df_load = pd.read_csv(path)
    # Only the first 200 rows of the file are used
    load = df_load.iloc[0:200, :]

    # Each cell stores a Python literal as text; parse it back into an object
    # (presumably a list of numbers per row -- confirm against the CSV)
    raw_data = [x for x in load["kWh"].apply(ast.literal_eval)]
    raw_time = [x for x in load["time"].apply(ast.literal_eval)]
    # print(time)
    # exit()

    # Find the maximum length of lists
    max_length = max(len(lst) for lst in raw_data)

    # Filter the lists that have the maximum length
    # NOTE(review): the time rows are filtered by their own length against the
    # kWh maximum; this assumes each row's time list is as long as its kWh list
    filtered_data = [lst for lst in raw_data if len(lst) == max_length]
    filtered_time = [lst for lst in raw_time if len(lst) == max_length]

    filtered_data = np.array(filtered_data)
    filtered_time = np.array(filtered_time)
    # eliminate rows with many zeros
    # Define a threshold for the maximum number of 0s allowed in a row
    threshold = 1  # This means we allow up to 1 zero in a row
    # Count the number of 0s in each row
    zero_counts = np.count_nonzero(filtered_data == 0, axis=1)

    # Filter out rows that have more than 'threshold' number of 0s
    # (the mask computed from the kWh rows is reused on the time rows, so both
    # arrays must have the same number of rows)
    data = filtered_data[zero_counts <= threshold]
    time = filtered_time[zero_counts <= threshold]

    # Save the array to a CSV file
    np.savetxt("Load_profile1.csv", data, delimiter=",", fmt="%s")
    # Only the length of the first row's time axis is used: the actual time
    # values are discarded and replaced by the integer index 0, 1, 2, ...
    time = time[0, :].reshape(-1, )
    time = np.array([1 * i for i in range(len(time))])

    # Transpose so that rows are time samples and columns are the individual series
    data = torch.tensor(data, dtype=torch.float32).T

    # Full-range slices: no subsampling is applied here
    data = data[::, :]
    time = time[::, ]

    # Light smoothing with sigma 0.2; a scalar sigma is applied along every axis
    # NOTE(review): a torch tensor is passed to SciPy here -- presumably it is
    # converted to an array internally; the result is wrapped in a tensor again below
    data = gaussian_filter(data, 0.2)
    t = torch.tensor(time, dtype=torch.float)
    y = torch.tensor(data, dtype=torch.float)

    # L_high: number of high-rate samples; L_low: number of low-rate samples kept
    L_high, L_low = len(t), 160
    # Evenly spaced integer indices into the high-rate axis
    t_low_index = np.linspace(0, L_high - 1, L_low, dtype=int)

    t_high = t.numpy()
    # Shift the time axis so that it starts at 0
    t_high = t_high - t_high[0]
    t_low = t_high[t_low_index]

    # Column 16 is the high-rate series, column 20 the series observed at low rate
    # (assumes at least 21 series survive the filtering above -- TODO confirm)
    y_high_split = y[:, 16:17].numpy() #16:18
    y_low_high_res = y[:, 20:21].numpy()  #20: 22
    y_high_split = normalize_columns_to_range(y_high_split)
    y_low_high_res = normalize_columns_to_range(y_low_high_res)

    # Low-rate observations are the high-resolution series sampled at t_low_index
    y_low_split = y_low_high_res[t_low_index, :]
    LR_interval = L_high // L_low

    # Diagnostics; the time axis is an index, the "/min" unit is only a label here
    print("Low rate: {}/min".format(1/(t_low[1] - t_low[0])))
    print("High rate: {}/min".format(1 / (t_high[1] - t_high[0])))

    print("t_high[0] = ", t_high[0])
    print("t_high[-1] = ", t_high[-1])
    print("t_high[1] - t_high[0] = {} min".format(t_high[1] - t_high[0]))
    print("t_low[1] - t_low[0] = {} min".format(t_low[1] - t_low[0]))

    print(t_high.shape, y_high_split.shape, t_low.shape, y_low_split.shape, y_low_high_res.shape, t_low_index.shape)
    print("LR_interval", LR_interval)

    # Overview figure; it is created here but plt.show() is not called
    plt.figure(figsize=(10,5))

    plt.plot(t_high, y_low_high_res[:, 0], label='Ground Truth y_low_high_res', linestyle='dashed', color='blue')
    # plt.plot(t_high, y_low_high_res[:, 1], label='Ground Truth y_low_high_res', linestyle='dashed', color='orange')

    plt.plot(t_low, y_low_split[:, 0], 'o', label='Low 1 (observed)', color='blue', alpha=1, )
    # plt.plot(t_low, y_low_split[:, 1], 'o', label='Low 2 (observed)', color='orange', alpha=1, )

    plt.plot(t_high, y_high_split[:, 0], label='Ground Truth y_high', linestyle='-.', color='red')
    # plt.plot(t_high, y_high_split[:, 1], label='Ground Truth y_high', linestyle='-.', color='green')

    plt.legend()
    plt.title("raw data")


    return t_high, y_high_split, t_low, y_low_split, y_low_high_res, t_low_index, LR_interval


def get_data_PV_online_kit(path):
    '''
    Load one PV CSV file and split it into a high-rate and a low-rate part.

    Rows 500..999 of four "ShuntPDC_kW_Avg_*" columns are taken, every 5th
    row is kept, and each column is normalized to [0.5, 1]. Columns 0-1 form
    the high-rate series; columns 2-3 are the series observed only at 8
    evenly spaced indices. The time axis is the row index, shifted to start at 0.

    y shape: (length, #lines)  #note #lines == #features
    t shape: (length)

    Returns (t_high, y_high_split, t_low, y_low_split, y_low_high_res,
    t_low_index, LR_interval).
    '''
    frame = pd.read_csv(path)

    # only use 4 curves
    channels = ["ShuntPDC_kW_Avg_4", "ShuntPDC_kW_Avg_5", "ShuntPDC_kW_Avg_6", "ShuntPDC_kW_Avg_7"]
    power = frame[channels].values.astype(np.float32)

    # TIMESTAMP only provides the number of samples (1 sample/min by observation),
    # so the time axis is assigned manually as the row index
    n_samples = len(frame[["TIMESTAMP"]].values.reshape(-1, ))
    sample_index = np.array(list(range(n_samples)))

    # rows 500..999, keeping every 5th row
    window = slice(500, 1000, 5)
    power = normalize_columns_to_range(power[window], min_val=0.5, max_val=1)

    y = torch.tensor(power, dtype=torch.float).numpy()
    t_high = torch.tensor(sample_index[window], dtype=torch.float).numpy()

    L_high, L_low = len(t_high), 8
    t_low_index = np.linspace(0, L_high - 1, L_low, dtype=int)

    # shift the time axis so it starts at 0, then pick the low-rate instants
    t_high = t_high - t_high[0]
    t_low = t_high[t_low_index]

    y_high_split = y[:, 0:2]
    y_low_high_res = y[:, 2:4]
    y_low_split = y_low_high_res[t_low_index, :]
    LR_interval = L_high // L_low

    return t_high, y_high_split, t_low, y_low_split, y_low_high_res, t_low_index, LR_interval



def get_data_PV_online():
    """
    Load several days of PV data and concatenate them into one dataset.

    For each day 1..10 the file "./2017/01/onemin-Ground-2017-01-DD.csv" is
    loaded with get_data_PV_online_kit, and the per-day arrays are concatenated
    along the time axis. The per-day low-rate indices are shifted by the number
    of high-rate samples already accumulated, so they index into the
    concatenated series.

    Returns:
        tuple: (t_high, y_high_split, t_low, y_low_split, y_low_high_res,
        t_low_index, LR_interval). t_high is the integer index of the
        concatenated high-rate series; LR_interval is the value reported for
        the last file loaded.
    """
    # pv data
    base_path = "./2017/01/onemin-Ground-2017-01-"

    # Per-day results, concatenated after the loop
    high_res_data_list = []
    low_res_data_list = []
    low_to_high_res_data_list = []
    downsample_index_list = []

    # Number of high-rate samples accumulated so far (offset for downsample indices)
    offset = 0
    day_profile = 11  # exclusive upper bound: days 1..10 are loaded
    for day in range(1, day_profile):
        # Day numbers are zero-padded to two digits in the file names
        path = f"{base_path}{day:02d}.csv"
        _, high_res_day, _, low_res_day, low_to_high_res_day, downsample_index_day, lr_interval = get_data_PV_online_kit(path)
        high_res_data_list.append(high_res_day)
        low_res_data_list.append(low_res_day)
        low_to_high_res_data_list.append(low_to_high_res_day)

        # Shift this day's indices so they refer to the concatenated series
        downsample_index_list.append(downsample_index_day + offset)
        offset += high_res_day.shape[0]

    # Concatenate the data lists into single arrays
    high_res_data = np.concatenate(high_res_data_list, axis=0)
    low_res_data = np.concatenate(low_res_data_list, axis=0)
    low_to_high_res_data = np.concatenate(low_to_high_res_data_list, axis=0)
    downsample_index = np.concatenate(downsample_index_list, axis=0)

    # Derive the time axis from the data actually loaded instead of a fixed
    # length, so it always matches the concatenated series
    t_high = np.arange(high_res_data.shape[0])
    t_low = t_high[downsample_index]

    return t_high, high_res_data, t_low, low_res_data, low_to_high_res_data, downsample_index, lr_interval


def get_data_200bus_9_10_kit(path):
    '''
    Load one event CSV and split it into a high-rate and a low-rate series.

    The "Time" column is parsed with the format '%Y-%m-%d %H:%M:%S:%f' and
    reduced to second + microsecond / 1e6 (minutes and hours are not included).
    Rows 180..279 of the columns XData, YData, XData2, ZData are normalized
    column-wise to [0.5, 1]. Column 0 becomes the high-rate series; column 1 is
    the series observed only at 20 evenly spaced indices.

    Diagnostics are printed and a matplotlib figure is created (not shown).

    y shape: (length, #lines)  #note #lines == #features
    t shape: (length)

    Returns (t_high, y_high_split, t_low, y_low_split, y_low_high_res,
    t_low_index, LR_interval).
    '''
    print("data from utils")
    frame = pd.read_csv(path)

    def seconds_of(stamp):
        # seconds part of the timestamp, including the fractional part
        return stamp.second + stamp.microsecond / 1000000

    stamps = pd.to_datetime(frame['Time'], format='%Y-%m-%d %H:%M:%S:%f')
    seconds = stamps.apply(seconds_of).values
    signals = frame[["XData", "YData", "XData2", "ZData"]].values

    # keep rows 180..279 only
    window = slice(180, 280)
    seconds = seconds[window]
    signals = normalize_columns_to_range(signals[window, :], min_val=0.5, max_val=1)

    t_high = torch.tensor(seconds, dtype=torch.float).numpy()
    y = torch.tensor(signals, dtype=torch.float).numpy()

    L_high, L_low = len(t_high), 20
    t_low_index = np.linspace(0, L_high - 1, L_low, dtype=int)

    # shift the time axis so it starts at 0, then pick the low-rate instants
    t_high = t_high - t_high[0]
    t_low = t_high[t_low_index]

    y_high_split = y[:, 0:1]
    y_low_high_res = y[:, 1:2]
    y_low_split = y_low_high_res[t_low_index, :]
    LR_interval = L_high // L_low

    low_step = t_low[1] - t_low[0]
    high_step = t_high[1] - t_high[0]
    print("Low rate: {}/s".format(1/low_step))
    print("High rate: {}/s".format(1 / high_step))
    print("t_high[0]", t_high[0])
    print("t_high[-1]", t_high[-1])

    # overview figure: low-rate ground truth, its observed samples, high-rate ground truth
    plt.figure(figsize=(10,5))
    plt.plot(
        t_high, y_low_high_res[:, 0],
        label='Ground Truth y_low_high_res', linestyle='dashed', color='blue',
    )
    plt.plot(
        t_low, y_low_split[:, 0], 'o',
        label='Low 1 (observed)', color='blue', alpha=1,
    )
    plt.plot(
        t_high, y_high_split[:, 0],
        label='Ground Truth y_high', linestyle='-.', color='red',
    )
    plt.legend()
    plt.title("raw data")

    return t_high, y_high_split, t_low, y_low_split, y_low_high_res, t_low_index, LR_interval



def get_data_200bus_9_10():
    """
    Load several event files and concatenate them into one dataset.

    The files "./VA-event3-1.csv" .. "./VA-event3-3.csv" are loaded with
    get_data_200bus_9_10_kit and the per-file arrays are concatenated along
    the time axis. The per-file low-rate indices are shifted by the number of
    high-rate samples already accumulated, so they index into the concatenated
    series.

    Returns:
        tuple: (t_high, y_high_split, t_low, y_low_split, y_low_high_res,
        t_low_index, LR_interval). t_high is the integer index of the
        concatenated high-rate series; LR_interval is the value reported for
        the last file loaded.
    """
    base_path = "./VA-event3-"

    # Per-file results, concatenated after the loop
    high_res_data_list = []
    low_res_data_list = []
    low_to_high_res_data_list = []
    downsample_index_list = []

    # Number of high-rate samples accumulated so far (offset for downsample indices)
    offset = 0
    day_profile = 4  # exclusive upper bound: files 1..3 are loaded
    for file_no in range(1, day_profile):
        path = f"{base_path}{file_no}.csv"
        _, high_res_part, _, low_res_part, low_to_high_res_part, downsample_index_part, lr_interval = get_data_200bus_9_10_kit(path)
        high_res_data_list.append(high_res_part)
        low_res_data_list.append(low_res_part)
        low_to_high_res_data_list.append(low_to_high_res_part)

        # Shift this file's indices so they refer to the concatenated series
        downsample_index_list.append(downsample_index_part + offset)
        offset += high_res_part.shape[0]

    # Concatenate the data lists into single arrays
    high_res_data = np.concatenate(high_res_data_list, axis=0)
    low_res_data = np.concatenate(low_res_data_list, axis=0)
    low_to_high_res_data = np.concatenate(low_to_high_res_data_list, axis=0)
    downsample_index = np.concatenate(downsample_index_list, axis=0)

    # Derive the time axis from the data actually loaded instead of a fixed
    # length, so it always matches the concatenated series
    t_high = np.arange(high_res_data.shape[0])
    t_low = t_high[downsample_index]

    return t_high, high_res_data, t_low, low_res_data, low_to_high_res_data, downsample_index, lr_interval



def get_data_ari_quality():
    """
    Build a high-rate / low-rate series pair from "./AirQualityUCI.csv".

    Rows 5000..6599 of the "C6H6(GT)" column are smoothed with a Gaussian
    filter (sigma 2), reshaped to two columns (consecutive samples are paired
    row by row) and normalized column-wise to [0.5, 1]. Column 0 is the
    high-rate series; column 1 is observed only at 80 evenly spaced indices.
    The time axis is the integer row index.

    Diagnostics are printed and a matplotlib figure is created (not shown).

    Returns (t_high, y_high_split, t_low, y_low_split, y_low_high_res,
    t_low_index, LR_interval).
    """
    # path = "C:\\Software\\Geometric-DL\\HR-LR-DL\\Data\\AirQualityUCI.csv"
    frame = pd.read_csv("./AirQualityUCI.csv")

    # The parsed timestamps are not used afterwards; the call is kept so that a
    # missing or unparseable 'Time' column still fails at this point.
    pd.to_datetime(frame['Time'])

    series = frame[["C6H6(GT)"]].values[5000:6600, :]
    smoothed = gaussian_filter(series, 2)

    # Reshape to (N, 2), then normalize each column
    y = normalize_columns_to_range(smoothed.reshape(-1, 2), min_val=0.5, max_val=1)

    L_high, L_low = y.shape[0], 80
    t_low_index = np.linspace(0, L_high - 1, L_low, dtype=int)

    # time axis assigned manually as the row index
    t_high = np.array(list(range(L_high)))
    t_low = t_high[t_low_index]

    y_high_split = y[:, 0:1]
    y_low_high_res = y[:, 1:2]
    y_low_split = y_low_high_res[t_low_index, :]
    LR_interval = L_high // L_low

    low_step = t_low[1] - t_low[0]
    high_step = t_high[1] - t_high[0]
    print("Low rate: {}/s".format(1 / low_step))
    print("High rate: {}/s".format(1 / high_step))
    print("t_high[0]", t_high[0])
    print("t_high[-1]", t_high[-1])

    # overview figure: low-rate ground truth, its observed samples, high-rate ground truth
    plt.figure(figsize=(10, 5))
    plt.plot(
        t_high, y_low_high_res[:, 0],
        label='Ground Truth y_low_high_res', linestyle='dashed', color='blue',
    )
    plt.plot(
        t_low, y_low_split[:, 0], 'o',
        label='Low 1 (observed)', color='blue', alpha=1,
    )
    plt.plot(
        t_high, y_high_split[:, 0],
        label='Ground Truth y_high', linestyle='-.', color='red',
    )
    plt.legend()
    plt.title("raw data")

    return t_high, y_high_split, t_low, y_low_split, y_low_high_res, t_low_index, LR_interval




def create_sequences(data, seq_length, value_high_real = None):
    '''
    Build sliding-window (input sequence, target) pairs.

    data: LR after interpolation: used as output point
    value_high_real: real HR data, as input seq

    For every start index i in range(len(data) - seq_length) the input window
    is source[i:i + seq_length], where source is value_high_real if given and
    data otherwise; the target is always data[i + seq_length].

    Returns the stacked windows and the stacked targets as two tensors.
    '''
    # windows come from the HR series when provided, otherwise from data itself
    window_source = data if value_high_real is None else value_high_real
    n_windows = len(data) - seq_length

    windows = [window_source[start:start + seq_length] for start in range(n_windows)]
    labels = [data[start + seq_length] for start in range(n_windows)]

    return torch.stack(windows), torch.stack(labels)


# Function to generate a spiral dataset
def generate_spiral_time_series(n_points, n_spirals, noise=0.00):
    """
    Generate a 2D spiral sampled over time.

    Parameters:
        n_points (int): Number of samples on the time axis [0, 20*pi].
        n_spirals: Not used by this function.
        noise (float): Scale of the standard-normal noise added to x and y.

    Returns:
        numpy.ndarray: Array of shape (n_points, 3) with columns time, x, y.
    """
    t_end = 20 * np.pi
    t = np.linspace(0, t_end, n_points)  # time variable

    # Spiral coordinates, scaled by t_end * pi and shifted by 1
    scale = t_end * np.pi
    x = (t * np.cos(t)) / scale + 1
    y = (t * np.sin(t)) / scale + 1

    # Noise is drawn for x first and then for y
    x = x + noise * np.random.randn(n_points)
    y = y + noise * np.random.randn(n_points)

    # Rows are samples; columns are (time, x, y)
    return np.vstack([t, x, y]).T


def get_data_spiral():
    """
    Build a high-rate / low-rate series pair from a synthetic spiral.

    A 500-point spiral is generated; its x coordinate is the high-rate series
    and its y coordinate is the series observed only at 20 evenly spaced
    indices. Two matplotlib figures are created (not shown).

    Returns (t_high, y_high_split, t_low, y_low_split, y_low_high_res,
    t_low_index, LR_interval); the arrays are cast to float32 and the index
    array to int32.
    """
    n_points = 500  # number of points in the time series
    n_spirals = 1  # single spiral for simplicity
    spiral = generate_spiral_time_series(n_points, n_spirals)

    print(spiral.shape)

    # columns: time, x (high-rate series), y (series observed at low rate)
    t_high = spiral[:, 0]
    y_high_split = spiral[:, 1:2]
    y_low_high_res = spiral[:, 2:3]

    L_high, L_low = len(y_high_split), 20
    LR_interval = L_high // L_low
    t_low_index = np.linspace(0, L_high - 1, L_low, dtype=int)

    t_low = t_high[t_low_index]
    y_low_split = y_low_high_res[t_low_index, :]

    # Figure 1: the spiral in the x-y plane
    plt.figure(figsize=(8, 6))
    plt.plot(y_high_split, y_low_high_res, label="Spiral", color="g")
    plt.title("Spiral Time Series Visualization")
    plt.xlabel("X")
    plt.ylabel("Y")
    plt.legend()

    # Figure 2: both coordinates over time plus the low-rate observations
    plt.figure(figsize=(10, 5))
    plt.plot(
        t_high, y_low_high_res[:, 0],
        label='Ground Truth y_low_high_res', linestyle='dashed', color='blue',
    )
    plt.plot(
        t_low, y_low_split[:, 0], 'o',
        label='Low 1 (observed)', color='blue', alpha=1,
    )
    plt.plot(
        t_high, y_high_split[:, 0],
        label='Ground Truth y_high', linestyle='-.', color='red',
    )
    plt.legend()
    plt.title("raw data")
    # plt.show()

    return (
        t_high.astype(np.float32),
        y_high_split.astype(np.float32),
        t_low.astype(np.float32),
        y_low_split.astype(np.float32),
        y_low_high_res.astype(np.float32),
        t_low_index.astype(np.int32),
        LR_interval,
    )



if __name__ == "__main__":
    # Script entry point: build the synthetic spiral dataset.
    get_data_spiral()











