import argparse

import numpy as np
import pandas as pd

from preprocessing import *


def _task_configs(scale, segment_num, decomp, input_dim, asr_sample_drift=0):
    """Build the per-task ("SSR" / "ASR") hyper-parameter dicts for one dataset.

    The two task entries share every value except ``sample_drift``, which is
    always 0 for "SSR" and ``asr_sample_drift`` for "ASR". ``decomp`` is used
    for both ``decomp_f`` and ``decomp_c`` (they are equal for every dataset
    in this table). Each task gets its own dict, so mutating one entry never
    affects the other.

    Returns:
        dict: ``{"SSR": {...}, "ASR": {...}}`` with keys ``scale``,
        ``segment_num``, ``decomp_f``, ``decomp_c``, ``sample_drift`` and
        ``input_dim``, in that order.
    """
    def _single(sample_drift):
        return {
            "scale": scale,
            "segment_num": segment_num,
            "decomp_f": decomp,
            "decomp_c": decomp,
            "sample_drift": sample_drift,
            "input_dim": input_dim,
        }

    return {"SSR": _single(0), "ASR": _single(asr_sample_drift)}


# Hyper-parameters per dataset name, then per task type ("SSR" or "ASR").
# ``input_dim`` must equal the number of variables the data-loading code
# selects for that dataset.
data_parser = {
    "ETTm1": _task_configs(scale=4, segment_num=48, decomp=25, input_dim=7),
    "ETTm2": _task_configs(scale=4, segment_num=48, decomp=25, input_dim=7),
    "weather": _task_configs(scale=6, segment_num=48, decomp=25, input_dim=6),
    "PEMS-SF": _task_configs(scale=6, segment_num=48, decomp=17, input_dim=9),
    "etth1": _task_configs(scale=12, segment_num=12, decomp=7, input_dim=7, asr_sample_drift=8),
    "etth2": _task_configs(scale=12, segment_num=12, decomp=7, input_dim=7, asr_sample_drift=8),
    "SelfRegulationSCP1": _task_configs(scale=4, segment_num=200, decomp=51, input_dim=6),
    "SelfRegulationSCP2": _task_configs(scale=4, segment_num=200, decomp=51, input_dim=7),
    "MotorImagery": _task_configs(scale=10, segment_num=100, decomp=17, input_dim=4),
}

# The data-loading code in this script accepts the hourly ETT datasets as
# "ETTh1" / "ETTh2" (reading data/ETTh1.csv, data/ETTh2.csv). Alias those
# spellings to the lower-case entries so a config exists for every ETT name
# that can actually be loaded. The original lower-case keys are kept for
# backward compatibility.
data_parser["ETTh1"] = data_parser["etth1"]
data_parser["ETTh2"] = data_parser["etth2"]


parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, required=False, default='ETTm2', help='The dataset name.')
parser.add_argument('--version', type=str, required=False, default='standard', help='Standard SRT or SRT-large.')
parser.add_argument('--task_type', type=str, required=False, default='SSR', help='The task type. This can be set to SSR or ASR.')
parser.add_argument('--sample_step', type=int, required=False, default=4, help='The steps for sampling.')
parser.add_argument('--decoder_layer', type=int, required=False, default=3, help='The number of layers in the decoder.')
parser.add_argument("--n_head", type=int, required=False, default=4, help='The number of heads in the multi-head attention layer.')
parser.add_argument("--ff_dim", type=int, required=False, default=128, help='The dimension of the feedforward layer.')
parser.add_argument("--h_dim", type=int, required=False, default=128, help='The hidden dimension of model.')
parser.add_argument('--device', type=str, required=False, default='cuda', help='The device to use for training.')
parser.add_argument("--use_itf", type=bool, required=False, default=True, help='Whether to use ITF.')
parser.add_argument('--itf_schema', type=list, required=False, default=[128], help='The schema of the ITF.')
parser.add_argument('--itf-hidden', type=int, required=False, default=128, help='The hidden dimension of the ITF.')
parser.add_argument('--itf_dim', type=int, required=False, default=3, help='The dimension of the ITF.')
parser.add_argument('--predictor', type=str, required=False, default='VP', help='The velocity predictor model.')
parser.add_argument('--epoch', type=int, required=False, default=200, help='The number of epochs for training.')
parser.add_argument('--sinu_pe', type=bool, required=False, default=False, help='Whether to use sinusoidal positional encoding.')
parser.add_argument("--periodicity", type=int, required=False, default=0, help='The prior periodicity of the data.')
parser.add_argument('--unfold_dim', type=str, required=False, default='self', help='The dimension to unfold.')
parser.add_argument('--unfold_style', type=str, required=False, default='one', help='The style of unfolding.')
parser.add_argument("--time_emd", type=bool, required=False, default=False, help='Whether to use time embedding.')
parser.add_argument('--t_dim', type=int, required=False, default=128, help='The dimension of time embedding.')
parser.add_argument('--re_loss', type=bool, required=False, default=False, help='Whether to use reconstruction loss.')
parser.add_argument('--sample_batch_size', type=int, required=False, default=16, help='The batch size for sampling.')
parser.add_argument('--batch_size', type=int, required=False, default=32, help='The batch size for training.')
parser.add_argument('--use_both', type=bool, required=False, default=False, help='Whether to use both conditions without disentanglement.')
parser.add_argument('--shuffle', type=bool, required=False, default=True, help='Whether to shuffle the dataset.')
parser.add_argument('--retrain', type=bool, required=False, default=False, help='Whether to retrain SRT-large.')
parser.add_argument('--refine', type=bool, required=False, default=False, help='Whether to refine the gaps.')
parser.add_argument('--verbose', type=bool, required=False, default=False, help='Whether to print verbose output.')
args = parser.parse_args()


# Accepted ETT dataset names -> stem of the CSV file under data/.
# data_parser keys the hourly ETT configs as "etth1"/"etth2", while the files
# are read as data/ETTh1.csv and data/ETTh2.csv, so both spellings resolve to
# the same file.
_ETT_CSV_STEMS = {
    "ETTm1": "ETTm1",
    "ETTm2": "ETTm2",
    "ETTh1": "ETTh1",
    "ETTh2": "ETTh2",
    "etth1": "ETTh1",
    "etth2": "ETTh2",
}

# Columns of data/weather.csv used for each task type (6 each).
_WEATHER_COLUMNS = {
    "SSR": ["p (mbar)", "T (degC)", "rh (%)", "VPact (mbar)", "wd (deg)", "Tlog (degC)"],
    "ASR": ["p (mbar)", "T (degC)", "rho (g/m**3)", "wv (m/s)", "rain (mm)", "SWDR (W/m???)"],
}


def _select_columns(frame, columns):
    """Return ``frame[columns].values``, tolerating a garbled unit suffix.

    A requested name is used verbatim when it is a column of ``frame``. If it
    is not, and it contains "?" (e.g. "SWDR (W/m???)", whose unit glyph looks
    lost to an encoding round-trip), it is replaced by the single column whose
    name starts with the text before the first "?". Names that still do not
    match raise ``KeyError`` from pandas, exactly as a direct lookup would.
    """
    resolved = []
    for name in columns:
        if name not in frame.columns and "?" in name:
            prefix = name.split("?", 1)[0]
            candidates = [col for col in frame.columns if str(col).startswith(prefix)]
            if len(candidates) == 1:
                name = candidates[0]
        resolved.append(name)
    return frame[resolved].values


# Load the raw series for ``args.dataset`` into the module-level ``data``.
if args.dataset in _ETT_CSV_STEMS:
    # Skip column 0 (presumably the timestamp column -- confirm against the CSV).
    data = pd.read_csv(f"data/{_ETT_CSV_STEMS[args.dataset]}.csv").values[:, 1:]
elif args.dataset in ["SelfRegulationSCP1", "SelfRegulationSCP2"]:
    train_data, test_data = read_UEA(args.dataset, "negativity")
    # Stack the train and test splits along axis 0.
    data = np.concatenate([train_data, test_data], axis=0)
elif args.dataset in ["MotorImagery"]:
    train_data, test_data = read_UEA(args.dataset, "finger")
    data = np.concatenate([train_data, test_data], axis=0)
    # Keep 4 entries of the last axis (matches input_dim=4 in data_parser).
    data = data[:, :, [28, 29, 36, 37]]
elif args.dataset in ["PEMS-SF"]:
    # Columns 100..108: 9 series, matching input_dim=9 in data_parser.
    data = pd.read_csv(f"data/{args.dataset}.csv").values[:, 100:109]
elif args.dataset in ["weather"]:
    # Previously an unknown task type silently left ``data`` as a DataFrame.
    if args.task_type not in _WEATHER_COLUMNS:
        raise ValueError(f"Task type {args.task_type} not supported for dataset {args.dataset}.")
    data = _select_columns(pd.read_csv(f"data/{args.dataset}.csv"), _WEATHER_COLUMNS[args.task_type])
else:
    raise ValueError(f"Dataset {args.dataset} not supported.")
