import experiments.data.data_generation
import experiments.utils
from sklearn.preprocessing import normalize, StandardScaler
import numpy as np
import pandas as pd
import csv
import torch

def get_data(args, n, d, T=10, dataset="synthetic", dataset_id=0, filename_data=None, filename_gt=None):
    """Load or simulate the dataset selected by ``dataset``.

    The five returned values mean different things depending on ``dataset``:

    * ``"synthetic"``, ``"time_series"``, ``"time_series_laplace"``, ``"laplace"``:
      ``(X, C_true, cond_num, B_true, W_true)``. These are simulated via
      ``experiments.data.data_generation``; ``cond_num`` is always 0.
    * ``"finance"``: ``(X, 0, 0, B_true, 0)``.
    * ``"S&P"``: ``(log_returns, prices, dividends, 0, 0)``.
    * ``"dream3"``: ``(Xtrain, 0, 0, Gref, Gref)``.

    Args:
        args: experiment configuration namespace. Which attributes are read
            depends on the dataset (e.g. ``weight_bounds``, ``edges``,
            ``number_of_lags``, ``graph_type``, ``sparsity``, ``noise_std``,
            ``noise``, ``sparsity_type`` for simulated data).
        n: forwarded as ``n`` to ``sparse_input_sem`` for simulated data. For
            ``"dream3"`` it is only compared against ``Xtrain.shape[1]``.
        d: first argument of ``simulate_time_unrolled_dag`` for simulated data.
            It is presumably the number of variables (unused by file-backed
            datasets).
        T: forwarded to ``sparse_input_sem`` for simulated data.
        dataset: name of the dataset to load.
        dataset_id: DREAM3 network index, used only for ``"dream3"``.
        filename_data: data file name for ``"finance"`` / ``"S&P"``.
        filename_gt: ground-truth relationships file name for ``"finance"``.

    Raises:
        ValueError: if ``dataset`` is not one of the names above.
    """
    if dataset in ("time_series", "time_series_laplace", "laplace", "synthetic"):
        return _get_synthetic_data(args, n, d, T)
    if dataset == "finance":
        return _get_finance_data(args, filename_data, filename_gt)
    if dataset == "S&P":
        return _get_sp500_data(filename_data)
    if dataset == "dream3":
        return _get_dream3_data(n, dataset_id)
    # Fail loudly: returning None here only surfaces later as a confusing
    # tuple-unpacking TypeError in the caller.
    raise ValueError("Dataset {} not found".format(dataset))


def _get_synthetic_data(args, n, d, T):
    """Simulate a random time-unrolled DAG, its edge weights and sparse-input SEM data."""
    (a, b) = tuple(args.weight_bounds)
    k = args.edges

    # Random DAG. k * d is passed as the second argument (the original note
    # reads "avg degree = k"). The per-lag degree list is hard-coded to 2.
    average_degrees_per_lagged_node = [2 for _ in range(args.number_of_lags)]
    B_true = experiments.data.data_generation.simulate_time_unrolled_dag(
        d, k * d, args.graph_type, args.number_of_lags, average_degrees_per_lagged_node)

    # Weights on the adjacency structure, drawn from the ranges (-b, -a) and (a, b).
    W_true = experiments.data.data_generation.simulate_parameter(
        np.array(B_true), w_ranges=((-b, -a), (a, b)))
    W_true = list(W_true)

    X, C_true = experiments.data.data_generation.sparse_input_sem(
        W_true, T, n=n, sparsity=args.sparsity, std=args.noise_std,
        noise_type=args.noise, sparsity_type=args.sparsity_type)

    cond_num = 0
    # Stack the matrices side by side (presumably one per lag -- confirm
    # against data_generation).
    W_true = np.concatenate(W_true, axis=1)
    B_true = np.concatenate(B_true, axis=1)

    return X, C_true, cond_num, B_true, W_true


def _get_finance_data(args, filename_data, filename_gt):
    """Load a FinanceCPT returns file and its ground-truth adjacency matrix."""
    df = pd.read_csv('experiments/data/FinanceCPT/returns/{}'.format(filename_data), sep=',', header=None)
    X = df.to_numpy()
    d = X.shape[-1]
    B_true = experiments.utils.edges_to_adjacency(
        'experiments/data/FinanceCPT/relationships/{}'.format(filename_gt),
        d=d, time_lag=args.number_of_lags)
    return X, 0, 0, B_true, 0


def _get_sp500_data(filename_data, logreturns=True):
    """Load S&P 500 values, their log-returns and the dividends table.

    Side effect: writes the ``Date`` column to
    ``experiments/data/S&P500/Dates.csv`` and prints previews/shapes.

    Args:
        filename_data: CSV file name (without extension) under experiments/data/S&P500.
        logreturns: if True, the first returned array holds log-returns. If
            False, it holds the raw values themselves.
    """
    df = pd.read_csv('experiments/data/S&P500/{}.csv'.format(filename_data), sep=',')
    print(df.head())
    date = df["Date"]
    date.to_csv("experiments/data/S&P500/Dates.csv", index=False)
    df = df.drop(["Date"], axis=1)

    X = df.to_numpy()
    print(X.shape)
    if logreturns:
        # log(P_i(t) / P_i(t - 1)); the result has one row fewer than X.
        X_logreturns = np.log(X[1:, :] / X[:-1, :])
        print(X_logreturns.shape)
    else:
        # Use the values themselves so the result is always defined.
        X_logreturns = X

    df = pd.read_csv('experiments/data/S&P500/Dividends.csv', sep=',')
    print(df.head())
    df = df.drop(["Date"], axis=1)
    D = df.to_numpy()

    return X_logreturns, X, D, 0, 0


def _get_dream3_data(n, dataset_id):
    """Load a DREAM3 dataset (adapted from eSRU, Khanna & Tan, ICLR 2020)."""
    Xtrain, Gref = getGeneTrainingData(dataset_id)
    if n != Xtrain.shape[1]:
        # Best-effort warning, as in the original eSRU code; loading continues.
        print("Error::Dimension mismatch for input training data..")

    # Make input signal zero mean and appropriately scaled. Note that .mean()
    # without an axis subtracts a single global mean, not a per-variable one.
    Xtrain = Xtrain - Xtrain.mean()
    inputSignalMultiplier = 50
    Xtrain = inputSignalMultiplier * Xtrain

    return Xtrain, 0, 0, Gref, Gref

########################## code from eSRU Khanna and Tan 2020 ICLR
def getGeneTrainingData(dataset_id):
    """Load one DREAM3 in-silico (size 100) dataset and its reference network.

    Args:
        dataset_id: 1 -> Ecoli1, 2 -> Ecoli2, 3 -> Yeast1, 4 -> Yeast2, 5 -> Yeast3.

    Returns:
        ``(Xtrain, Gref)``. ``Xtrain`` is the ``'TsData'`` entry loaded by
        ``loadTrainingData``. ``Gref`` is the ground-truth adjacency matrix from
        ``loadTrueNetwork``, sized by ``Xtrain.shape[1]``.

    Raises:
        ValueError: if ``dataset_id`` is not one of the ids listed above.
    """
    dream3_networks = {
        1: "Ecoli1",
        2: "Ecoli2",
        3: "Yeast1",
        4: "Yeast2",
        5: "Yeast3",
    }
    if dataset_id not in dream3_networks:
        # Raise immediately instead of printing and then crashing on an
        # unbound file path.
        raise ValueError(
            "Error while loading gene training data: unknown dataset_id {!r} "
            "(expected one of {})".format(dataset_id, sorted(dream3_networks)))

    network_name = dream3_networks[dataset_id]
    InputDataFilePath = "experiments/data/dream3/Dream3TensorData/Size100{}.pt".format(network_name)
    RefNetworkFilePath = "experiments/data/dream3/TrueGeneNetworks/InSilicoSize100-{}.tsv".format(network_name)

    Xtrain = loadTrainingData(InputDataFilePath)
    n = Xtrain.shape[1]
    Gref = loadTrueNetwork(RefNetworkFilePath, n)

    return Xtrain, Gref

########################## code from eSRU Khanna and Tan 2020 ICLR
######################################
# Function for loading input data 
######################################
def loadTrainingData(inputDataFilePath):
    """Return the entry stored under ``'TsData'`` in a ``torch.save`` file.

    Args:
        inputDataFilePath: path to a file loadable by ``torch.load`` whose
            content supports ``['TsData']`` indexing.

    Returns:
        The object stored under ``'TsData'`` (presumably the time-series
        tensor -- verify against the data files).
    """
    # NOTE(review): torch.load is pickle-based; only use it on trusted files.
    return torch.load(inputDataFilePath)['TsData']

#######################################################
# Function for reading ground truth network from file 
#######################################################
def loadTrueNetwork(inputFilePath, networkSize):
    """Build a ``networkSize x networkSize`` int16 adjacency matrix from a TSV edge list.

    Each row's first two tab-separated fields are node labels made of one
    leading character followed by a 1-based integer (e.g. ``G12``). Any further
    columns are ignored. The returned matrix has ``result[src - 1, dst - 1] == 1``
    for every row ``src<TAB>dst``.

    Args:
        inputFilePath: path to the TSV edge list.
        networkSize: number of nodes; sets the matrix dimensions.

    Returns:
        The adjacency matrix (a transposed view of the target-major matrix built below).
    """
    # Parse every row first, then fill the matrix.
    with open(inputFilePath) as tsv_file:
        parsed_edges = [(int(fields[0][1:]), int(fields[1][1:]))
                        for fields in csv.reader(tsv_file, delimiter='\t')]

    # reshape keeps an (0, 2) shape when the file has no rows.
    edge_array = np.array(parsed_edges, dtype=np.int16).reshape(-1, 2)

    target_major = np.zeros((networkSize, networkSize), dtype=np.int16)
    target_major[edge_array[:, 1] - 1, edge_array[:, 0] - 1] = 1

    return target_major.T


def data_transform(X, args):
    """Optionally rescale ``X`` according to ``args.transformation``.

    * ``'norm'``: ``sklearn.preprocessing.normalize(X)``, which scales each
      row to unit L2 norm with the default arguments.
    * ``'stand'``: a ``StandardScaler`` fitted on ``X`` and applied to it
      (each column centred and scaled to unit variance).
    * anything else: ``X`` is returned unchanged (the same object).
    """
    method = args.transformation
    if method == 'norm':
        return normalize(X)
    if method == 'stand':
        return StandardScaler().fit(X).transform(X)
    return X
    