import experiments.data.data_generation
import experiments.utils
from sklearn.preprocessing import normalize, StandardScaler
import numpy as np
import pandas as pd
import csv
import torch

def get_data(args, n, d, T=10, dataset="synthetic", dataset_id=0, filename_data=None, filename_gt=None):
    """Simulate or load the dataset selected by ``dataset``.

    The datasets "time_series", "synthetic", "cyclic" and "dynamic" are simulated
    through ``experiments.data.data_generation``. All other datasets are read from
    disk using paths relative to the current working directory.

    Args:
        args: Experiment configuration object. The simulated datasets read
            ``weight_bounds``, ``edges``, ``graph_type``, ``sparsity``,
            ``noise_std`` and ``noise_effect``. "time_series" additionally reads
            ``number_of_lags``; the other simulated datasets additionally read
            ``noise``, ``trans_clos`` and ``fixSup``. "finance" reads
            ``number_of_lags``. The remaining datasets do not touch ``args``.
        n: Sample count forwarded to the simulators. For "dream3" it is only
            compared against ``Xtrain.shape[1]`` to print a mismatch warning.
        d: Number of nodes forwarded to the graph simulators (recomputed from
            the data for "finance").
        T: Forwarded to ``sparse_rct_sem`` ("time_series") and
            ``block_matrices`` ("dynamic").
        dataset: Name of the dataset to produce.
        dataset_id: Identifier forwarded to ``getGeneTrainingData`` ("dream3").
        filename_data: Data file name used by "swiss_temps", "finance", "S&P".
        filename_gt: Ground-truth file name used by "finance".

    Returns:
        A 5-tuple whose meaning depends on ``dataset``:

        * simulated datasets: ``(X, C_true, cond_num, B_true, W_true)``
        * "sachs", "thames", "finance": ``(X, 0, 0, B_true, 0)``
        * "us_temps", "swiss_temps": ``(X, 0, 0, 0, 0)``
        * "S&P": ``(X_logreturns, X, D, 0, 0)``
        * "dream3": ``(Xtrain, 0, 0, Gref, Gref)``

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    if dataset in ("time_series", "synthetic", "cyclic", "dynamic"):
        data_gen = experiments.data.data_generation
        (a, b) = tuple(args.weight_bounds)
        k = args.edges
        # Edge weights are drawn from [-b, -a] U [a, b].
        signed_ranges = ((-b, -a), (a, b))

        def spectral_sem(weights):
            # Shared call for the three datasets generated by the spectral SEM.
            return data_gen.sparse_spectral_sem(
                weights, n, sparsity=args.sparsity, std=args.noise_std,
                noise_type=args.noise, noise_effect=args.noise_effect,
                trans_clos=args.trans_clos, fix_sup=args.fixSup)

        if dataset == "time_series":
            # Random time-unrolled DAG, with a fixed average degree of 2 per lag.
            average_degrees_per_lagged_node = [2 for _ in range(args.number_of_lags)]
            B_true = data_gen.simulate_time_unrolled_dag(
                d, k * d, args.graph_type, args.number_of_lags, average_degrees_per_lagged_node)

            # Sample weights on the adjacency matrices (one per lag).
            W_true = list(data_gen.simulate_parameter(np.array(B_true), w_ranges=signed_ranges))

            X, C_true = data_gen.sparse_rct_sem(
                W_true, T, n=n, sparsity=args.sparsity, std=args.noise_std,
                noise_type="gauss", noise_effect=args.noise_effect)

            # This generator does not report a condition number.
            cond_num = 0
            # Stack the per-lag matrices side by side.
            W_true = np.concatenate(W_true, axis=1)
            B_true = np.concatenate(B_true, axis=1)

        elif dataset == "dynamic":  # dynamic dataset 2-block
            B_true = data_gen.simulate_graph(d, k * d, args.graph_type, selfloops=True)
            W_true = data_gen.simulate_parameter(B_true, w_ranges=signed_ranges)

            # Only the block weight matrix is needed to generate the data.
            _, W = data_gen.block_matrices(B_true, W_true, T)

            # data initialization spectral SEM: X = C(I + \bar(W))
            X, C_true, cond_num = spectral_sem(W)
            X, C_true = X.reshape((n * T, d)), C_true.reshape((n * T, d))

        elif dataset == "synthetic":
            B_true = data_gen.simulate_dag(d, k * d, args.graph_type)
            W_true = data_gen.simulate_parameter(B_true, w_ranges=signed_ranges)
            X, C_true, cond_num = spectral_sem(W_true)

        else:  # dataset == "cyclic"
            B_true = data_gen.simulate_digraph(d, k * d, args.graph_type)
            W_true = data_gen.simulate_convergent_parameter(B_true, w_ranges=(a, b))
            X, C_true, cond_num = spectral_sem(W_true)

        return X, C_true, cond_num, B_true, W_true

    if dataset == "sachs":
        X = np.load('data/sachs/data1.npy')
        B_true = np.load('data/sachs/DAG1.npy')
        print(B_true)

        print(X.shape)

        return X, 0, 0, B_true, 0

    if dataset == "thames":
        df = pd.read_csv('data/thames/thames_data.csv', sep=',', header=None)
        signals = df.to_numpy()
        signals = signals[:, :13 * 50].flatten()

        # Regroup the flat signal into consecutive chunks of 13 values per row.
        n_rows, n_nodes = 7 * 50, 13
        X = np.zeros((n_rows, n_nodes))
        X[:, :] = signals[:n_rows * n_nodes].reshape(n_rows, n_nodes)

        # Ground-truth edges given as 1-indexed (source, target) pairs.
        x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]) - 1
        y = np.array([6, 6, 6, 6, 6, 7, 13, 13, 13, 13, 13, 13]) - 1

        B_true = np.zeros((n_nodes, n_nodes))
        B_true[x, y] = 1

        # Previously this branch fell through and returned None.
        return X, 0, 0, B_true, 0

    if dataset == "us_temps":
        df = pd.read_csv('experiments/data/us_temps/US_max_temps.csv', sep=',')
        X = df.to_numpy()
        return X, 0, 0, 0, 0

    if dataset == "swiss_temps":
        df = pd.read_csv('experiments/data/swiss_meteo/{}.csv'.format(filename_data), sep=',')
        df = df.drop(["Date"], axis=1)
        X = df.to_numpy()
        DX = X[1:, :] - X[:-1, :]
        # Manual switch: set to True to return first differences instead of raw values.
        changes = False
        if changes:
            return DX, 0, 0, 0, 0

        return X, 0, 0, 0, 0

    if dataset == "finance":
        df = pd.read_csv('experiments/data/FinanceCPT/returns/{}'.format(filename_data), sep=',', header=None)
        X = df.to_numpy()
        # The number of nodes is dictated by the file, not by the caller.
        d = X.shape[-1]
        B_true = experiments.utils.edges_to_adjacency(
            'experiments/data/FinanceCPT/relationships/{}'.format(filename_gt), d=d, time_lag=args.number_of_lags)
        return X, 0, 0, B_true, 0

    if dataset == "S&P":
        df = pd.read_csv('experiments/data/S&P500/{}.csv'.format(filename_data), sep=',')
        print(df.head())
        # Side effect: the date column is exported to its own CSV file.
        date = df["Date"]
        date.to_csv("experiments/data/S&P500/Dates.csv", index=False)
        df = df.drop(["Date"], axis=1)

        X = df.to_numpy()
        print(X.shape)
        # Log-returns log(P_i(t) / P_i(t - 1)); always computed because the
        # return value below requires them.
        X_logreturns = np.log(X[1:, :] / X[:-1, :])
        print(X_logreturns.shape)

        df = pd.read_csv('experiments/data/S&P500/Dividends.csv', sep=',')
        print(df.head())
        df = df.drop(["Date"], axis=1)
        D = df.to_numpy()

        return X_logreturns, X, D, 0, 0

    ########################## code from eSRU Khanna and Tan 2020 ICLR
    if dataset == "dream3":
        Xtrain, Gref = getGeneTrainingData(dataset_id)
        n1 = Xtrain.shape[1]
        if n != n1:
            # Best-effort warning only; loading continues as in the original code.
            print("Error::Dimension mismatch for input training data..")

        # Subtract the global mean and rescale the input signal.
        Xtrain = Xtrain - Xtrain.mean()
        inputSignalMultiplier = 50
        Xtrain = inputSignalMultiplier * Xtrain

        return Xtrain, 0, 0, Gref, Gref

    raise ValueError("Unknown dataset: {!r}".format(dataset))

########################## code from eSRU Khanna and Tan 2020 ICLR
def getGeneTrainingData(dataset_id):
    """Load one of the five DREAM3 size-100 time series and its reference network.

    Args:
        dataset_id: Integer in 1..5 selecting Ecoli1, Ecoli2, Yeast1, Yeast2
            or Yeast3 respectively.

    Returns:
        Tuple ``(Xtrain, Gref)``: the data returned by ``loadTrainingData`` and
        the adjacency matrix returned by ``loadTrueNetwork``, whose size is
        taken from ``Xtrain.shape[1]``.

    Raises:
        ValueError: If ``dataset_id`` is not one of 1..5. (Previously this
            printed a message and then failed with ``UnboundLocalError``.)
    """
    network_names = {
        1: "Ecoli1",
        2: "Ecoli2",
        3: "Yeast1",
        4: "Yeast2",
        5: "Yeast3",
    }

    network_name = network_names.get(dataset_id)
    if network_name is None:
        raise ValueError(
            "Error while loading gene training data: dataset_id must be one of "
            "{}, got {!r}".format(sorted(network_names), dataset_id)
        )

    InputDataFilePath = "experiments/data/dream3/Dream3TensorData/Size100{}.pt".format(network_name)
    RefNetworkFilePath = "experiments/data/dream3/TrueGeneNetworks/InSilicoSize100-{}.tsv".format(network_name)

    Xtrain = loadTrainingData(InputDataFilePath)
    # The reference network is square, with one node per entry along axis 1 of the data.
    n = Xtrain.shape[1]
    Gref = loadTrueNetwork(RefNetworkFilePath, n)

    return Xtrain, Gref

########################## code from eSRU Khanna and Tan 2020 ICLR
######################################
# Function for loading input data 
######################################
def loadTrainingData(inputDataFilePath):
    """Return the ``'TsData'`` entry of the object stored at ``inputDataFilePath``.

    The file is read with ``torch.load``.
    NOTE(review): ``torch.load`` may unpickle arbitrary objects depending on the
    torch version/settings - only use it on trusted files.
    """
    return torch.load(inputDataFilePath)['TsData']

#######################################################
# Function for reading ground truth network from file 
#######################################################
def loadTrueNetwork(inputFilePath, networkSize):
    """Read a tab-separated edge list and return it as an adjacency matrix.

    The first two fields of every row are node labels made of a one-character
    prefix followed by a 1-based integer index (the prefix is dropped with
    ``[1:]``). Any further fields are ignored. Empty rows, such as a trailing
    blank line, are skipped.

    Args:
        inputFilePath: Path of the TSV file to read.
        networkSize: Number of nodes; the result has shape
            ``(networkSize, networkSize)``.

    Returns:
        ``np.int16`` array ``G`` with ``G[i - 1, j - 1] == 1`` for every row whose
        first field has index ``i`` and whose second field has index ``j``, and
        0 everywhere else.
    """
    Gtrue = np.zeros((networkSize, networkSize), dtype=np.int16)

    # Single pass over the file: each edge is written straight into the matrix,
    # so there is no need to count the rows first.
    with open(inputFilePath, newline='') as tsvin:
        for row in csv.reader(tsvin, delimiter='\t'):
            if not row:
                # csv yields an empty list for a blank line.
                continue
            first = int(row[0][1:])
            second = int(row[1][1:])
            Gtrue[first - 1, second - 1] = 1

    return Gtrue


def data_transform(X, args):
    """Apply the preprocessing selected by ``args.transformation`` to ``X``.

    * ``'norm'``  - ``sklearn.preprocessing.normalize`` with its default settings.
    * ``'stand'`` - a ``StandardScaler`` fitted on ``X`` and applied to ``X``.
    * any other value - ``X`` is returned untouched.
    """
    mode = args.transformation

    if mode == 'norm':
        return normalize(X)

    if mode == 'stand':
        fitted_scaler = StandardScaler().fit(X)
        return fitted_scaler.transform(X)

    # No transformation requested.
    return X
    