import pandas as pd
import torch
import numpy as np
from torch.utils.data import DataLoader
import time
import Code.HSCIC as hs


# Structural equations used to simulate X and Y (they do not depend on any scenario argument)
def function_x(A, Z):
    """Structural equation for X: ``exp(-0.5 * A * A) * sin(2 * A) + 2 * Z`` (noise-free part).

    Works element-wise on scalars or numpy arrays of matching/broadcastable shape.
    """
    damping = np.exp(-0.5 * A * A)
    oscillation = np.sin(2 * A)
    return damping * oscillation + 2 * Z


def function_y(A, Z, X):
    """Structural equation for Y: ``sin(2*X*Z) * exp(-0.5*Z*X) + 5*A`` (noise-free part).

    Works element-wise on scalars or numpy arrays of matching/broadcastable shape.
    """
    wave = np.sin(2 * X * Z)
    decay = np.exp(-0.5 * Z * X)
    return wave * decay + 5 * A


# Generate n_samples data-points using function_x(A, Z) and function_y(A, Z, X)
def simulation_data(n_samples, scenario=1):
    """Simulate ``n_samples`` observations of (A, X, Z, Y) plus their noise terms.

    Generative process (draws happen in this order from numpy's global RNG):
        Z   ~ N(0, 1)
        A   = Z * Z + N(0, 1)
        u_x ~ N(0, 0.1),  u_y ~ N(0, 0.1)      (0.1 is the standard deviation)
        X   = function_x(A, Z) + 0.2 * u_x
        Y   = function_y(A, Z, X) + 0.2 * u_y

    Args:
        n_samples: number of rows to generate.
        scenario: accepted for API compatibility; it is not used by this simulation.

    Returns:
        DataFrame with columns ['A', 'X', 'Z', 'Y', 'u_x', 'u_y'] and a default RangeIndex.
    """
    z = np.random.normal(0, 1, size=n_samples)
    a = z * z + np.random.normal(0, 1, size=n_samples)
    u_x = np.random.normal(0, 0.1, size=n_samples)
    u_y = np.random.normal(0, 0.1, size=n_samples)

    x = function_x(a, z) + 0.2 * u_x
    y = function_y(a, z, x) + 0.2 * u_y

    table = np.column_stack((a, x, z, y, u_x, u_y))
    return pd.DataFrame(table, columns=['A', 'X', 'Z', 'Y', 'u_x', 'u_y'])


# Processing data: normalize, split into train/test sets and wrap them in DataLoaders
def data_processing(df, batch_size):
    """Normalize the data, split it at random 80/20 and wrap both parts in DataLoaders.

    Args:
        df: DataFrame holding (at least) the columns 'A', 'X', 'Z' and 'Y'. Any index is
            accepted: rows are selected by position, not by label.
        batch_size: batch size of both returned DataLoaders.

    Returns:
        ``(train_dataloader, test_dataloader, norm_A, norm_XZ, norm_Y)``. Every row served by
        the (unshuffled) loaders is ``[X, Z, A, Y]`` after normalization. ``norm_A`` / ``norm_Y``
        (Series) and ``norm_XZ`` (DataFrame) are the normalized columns for all of ``df``.
        An empty split yields a DataLoader over a (0, 4) array.
    """
    # NOTE(review): columns are centred and divided by their *variance*, not their standard
    # deviation. generate_counterfactuals normalizes with the same convention, so the two must
    # be changed together if this is ever revisited.
    norm_A = (df['A'] - np.mean(df['A'])) / np.var(df['A'])
    norm_XZ = (df[['X', 'Z']] - np.mean(df[['X', 'Z']], axis=0)) / np.var(df[['X', 'Z']], axis=0)
    norm_Y = (df['Y'] - np.mean(df['Y'])) / np.var(df['Y'])

    # 80% of the samples are used for training and 20% for testing. flatnonzero returns row
    # *positions*, hence the positional (.iloc) selection below.
    rnd = np.random.uniform(0, 1, len(df))
    train_idx = np.flatnonzero(rnd < 0.8)
    test_idx = np.flatnonzero(rnd >= 0.8)

    def _stack_rows(idx):
        # Build the [X, Z, A, Y] matrix for the given positions; explicit column shapes keep
        # an empty split at shape (0, 4) instead of failing in reshape.
        return np.concatenate(
            (norm_XZ.iloc[idx].to_numpy(),
             norm_A.iloc[idx].to_numpy().reshape(-1, 1),
             norm_Y.iloc[idx].to_numpy().reshape(-1, 1)), 1)

    data_train = _stack_rows(train_idx)
    data_test = _stack_rows(test_idx)

    train_dataloader = DataLoader(data_train, batch_size=batch_size, shuffle=False)
    test_dataloader = DataLoader(data_test, batch_size=batch_size, shuffle=False)

    return train_dataloader, test_dataloader, norm_A, norm_XZ, norm_Y


# Select sample from observed data and generate n_count respective counterfactual samples
def generate_counterfactuals(n_count, meanA, varA, meanX, varX, meanY, varY, data):
    """Pick one observed sample at random and build ``n_count`` counterfactuals for it.

    The selected sample's Z and noise terms (u_x, u_y) are held fixed. A is redrawn as
    ``count_Z * count_Z + N(0, 1)`` with a fresh ``count_Z ~ N(0, 1)``. X and Y are then
    recomputed with function_x / function_y, adding ``0.2 * noise`` as in simulation_data.

    Args:
        n_count: number of counterfactuals to generate.
        meanA, varA: mean and variance used to normalize A.
        meanX, varX: two-element [X, Z] mean and variance (reshaped to (1, 2)).
        meanY, varY: mean and variance used to normalize Y.
        data: DataFrame with columns 'Z', 'u_x' and 'u_y'; a row is chosen by position.

    Returns:
        ``(count_A, count_X, count_Y)`` after normalization (centre, then divide by variance).
        ``count_A`` and ``count_Y`` have shape (n_count,) and ``count_X`` has shape
        (n_count, 2) with columns [X, Z].
    """
    # Select one observed sample. randint yields a row *position*, so read it positionally
    # (works for any index) and keep plain scalars.
    pos = int(np.random.randint(low=0, high=len(data), size=1)[0])
    Z = data['Z'].iloc[pos]
    noise_x = data['u_x'].iloc[pos]
    noise_y = data['u_y'].iloc[pos]
    Z_fixed = np.full(n_count, Z)

    # Generate n_count counterfactuals: A changes, Z and the noise terms stay fixed.
    count_Z = np.random.normal(0, 1, size=n_count)
    count_A = count_Z * count_Z + np.random.normal(0, 1, size=n_count)
    count_X = function_x(count_A, Z_fixed) + 0.2 * noise_x
    count_Y = function_y(count_A, Z_fixed, count_X) + 0.2 * noise_y

    # Normalize with the supplied observed-data statistics.
    count_A = (count_A - meanA) / varA
    count_X = np.column_stack((count_X, Z_fixed))
    count_X = (count_X - np.asarray(meanX).reshape(1, 2)) / np.asarray(varX).reshape(1, 2)
    count_Y = (count_Y - meanY) / varY

    return count_A, count_X, count_Y


# For n_samples data points, generate n_count counterfactuals with function generate_counterfactuals
# and store them in count_loader_tot
def counterfactual_simulations(n_samples, n_count, df, n_scenario):
    """Build ``n_samples`` independent counterfactual sets, each wrapped in a DataLoader.

    Normalization statistics (mean and variance of A, [X, Z] and Y) are computed once from
    ``df`` and passed to generate_counterfactuals for every set.

    Args:
        n_samples: number of observed samples to draw counterfactuals for.
        n_count: counterfactuals generated per sample.
        df: observed data, forwarded to generate_counterfactuals.
        n_scenario: accepted for API compatibility; it is not used here.

    Returns:
        dict mapping 0..n_samples-1 to a DataLoader (batch size 1, unshuffled) whose rows are
        ``[X, Z, A, Y]`` counterfactuals.
    """
    stats = (
        np.mean(df['A']), np.var(df['A']),
        np.mean(df[['X', 'Z']], axis=0), np.var(df[['X', 'Z']], axis=0),
        np.mean(df['Y']), np.var(df['Y']),
    )

    def _one_loader():
        cf_A, cf_X, cf_Y = generate_counterfactuals(n_count, *stats, df)
        rows = np.hstack((cf_X.reshape(cf_X.shape[0], -1),
                          cf_A.reshape(-1, 1),
                          cf_Y.reshape(-1, 1)))
        return DataLoader(rows, 1, shuffle=False)

    # Sequential comprehension: counterfactual sets are drawn in key order 0, 1, 2, ...
    return {i: _one_loader() for i in range(n_samples)}


# Train cnet model using as loss function: L = (accuracy loss) + beta * (hscic loss)
def train_model(train_dataloader, optimizer, cnet, loss_function, beta_hscic, num_epochs, test_dataloader):
    """Train ``cnet`` with the loss ``L = loss_function(y, Y) + beta_hscic * HSCIC``.

    Each batch is split into inputs (all but the last column) and target (last column).
    With the row layout produced by data_processing ([X, Z, A, Y]), input column 2 is A and
    input column 1 is Z, which is what the HSCIC term receives: hscic(Y, A, Z).

    After every epoch the model is evaluated on ``test_dataloader`` in eval mode and without
    building an autograd graph. Per-epoch averages over batches are then recorded.

    Args:
        train_dataloader: loader of 2-D training batches, target in the last column.
        optimizer: optimizer over ``cnet``'s parameters; stepped once per training batch.
        cnet: torch.nn.Module being trained (updated in place and returned). Its
            train/eval mode on entry is restored on exit.
        loss_function: accuracy loss, called as ``loss_function(prediction, target)``.
        beta_hscic: weight of the HSCIC penalty.
        num_epochs: number of passes over ``train_dataloader``.
        test_dataloader: loader of evaluation batches with the same layout.

    Returns:
        ``(cnet, loss_train_vals, loss_acc_train, loss_hscic_train,
        loss_test_vals, loss_acc_test, loss_hscic_test)``. Each list holds one value per
        epoch: the mean over batches of the total / accuracy / HSCIC loss, or NaN when the
        corresponding loader is empty.
    """
    start_time = time.time()

    loss_train_vals, loss_acc_train, loss_hscic_train = [], [], []
    loss_test_vals, loss_acc_test, loss_hscic_test = [], [], []

    hscic = hs.HSCIC()
    # Remember the caller's mode so it can be restored once training is over.
    was_training = cnet.training
    for _ in range(num_epochs):
        train_tot, train_acc, train_hs = [], [], []
        cnet.train()
        for data in train_dataloader:
            inputs, outputs = _split_inputs_target(data)

            optimizer.zero_grad()

            y = cnet(inputs)
            loss = loss_function(y, outputs)
            hscic_value = hscic(y, inputs[:, 2], inputs[:, 1])  # hscic(Y, A, Z)
            loss_model = loss + beta_hscic * hscic_value

            train_tot.append(loss_model.item())
            train_acc.append(loss.item())
            train_hs.append(hscic_value.item())

            loss_model.backward()
            optimizer.step()

        test_tot, test_acc, test_hs = [], [], []
        # Evaluation only: eval mode (inference behaviour for layers such as dropout or
        # batch-norm) and no autograd graph, since no backward pass uses these losses.
        cnet.eval()
        with torch.no_grad():
            for data in test_dataloader:
                inputs, outputs = _split_inputs_target(data)
                y_test = cnet(inputs)
                loss_test = loss_function(y_test, outputs)
                hscic_value_test = hscic(y_test, inputs[:, 2], inputs[:, 1])  # hscic(Y, A, Z)
                loss_model_test = loss_test + beta_hscic * hscic_value_test

                test_tot.append(loss_model_test.item())
                test_acc.append(loss_test.item())
                test_hs.append(hscic_value_test.item())

        loss_test_vals.append(_mean_or_nan(test_tot))
        loss_train_vals.append(_mean_or_nan(train_tot))
        loss_hscic_train.append(_mean_or_nan(train_hs))
        loss_hscic_test.append(_mean_or_nan(test_hs))
        loss_acc_train.append(_mean_or_nan(train_acc))
        loss_acc_test.append(_mean_or_nan(test_acc))

    cnet.train(was_training)

    print('--- %s seconds ----' % (time.time()-start_time))
    print("Finished Training")
    return cnet, loss_train_vals, loss_acc_train, loss_hscic_train, loss_test_vals, loss_acc_test, loss_hscic_test


def _split_inputs_target(batch):
    """Split a 2-D batch into float32 inputs (all but the last column) and target (last column)."""
    inputs, target = torch.split(batch, (batch.shape[1] - 1, 1), 1)
    return inputs.float(), target.float()


def _mean_or_nan(values):
    """Return the arithmetic mean of ``values``, or NaN when the list is empty."""
    return sum(values) / len(values) if values else float('nan')


# Find the outcome Yhat of the counterfactual samples in data_count_loader with trained cnet
def find_output(cnet, data_count_loader):
    """Evaluate ``cnet`` on every counterfactual sample in ``data_count_loader``.

    Each batch row is ``[Xcf, Zcf, Acf, Ycf]``. The last column is dropped, and the rest
    ``[Xcf, Zcf, Acf]`` is fed to ``cnet`` as float32. Any batch size is supported, as long
    as ``cnet`` returns one value per sample.

    Args:
        cnet: trained model; called without gradient tracking.
        data_count_loader: iterable of 2-D tensors laid out as above.

    Returns:
        DataFrame with columns "var_Z" (input column 1, i.e. Zcf) and "values" (the model
        output). It has one row per sample, in loader order, with a fresh RangeIndex.
    """
    var_z, values = [], []
    # Inference only: no autograd graph is needed for the outputs.
    with torch.no_grad():
        for data in data_count_loader:
            inputs, _ = torch.split(data, (data.shape[1] - 1, 1), 1)
            inputs = inputs.float()
            y_test = cnet(inputs)
            var_z.extend(inputs[:, 1].tolist())
            values.extend(y_test.reshape(-1).tolist())

    # Build the frame once instead of concatenating per batch (avoids quadratic copying).
    return pd.DataFrame({"var_Z": var_z, "values": values})
