

import matplotlib.pyplot as plt
import numpy as np
import csv
from scipy import stats
from typing import Dict, DefaultDict, List, Any, Tuple, Optional


def _read_outcome_rows(path: str) -> List[List[str]]:
    """Return the data rows of the outcome CSV at *path*, without the header row."""
    # newline='' is what the csv module requires for correct line/quote handling.
    with open(path, newline='') as f:
        rows = list(csv.reader(f))
    return rows[1:]


def _group_column(
        rows: List[List[str]],
        n_groups: int,
        max_try: int,
        col: int,
        cast: Any,
    ) -> List[List[Any]]:
    """Split column *col* into *n_groups* runs of *max_try* consecutive rows, casting each cell."""
    return [
        [cast(rows[max_try * g + r][col]) for r in range(max_try)]
        for g in range(n_groups)
    ]


def _log_space_interval(
        samples: List[float],
        confidence: float,
    ) -> Tuple[float, float, float]:
    """Return (geometric mean, lower error, upper error) of positive *samples*.

    A Student-t confidence interval is built on log(samples) and mapped back with
    exp, so the centre and both error lengths are in the same units as the data;
    the interval is asymmetric around the centre.
    """
    logs = np.log(samples)
    log_centre = np.mean(logs)
    half_width = stats.sem(logs) * stats.t.ppf((1 + confidence) / 2., len(logs) - 1)
    centre = np.exp(log_centre)
    return (
        centre,
        centre - np.exp(log_centre - half_width),
        np.exp(log_centre + half_width) - centre,
    )


def plot_results(
        dispersion: float,
        alpha: float,
        max_try: int,
        list_for_fnum: List[int],
        ylim_upper: int,
    ):
    """Plot ESDA vs CP success rates and run times for one synthetic-data setting.

    Reads ``experiments/synthetic_data/csv_data/artificial_data_max_try_{max_try}
    _phi_{dispersion}_alpha_{alpha}.csv`` and shows two figures: a bar chart of the
    percentage of successful trials per family count, and a line chart of run
    times (geometric mean with a 95% confidence interval computed in log space).

    The CSV is expected to contain one header row followed by ``max_try`` rows per
    entry of ``list_for_fnum``, grouped in the same order. Columns used:
    4 = ESDA time, 5 = ESDA success, second-to-last = CP time, last = CP success.
    Success cells must be int-convertible (presumably 0/1 -- confirm against the
    data generator); time cells must be positive because they are log-transformed.

    Args:
        dispersion: Dispersion (phi) value; 0.0 and 1.0 are rendered as ``0``/``1``
            in the file name and titles.
        alpha: Alpha value used in the file name and titles.
        max_try: Number of trials per family count.
        list_for_fnum: Family counts in CSV group order; used as x-axis values.
        ylim_upper: Upper y-limit of the time plot.

    Raises:
        ValueError: If ``list_for_fnum`` is empty, ``max_try`` is not positive,
            or the CSV has fewer data rows than ``max_try * len(list_for_fnum)``.
    """
    # Render these whole-number values without ".0" so the file name reads
    # phi_0 / phi_1 instead of phi_0.0 / phi_1.0.
    if dispersion == 1.0:
        dispersion = 1
    if dispersion == 0.0:
        dispersion = 0

    filename_outcome_csv = f'experiments/synthetic_data/csv_data/artificial_data_max_try_{max_try}_phi_{dispersion}_alpha_{alpha}.csv'

    n_groups = len(list_for_fnum)
    if n_groups == 0:
        raise ValueError("list_for_fnum must contain at least one family count")
    if max_try <= 0:
        raise ValueError(f"max_try must be positive, got {max_try}")

    rows = _read_outcome_rows(filename_outcome_csv)
    if len(rows) < max_try * n_groups:
        raise ValueError(
            f"{filename_outcome_csv} has {len(rows)} data rows; expected at least "
            f"{max_try * n_groups} ({max_try} trials x {n_groups} family counts)"
        )

    fnums = np.asarray(list_for_fnum)
    title = f"Dispersion = {dispersion}, alpha = {alpha}"

    # ---- Success-rate bar chart -------------------------------------------
    suc_num_SDA = [sum(group) for group in _group_column(rows, n_groups, max_try, 5, int)]
    suc_num_CP = [sum(group) for group in _group_column(rows, n_groups, max_try, -1, int)]
    print("suc_num_ESDA = ", suc_num_SDA)
    print("suc_num_CP =", suc_num_CP)

    # The axis is labelled as a percentage, so scale counts by the trial count
    # (this equals the raw count when max_try == 100).
    pct_SDA = [100.0 * count / max_try for count in suc_num_SDA]
    pct_CP = [100.0 * count / max_try for count in suc_num_CP]

    width = 0.3
    positions = np.arange(n_groups)  # one slot per family count on the x-axis

    _, ax = plt.subplots(figsize=(15, 6))
    ax.set_xlabel("The Number of Families", fontsize=14)
    ax.set_ylabel(f"Percentage of Success (out of {max_try} trials)", fontsize=14)
    ax.bar(positions, pct_SDA, width=width, color='b', label='ESDA', align='center')
    ax.bar(positions + width, pct_CP, width=width, color='g', label='CP', align='center')
    # Centre each family-count label between its pair of bars.
    ax.set_xticks(positions + width / 2)
    ax.set_xticklabels(fnums, fontsize=12)
    ax.set_yticks(np.arange(0, 105, 10))
    ax.tick_params(axis='y', labelsize=12)
    ax.set_ylim(0, 105)
    ax.legend()
    ax.grid()
    ax.set_title(title, fontsize=14)
    plt.show()

    # ---- Run-time line chart ----------------------------------------------
    confidence = 0.95
    time_SDA = _group_column(rows, n_groups, max_try, 4, float)
    time_CP = _group_column(rows, n_groups, max_try, -2, float)

    # Times are summarised in log space; centre and error bars are
    # back-transformed so both are in seconds on the plotted axis.
    # Each array has shape (n_groups, 3): centre, lower error, upper error.
    sda_summary = np.array([_log_space_interval(t, confidence) for t in time_SDA])
    cp_summary = np.array([_log_space_interval(t, confidence) for t in time_CP])

    _, ax = plt.subplots(figsize=(15, 6))
    ax.set_xlabel("The Number of Families", fontsize=14)
    ax.set_ylabel("Time Complexity (seconds)", fontsize=14)
    # yerr of shape (2, N) gives asymmetric (lower, upper) error bars.
    ax.errorbar(fnums, sda_summary[:, 0], yerr=sda_summary[:, 1:].T,
                label='ESDA', marker=".", capsize=5)
    ax.errorbar(fnums, cp_summary[:, 0], yerr=cp_summary[:, 1:].T,
                label='CP', marker=".", capsize=5)
    ax.set_xticks(np.arange(0, fnums[-1] + 100, 500))
    ax.tick_params(axis='x', labelsize=12)
    ax.set_ylim(0.001, ylim_upper)
    ax.tick_params(axis='y', which='both', labelsize=12)
    ax.legend()
    ax.grid()
    ax.set_title(title, fontsize=14)
    plt.show()

