import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

def read_tsv(tsv_file, sep=','):
    """Load a delimited text file into a pandas DataFrame.

    Despite the function's name, the default delimiter is a comma, which
    keeps existing callers working unchanged. Pass ``sep='\\t'`` to read
    a genuinely tab-separated file.

    Args:
        tsv_file: Path or file-like object accepted by ``pandas.read_csv``.
        sep: Field delimiter forwarded to ``pandas.read_csv``.

    Returns:
        The parsed ``pandas.DataFrame``.
    """
    # low_memory=False makes pandas parse the file in a single pass rather
    # than in chunks, avoiding mixed-dtype inference within a column.
    return pd.read_csv(tsv_file, sep=sep, low_memory=False)

def plot_CD(make_plot=False):
    """Show a line chart of causal disentanglement scores per seed.

    The scores for ILCM, BetaVAE and DVAE over ten seeds are hard-coded
    below. The function does nothing and returns ``None`` unless
    *make_plot* is truthy.
    """
    if not make_plot:
        return

    seed_ids = [1, 2, 5, 6, 7, 8, 9, 10, 11, 12]

    # One entry per curve: (legend label, line colour, marker, scores).
    curves = (
        ('ILCM', 'blue', 'o',
         [0.69, 0.74, 0.67, 0.65, 0.32, 0.73, 0.70, 0.50, 0.56, 0.49]),
        ('BetaVAE', 'green', 's',
         [0.39, 0.15, 0.25, 0.19, 0.25, 0.51, 0.35, 0.48, 0.31, 0.15]),
        ('DVAE', 'red', '^',
         [0.54, 0.72, 0.51, 0.50, 0.44, 0.62, 0.49, 0.54, 0.68, 0.39]),
    )

    plt.figure(figsize=(10, 6))
    for legend_name, colour, symbol, scores in curves:
        plt.plot(seed_ids, scores, label=legend_name, color=colour, marker=symbol)

    plt.legend()
    plt.xlabel('Seeds')
    plt.ylabel('Causal Disentanglement Score')
    plt.title("Models' Disentanglement Scores")
    plt.grid(True)
    plt.show()


def plot_CC(make_plot=False):
    """Show a line chart of causal completeness scores per seed.

    The scores for ILCM, BetaVAE and DVAE over ten seeds are hard-coded
    below. The function does nothing and returns ``None`` unless
    *make_plot* is truthy.
    """
    if not make_plot:
        return

    seed_ids = [1, 2, 5, 6, 7, 8, 9, 10, 11, 12]

    # One entry per curve: (legend label, line colour, marker, scores).
    curves = (
        ('ILCM', 'blue', 'o',
         [0.76, 0.79, 0.77, 0.76, 0.37, 0.73, 0.88, 0.62, 0.74, 0.50]),
        ('BetaVAE', 'green', 's',
         [0.53, 0.24, 0.54, 0.41, 0.45, 0.63, 0.65, 0.61, 0.40, 0.42]),
        ('DVAE', 'red', '^',
         [0.68, 0.77, 0.56, 0.69, 0.54, 0.69, 0.74, 0.63, 0.76, 0.56]),
    )

    plt.figure(figsize=(10, 6))
    for legend_name, colour, symbol, scores in curves:
        plt.plot(seed_ids, scores, label=legend_name, color=colour, marker=symbol)

    plt.legend()
    plt.xlabel('Seeds')
    plt.ylabel('Causal Completeness Score')
    plt.title("Models' Completeness Scores")
    plt.grid(True)
    plt.show()


def plot_m_std(ilcm_res, softilcm_res):
    """Plot the row-wise mean with a +/- one-std band for ILCM and Soft-ILCM.

    Statistics are taken along axis 1, so every row of each input
    contributes one point on the x-axis.

    Args:
        ilcm_res: 2-D array-like of ILCM scores, one row per x position.
        softilcm_res: 2-D array-like of Soft-ILCM scores with the same
            number of rows as *ilcm_res*.

    The x tick labels are generated as ``D4, D5, ...`` with one label per
    row, so the function works for any row count (12 rows give D4..D15).
    """
    ilcm_scores = np.asarray(ilcm_res)
    soft_scores = np.asarray(softilcm_res)

    means_ilcm = np.mean(ilcm_scores, axis=1)
    stds_ilcm = np.std(ilcm_scores, axis=1)
    means_soft = np.mean(soft_scores, axis=1)
    stds_soft = np.std(soft_scores, axis=1)

    row_indices = np.arange(len(means_soft))
    # Labels start at D4; build them from the data so the number of labels
    # always matches the number of ticks.
    first_label_index = 4
    tick_labels = ['D{}'.format(first_label_index + i) for i in range(len(row_indices))]

    fig, ax = plt.subplots()

    ax.plot(row_indices, means_soft, marker='o', label='Mean-Soft-ILCM', color='red')
    ax.fill_between(row_indices, means_soft - stds_soft, means_soft + stds_soft,
                    alpha=0.5, color=(1.0, 0.6, 0.76, 1), label='Std-Soft-ILCM')

    ax.plot(row_indices, means_ilcm, marker='o', label='Mean-ILCM', color='green')
    ax.fill_between(row_indices, means_ilcm - stds_ilcm, means_ilcm + stds_ilcm,
                    alpha=0.5, color=(0, 0.6, 0, 1), label='Std-ILCM')

    ax.set_xticks(row_indices)
    ax.set_xticklabels(tick_labels, fontsize=12)
    ax.set_xlabel('Causal Variables', fontsize=15)
    ax.set_ylabel('Causal Disentanglement Score', fontsize=15)
    ax.set_title('Means with Standard Deviations', fontsize=15)
    ax.tick_params(axis='both', which='both', labelsize=12)
    ax.legend(fontsize=14)
    plt.tight_layout()
    plt.show()


def plot_best_worst(ilcm_res, softilcm_res):
    """Plot best (row max) and worst (row min) scores for ILCM vs Soft-ILCM.

    Two side-by-side panels are drawn: the left one shows the per-row
    maximum of each input, the right one the per-row minimum.

    Args:
        ilcm_res: 2-D array-like of ILCM scores, one row per x position.
        softilcm_res: 2-D array-like of Soft-ILCM scores with the same
            number of rows as *ilcm_res*.

    The x tick labels are generated as ``D4, D5, ...`` with one label per
    row, so the function works for any row count (12 rows give D4..D15).
    """
    best_ilcm = np.max(ilcm_res, axis=1)
    worst_ilcm = np.min(ilcm_res, axis=1)
    best_softilcm = np.max(softilcm_res, axis=1)
    worst_softilcm = np.min(softilcm_res, axis=1)

    row_indices = np.arange(len(ilcm_res))
    # Labels start at D4; build them from the data so the number of labels
    # always matches the number of ticks.
    first_label_index = 4
    tick_labels = ['D{}'.format(first_label_index + i) for i in range(len(row_indices))]

    # 1 row, 2 columns: best on the left, worst on the right.
    _, axs = plt.subplots(1, 2, figsize=(12, 5))

    panels = (
        (axs[0], 'Best Performance', best_softilcm, best_ilcm),
        (axs[1], 'Worst Performance', worst_softilcm, worst_ilcm),
    )
    for ax, title, soft_curve, ilcm_curve in panels:
        ax.plot(row_indices, soft_curve, marker='o', label='Soft-ILCM', color='red')
        ax.plot(row_indices, ilcm_curve, marker='o', label='ILCM', color='green')
        ax.set_xticks(row_indices)
        ax.set_xticklabels(tick_labels, fontsize=12)
        ax.set_xlabel('Causal Variables', fontsize=15)
        ax.set_ylabel('Causal Disentanglement Score', fontsize=15)
        ax.set_title(title, fontsize=15)
        ax.tick_params(axis='both', which='both', labelsize=12)
        ax.legend(fontsize=14)

    plt.tight_layout()
    plt.show()


def get_res_var_seeds(log_path='', make_plot=False):
    """Load the metrics log and plot ILCM vs Soft-ILCM disentanglement scores.

    The ``'Causal Disentanglement'`` column is reshaped into a 24 x 11
    matrix, the last value of every row is dropped, and the rows are
    split alternately: even rows go to ILCM, odd rows to Soft-ILCM. The
    two resulting 12 x 10 arrays are passed to ``plot_m_std`` and
    ``plot_best_worst``.

    Args:
        log_path: Path of the metrics file, read with ``read_tsv``.
        make_plot: When falsy the function returns immediately without
            reading the file or plotting.

    Raises:
        KeyError: If the file has no ``'Causal Disentanglement'`` column.
        ValueError: If the column does not hold exactly 24 * 11 values.
    """
    if not make_plot:
        return

    n_rows, n_cols = 24, 11

    res = read_tsv(log_path)
    print(res.columns.to_list())

    scores = res['Causal Disentanglement'].to_numpy().reshape(n_rows, n_cols)
    # NOTE(review): the 11th value of each row is discarded; presumably it
    # is an extra/aggregate entry rather than a per-seed score -- confirm.
    scores = scores[:, :-1]

    # Rows alternate between the two models.
    ilcm_res = scores[::2]
    softilcm_res = scores[1::2]

    plot_m_std(ilcm_res, softilcm_res)
    plot_best_worst(ilcm_res, softilcm_res)
    print('well done!')


if __name__ == "__main__":
    import sys

    # Default location of the metrics log. An alternative path may be given
    # as the first command-line argument so the script is usable on other
    # machines without editing the source.
    default_log_path = '/home/zahra/Desktop/experiments/average_metrics.tsv'
    log_path = sys.argv[1] if len(sys.argv) > 1 else default_log_path

    # The two hard-coded seed plots are disabled; only the log-based plots run.
    plot_CD(make_plot=False)
    plot_CC(make_plot=False)
    get_res_var_seeds(log_path=log_path, make_plot=True)