import numpy as np
import pandas as pd
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch
import Code.HSCIC as hs


# Helpers to sample latent configurations and convert them to flat image indices
def sample_latent(size, latents_sizes):
    """Draw ``size`` random latent configurations.

    Column ``j`` of the result holds integers drawn uniformly from
    ``[0, latents_sizes[j])`` with the global NumPy RNG. The array is float
    because it is allocated with ``np.zeros``.

    Args:
        size: number of configurations (rows) to draw.
        latents_sizes: 1-D NumPy array with the number of values per factor.

    Returns:
        float ndarray of shape ``(size, latents_sizes.size)``.
    """
    n_factors = latents_sizes.size
    samples = np.zeros((size, n_factors))
    for col in range(n_factors):
        samples[:, col] = np.random.randint(latents_sizes[col], size=size)
    return samples


def latent_to_index(latents, latents_bases):
    """Map latent configurations to flat integer indices.

    Each row of ``latents`` is dotted with ``latents_bases`` and the result is
    cast to ``int``. Presumably ``latents_bases`` holds the mixed-radix strides
    of the image array, so the result indexes ``imgs`` directly — confirm
    against the caller.

    Args:
        latents: array-like or DataFrame of shape (n, n_factors) (or a single row).
        latents_bases: array-like of length n_factors.

    Returns:
        ndarray of Python-``int`` dtype with one index per row.
    """
    flat = np.asarray(latents).dot(np.asarray(latents_bases))
    return flat.astype(int)


# Generate a new dataset of latent variables by matching uniform samples against a defined SCM
def find_causal_dataset(n_small, n_large, latent_sizes, latents_bases, imgs):
    """Build a latent dataset whose (shape, posY, posX, scale) follow a fixed SCM.

    A pool of ``n_large`` latent configurations is drawn uniformly with
    ``sample_latent``. Independently, ``n_small`` configurations are drawn and
    their posX / scale are replaced by the structural equations

        posX  = round(N(shape + posY, 1)), clipped to [0, latent_sizes[4] - 1]
        scale = int(clip((posX/24 + posY/24) * shape + eps, 0, latent_sizes[2] - 1))
                with eps ~ N(0, 1)

    Pool rows whose (shape, posX, posY, scale) equal an SCM sample are kept,
    together with that sample's noise ``eps`` (column ``noise_scale``). Finally
    a noisy outcome is attached:

        output = exp(shape) * posX + scale**2 * sin(posY) + N(0, 0.01)

    All randomness comes from the global NumPy RNG, in the same order as before.

    Args:
        n_small: number of SCM samples to draw.
        n_large: size of the uniformly drawn pool that is matched against.
        latent_sizes: number of values per factor, in the column order
            ['color', 'shape', 'scale', 'orientation', 'posX', 'posY'].
        latents_bases: not used here; kept for signature compatibility.
        imgs: not used here; kept for signature compatibility.

    Returns:
        DataFrame with columns ['color', 'shape', 'scale', 'orientation',
        'posX', 'posY', 'noise_scale', 'output']. Its index comes from the
        filtered merge and is therefore not necessarily contiguous.
    """
    columns = ['color', 'shape', 'scale', 'orientation', 'posX', 'posY']
    df_total = pd.DataFrame(sample_latent(n_large, latent_sizes), columns=columns).astype('Int64')
    df_small = pd.DataFrame(sample_latent(n_small, latent_sizes), columns=columns).astype('Int64')

    # sample_latent draws each factor from randint(size), i.e. 0 .. size-1. Clipping to a
    # larger bound (the old posX clip to 32) produces values that can never match the
    # pool, so those SCM samples were silently dropped instead of landing on the boundary.
    max_scale = int(latent_sizes[columns.index('scale')]) - 1
    max_posX = int(latent_sizes[columns.index('posX')]) - 1

    shape_small = df_small['shape'].to_numpy(dtype=int)
    posY_small = df_small['posY'].to_numpy(dtype=int)

    # Position X := round(N(shape + posY, 1)), clipped to the valid posX range
    values_posX = np.round(np.random.normal(shape_small + posY_small, 1))
    values_posX = np.clip(values_posX, 0, max_posX).astype(int)

    fr_sh_pos = pd.DataFrame({'shape': shape_small, 'posY': posY_small, 'posX': values_posX})

    # Scale := (posX/24 + posY/24) * shape + noise, clipped, then truncated to int
    noise_scale = np.random.normal(0, 1, size=n_small)
    values_scale = (values_posX / 24 + posY_small / 24) * shape_small + noise_scale
    fr_sh_pos['scale'] = np.clip(values_scale, 0, max_scale).astype(int)
    fr_sh_pos['noise_scale'] = noise_scale

    # Keep only pool rows that match an SCM sample (left join + indicator keeps pool order)
    df_all_merged = df_total.merge(fr_sh_pos, on=['shape', 'posX', 'posY', 'scale'],
                                   how='left', indicator=True)
    df_total_final = df_all_merged[df_all_merged['_merge'] == 'both'].drop('_merge', axis=1)

    # Vectorized outcome on named columns (positional row[i] access inside apply is
    # deprecated in pandas and fails for an empty frame).
    feats = df_total_final[['shape', 'scale', 'posX', 'posY']].astype(float)
    signal = np.exp(feats['shape']) * feats['posX'] + feats['scale'] ** 2 * np.sin(feats['posY'])
    df_total_final['output'] = signal + np.random.normal(0, 0.01, size=len(df_total_final))
    return df_total_final


# Dataframe variables normalization given mean and variance
def normalize_data(df, mean, var):
    """Return ``(df - mean) / var``.

    Works for DataFrames (with ``mean``/``var`` aligned by column label) as well
    as scalars or arrays.

    NOTE(review): the divisor is used exactly as passed. If callers pass a
    variance (as the name suggests) rather than a standard deviation, the
    result is not a z-score. Confirm what the caller supplies.
    """
    centered = df - mean
    return centered / var


# Apply DataLoader on images, latents variables and labels
def get_loaders(df_tabular, imgs, batch_size):  # as input the datasets df_tabular:[latent variables, Y] and imgs:[images]
    """Split tabular features, images and labels 80/20 and wrap them in DataLoaders.

    The split is drawn with the global NumPy RNG. Row ``k`` of ``df_tabular``
    is paired with ``imgs[k]``. The six loaders share ``batch_size`` and are not
    shuffled, so zipping the image/tabular/label loaders of one split yields
    aligned batches.

    Args:
        df_tabular: DataFrame of features plus an 'output' column (the label).
        imgs: ndarray of images aligned with ``df_tabular`` rows, each 64x64
            (``reshape(-1, 64, 64)`` below enforces 4096 pixels per sample).
        batch_size: batch size for every loader.

    Returns:
        (train_img, train_tab, train_lab, test_img, test_tab, test_lab) DataLoaders.
    """

    def _to_tensor(array):
        # Same dtype handling as torchvision's ToTensor (uint8 -> default float / 255,
        # other dtypes unchanged), but WITHOUT ToTensor's (H, W, C) -> (C, H, W)
        # transpose: applied to an (N, 64, 64) stack that transpose gave (64, N, 64),
        # and reshaping that back to (N, 64, 64) mixed pixels of different images.
        tensor = torch.from_numpy(np.ascontiguousarray(array))
        if tensor.dtype == torch.uint8:
            tensor = tensor.to(dtype=torch.get_default_dtype()).div(255)
        return tensor

    # 80/20 train/test data split
    rnd = np.random.uniform(0, 1, len(df_tabular))
    train_idx = np.flatnonzero(rnd < 0.8)
    test_idx = np.flatnonzero(rnd >= 0.8)

    # Split in train/test
    d_tab = np.array(df_tabular.drop('output', axis=1)).astype(np.float32)
    n_features = d_tab.shape[1]
    tabular_tensor_train = _to_tensor(d_tab[train_idx]).reshape(-1, n_features)
    tabular_tensor_test = _to_tensor(d_tab[test_idx]).reshape(-1, n_features)
    img_tensor_train = _to_tensor(imgs[train_idx]).reshape(-1, 64, 64)
    img_tensor_test = _to_tensor(imgs[test_idx]).reshape(-1, 64, 64)
    d_lab = np.array(df_tabular['output']).astype(np.float32)
    labels_train = _to_tensor(d_lab[train_idx]).reshape(-1)
    labels_test = _to_tensor(d_lab[test_idx]).reshape(-1)

    # Get loaders (no shuffling, so zipped loaders of one split stay aligned)
    loader_trainer_img = DataLoader(img_tensor_train, batch_size=batch_size, num_workers=0)
    loader_trainer_tab = DataLoader(tabular_tensor_train, batch_size=batch_size, num_workers=0)
    loader_trainer_lab = DataLoader(labels_train, batch_size=batch_size, num_workers=0)
    loader_test_img = DataLoader(img_tensor_test, batch_size=batch_size, num_workers=0)
    loader_test_tab = DataLoader(tabular_tensor_test, batch_size=batch_size, num_workers=0)
    loader_test_lab = DataLoader(labels_test, batch_size=batch_size, num_workers=0)

    return loader_trainer_img, loader_trainer_tab, loader_trainer_lab, loader_test_img, loader_test_tab, loader_test_lab


# Train model with Loss = (MSE Loss) + beta * (HSCIC loss)
def train_model(loader_train_img, loader_train_tab, loader_train_lab, model, optimizer, beta, num_epochs,
                loader_test_img, loader_test_tab, loader_test_lab, loss_function):
    """Train ``model`` with loss = loss_function(y, target) + beta * HSCIC.

    Each epoch runs one optimisation pass over the zipped train loaders and then
    evaluates the same losses on the zipped test loaders without gradients.
    Per-epoch averages of the batch losses are recorded. An empty loader yields
    NaN rather than raising ``ZeroDivisionError``.

    Args:
        loader_train_img / loader_train_tab / loader_train_lab: aligned train loaders
            (images are reshaped to (-1, 1, 64, 64) before the model call).
        model: callable ``model(imgs, tabular)``.
        optimizer: torch optimizer over ``model``'s parameters.
        beta: weight of the HSCIC term.
        num_epochs: number of training epochs to run (epochs are numbered 1..num_epochs).
        loader_test_img / loader_test_tab / loader_test_lab: aligned test loaders.
        loss_function: callable ``loss_function(prediction, target)``.

    Returns:
        (model, total_loss_train, accuracy_loss_train, hscic_train,
         total_loss_test, accuracy_loss_test, hscic_test), each list holding one
        value per epoch.
    """
    total_loss_test = []
    total_loss_train = []
    hscic_train = []
    accuracy_loss_train = []
    hscic_test = []
    accuracy_loss_test = []
    hscic = hs.HSCIC()

    def _mean(values):
        # Average of per-batch values; NaN when a loader produced no batches.
        return sum(values) / len(values) if values else float('nan')

    def _hscic_inputs(features):
        # HSCIC(Y, position X, [shape, position Y]) per the original comment.
        # Column positions (4, 2) and (1, 5) are kept exactly as before.
        # NOTE(review): a 5-column table (as built by get_loaders) has no column 5;
        # confirm the tabular column order these positions assume.
        return features[:, [4, 2]], features[:, [1, 5]]

    # range(1, num_epochs) used to run only num_epochs - 1 epochs (none for num_epochs == 1)
    for epoch in range(1, num_epochs + 1):
        hscic_train_epoch = []
        acc_train_epoch = []
        epoch_loss = []
        epoch_loss_test = []
        acc_test_epoch = []
        hscic_test_epoch = []

        for imgs, tabular, output in zip(loader_train_img, loader_train_tab, loader_train_lab):
            optimizer.zero_grad()

            imgs = imgs.reshape(-1, 1, 64, 64)
            y = model(imgs, tabular)

            loss = loss_function(y.flatten(), output)
            a_feats, x_feats = _hscic_inputs(tabular)
            loss_hscic = hscic(y, a_feats, x_feats)

            total_loss = loss + beta * loss_hscic
            hscic_train_epoch.append(loss_hscic.item())
            acc_train_epoch.append(loss.item())
            epoch_loss.append(total_loss.item())

            total_loss.backward()
            optimizer.step()

        hscic_train.append(_mean(hscic_train_epoch))
        accuracy_loss_train.append(_mean(acc_train_epoch))
        total_loss_train.append(_mean(epoch_loss))

        print('==> Epoch {}, Accuracy Loss {:.6f}, HSCIC Loss {:.6f}'.format(epoch, accuracy_loss_train[-1], hscic_train[-1]))
        print('==> Epoch {}, Average Loss {:.6f}'.format(epoch, total_loss_train[-1]))

        # Evaluation only reads .item() values, so no autograd graph is needed.
        with torch.no_grad():
            for imgs, numeric_features, output in zip(loader_test_img, loader_test_tab, loader_test_lab):
                imgs = imgs.reshape(-1, 1, 64, 64)

                y_test = model(imgs, numeric_features)
                a_feats, x_feats = _hscic_inputs(numeric_features)
                loss_hscic_test = hscic(y_test, a_feats, x_feats)

                loss = loss_function(y_test.flatten(), output)
                total_loss = loss + beta * loss_hscic_test
                epoch_loss_test.append(total_loss.item())
                hscic_test_epoch.append(loss_hscic_test.item())
                acc_test_epoch.append(loss.item())

        hscic_test.append(_mean(hscic_test_epoch))
        accuracy_loss_test.append(_mean(acc_test_epoch))
        total_loss_test.append(_mean(epoch_loss_test))

    return model, total_loss_train, accuracy_loss_train, hscic_train, total_loss_test, accuracy_loss_test, hscic_test


def generate_counterfactuals(n_count, mean, var, imgs, df_latent, latents_bases):
    """Generate ``n_count`` counterfactuals of one random factual row of ``df_latent``.

    One row is chosen uniformly at random. Its shape, posY, orientation and
    ``noise_scale`` are held fixed, and color is set to 0. posX is intervened on
    by taking the posX values of ``n_count`` randomly chosen rows. Scale is then
    recomputed from the factual noise:

        scale = int32(clip((posX/24 + posY/24) * shape + noise_scale, 0, 5))

    Args:
        n_count: number of counterfactuals to generate.
        mean, var: forwarded to ``normalize_data``.
        imgs: image array indexed by ``latent_to_index``.
        df_latent: frame with 'shape', 'posY', 'orientation', 'posX' and
            'noise_scale' columns (e.g. the result of ``find_causal_dataset``).
        latents_bases: forwarded to ``latent_to_index``.

    Returns:
        (norm_df_count, imgs_sampled_causal_count): the normalized counterfactual
        features (without 'scale' and 'output'; the column order after
        normalization depends on how ``mean``/``var`` align) and the matching images.
    """
    # Select rows by POSITION. df_latent may carry a non-contiguous index (the
    # filtered merge in find_causal_dataset keeps the original labels), so using a
    # random position as a .loc label could raise KeyError or select the wrong row.
    n_rows = len(df_latent)
    row = np.random.randint(low=0, high=n_rows, size=1)[0]
    shape = df_latent['shape'].iloc[row]
    posY = df_latent['posY'].iloc[row]
    color = 0
    orientation = df_latent['orientation'].iloc[row]
    noise_scale = df_latent['noise_scale'].iloc[row]

    # generate counterfactuals of posX and scale
    count_rows = np.random.randint(low=0, high=n_rows, size=n_count)
    count_posx = df_latent['posX'].iloc[count_rows].to_numpy(dtype=float)
    count_values_scale = (count_posx / 24 + float(posY) / 24) * float(shape) + float(noise_scale)
    count_values_scale = np.clip(count_values_scale, 0, 5)  # truncated to int32 below

    # store counterfactual latent variables in df_count
    df_count = pd.DataFrame({
        'color': np.full(n_count, color),
        'shape': np.full(n_count, shape),
        'scale': count_values_scale,
        'orientation': np.full(n_count, orientation),
        'posX': count_posx,
        'posY': np.full(n_count, posY),
    }).astype('int32')

    indices_sampled_count = latent_to_index(df_count, latents_bases)
    imgs_sampled_causal_count = imgs[indices_sampled_count]  # select counterfactual images

    # Counterfactual outcome not needed; placeholder column, presumably so the
    # columns line up with mean/var in normalize_data (dropped again below).
    df_count['output'] = 0
    df_count = df_count.drop(['scale'], axis=1)

    norm_df_count = normalize_data(df_count, mean, var)  # normalize counterfactuals
    norm_df_count = norm_df_count.drop(['output'], axis=1)

    return norm_df_count, imgs_sampled_causal_count


# Generate n_count counterfactuals for n_samples datapoints with function generate_counterfactuals
def counterfactual_simulations(n_samples, n_count, mean, var, imgs, df_total_final, latents_bases):
    """Build per-sample DataLoaders of counterfactual images and tabular features.

    For each of ``n_samples`` iterations, ``generate_counterfactuals`` draws one
    factual row and ``n_count`` counterfactuals of it. The normalized features
    have their 'color' column overwritten with 0. Features and images are then
    wrapped in DataLoaders with ``batch_size=1``.

    Returns:
        (loader_trainer_images, loader_trainer_tabular): dicts keyed by iteration
        number ``0 .. n_samples - 1``. Image tensors are (n_count, 64, 64); tabular
        tensors are (n_count, n_features) float32.
    """
    loader_trainer_images = {}
    loader_trainer_tabular = {}
    transform = transforms.ToTensor()  # loop-invariant; create once

    for i in range(n_samples):
        # df_count: counterfactual latent variables, imgs_sampled_causal_count: counterfactual images
        df_count, imgs_sampled_causal_count = \
            generate_counterfactuals(n_count, mean, var, imgs, df_total_final, latents_bases)
        df_count['color'] = 0

        d_tab = np.array(df_count).astype(np.float32)
        # ToTensor turns a 2-D (N, F) array into (1, N, F); reshape back to (N, F).
        tabular_tensor_count = transform(d_tab).reshape(-1, d_tab.shape[1])
        # ToTensor treats a 3-D array as (H, W, C) and returns (C, H, W), i.e.
        # (N, 64, 64) -> (64, N, 64). A plain reshape back interleaves pixels of
        # different images; permuting the axes restores (N, 64, 64) exactly.
        tensor_imgs_count = transform(imgs_sampled_causal_count).permute(1, 2, 0).reshape(-1, 64, 64)

        # store the loaders for this factual sample
        loader_trainer_images[i] = DataLoader(tensor_imgs_count, batch_size=1, num_workers=0)
        loader_trainer_tabular[i] = DataLoader(tabular_tensor_count, batch_size=1, num_workers=0)

    return loader_trainer_images, loader_trainer_tabular


# Find outcome Yhat for counterfactual samples and store them in DataFrame df
def find_output(model, loader_trainer_images, loader_trainer_tabular):
    """Run ``model`` on zipped image/tabular loaders and collect one prediction per batch.

    Images are reshaped to (-1, 1, 64, 64) before the call. Only ``y[0][0]`` of
    each batch is kept, so this is intended for loaders with ``batch_size=1``
    (as built by ``counterfactual_simulations``). Inference runs without
    gradient tracking.

    Returns:
        DataFrame with a single column 'values', one row per batch.
    """
    values = []  # list append instead of repeated np.concatenate (quadratic)

    with torch.no_grad():
        for imgs, tabular in zip(loader_trainer_images, loader_trainer_tabular):
            imgs = imgs.reshape(-1, 1, 64, 64)
            y = model(imgs, tabular)
            values.append(y[0][0].item())

    return pd.DataFrame(values, columns=['values'])
