import matplotlib.pyplot as plt
from matplotlib.transforms import Bbox
import numpy as np
import os
from pathlib import Path
from os import listdir
from scipy.optimize import approx_fprime
from scipy.stats import beta
import pickle

from utils.constants import *

def sample_one_document(k):
    """Draw one document: a length-``k`` vector of independent Uniform[0, 1) values.

    Uses NumPy's global random state, so results follow ``np.random.seed``.
    """
    return np.random.random_sample(size=(k,))
    
    
def beta_params_from_mean_variance(mean, variance):
    """Return the (alpha, beta) parameters of the Beta distribution with the given mean and variance.

    Works on scalars and, elementwise, on broadcast-compatible NumPy arrays.

    :param mean: Mean(s) of the target Beta distribution.
    :param variance: Variance(s); every entry must satisfy 0 < variance < mean * (1 - mean).
    :return: Tuple ``(alpha, beta)`` with the broadcast shape of the inputs.
    :raises AssertionError: If any variance lies outside the admissible range.
    """
    # np.all + elementwise '&' keeps the check valid for arrays: a chained comparison on an
    # array with more than one element raises "truth value of an array is ambiguous".
    assert np.all((0 < variance) & (variance < mean * (1 - mean))), \
        "Variance must be between 0 and mean * (1 - mean)."
    # Shared factor of the method-of-moments solution: alpha = m * c, beta = (1 - m) * c.
    common = (mean * (1 - mean)) / variance - 1
    alpha = mean * common
    beta_param = (1 - mean) * common
    return alpha, beta_param
    
    
def _sample_beta_rows(num_rows, k, one_minus_first_row, var_percentage):
    """Return a (num_rows, k) array: row 0 from sample_one_document, the rest Beta-distributed
    per column with mean equal to row 0's entry (or 1 minus it)."""
    rows = np.zeros((num_rows, k))
    rows[0] = sample_one_document(k)
    mean = 1 - rows[0] if one_minus_first_row else rows[0]
    variance = var_percentage * mean * (1 - mean)
    for col in range(k):
        # Scalar per-column call: the parameters are identical to a vectorized call, and the
        # helper's validity assertion stays scalar-safe.
        a, b = beta_params_from_mean_variance(mean[col], variance[col])
        # Fill the remaining num_rows - 1 rows (the original used n here for x_star, which
        # failed whenever s != n).
        rows[1:, col] = beta.rvs(a, b, size=num_rows - 1)
    return rows


def generate_beta_x0_x_star(k, n, s, one_minus_initial_doc=False, one_minus_info_need=False, var_percentage=0.5):
    """Generate x_0 (shape (n, k)) and x_star (shape (s, k)) based on the given parameters.

    The first row of each array is drawn by ``sample_one_document``. Every other row is drawn,
    column by column, from a Beta distribution whose mean is that first row (or 1 minus it) and
    whose variance is ``var_percentage * mean * (1 - mean)``.

    :param k: Number of columns (features).
    :param n: Number of rows of x_0.
    :param s: Number of rows of x_star.
    :param one_minus_initial_doc: Centre x_0's Beta draws on 1 - x_0[0] instead of x_0[0].
    :param one_minus_info_need: Centre x_star's Beta draws on 1 - x_star[0] instead of x_star[0].
    :param var_percentage: Fraction of the maximal Beta variance to use; must be in (0, 1).
    :return: Tuple ``(x_0, x_star)``.
    """
    assert 0 < var_percentage < 1, "var_percentage must be between 0 and 1."
    # x_0 is sampled entirely before x_star, preserving the random-number consumption order.
    x_0 = _sample_beta_rows(n, k, one_minus_initial_doc, var_percentage)
    x_star = _sample_beta_rows(s, k, one_minus_info_need, var_percentage)
    return x_0, x_star


def create_cov_matrix(n, s, rho1, rho2, scaler=1):
    """Build the (n + s) x (n + s) block-diagonal covariance matrix.

    The top-left n x n block holds rho1 ** |i - j|, the bottom-right s x s block holds
    rho2 ** |i - j|, both with ones on the diagonal, and the off-diagonal blocks are zero.
    The whole matrix is finally multiplied by ``scaler``.
    """
    cov = np.zeros((n + s, n + s))
    for start, size, rho in ((0, n, rho1), (n, s, rho2)):
        # Powers evaluated with Python's ** (one per distinct lag), then gathered by lag.
        powers = np.array([rho ** lag for lag in range(size)], dtype=float)
        offsets = np.arange(size)
        lags = np.abs(offsets[:, None] - offsets[None, :])
        cov[start:start + size, start:start + size] = powers[lags]
    np.fill_diagonal(cov, 1)
    return cov * scaler


def generate_normal_x0_x_star(k, n, s, rho1, rho2, scaler=0.2):
    """Generate x_0 (shape (n, k)) and x_star (shape (s, k)) from a clipped multivariate normal.

    Each of the k columns is one draw of an (n + s)-dimensional normal with mean 0.5 and
    covariance ``create_cov_matrix(n, s, rho1, rho2, scaler)``, clipped to [0, 1]. The first n
    coordinates go to x_0 and the last s to x_star.
    """
    centre = [0.5] * (n + s)
    cov = create_cov_matrix(n, s, rho1, rho2, scaler)
    x_0, x_star = np.zeros((n, k)), np.zeros((s, k))
    for feature in range(k):
        draw = np.clip(np.random.multivariate_normal(centre, cov), 0, 1)
        x_0[:, feature] = draw[:n]
        x_star[:, feature] = draw[n:n + s]
    return x_0, x_star


def gradient_of_f_at_xi(f, x, i, epsilon=None):
    """
    Numerically approximate the gradient of ``f`` with respect to row ``i`` of ``x``.

    :param f: Function that takes an array of shape (n, k) and returns a scalar.
    :param x: Array-like of shape (n, k) representing n points in k-dimensional space.
        It is not modified.
    :param i: Index of the row of x with respect to which the gradient is calculated.
    :param epsilon: Finite-difference step passed to ``scipy.optimize.approx_fprime``;
        defaults to ``DELTA_``.
    :return: A k-dimensional gradient vector of f at x[i].
    """
    step = DELTA_ if epsilon is None else epsilon
    # Work on a float copy: with an integer-typed x, writing the perturbed row back into an
    # integer array would truncate the finite-difference step and yield a zero gradient.
    base = np.array(x, dtype=float)

    def f_isolated(new_xi):
        """Evaluate f with row i replaced by new_xi and every other row held fixed."""
        x_copy = base.copy()
        x_copy[i] = new_xi
        return f(x_copy)

    return approx_fprime(base[i], f_isolated, step)


def is_float(value):
    """Return True if ``float(value)`` succeeds, False otherwise.

    Besides unparsable strings (ValueError), this also reports inputs float() rejects with
    TypeError (e.g. None, lists) or OverflowError (ints too large for a float) as non-floats
    instead of letting the exception escape.
    """
    try:
        float(value)
    except (TypeError, ValueError, OverflowError):
        return False
    return True


def bootstrap_ci(z, bootstrap_B, alpha=0.05):
    """Percentile-bootstrap confidence interval for the mean of ``z``.

    Draws ``bootstrap_B`` resamples of z (with replacement, same length as z) from NumPy's
    global random state and takes the mean of each. Returns ``[lower, upper]``: the
    alpha/2 and 1 - alpha/2 quantiles of those means.
    """
    sample_size = len(z)
    resampled_means = [
        np.mean(np.random.choice(z, size=sample_size, replace=True))
        for _ in range(bootstrap_B)
    ]
    lower_q = alpha / 2
    upper_q = 1 - lower_q
    return [np.quantile(resampled_means, q) for q in (lower_q, upper_q)]


def save_results(results, desc, ranking_function_lst, additional_param_lst,
                 k_vals, n_vals, s_vals, lam_vals, B, params=None, suffix=''):
    """Persist experiment results under ``SAVE_PATH/desc``.

    Writes ``<desc>_args.txt``: one line each for k_vals, n_vals, s_vals, lam_vals and B,
    followed by one ``key: value`` line per entry of ``params``. Then pickles ``results[i]``
    to ``<ranking_function_lst[i]>_<additional_param_lst[i]><suffix>.pkl``. On IOError the
    error and the results are printed instead of being raised.
    """
    if params is None:
        params = {}
    try:
        Path(SAVE_PATH + '/' + desc).mkdir(parents=True, exist_ok=True)
        out_dir = f"{SAVE_PATH}/{desc}"

        with open(f"{out_dir}/{desc}_args.txt", 'w') as args_file:
            for grid_values in (k_vals, n_vals, s_vals, lam_vals, B):
                args_file.write(str(grid_values) + "\n")
            args_file.writelines(f"{key}: {val}\n" for key, val in params.items())

        for idx, result in enumerate(results):
            pkl_name = f"{ranking_function_lst[idx]}_{additional_param_lst[idx]}{suffix}.pkl"
            with open(f"{out_dir}/{pkl_name}", 'wb') as pkl_file:
                pickle.dump(result, pkl_file)

    except IOError as err:
        print(err)
        print("Couldn't save results")
        print("Results: ", results)


def load_results(res_file_names):
    """Unpickle each file in ``res_file_names`` and return the objects in the same order.

    NOTE: pickle can execute arbitrary code on load; only use this on trusted result files.
    """
    def _read_pickle(path):
        with open(path, 'rb') as handle:
            return pickle.load(handle)

    return [_read_pickle(path) for path in res_file_names]


class YAxis:
    """Settings for one subplot's y axis (label, data index, limits, ticks, legend)."""

    def __init__(self, axis_name, index=None, limits=None, ticks=None, legend_loc='upper right', legend_font_size=LEGEND_FONT_SIZE):
        """
        :param axis_name: Y-axis label; also the key for default limits/ticks in
            AXIS_LIMITS / AXIS_TICKS.
        :param index: Position within each result entry to plot; None leaves the choice to
            the caller.
        :param limits: AUTO for automatic limits (stored as None), a key of AXIS_LIMITS for a
            preset, None for the preset registered under ``axis_name`` (if any), or explicit
            limits (tuples, lists and arrays are all accepted).
        :param ticks: Same resolution rules as ``limits``, against AXIS_TICKS.
        :param legend_loc: Legend location; None means no explicit location.
        :param legend_font_size: Legend font size.
        """
        self.axis_name = axis_name
        self.index = index
        self.limits = self._resolve_setting(limits, axis_name, AXIS_LIMITS)
        self.ticks = self._resolve_setting(ticks, axis_name, AXIS_TICKS)
        self.legend_loc = legend_loc
        self.legend_font_size = legend_font_size

    @staticmethod
    def _resolve_setting(value, axis_name, presets):
        """Resolve a limits/ticks argument against a preset table (rules in ``__init__``)."""
        try:
            hash(value)
        except TypeError:
            # Unhashable explicit values (lists, NumPy arrays) cannot be dict keys; the
            # previous `value in presets.keys()` check raised TypeError for them. Use as given.
            return value
        if value == AUTO:
            return None
        if value in presets:
            return presets[value]
        if value is None:
            return presets[axis_name] if axis_name in presets else None
        return value


def create_plots(res, x_vals, x_title, labels, colors_dict, x_ticks=None, graphs=None, figsize=(25, 5), save_path=None, split=False):
    """Draw one subplot per graph spec, with one line (plus vertical interval bars) per label.

    :param res: Sequence aligned with ``labels``. ``res[j][x]`` is an indexable entry for x value
        ``x``: the element at ``index`` is the plotted point and the element at
        ``index + len(entry) // 2`` is drawn as a vertical segment at x.
    :param x_vals: X values to plot, used as keys into each ``res[j]``.
    :param x_title: X-axis label.
    :param labels: Series labels, also keys of ``colors_dict``. A legend is drawn only when
        there is more than one label.
    :param colors_dict: Maps each label to its color.
    :param x_ticks: Explicit x ticks, ALL to use ``x_vals``, or None for the default ticks.
    :param graphs: One spec per subplot: a tuple of positional or a dict of keyword arguments
        for ``YAxis``. Defaults to (PUBLISHERS_WELFARE,), (USERS_WELFARE,), (CONVERGENCE_RATE,).
    :param figsize: Figure size passed to ``plt.subplots``.
    :param save_path: If given, the figure is saved there before being shown.
    :param split: When saving, write each subplot to its own file via ``save_subplots``.
    """
    if graphs is None:
        graphs = [(PUBLISHERS_WELFARE,), (USERS_WELFARE,), (CONVERGENCE_RATE,)]
    # squeeze=False always yields a 2-D array of Axes, so a single graph is handled like
    # several (previously a bare Axes reached save_subplots, which iterates it).
    fig, axes_grid = plt.subplots(1, len(graphs), figsize=figsize, sharey=False, squeeze=False)
    axes_row = axes_grid[0]

    # Offset from a value's index to the index of its interval in each result entry.
    ci_offset = len(res[0][x_vals[0]]) // 2

    if isinstance(x_ticks, str) and x_ticks == ALL:
        x_ticks = x_vals

    for i, (graph, axes) in enumerate(zip(graphs, axes_row)):
        y_axis = YAxis(*graph) if isinstance(graph, tuple) else YAxis(**graph)
        # Without an explicit index, the i-th subplot shows the i-th value of each entry.
        index = y_axis.index if y_axis.index is not None else i

        for j, label in enumerate(labels):
            color = colors_dict[label]
            plot_values = [res[j][x][index] for x in x_vals]
            axes.plot(x_vals, plot_values, marker='o', label=label, color=color)
            for x in x_vals:
                axes.plot([x, x], res[j][x][index + ci_offset], color=color)

        axes.set_ylabel(y_axis.axis_name, fontsize=LABELS_FONT_SIZE)
        axes.set_xlabel(x_title, fontsize=LABELS_FONT_SIZE)
        if len(labels) > 1:
            if y_axis.legend_loc is not None:
                axes.legend(loc=y_axis.legend_loc, fontsize=y_axis.legend_font_size)
            else:
                axes.legend(fontsize=y_axis.legend_font_size)
        if y_axis.limits is not None:
            axes.set_ylim(y_axis.limits)
        if y_axis.ticks is not None:
            axes.set_yticks(y_axis.ticks)
        if x_ticks is not None:
            axes.set_xticks(x_ticks)

        axes.tick_params(axis='x', labelsize=TICKS_FONT_SIZE)
        axes.tick_params(axis='y', labelsize=TICKS_FONT_SIZE)

    if save_path is not None:
        if split:
            save_subplots(fig, axes_row, save_path)
        else:
            plt.savefig(save_path, bbox_inches='tight')
    plt.show()


def save_subplots(fig, axs, save_path, padding_factor=0.05):
    """Save each subplot of ``fig`` to its own PNG, cropped to the subplot plus padding.

    Files are written to ``<save_path without its .png suffix>/<basename>_<i>.png`` with a
    1-based ``i``; the directory is created if needed.

    :param fig: Figure containing the axes.
    :param axs: A single Axes or any (possibly nested) array/sequence of Axes.
    :param save_path: Path (str or os.PathLike) of the combined figure, typically ending in .png.
    :param padding_factor: Padding added on every side, in inches (the tight bbox is
        converted to inches through ``fig.dpi_scale_trans``).
    """
    # os.fspath: a pathlib.Path has its own .replace(), which *renames files on disk*.
    save_path = os.fspath(save_path)
    # Strip only a trailing ".png" (str.replace removed every occurrence in the path).
    general_path = save_path[:-len('.png')] if save_path.endswith('.png') else save_path
    fig_name = os.path.basename(general_path)
    os.makedirs(general_path, exist_ok=True)
    # np.ravel accepts a bare Axes, a list, or a 2-D Axes array alike.
    for idx, ax in enumerate(np.ravel(axs), start=1):
        bbox = ax.get_tightbbox(fig.canvas.get_renderer())
        bbox = bbox.transformed(fig.dpi_scale_trans.inverted())
        bbox = Bbox.from_extents(
            bbox.x0 - padding_factor,  # Left padding
            bbox.y0 - padding_factor,  # Bottom padding
            bbox.x1 + padding_factor,  # Right padding
            bbox.y1 + padding_factor   # Top padding
        )
        fig.savefig(f"{general_path}/{fig_name}_{idx}.png", bbox_inches=bbox)
     
        
def _list_result_files(dir_names):
    """Collect the .csv/.pkl files of the given SAVE_PATH sub-directories.

    Returns ``(file_names, file_paths)`` in directory-listing order, or None (after printing a
    message) as soon as one directory does not exist.
    """
    file_names, file_paths = [], []
    for dir_name in dir_names:
        dir_path = Path(SAVE_PATH + '/' + dir_name)
        try:
            entries = listdir(dir_path)
        except FileNotFoundError:
            print("No such directory:", dir_path)
            return None
        for file_name in entries:
            if file_name.endswith(('.csv', '.pkl')):
                file_names.append(file_name)
                file_paths.append(dir_path.joinpath(file_name))
    return file_names, file_paths


def create_plots_from_files(desc, key_x_function, file_label_function, colors_dict, x_ticks=None, graphs=None, 
                            condition=lambda k, n, s, lam: True, labels_sort_key=None, sort=True, figsize=(25, 5), split=False):
    """Load pickled results from SAVE_PATH/<desc> (or several directories) and plot them.

    Each result file holds a dict keyed by (k, n, s, lam); keys are mapped to x values with
    ``key_x_function`` and filtered with ``condition``. The figure is saved as
    FIGURES_PATH/<first directory>.png through ``create_plots``.

    :param desc: A directory name under SAVE_PATH, or a list of such names.
    :param key_x_function: Maps (k, n, s, lam) to an x value; called with
        (K_STR, N_STR, S_STR, LAM_STR) to build the x-axis title.
    :param file_label_function: Maps a result file name to its legend label.
    :param condition: Predicate on (k, n, s, lam) selecting the keys to plot.
    :param labels_sort_key: Optional key applied to labels to order the series.
    :param sort: Sort x values (and each result dict) ascending.
    Remaining parameters are forwarded to ``create_plots``.
    """
    if graphs is None:
        graphs = [(PUBLISHERS_WELFARE,), (USERS_WELFARE,), (CONVERGENCE_RATE,)]
    dir_names = [desc] if isinstance(desc, str) else list(desc)
    found = _list_result_files(dir_names)
    if found is None:
        return
    files_names, files_full_path = found
    if not files_full_path:
        # Previously crashed with IndexError on res[0].
        print("No result files found in:", dir_names)
        return

    res = load_results(files_full_path)
    # The x grid is taken from the first file's keys.
    x_vals = [key_x_function(k, n, s, lam) for k, n, s, lam in res[0] if condition(k, n, s, lam)]
    x_vals = sorted(x_vals) if sort else x_vals
    res = [{key_x_function(k, n, s, lam): r[(k, n, s, lam)] for k, n, s, lam in r
            if condition(k, n, s, lam)}
           for r in res]
    if sort:
        res = [dict(sorted(r.items(), key=lambda item: item[0])) for r in res]
    labels = [file_label_function(file_name) for file_name in files_names]

    if labels_sort_key is not None:
        combined_list = sorted(zip(res, labels), key=lambda pair: labels_sort_key(pair[1]))
        res, labels = zip(*combined_list)

    x_title = key_x_function(K_STR, N_STR, S_STR, LAM_STR)

    fig_name = f"{FIGURES_PATH}/{dir_names[0]}.png"
    create_plots(res, x_vals, x_title, labels, colors_dict, x_ticks,
                 graphs, figsize, save_path=fig_name, split=split)


def create_single_plot_from_files(desc, file_x_function, color, x_title, x_ticks=None,
                                  graphs=None, functions_sort_key=None, figsize=(25, 5)):
    """Plot one series built from the result files in SAVE_PATH/<desc>.

    Every .csv/.pkl file is unpickled with ``load_results``. The first value of its dict is
    plotted at x = round(float(file_x_function(file_name)), 2). The figure is saved as
    FIGURES_PATH/<desc>.png.

    :param desc: Directory name under SAVE_PATH.
    :param file_x_function: Maps a result file name to its x value (anything float() accepts).
    :param color: Color of the single series.
    :param x_title: X-axis label.
    :param functions_sort_key: Optional key applied to the rounded x values to order points;
        without it, points keep the directory-listing order.
    Remaining parameters are forwarded to ``create_plots``.
    """
    if graphs is None:
        graphs = [(PUBLISHERS_WELFARE,), (USERS_WELFARE,), (CONVERGENCE_RATE,)]
    dir_path = Path(SAVE_PATH + '/' + desc)
    try:
        files_names = listdir(dir_path)
    except FileNotFoundError:
        print("No such directory:", dir_path)
        return
    result_names = [name for name in files_names if name.endswith(('.csv', '.pkl'))]
    if not result_names:
        # Previously crashed with IndexError inside create_plots.
        print("No result files found in:", dir_path)
        return
    res = load_results([dir_path.joinpath(name) for name in result_names])

    functions = [float(file_x_function(name)) for name in result_names]
    # NOTE: files whose x values round to the same 2-decimal key overwrite each other here.
    res = [{np.round(x, 2): list(r.values())[0] for x, r in zip(functions, res)}]

    if functions_sort_key is not None:
        res[0] = dict(sorted(res[0].items(), key=lambda item: functions_sort_key(item[0])))
        functions = list(res[0].keys())

    create_plots(res, np.round(functions, 2), x_title, [None],
                 {None: color}, x_ticks, graphs, figsize, save_path=f"{FIGURES_PATH}/{desc}.png")