import os
import sys
import time
import torch
import numpy as np
import argparse
from collections import defaultdict

from sklearn.model_selection import ShuffleSplit
from sktime.dists_kernels import SignatureKernel

from sklearn.svm import SVR
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.preprocessing import MinMaxScaler
from collections import defaultdict

from permetrics import RegressionMetric


def setup_cmdline_parsing():
    """Create the command-line parser used by this script.

    Two argument groups are registered: one with the input/output file
    locations (all string-valued) and one with the integer ``--level``
    setting.

    Returns:
        argparse.ArgumentParser: the configured, not yet evaluated, parser.
    """
    parser = argparse.ArgumentParser()

    # File locations: (flag, default) pairs, all parsed as strings.
    io_group = parser.add_argument_group('Data loading/saving arguments')
    io_options = (
        ("--vecs-inp-file", "simu.pt"),
        ("--prms-inp-file", "prms.pt"),
        ("--kern-out-base", "/tmp/kern"),
        ("--stat-out-file", "/tmp/stat.pt"),
    )
    for flag, default_value in io_options:
        io_group.add_argument(flag, type=str, default=default_value)

    config_group = parser.add_argument_group('Configuration')
    config_group.add_argument("--level", type=int, default=2)
    # Disabled option, kept for reference:
    # config_group.add_argument('--param-ids', type=int, nargs='+')

    return parser


def lag_transform(X, lags=(0, 1, 2)):
    """Build cumulative lag embeddings of a batch of time series.

    For the i-th entry of ``lags`` the result holds one tensor that
    concatenates, along the last axis, one copy of ``X`` per lag in
    ``lags[:i+1]``. Each copy is shifted forward in time by its lag; the
    first ``lag`` time steps of that copy stay zero.

    Args:
        X: tensor that unpacks as ``(N, T, D)``.
        lags: sequence of non-negative integer lags (must support slicing).
            A lag of ``T`` or more contributes an all-zero block.

    Returns:
        dict mapping ``lags[i]`` to a tensor of shape ``(N, T, D*(i+1))``
        allocated with ``torch.zeros`` (i.e. torch's default dtype).
    """
    N, T, D = X.shape
    lagged_X = {}
    for i, lag_key in enumerate(lags):
        Y = torch.zeros(N, T, D * (i + 1))
        for j, l in enumerate(lags[:i + 1]):
            # Clamp at zero: a plain ``T - l`` turns negative for l > T and
            # would then slice from the end of X instead of selecting nothing.
            n_src = max(T - l, 0)
            Y[:, l:, j * D:(j + 1) * D] = X[:, :n_src, :]
        lagged_X[lag_key] = Y
    return lagged_X


def time_subsample(vecs, sample_rate=0.2):
    """Randomly keep a fraction of the last axis of every series.

    Every series gets its own index subset, taken as the leading entries of
    a fresh ``torch.randperm`` permutation, so the kept columns appear in
    permutation order rather than sorted order.

    Args:
        vecs: tensor that unpacks as ``(N, D, T)``.
        sample_rate: fraction of the ``T`` columns to keep;
            ``int(T * sample_rate)`` columns are retained.

    Returns:
        Tensor of shape ``(N, D, int(T * sample_rate))`` allocated with
        ``torch.zeros``.
    """
    n_series, n_dims, n_steps = vecs.shape
    n_kept = int(n_steps * sample_rate)

    subsampled = torch.zeros(n_series, n_dims, n_kept)
    for row, series in enumerate(vecs):
        chosen = torch.randperm(n_steps)[:n_kept]
        subsampled[row] = series[:, chosen]
    return subsampled


def _select_svr_params(svr, K, lags, trn_idx, y_trn, C_s, e_s):
    """Pick (C, epsilon, lag) with the lowest hold-out MSE inside one training fold."""
    inner_cv = ShuffleSplit(n_splits=1, test_size=.20)
    # Start from +inf so that any finite score is accepted; a fixed finite
    # sentinel would reject every candidate when the targets are large.
    best_score = float("inf")
    best_params = None

    for l in lags:
        K_trn = K[l][trn_idx, :][:, trn_idx]
        for C in C_s:
            for e in e_s:
                svr.C = C
                svr.epsilon = e

                # NOTE(review): a new random hold-out split is drawn for every
                # (lag, C, epsilon) candidate, so candidates are scored on
                # different subsets. Kept as in the original protocol --
                # confirm whether one shared split per fold was intended.
                fit_idx, val_idx = next(inner_cv.split(range(K_trn.shape[0])))
                K_fit = K_trn[fit_idx, :][:, fit_idx]
                K_val = K_trn[val_idx, :][:, fit_idx]

                svr.fit(K_fit, y_trn[fit_idx])
                score = mean_squared_error(y_trn[val_idx], svr.predict(K_val))

                # Strict '<' keeps the first candidate on ties.
                if score < best_score:
                    best_score = score
                    best_params = (C, e, l)

    if best_params is None:
        raise ValueError(
            "hyper-parameter search found no valid candidate; "
            "check that lags, C_s and e_s are non-empty"
        )
    return best_params


def run_regression(K, lags, vecs, y, C_s, e_s, n_splits=10, test_size=.20, id=0):
    """Evaluate a precomputed-kernel SVR with repeated random train/test splits.

    For every outer split, the lag, ``C`` and ``epsilon`` are chosen by
    hold-out validation on the training part only; the SVR is then refit on
    the whole training part with the chosen setting and scored on the test
    part.

    Args:
        K: mapping from lag to a square kernel matrix indexable as
            ``K[l][rows, :][:, cols]``.
        lags: candidate lags (keys of ``K``).
        vecs: only passed to ``ShuffleSplit.split`` to determine the samples
            to split.
        y: 2-d target tensor indexed as ``y[idx, id]``; ``.numpy()`` is
            called on its test slice, so a torch tensor is expected.
        C_s: candidate values for the SVR ``C`` parameter.
        e_s: candidate values for the SVR ``epsilon`` parameter.
        n_splits: number of outer random splits.
        test_size: fraction of samples held out in each outer split.
        id: column of ``y`` to regress on.

    Returns:
        Tuple ``(rmses, r2vals, smapes)`` of lists with one entry per outer
        split.

    Raises:
        ValueError: if the hyper-parameter grid yields no candidate.
    """
    rmses, r2vals, smapes = [], [], []
    cv = ShuffleSplit(n_splits=n_splits, test_size=test_size)
    metric = RegressionMetric()

    # outer loop
    for trn_idx, tst_idx in cv.split(vecs):
        svr = SVR(kernel='precomputed')
        y_trn = y[trn_idx, id]
        y_tst = y[tst_idx, id].numpy()

        best_C, best_e, best_l = _select_svr_params(
            svr, K, lags, trn_idx, y_trn, C_s, e_s
        )

        # Refit on the full training fold with the selected setting.
        svr.C = best_C
        svr.epsilon = best_e
        K_trn = K[best_l][trn_idx, :][:, trn_idx]
        K_tst = K[best_l][tst_idx, :][:, trn_idx]
        svr.fit(K_trn, y_trn)
        y_hat = svr.predict(K_tst)

        rmses.append(metric.root_mean_squared_error(y_tst, y_hat))
        smapes.append(metric.symmetric_mean_absolute_percentage_error(y_tst, y_hat))
        r2vals.append(r2_score(y_tst, y_hat))

    return rmses, r2vals, smapes


def main():
    """Run the signature-kernel SVR experiment end to end.

    Steps: parse the command line, load the time series and target
    parameters, min-max scale the series, build lagged versions, compute (or
    load cached) signature kernels per lag, run the regression for every
    target column, print summary statistics and save the per-split scores.
    """
    parser = setup_cmdline_parsing()
    args = parser.parse_args()
    print(args)

    # NOTE: torch.load unpickles the file contents; only load trusted files.
    prms = torch.load(args.prms_inp_file)
    vecs = torch.load(args.vecs_inp_file)
    prms_ids = list(range(prms.shape[1]))

    vecs = vecs.permute(0,2,1)
    N,D,T = vecs.shape
    print(f'{N} time series of dim {D} with {T} timepoints!')

    # NOTE(review): after permuting back, the tensor is laid out as (N, T, D),
    # so ``.view(-1, T)`` forms rows from T consecutive memory entries of
    # that layout. If per-dimension scaling was intended this would be
    # ``.view(-1, D)`` -- TODO confirm; left unchanged to keep results stable.
    data = vecs.permute(0,2,1).view(-1,T)
    scaler = MinMaxScaler()
    scaler.fit(data)
    vecs = torch.tensor(scaler.transform(data)).view(N,T,D)

    # Single source of truth for the lags: used for the lag transform, the
    # kernel computation and the model selection alike.
    lags = [0,1,2]
    lagged_vecs = lag_transform(vecs, lags=lags)

    K_ss = {}
    for l in lags:
        # NOTE: the cache name encodes only level and lag, so a kernel cached
        # from a different input file is reused silently.
        kern_out_file = args.kern_out_base + "_level_{}_lag_{}.pt".format(args.level,l)
        if os.path.exists(kern_out_file):
            print('Loading {}'.format(kern_out_file))
            K_ss[l] = torch.load(kern_out_file)
        else:
            t0 = time.time()
            sk = SignatureKernel(normalize=True, level=args.level)
            K_ss[l] = sk.transform(lagged_vecs[l].permute(0,2,1).numpy())
            print('Computed {} in {} sec'.format(kern_out_file, time.time()-t0))
            torch.save(K_ss[l], kern_out_file)

    C_s = np.logspace(-3, 1, 5) # from paper
    e_s = np.logspace(-4, 1, 5) # from paper

    stats = defaultdict(list)
    cv_runs = 10
    for aux_d in prms_ids:
        rmses_ss, r2vals_ss, smapes_ss = run_regression(K_ss, lags, vecs, prms, C_s, e_s, cv_runs, 0.2, aux_d)
        print('[{}]: RMSE={:0.4f} +/- {:0.4f} | R2={:0.4f} +/- {:0.4f} | SMAPE={:0.4f} +/- {:0.4f} '.format(
            aux_d,
            np.mean(rmses_ss),
            np.std(rmses_ss),
            np.mean(r2vals_ss),
            np.std(r2vals_ss),
            np.mean(smapes_ss),
            np.std(smapes_ss),
        ))

        # One score per CV run, stored per target parameter.
        stats['r2s_param'+str(aux_d)].extend(r2vals_ss)
        stats['rmse_param'+str(aux_d)].extend(rmses_ss)
        stats['smape_param'+str(aux_d)].extend(smapes_ss)

    torch.save(stats, args.stat_out_file)
    

# Run the experiment only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()






