from set_param import *
from utils import slice_dataset, normalisation
from RFM import RFM
from models.fcst_model_MLP import TimeSeriesMLP
from models.fcst_model_LSTM import LSTMSequenceGenerator
from models.fcst_model_UNet import UNet1D
from models.fcst_model_TCN import Predictor
from models.fcst_model_transformer import TimeSeriesGenerator
from models.fcst_model_VP import CrossMaskedDecoder
from models.fcst_model_VP_large import CrossMaskedDecoder as DecoderLarge
from decomposition import Decomposition


# preprocess data
# Reads `args`, `data`, `data_parser` and `np` from `from set_param import *`.
SEGMENT_NUM = data_parser[args.dataset][args.task_type]["segment_num"]
TRAIN_SCALE = data_parser[args.dataset][args.task_type]["scale"]
# Sampling a SEQ_HR-long window every TRAIN_SCALE steps (starting at 0)
# keeps exactly SEQ_LR points.
SEQ_HR = SEGMENT_NUM * TRAIN_SCALE + 1
SEQ_LR = SEGMENT_NUM + 1

# Datasets that are truncated to SEQ_HR and z-scored along axis 1 per sample,
# instead of being globally normalised and cut into SEQ_HR windows. Defined
# once so the truncation check and the normalisation check cannot drift apart.
PER_SAMPLE_NORM_DATASETS = ("MotorImagery", "SelfRegulationSCP1", "SelfRegulationSCP2")
USE_PER_SAMPLE_NORM = args.dataset in PER_SAMPLE_NORM_DATASETS

if USE_PER_SAMPLE_NORM:
    # `data` is rank 3 here; presumably (samples, time, channels) — confirm.
    data = data[:, :SEQ_HR, :]
print("original data : ", data.shape)

if USE_PER_SAMPLE_NORM:
    mean = data.mean(axis=1, keepdims=True)
    # Small epsilon keeps the division finite for constant series.
    std = data.std(axis=1, keepdims=True) + 1e-6
    normalized_data = (data - mean) / std
    data = normalized_data
    print("normalized data : ", data.shape)
else:
    # NOTE(review): normalisation() sees the full dataset before the
    # train/test split below, so its statistics presumably include test
    # samples — confirm this is intended.
    data, mean_list, std_list = normalisation(data)
    data = slice_dataset(data, SEQ_HR)
    print("sliced data : ", data.shape)

# The first 20% of samples (after the optional shuffle) form the test split.
test_num = int(data.shape[0] * 0.2)
if args.shuffle:
    # Shuffles along axis 0, in place. NOTE(review): the per-sample `mean` /
    # `std` above are NOT permuted with it, so they no longer line up with
    # rows of `data` if used later for de-normalisation — confirm.
    np.random.shuffle(data)
test_data = data[:test_num, :, :]
train_data = data[test_num:, :, :]
print("train data : ", train_data.shape)
print("test data : ", test_data.shape)
print()
# Untouched high-resolution copy of the test split (ground truth).
tssr_gt = test_data.copy()

# Positions of the high-resolution points that the low-resolution view keeps:
# every TRAIN_SCALE-th step starting at 0 (same "scale" entry as TRAIN_SCALE).
lr_sample_idx = range(0, test_data.shape[1], TRAIN_SCALE)
if args.task_type == "SSR":
    # Low-resolution input is a plain subsample of the test series.
    test_lr = test_data[:, lr_sample_idx, :]
else:
    # Each low-resolution point is the mean over axis 1 of the points after
    # the previous sample position up to and including the current one; the
    # first window is just index 0.
    start_idx = np.concatenate([np.array([-1]), np.array(lr_sample_idx[:-1])]) + 1
    test_lr = np.stack(
        [
            test_data[:, s:e + 1, :].mean(axis=1)
            for s, e in zip(start_idx, lr_sample_idx)
        ],
        axis=1,
    )


# init predictor
# Builds `velocity_predictor`, the network wrapped by the rectified-flow
# models below, according to `args.predictor` (and `args.version` for "VP").
INPUT_DIM = data_parser[args.dataset][args.task_type]["input_dim"]
if args.predictor == "VP" :
    if args.version == "standard" :
        velocity_predictor = CrossMaskedDecoder(
            x_dim= INPUT_DIM, 
            # Conditioning width doubles when both condition streams are used.
            c_dim= INPUT_DIM * 2 if args.use_both else INPUT_DIM, 
            d_model= args.h_dim, 
            d_ff= args.ff_dim,
            n_heads= args.n_head, 
            num_layers= args.decoder_layer, 
            itf_dim= args.itf_dim,
            itf_hidden= args.itf_hidden,
            itf_schema= args.itf_schema + [SEQ_HR],
            task_type= args.task_type,
            device= args.device,
            itf= args.use_itf,
            args= args,
        )
    elif args.version == "large" :
        # The "large" variant fixes its width/depth instead of reading them
        # from args.
        velocity_predictor = DecoderLarge(
            x_dim= INPUT_DIM, 
            c_dim= INPUT_DIM * 2 if args.use_both else INPUT_DIM, 
            d_model= 512, 
            d_ff= 512,
            n_heads= 16, 
            num_layers= 8, 
            itf_dim= args.itf_dim,
            itf_hidden= args.itf_hidden,
            itf_schema= args.itf_schema + [SEQ_HR],
            task_type= args.task_type,
            device= args.device,
            itf= args.use_itf,
            args= args,
        )
    else :
        # A bare `raise` here has no active exception and would surface only
        # as "RuntimeError: No active exception to reraise".
        raise ValueError(
            f"Unknown VP version {args.version!r}; expected 'standard' or 'large'"
        )
# NOTE(review): the branches below always use cond_dim/covariate_dim =
# INPUT_DIM and (except TCN) do not forward args.use_itf, although RFM is
# later given use_both=args.use_both — confirm they are only run with
# use_both disabled.
elif args.predictor == "Transformer" :
    velocity_predictor = TimeSeriesGenerator(
        input_dim= INPUT_DIM , 
        cond_dim= INPUT_DIM, 
        itf_dim= args.itf_dim,
        itf_hidden= args.itf_hidden,
        itf_schema= args.itf_schema + [SEQ_HR],
        device= args.device,
    )
elif args.predictor == "TCN" :
    velocity_predictor = Predictor(
        input_dim= INPUT_DIM,       
        cond_dim= INPUT_DIM,
        itf_dim= args.itf_dim,
        itf_hidden= args.itf_hidden,
        seq_len= SEQ_HR,
        itf_schema= args.itf_schema + [SEQ_HR],
        device= args.device,
        itf= args.use_itf,
    )
elif args.predictor == "MLP" :
    velocity_predictor = TimeSeriesMLP(
        feature_dim= INPUT_DIM,
        time_length= SEQ_HR,
        itf_dim= args.itf_dim,
        itf_hidden= args.itf_hidden,
        itf_schema= args.itf_schema + [SEQ_HR],
        device= args.device,
    )
elif args.predictor == "LSTM" :
    velocity_predictor = LSTMSequenceGenerator(
        input_dim= INPUT_DIM, 
        covariate_dim= INPUT_DIM,
        output_dim= INPUT_DIM,
        itf_dim= args.itf_dim,
        itf_hidden= args.itf_hidden,
        itf_schema= args.itf_schema + [SEQ_HR],
        device= args.device,
    )
elif args.predictor == "UNet" :
    velocity_predictor = UNet1D(
        input_dim= INPUT_DIM,
        covariate_dim= INPUT_DIM,
        output_dim= INPUT_DIM,
        itf_dim= args.itf_dim,
        itf_hidden= args.itf_hidden,
        itf_schema= args.itf_schema + [SEQ_HR],
        device= args.device,
    )
else :
    raise ValueError(
        f"Unknown predictor {args.predictor!r}; expected one of "
        "'VP', 'Transformer', 'TCN', 'MLP', 'LSTM', 'UNet'"
    )

# init rectified flow
# Two RFM wrappers that differ only in target_type ("s" vs "t"). Both receive
# the very same `velocity_predictor` object, so they share its parameters.
_rfm_shared_kwargs = dict(
    model=velocity_predictor,
    model_type=args.predictor,
    sr_ratio=TRAIN_SCALE,
    device=args.device,
    use_both=args.use_both,
)
gen_model_s = RFM(target_type="s", **_rfm_shared_kwargs)
gen_model_t = RFM(target_type="t", **_rfm_shared_kwargs)

# init decomposition
# Two Decomposition instances whose kernel sizes come from the per-dataset,
# per-task config entries "decomp_f" and "decomp_c".
_decomp_cfg = data_parser[args.dataset][args.task_type]
decomp_f = Decomposition(kernel_size=_decomp_cfg["decomp_f"])
decomp_c = Decomposition(kernel_size=_decomp_cfg["decomp_c"])