# -*- coding: utf-8 -*-


import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import math
from scipy.stats import wishart
from sklearn import linear_model
import random


def calculate_sig_num(sig_degree, d):
    """
    Calculate the number of signatures to a truncated level.
    The count is 1 (the empty word) + d + d**2 + ... + d**sig_degree.
    Input: 
        sig_degree: the maximum order of signatures
        d:          dimension of the underlying process
    Output:
        sig_num:    number of signatures truncated to order=sig_degree
    """

    if d >= 2:
        # Cast to Python ints so numpy integer inputs cannot overflow silently.
        d = int(d)
        sig_degree = int(sig_degree)
        # d + ... + d**K == (d**(K+1) - d) / (d - 1) and the division is exact,
        # so use integer floor division: float division + round() returns a
        # wrong count once the value exceeds 2**53.
        return (d ** (sig_degree + 1) - d) // (d - 1) + 1

    # For d == 1 every level contains exactly one word.
    return sig_degree + 1

def calculate_signature_path(X,Y,method="S"):
    """
    Calculate a single signature integral: \\int X d Y, along the whole path.
    Input: 
        X:      integrand sequence (X in "\\int X d Y")
        Y:      integrator sequence (Y in "\\int X d Y"), indexed like X
        method: "Ito" uses the left-point rule; any other value (default "S")
                uses the trapezoid (Stratonovich) rule
    Output:
        sig:    list of running integral values, starting at 0, one per point of X
    """

    use_ito = method == "Ito"
    running = 0
    path = [running]
    for idx in range(1, len(X)):
        dY = Y[idx] - Y[idx - 1]
        if use_ito:
            # Left-point evaluation of the integrand.
            integrand = X[idx - 1]
        else:
            # Midpoint average of the integrand over the step.
            integrand = (X[idx] + X[idx - 1]) / 2
        running = running + integrand * dY
        path.append(running)

    return path



def calculate_signature_to_K(X_matrix, K, method="S"):
    """
    Calculate all signatures truncated to order K.
    Input: 
        X_matrix:     underlying process X, a (T, d) array (one row per time point)
        K:            highest order of signature (non-negative integer)
        method:       "S" or "Ito".  S=Stratonovich, Ito=Ito
    Output:
        1-D array of the signature of X at the final time, truncated to order K:
        a leading 1 (empty word) followed by orders 1..K in the recursive order
        used by generate_index_list.  K == 0 returns array([1.]).
    Raises:
        ValueError if K is negative.
    """

    if K < 0:
        raise ValueError("K must be non-negative, got {}".format(K))
    if K == 0:
        # The order-0 truncated signature is just the constant term.
        return np.array([1.0])

    X_matrix = np.asarray(X_matrix)
    (T, d) = np.shape(X_matrix)

    # level_start[k] = number of signature terms of orders 1..k, computed with
    # exact integer arithmetic (works for d == 1 and d >= 2 alike).  Order k
    # occupies columns [level_start[k-1], level_start[k]).
    level_start = [0]
    for k in range(1, K + 1):
        level_start.append(level_start[-1] + d ** k)

    signature_res = np.zeros((T, level_start[-1]))

    # Order 1: increments of each coordinate relative to the starting point.
    signature_res[:, :d] = X_matrix - X_matrix[0, :]

    location = d
    for k in range(1, K):
        # Every order-k term, integrated against every coordinate, gives the
        # order-(k+1) terms in recursive order.
        for i in range(level_start[k - 1], level_start[k]):
            for j in range(d):
                signature_res[:, location + j] = calculate_signature_path(
                    signature_res[:, i], X_matrix[:, j], method)
            location = location + d

    return np.concatenate(([1.0], signature_res[-1, :]))




def generate_index_list(T, d, K):
    """
    Generate the index list using recursive order, begin with 1.
    Each level is built by appending every letter 1..d to every word of the
    previous level, keeping the previous level's order.  Level 1 is always
    included, even when K < 1.
    Input: 
        T:     time length (unused; kept for interface compatibility)
        d:     dimension of the underlying process
        K:     highest order of signature
    Output:
        list of str(tuple) labels, starting with "()" for the empty word
    """

    letters = range(1, d + 1)
    frontier = [(letter,) for letter in letters]
    words = list(frontier)

    for _ in range(1, K):
        frontier = [word + (letter,) for word in frontier for letter in letters]
        words.extend(frontier)

    return [str(word) for word in [()] + words]


def generate_index_list_begin_zero(T, d, K):
    """
    Generate the index list using recursive order, begin with 0.
    Same ordering as generate_index_list, but letters are 0..d-1.  Level 1 is
    always included, even when K < 1.
    Input: 
        T:     time length (unused; kept for interface compatibility)
        d:     dimension of the underlying process
        K:     highest order of signature
    Output:
        list of str(tuple) labels, starting with "()" for the empty word
    """
    letters = range(d)
    previous_level = [(a,) for a in letters]
    collected = [()] + previous_level

    level = 1
    while level < K:
        # Extend every word of the previous level by every letter, in order.
        previous_level = [w + (a,) for w in previous_level for a in letters]
        collected += previous_level
        level += 1

    return list(map(str, collected))


def _stratonovich_second_moment(Omega, T, ind_m, ind_n):
    """
    Closed-form second moment of two Stratonovich signature terms of Brownian motion.
    Formulas are implemented for words of length at most 4; any other pair raises
    ValueError rather than returning a meaningless value.
    Input:
        Omega:  inter-dimensional correlation matrix, indexed as Omega[i, j]
        T:      time length
        ind_m:  first word, tuple of 0-based dimension indices
        ind_n:  second word, tuple of 0-based dimension indices
    Output:
        the second moment of the two signature terms
    """
    lm, ln = len(ind_m), len(ind_n)

    # Words of different parity have zero second moment.
    if lm % 2 != ln % 2:
        return 0

    if lm == 0 and ln == 0:
        return 1

    if lm == 0 and ln == 2:
        return Omega[ind_n[0], ind_n[1]]/2 * T
    if lm == 2 and ln == 0:
        return Omega[ind_m[0], ind_m[1]]/2 * T
    if lm == 0 and ln == 4:
        return Omega[ind_n[0], ind_n[1]]/2 * Omega[ind_n[2], ind_n[3]]/2 * T**2 / 2
    if lm == 4 and ln == 0:
        return Omega[ind_m[0], ind_m[1]]/2 * Omega[ind_m[2], ind_m[3]]/2 * T**2 / 2

    if lm == 1 and ln == 1:
        return Omega[ind_m[0], ind_n[0]] * T
    if lm == 1 and ln == 3:
        res = Omega[ind_m[0], ind_n[2]] * Omega[ind_n[0], ind_n[1]]/2 * T**2 / 2
        res = res + Omega[ind_n[1], ind_n[2]]/2 * Omega[ind_m[0], ind_n[0]] * T**2 / 2
        return res
    if lm == 3 and ln == 1:
        res = Omega[ind_m[2], ind_n[0]] * Omega[ind_m[0], ind_m[1]]/2 * T**2 / 2
        res = res + Omega[ind_m[1], ind_m[2]]/2 * Omega[ind_m[0], ind_n[0]] * T**2 / 2
        return res
    if lm == 2 and ln == 2:
        res = Omega[ind_m[1], ind_n[1]] * Omega[ind_m[0], ind_n[0]] * T**2 / 2
        res = res + Omega[ind_n[0], ind_n[1]]/2 * Omega[ind_m[0], ind_m[1]]/2 * T**2
        return res

    if lm == 2 and ln == 4:
        res = Omega[ind_m[1], ind_n[3]] * Omega[ind_m[0], ind_n[2]] * Omega[ind_n[0], ind_n[1]]/2 * T**3 / 6
        res = res + Omega[ind_m[1], ind_n[3]] * Omega[ind_n[1], ind_n[2]]/2 * Omega[ind_m[0], ind_n[0]] * T**3 / 6
        res = res + Omega[ind_n[2], ind_n[3]]/2 * Omega[ind_m[1], ind_n[1]] * Omega[ind_m[0], ind_n[0]] * T**3 / 6
        res = res + Omega[ind_n[2], ind_n[3]]/2 * Omega[ind_n[0], ind_n[1]]/2 * Omega[ind_m[0], ind_m[1]]/2 * T**3 / 2
        return res
    if lm == 4 and ln == 2:
        res = Omega[ind_n[1], ind_m[3]] * Omega[ind_n[0], ind_m[2]] * Omega[ind_m[0], ind_m[1]]/2 * T**3 / 6
        res = res + Omega[ind_n[1], ind_m[3]] * Omega[ind_m[1], ind_m[2]]/2 * Omega[ind_n[0], ind_m[0]] * T**3 / 6
        res = res + Omega[ind_m[2], ind_m[3]]/2 * Omega[ind_n[1], ind_m[1]] * Omega[ind_n[0], ind_m[0]] * T**3 / 6
        res = res + Omega[ind_m[2], ind_m[3]]/2 * Omega[ind_m[0], ind_m[1]]/2 * Omega[ind_n[0], ind_n[1]]/2 * T**3 / 2
        return res

    if lm == 3 and ln == 3:
        res = Omega[ind_m[2], ind_n[2]] * Omega[ind_m[1], ind_n[1]] * Omega[ind_m[0], ind_n[0]] * T**3 / 6
        res = res + Omega[ind_m[2], ind_n[2]] * Omega[ind_n[0], ind_n[1]]/2 * Omega[ind_m[0], ind_m[1]]/2 * T**3 / 3
        res = res + Omega[ind_m[1], ind_m[2]]/2 * Omega[ind_m[0], ind_n[2]] * Omega[ind_n[0], ind_n[1]]/2 * T**3 / 6
        res = res + Omega[ind_n[1], ind_n[2]]/2 * Omega[ind_m[2], ind_n[0]] * Omega[ind_m[0], ind_m[1]]/2 * T**3 / 6
        res = res + Omega[ind_n[1], ind_n[2]]/2 * Omega[ind_m[1], ind_m[2]]/2 * Omega[ind_m[0], ind_n[0]] * T**3 / 3
        return res

    if lm == 4 and ln == 4:
        res = Omega[ind_m[3], ind_n[3]] * Omega[ind_m[2], ind_n[2]] * Omega[ind_m[1], ind_n[1]] * Omega[ind_m[0], ind_n[0]] * T**4 / 24
        res = res + Omega[ind_m[3], ind_n[3]] * Omega[ind_m[2], ind_n[2]] * Omega[ind_m[0], ind_m[1]]/2 * Omega[ind_n[0], ind_n[1]]/2 * T**4 / 12
        res = res + Omega[ind_m[3], ind_n[3]] * Omega[ind_m[1], ind_m[2]]/2 * Omega[ind_m[0], ind_n[2]] * Omega[ind_n[0], ind_n[1]]/2 * T**4 / 24
        res = res + Omega[ind_m[3], ind_n[3]] * Omega[ind_n[1], ind_n[2]]/2 * Omega[ind_m[2], ind_n[0]] * Omega[ind_m[0], ind_m[1]]/2 * T**4 / 24
        res = res + Omega[ind_m[3], ind_n[3]] * Omega[ind_n[1], ind_n[2]]/2 * Omega[ind_m[1], ind_m[2]]/2 * Omega[ind_m[0], ind_n[0]] * T**4 / 12
        res = res + Omega[ind_m[2], ind_m[3]]/2 * Omega[ind_m[1], ind_n[3]] * Omega[ind_m[0], ind_n[2]] * Omega[ind_n[0], ind_n[1]]/2 * T**4 / 24
        res = res + Omega[ind_m[2], ind_m[3]]/2 * Omega[ind_m[1], ind_n[3]] * Omega[ind_n[1], ind_n[2]]/2 * Omega[ind_m[0], ind_n[0]] * T**4 / 24
        res = res + Omega[ind_n[2], ind_n[3]]/2 * Omega[ind_m[3], ind_n[1]] * Omega[ind_m[2], ind_n[0]] * Omega[ind_m[0], ind_m[1]]/2 * T**4 / 24
        res = res + Omega[ind_n[2], ind_n[3]]/2 * Omega[ind_m[3], ind_n[1]] * Omega[ind_m[1], ind_m[2]]/2 * Omega[ind_m[0], ind_n[0]] * T**4 / 24
        res = res + Omega[ind_n[2], ind_n[3]]/2 * Omega[ind_m[2], ind_m[3]]/2 * Omega[ind_m[1], ind_n[1]] * Omega[ind_m[0], ind_n[0]] * T**4 / 12
        res = res + Omega[ind_n[2], ind_n[3]]/2 * Omega[ind_m[2], ind_m[3]]/2 * Omega[ind_m[0], ind_m[1]]/2 * Omega[ind_n[0], ind_n[1]]/2 * T**4 / 4
        return res

    raise ValueError(
        "No Stratonovich formula for words of lengths {} and {}".format(lm, ln))


def calculate_corr_by_formula(Omega, T, d, sig_degree, method="Ito"):
    """
    Calculate the theoretical correlation matrix for signatures of Brownian motion using formulas in the paper. 
    If considering Stratonovich signatures, closed-form formulas are implemented only
    for orders up to 4 (sig_degree <= 4); larger orders raise ValueError.
    Input: 
        Omega:      inter-dimensional correlation matrix (d x d)
        T:          time length
        d:          dimension of the underlying process
        sig_degree: the maximum order of signatures
        method:     "S" or "Ito".  S=Stratonovich, Ito=Ito (any value other
                    than "Ito" is treated as Stratonovich)
    Output:
        the theoretical correlation matrix for signatures of Brownian motion, as a
        numpy array with rows/columns in generate_index_list_begin_zero order
    Raises:
        ValueError: Stratonovich method with sig_degree > 4.
    """
    import ast

    # Without this check, uncovered word-length pairs used to silently reuse the
    # previous pair's value.
    if method != "Ito" and sig_degree > 4:
        raise ValueError(
            "Stratonovich formulas are implemented only for sig_degree <= 4, "
            "got sig_degree={}".format(sig_degree))

    Omega = np.asarray(Omega)
    index_list = generate_index_list_begin_zero(T, d, sig_degree)
    # Labels look like "()", "(0,)", "(0, 1)"; literal_eval parses them without eval.
    words = [ast.literal_eval(label) for label in index_list]
    n = len(words)

    if method == "Ito":
        # Float array: filling an int DataFrame with floats relied on deprecated
        # pandas upcasting and was slow per-cell .loc access.
        corr_matrix_by_formula_I = np.ones((n, n))
        for a, ind_m in enumerate(words):
            for b, ind_n in enumerate(words):
                if len(ind_m) != len(ind_n):
                    res = 0
                else:
                    # Product of letter-wise correlations for equal-length words.
                    res = 1
                    for i in range(len(ind_m)):
                        res = res * Omega[ind_m[i], ind_n[i]]
                corr_matrix_by_formula_I[a, b] = res
        return corr_matrix_by_formula_I

    second_mom_by_formula_S = np.ones((n, n))
    for a, ind_m in enumerate(words):
        for b, ind_n in enumerate(words):
            second_mom_by_formula_S[a, b] = _stratonovich_second_moment(Omega, T, ind_m, ind_n)

    # Normalise second moments to correlations: D^{-1/2} M D^{-1/2}.
    diag_matrix_S = np.linalg.inv(np.sqrt(np.diag(np.diag(second_mom_by_formula_S))))
    corr_matrix_by_formula_S = diag_matrix_S @ second_mom_by_formula_S @ diag_matrix_S

    return corr_matrix_by_formula_S


def calculate_sample_corr_matrix(signature_list):
    """
    Calculate the sample correlation matrix for signatures.
    Each column (one signature term) is scaled to unit Euclidean norm and the
    Gram matrix of the scaled columns is returned; no centring is applied.
    Input: 
        signature_list: samples of signatures, one sample per row
    Output:
        the sample correlation matrix for signatures
    """

    samples = np.array(signature_list)
    column_scale = np.sqrt((samples ** 2).sum(axis=0))
    normalized = samples / column_scale
    return normalized.T @ normalized



def calculate_irrepresent_condition(corr_matrix, A_index, sign_A):
    """
    Calculate the vector of the irrepresentable condition:
        Sigma_{Ac,A} @ Sigma_{A,A}^{-1} @ sign_A
    Input: 
        corr_matrix:   correlation matrix of all predictors
        A_index:       index for true predictors, A* in the paper
        sign_A:        sign of beta coefficients of true predictors
    Output:
        the vector of the irrepresentable condition (one entry per predictor
        not in A_index, in increasing index order)
    Raises:
        numpy.linalg.LinAlgError if Sigma_{A,A} is singular.
    """

    corr_matrix = np.asarray(corr_matrix)
    sig_num = np.shape(corr_matrix)[0]
    A_set = set(A_index)
    Ac_index = [i for i in range(sig_num) if i not in A_set]

    Sigma_Ac_A = corr_matrix[Ac_index, :][:, A_index]
    Sigma_A_A = corr_matrix[A_index, :][:, A_index]

    # Solve Sigma_A_A x = sign_A instead of forming the explicit inverse:
    # more accurate and cheaper for the same result.
    res = Sigma_Ac_A @ np.linalg.solve(Sigma_A_A, sign_A)

    return res






def check_Lasso_select_result(coefs, beta_location, beta_values):
    """
    Check whether the Lasso is consistent.
    The path is consistent if, on some interval between consecutive knots, the
    active set equals the true support and every active coefficient has a sign
    that does not contradict the true coefficient.
    Input: 
        coefs:         coefficients estimated by sklearn.linear_model.lars_path package,
                       shape (n_features, n_knots)
        beta_location: list of locations of true predictors in the true model
        beta_values:   values of true predictors, aligned with beta_location
    Output:
        True if the Lasso path is sign-consistent on some interval, else False
    """

    wanted = set(beta_location)
    _, n_knots = np.shape(coefs)

    for left in range(n_knots - 1):
        # Sum of the two knot endpoints: nonzero exactly where the variable is
        # active on this interval.
        interval = coefs[:, left] + coefs[:, left + 1]
        support = np.nonzero(interval)[0]
        if set(support) != wanted:
            continue

        signs_agree = all(
            not (beta_values[beta_location.index(loc)] * interval[loc] < 0)
            for loc in support
        )
        if signs_agree:
            return True

    return False

    
    
def check_Lasso_confusion_matrix(coefs, beta_location, beta_values):
    """
    Calculate the maximum precision, recall, and F1-score defined in Appendix D.
    On each interval between consecutive knots of the Lasso path the selected
    set is compared with the true support; the maxima over all intervals are
    returned.  Precision (resp. recall) is taken as 0 when no predictor is
    selected (resp. the true support is empty) instead of dividing by zero.
    Input: 
        coefs:         coefficients estimated by sklearn.linear_model.lars_path package,
                       shape (n_features, n_knots)
        beta_location: location of true predictors in the true model
        beta_values:   values of true predictors in the true model (not used
                       here; kept for interface symmetry with check_Lasso_select_result)
    Output:
        maximum precision, maximum recall, and maximum F1-score
    """

    beta_location_set = set(beta_location)
    (factor_num, knot_num) = np.shape(coefs)
    all_locations = set(range(factor_num))
    nonbeta_location_set = all_locations - beta_location_set

    max_precision = 0
    max_recall = 0
    max_F1 = 0
    for i in range(knot_num - 1):
        # Nonzero in the sum of both endpoints <=> active on this interval.
        estimate_beta_location = set(np.nonzero(coefs[:, i] + coefs[:, i + 1])[0])
        estimate_nonbeta_location = all_locations - estimate_beta_location

        TP = len(estimate_beta_location & beta_location_set)
        FN = len(estimate_nonbeta_location & beta_location_set)
        FP = len(estimate_beta_location & nonbeta_location_set)

        # Guard empty selections / empty true support (previously ZeroDivisionError).
        precision = TP / (TP + FP) if TP + FP > 0 else 0.0
        recall = TP / (TP + FN) if TP + FN > 0 else 0.0
        if precision != 0 and recall != 0:
            F1 = 2 / (1 / precision + 1 / recall)
        else:
            F1 = 0

        max_precision = max(max_precision, precision)
        max_recall = max(max_recall, recall)
        max_F1 = max(max_F1, F1)

    return max_precision, max_recall, max_F1





