import pdb
import numpy as np
import pandas as pd
import os
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from utils.timefeatures import time_features
import warnings
import random
import torch
from scipy.stats import linregress
import copy
warnings.filterwarnings('ignore')
import matplotlib.pyplot as plt



class Dataset_Generator(Dataset):
    """Windowed time-series dataset whose targets are synthetically edited.

    Every sliding window of the selected split yields one sample per entry in
    ``orders``:

    * ``x``: the input window, shape ``(n_vars, seq_len)``;
    * ``y``: the next ``label_len + pred_len`` steps, edited according to the
      order text, shape ``(n_vars, label_len + pred_len)``;
    * the order string repeated once per variable.
    """

    # Orders applied when the caller does not pass ``orders``.
    DEFAULT_ORDERS = ('increase amplitude',)

    def __init__(self, root_path, flag='train', size=None,
                 features='S', data_path='ETTh1.csv',
                 target='OT', scale=True, timeenc=0, freq='h',
                 orders=None):
        """Load ``root_path/data_path`` and build every sample for ``flag``.

        Args:
            root_path: Directory containing the CSV file.
            flag: Split to build: ``'train'``, ``'val'`` or ``'test'``.
            size: ``(seq_len, label_len, pred_len)``; ``(384, 96, 96)`` when None.
            features: ``'S'`` keeps only the target column; ``'M'`` / ``'MS'``
                keep every column except ``date``.
            data_path: CSV file name; it must contain ``date`` and ``target``.
            target: Target column name; it is moved to the last position.
            scale: Standardize with statistics fitted on the training split.
            timeenc, freq: Stored on the instance; not used by this class.
            orders: Edit instructions applied to every window (defaults to
                ``DEFAULT_ORDERS``). Unrecognized orders leave ``y`` unchanged.
        """
        if size is None:
            self.seq_len = 24 * 4 * 4
            self.label_len = 24 * 4
            self.pred_len = 24 * 4
        else:
            self.seq_len = size[0]
            self.label_len = size[1]
            self.pred_len = size[2]
        assert flag in ['train', 'test', 'val']
        type_map = {'train': 0, 'val': 1, 'test': 2}
        self.set_type = type_map[flag]

        self.features = features
        self.target = target
        self.scale = scale
        self.timeenc = timeenc
        self.freq = freq
        self.orders = list(self.DEFAULT_ORDERS if orders is None else orders)

        self.root_path = root_path
        self.data_path = data_path
        self.scaler = None  # fitted in _generate_pairs when scale=True
        self.all_x = []
        self.all_y = []
        self.all_orders = []
        self.timestamp = []  # not populated by this class
        self.all_prompt_bank = []  # not populated by this class
        self.__read_data__()

    def _shuffle(self):
        """Shuffle the samples in place, keeping (x, y, order) triples aligned."""
        perm = list(range(len(self.all_x)))
        random.shuffle(perm)
        perm = np.asarray(perm, dtype=np.intp)
        self.all_x = np.asarray(self.all_x)[perm]
        self.all_y = np.asarray(self.all_y)[perm]
        self.all_orders = np.asarray(self.all_orders)[perm]

    @staticmethod
    def _dataset_name(file_path):
        """Return the file's base name up to its first dot.

        Only the final path component is inspected, so dots in directory
        names (e.g. ``./dataset/ETTh1.csv``) do not affect the result.
        """
        return os.path.basename(file_path).split('.')[0]

    def __read_data__(self):
        """Read the CSV, generate all samples, and stack them into arrays.

        Raises:
            ValueError: if the selected split is too short to yield a window.
        """
        file_path = os.path.join(self.root_path, self.data_path)
        data = pd.read_csv(file_path)
        print("===========loading file===========,", file_path)
        self._generate_pairs(data, self._dataset_name(file_path))
        if len(self.all_x) == 0:
            raise ValueError(
                f"no samples generated from {file_path!r}: the split is shorter "
                "than seq_len + label_len + pred_len")
        self.all_x = np.stack(self.all_x, axis=0)
        self.all_y = np.stack(self.all_y, axis=0)
        self.all_orders = np.stack(self.all_orders, axis=0)

    def _generate_pairs(self, df_raw, csv_file):
        """Append one (x, edited y, order) sample per window and order.

        Args:
            df_raw: Raw frame with a ``date`` column and ``self.target``.
            csv_file: Dataset name. ``ETTh1/2`` and ``ETTm1/2`` use fixed
                12/4/4 "month" borders; anything else uses a 60/20/20 split.

        Raises:
            ValueError: if ``self.features`` is not ``'S'``, ``'M'`` or ``'MS'``.
        """
        cols = list(df_raw.columns)
        cols.remove(self.target)
        cols.remove('date')
        df_raw = df_raw[['date'] + cols + [self.target]]

        n_rows = len(df_raw)
        num_train = int(n_rows * 0.6)
        num_test = int(n_rows * 0.2)
        num_vali = n_rows - num_train - num_test
        if csv_file in ['ETTh1', 'ETTh2']:
            month = 30 * 24
        elif csv_file in ['ETTm1', 'ETTm2']:
            month = 30 * 24 * 4
        else:
            month = None
        if month is None:
            border1s = [0, num_train - self.seq_len, n_rows - num_test - self.seq_len]
            border2s = [num_train, num_train + num_vali, n_rows]
        else:
            # 12 "months" train, the next 4 val, the next 4 test; val/test
            # starts are pulled back by seq_len so their first window has input.
            border1s = [0, 12 * month - self.seq_len, 16 * month - self.seq_len]
            border2s = [12 * month, 16 * month, 20 * month]
        border1 = border1s[self.set_type]
        border2 = border2s[self.set_type]

        if self.features in ('M', 'MS'):
            df_data = df_raw[df_raw.columns[1:]]
        elif self.features == 'S':
            df_data = df_raw[[self.target]]
        else:
            raise ValueError(
                f"features must be 'S', 'M' or 'MS', got {self.features!r}")

        if self.scale:
            # Fit on the instance scaler so inverse_transform can undo it.
            self.scaler = StandardScaler()
            train_data = df_data[border1s[0]:border2s[0]]
            self.scaler.fit(train_data.values)
            data = self.scaler.transform(df_data.values)
        else:
            self.scaler = None
            data = df_data.values
        # The edits add float noise in place; integer storage would truncate
        # or raise, so always work in float.
        data = np.asarray(data, dtype=float)

        orders = random.sample(self.orders, len(self.orders))
        horizon = self.pred_len + self.label_len
        steps = np.arange(horizon)
        # Last start whose target window still ends at or before border2.
        last_start = border2 - self.seq_len - horizon
        for i in range(border1, last_start + 1):
            batch_x = data[i:i + self.seq_len].T
            batch_y = data[i + self.seq_len:i + self.seq_len + horizon].T
            n_vars = batch_x.shape[0]
            for order in orders:
                self.all_x.append(batch_x)
                self.all_y.append(self._transform_target(order, batch_x, batch_y, steps))
                self.all_orders.append(np.array([order] * n_vars))

    @staticmethod
    def _transform_target(order, batch_x, batch_y, steps):
        """Return ``batch_y`` edited according to the instruction ``order``.

        Args:
            order: Instruction text. Unrecognized orders return ``batch_y``
                itself; recognized ones return a new array.
            batch_x: Input window ``(n_vars, seq_len)``; only its last column
                is used, as the anchor for the exponential/logarithmic edits.
            batch_y: Target window ``(n_vars, horizon)``.
            steps: ``np.arange(horizon)``.

        Several branches make random draws whose values are discarded (they
        fed an auxiliary series that is never returned). These draws are kept
        on purpose: removing them would shift the global NumPy random stream
        and change the targets produced under a fixed seed. Keep each
        branch's draws in exactly this order.
        """
        horizon = len(steps)
        half = horizon // 2
        n_vars = batch_y.shape[0]
        last_x = batch_x[:, -1:]   # (n_vars, 1) last observed input value
        first_y = batch_y[:, :1]   # (n_vars, 1) first target value

        def mentions(*keys):
            return any(key in order for key in keys)

        if order in ['linear growth', 'linear trend up', 'linear upward']:
            np.random.uniform(0.01, 0.015)  # discarded draw
            new_y = batch_y.copy()
            for v in range(n_vars):
                slope = linregress(steps, batch_y[v]).slope
                detrended = batch_y[v] - slope * steps
                new_y[v] = (detrended + np.random.uniform(0.007, 0.01) * steps
                            + np.random.uniform(-0.05, 0.04, horizon))
                np.random.uniform(2, 2.05, horizon)            # discarded
                np.random.uniform(0.2, 0.3, half)              # discarded
                np.random.uniform(-0.1, 0.0, horizon - half)   # discarded
                np.random.uniform(0.005, 0.006)                # discarded
            return new_y

        if order in ['linear decay', 'linear trend down']:
            np.random.uniform(0.01, 0.015)  # discarded draw
            new_y = batch_y.copy()
            for v in range(n_vars):
                slope = linregress(steps, batch_y[v]).slope
                detrended = batch_y[v] - slope * steps
                new_y[v] = (detrended * 0.55 - np.random.uniform(0.01, 0.015) * steps
                            + np.random.uniform(1.3, 1.33, horizon))
                np.random.uniform(-0.05, 0.04, horizon)        # discarded
                np.random.uniform(-0.05, 0.0, half)            # discarded
                np.random.uniform(0.01, 0.01, horizon - half)  # discarded
                np.random.uniform(0.002, 0.0025)               # discarded
            return new_y

        if order in ['increase amplitude', 'amplitude increase', 'raise amplitude']:
            np.random.uniform(-0.22, -0.2, horizon)  # discarded
            scale = np.random.uniform(1.4, 1.5)  # amplitude scaling factor
            np.random.uniform(1.5, 2.0)              # discarded
            return batch_y * scale + np.random.uniform(-1, 1, horizon)

        if order in ['amplitude decrease', 'decrease amplitude', 'reduce amplitude']:
            np.random.uniform(0.3, 0.33, horizon)    # discarded
            scale = np.random.uniform(0.3, 0.6)  # amplitude scaling factor
            return (batch_y * scale + np.random.uniform(-0.1, 0.4, horizon)
                    + np.random.uniform(-0.25, 0.25, horizon))

        if mentions('exponential growth', 'exponential goes up', 'exponential upward'):
            rate = np.random.uniform(0.006, 0.01)
            curve = np.exp(rate * steps)
            np.random.uniform(1.5, 1.51, horizon)    # discarded
            # Re-anchor y at the last input value, then add the rising curve.
            new_y = (batch_y + last_x - curve[0] + curve - first_y
                     + np.random.uniform(-0.012, 0.012, horizon))
            np.random.uniform(-0.01, 0.03, horizon)  # discarded
            np.random.uniform(-0.051, 0.21, horizon)  # discarded
            return new_y

        if mentions('logarithmic growth', 'logarithmic upward'):
            rate = np.random.uniform(0.015, 0.2)
            curve = np.log1p(rate * steps)
            np.random.uniform(2.1, 2.21, horizon)    # discarded
            return (batch_y + last_x - curve[0] + curve - first_y
                    + np.random.uniform(-0.1, 0.02, horizon))

        if mentions('exponential decay', 'exponential goes down'):
            np.random.uniform(1.5, 1.51, horizon)    # discarded
            rate = np.random.uniform(0.008, 0.015)
            curve = np.exp(rate * steps)
            return (batch_y + last_x - curve + curve[0] - first_y
                    + np.random.uniform(-0.2, 0.2, horizon))

        if mentions('logarithmic decay', 'logarithmic downward'):
            rate = np.random.uniform(0.015, 0.20)
            curve = np.log1p(rate * steps)
            np.random.uniform(1.6, 1.67, horizon)    # discarded
            return (batch_y + last_x - curve + curve[0] - first_y
                    + np.random.uniform(-0.01, 0.02, horizon))

        if mentions('linear growth and linear decay', 'linear goes up and linear goes down'):
            new_y = batch_y.copy()
            for v in range(n_vars):
                slope = linregress(steps, batch_y[v]).slope
                np.random.uniform(2.6, 2.68, horizon)  # discarded
                rise = np.random.uniform(0.006, 0.01)
                fall = np.random.uniform(0.01, 0.016)
                # Rise for the first half, then fall from the peak reached.
                trend = np.concatenate([
                    rise * steps[:half],
                    rise * steps[half] - fall * (steps[half:] - steps[half]),
                ])
                new_y[v] = (batch_y[v] - slope * steps) + trend
            return new_y

        if mentions('linear decay and linear growth', 'linear goes down and linear goes up'):
            new_y = batch_y.copy()
            for v in range(n_vars):
                slope = linregress(steps, batch_y[v]).slope
                np.random.uniform(2.7, 2.76, horizon)  # discarded
                fall = np.random.uniform(0.02, 0.03) - slope
                rise = np.random.uniform(0.02, 0.03) - slope
                # Fall for the first half, then rise from the trough reached.
                trend = np.concatenate([
                    -fall * steps[:half],
                    -fall * steps[half] + rise * (steps[half:] - steps[half]),
                ])
                new_y[v] = batch_y[v] + trend
            return new_y

        if 'keep stable' in order:
            new_y = batch_y.copy()
            for v in range(n_vars):
                slope = linregress(steps, batch_y[v]).slope
                detrended = batch_y[v] - slope * steps
                new_y[v] = (detrended * 0.2 + np.random.uniform(3, 3.02, horizon)
                            + np.random.uniform(0.0001, 0.0005) * steps)
                np.random.uniform(0.01, 0.02, half)            # discarded
                np.random.uniform(0.2, 0.21, horizon - half)   # discarded
                np.random.uniform(1.25, 1.26, horizon)         # discarded
            return new_y

        return batch_y

    def __getitem__(self, index):
        """Return ``(x, y, order_per_variable, [])`` for sample ``index``."""
        seq_x = self.all_x[index]
        seq_y = self.all_y[index]
        order = self.all_orders[index]
        return seq_x, seq_y, order, []  # seq_x_mark, seq_y_mark

    def __len__(self):
        """Number of generated samples (windows x orders)."""
        return len(self.all_x)

    def inverse_transform(self, data):
        """Undo the standardization; identity when the data was not scaled."""
        if self.scaler is None:
            return data
        return self.scaler.inverse_transform(data)
    
