import os
import yaml 

import torch
import pandas as pd
import numpy as np

import seaborn as sns
import matplotlib

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import networkx as nx

import multiprocessing as mp
from utils import utils
from utils import data

# Human-readable labels for dataframe column names; plot_times, for example,
# uses this mapping to build its x-axis label.
name_dict = {'forbid_prob': 'Forbidden edge probability',
             'sure_prob': 'Sure edge probability',
             'nb_nodes': 'Number of nodes',
             'prob_connection': 'Probability of edge in Erdos-Renyi graph',
             'prob_lower': 'DP-DAG',
             'prob_upper': 'DP-DAG',
             'brute_force_lower': 'brute force',
             'brute_force_upper': 'brute force',
             'lagrangian_lower': 'lagrangian',
             'lagrangian_upper': 'lagrangian'}

# Global plotting configuration, applied when this module is imported.
#
# Disabled alternatives, kept for reference:
# larger seaborn font scale
#sns.set_context("paper", font_scale=2)

# larger default figure size
# plt.figure(figsize=(12, 8))

# higher figure DPI (optional)
#plt.rcParams['figure.dpi'] = 150

# seaborn "ticks" style (optional)
#sns.set_style("ticks")
# NOTE(review): 'science' is not a built-in matplotlib style; presumably it is
# provided by the SciencePlots package, which must be installed (and, in recent
# versions, imported) for this call to succeed — confirm.
plt.style.use('science')

# Figure size in inches. 6.00117 looks like a LaTeX \textwidth in inches,
# halved for a half-width figure — TODO confirm. Height is 6/8 of the width.
textwidth = 6.00117/2
aspect_ratio = 6/8
scale = 1.0
width = textwidth * scale
height = width * aspect_ratio
# NOTE(review): this creates a figure as a side effect of importing the module.
plt.figure(figsize=(width, height))


# Tick label sizes.
# NOTE(review): plt.tick_params only affects the current Axes (the one on the
# figure created just above). Figures created later, e.g. in plot_times, do not
# inherit these sizes; if a global default is intended, set
# rcParams['xtick.labelsize'] / rcParams['ytick.labelsize'] instead.
plt.tick_params(axis='both', which='major', labelsize=12)
plt.tick_params(axis='both', which='minor', labelsize=10)

# Thicker axes lines for visibility (optional). As an rcParam, this applies to
# Axes created after this line.
plt.rcParams['axes.linewidth'] = 1.5


def replace_nan_with_tuple(arr):
    """Return *arr* unchanged if it is a length-2 NumPy array, else a NaN pair.

    Anything else (a scalar NaN, a list, an array of a different length) is
    replaced by ``np.array([nan, nan])``. Despite the name, the placeholder is
    an ndarray, not a tuple. Note that a 0-d ndarray makes ``len`` raise
    TypeError.
    """
    is_bound_pair = isinstance(arr, np.ndarray) and len(arr) == 2
    return arr if is_bound_pair else np.array([np.nan, np.nan])
    
def exchange_min_max_if_required(bounds):
    """Return *bounds* ordered as (lower, upper).

    If the first element is strictly greater than the second, a new
    ``np.array([second, first])`` is returned; otherwise *bounds* itself is
    returned unchanged. Comparisons involving NaN are False, so NaN-containing
    bounds are returned as-is.
    """
    lower, upper = bounds[0], bounds[1]
    needs_swap = lower > upper
    return np.array([upper, lower]) if needs_swap else bounds
    
def filter_by_dict(df, filter_dict):
    """Keep only rows of *df* where each ``column == value`` pair in *filter_dict* holds.

    Filters are applied one after another (logical AND). An empty
    *filter_dict* returns *df* itself. A missing column raises KeyError.
    """
    result = df
    for col, target in filter_dict.items():
        keep = result[col] == target
        result = result[keep]
    return result

def save_fig(fig, name, figure_path):
    """Save *fig* as ``figure_path/name`` with a tight bounding box."""
    target = os.path.join(figure_path, name)
    fig.savefig(target, bbox_inches='tight')


def is_within_bounds(value, bounds, tol=0.01):
    """Return whether *value* lies in ``[bounds[0], bounds[1]]`` up to a tolerance.

    Args:
        value: Scalar to test.
        bounds: Indexable ``(lower, upper)`` pair.
        tol: Slack added on both sides of the interval, so values within
            *tol* of an endpoint still count as covered. Defaults to 0.01,
            the value previously hard-coded here.

    Returns:
        A Python ``bool``. Any comparison involving NaN (in *value* or
        *bounds*) yields ``False``.
    """
    lower, upper = bounds[0], bounds[1]
    # bool() normalises NumPy's np.bool_ to a plain Python bool.
    return bool(lower - tol <= value <= upper + tol)
    
def proportion_of_width_covered(gt_bounds, gt_point, estimated_bounds):
    """Return the fraction of a reference interval covered by an estimated interval.

    Args:
        gt_bounds: ``(lower, upper)`` reference interval (callers in this file
            pass the brute-force bounds).
        gt_point: Reference point value, used instead of interval overlap when
            the reference interval is (near-)degenerate.
        estimated_bounds: ``(lower, upper)`` interval being evaluated.

    Returns:
        A float:
        * ``nan`` if the reference interval contains NaN (nothing to compare against);
        * if the reference is no wider than 0.05, ``1.0`` or ``0.0`` depending
          on whether *gt_point* lies within *estimated_bounds* (via ``is_within_bounds``);
        * ``0.0`` if the estimate is missing (NaN) or inverted (upper < lower);
        * otherwise overlap length divided by reference width, in ``[0, 1]``.
    """
    gt_lower, gt_upper = gt_bounds[0], gt_bounds[1]
    est_lower, est_upper = estimated_bounds[0], estimated_bounds[1]
    gt_width = gt_upper - gt_lower
    if np.isnan(gt_width):
        return np.nan
    if gt_width <= 0.05:
        # (Near-)point-identified reference: fall back to point coverage.
        return float(is_within_bounds(gt_point, estimated_bounds))
    if np.isnan(est_lower) or np.isnan(est_upper) or est_upper < est_lower:
        # A missing or inverted estimate covers nothing. The NaN check matters:
        # builtin min/max treat comparisons with NaN as False and would return
        # the reference endpoints below, so a NaN placeholder would otherwise
        # score full coverage (1.0).
        return 0.0
    overlap = max(min(gt_upper, est_upper) - max(gt_lower, est_lower), 0)
    return float(overlap / gt_width)

def bound_tightness(gt_bounds, estimated_bounds):
    """Return the fraction of the estimated interval that lies inside the reference interval.

    Args:
        gt_bounds: ``(lower, upper)`` reference interval.
        estimated_bounds: ``(lower, upper)`` interval being evaluated.

    Returns:
        ``overlap / estimated_width``, which is in ``[0, 1]`` for a valid estimate.
        A zero-width estimate returns ``nan``; previously this raised
        ZeroDivisionError for Python floats, or gave nan with a RuntimeWarning
        for NumPy floats. An inverted estimate (upper < lower) returns ``0.0``.
        NaN bounds propagate as ``nan``.
    """
    est_lower, est_upper = estimated_bounds[0], estimated_bounds[1]
    estimated_width = est_upper - est_lower
    if estimated_width == 0:
        return np.nan
    if estimated_width < 0:
        # An inverted interval cannot overlap anything.
        return 0.0
    overlap = max(min(gt_bounds[1], est_upper) - max(gt_bounds[0], est_lower), 0)
    return overlap / estimated_width


def calculate_point_coverage(df):
    """Add per-method point-coverage columns to *df* (in place) and return it.

    For each method's bounds column (Lagrangian, 'prob', brute force) and each
    ground-truth column ('ground_truth_effect_calc' -> ``*_calc`` outputs,
    'total_ground_truth_effect' -> plain outputs), stores row-wise
    ``is_within_bounds(ground_truth, bounds)``.

    The brute-force outputs were originally written with a stray space
    ('brute_point coverage', 'brute_point coverage_calc'). They are now written
    with consistent underscore names, and the old names are kept as aliases so
    existing readers keep working.
    """
    specs = (
        ('lag_point_coverage_calc', 'ground_truth_effect_calc', 'lagrangian_bounds'),
        ('prob_point_coverage_calc', 'ground_truth_effect_calc', 'prob_bounds'),
        ('brute_point_coverage_calc', 'ground_truth_effect_calc', 'brute_force_bounds'),
        ('lag_point_coverage', 'total_ground_truth_effect', 'lagrangian_bounds'),
        ('prob_point_coverage', 'total_ground_truth_effect', 'prob_bounds'),
        ('brute_point_coverage', 'total_ground_truth_effect', 'brute_force_bounds'),
    )
    for out_col, point_col, bounds_col in specs:
        # Default arguments bind the current column names into the lambda.
        df.loc[:, out_col] = df.apply(
            lambda row, p=point_col, b=bounds_col: is_within_bounds(row[p], row[b]),
            axis=1)
    # Legacy (misspelled) column names, kept for backward compatibility.
    df.loc[:, 'brute_point coverage_calc'] = df['brute_point_coverage_calc']
    df.loc[:, 'brute_point coverage'] = df['brute_point_coverage']
    return df

def calculate_bound_coverage(df):
    """Add per-method bound-coverage columns to *df* (in place) and return it.

    Each output column holds, row by row,
    ``proportion_of_width_covered(brute_force_bounds, ground_truth_effect_calc, <method bounds>)``.
    The brute-force column compares the reference bounds with themselves.
    """
    outputs = (
        ('lag_bound_coverage', 'lagrangian_bounds'),
        ('prob_bound_coverage', 'prob_bounds'),
        ('brute_bound_coverage', 'brute_force_bounds'),
    )
    for out_col, method_col in outputs:
        # Default argument binds the current method column into the lambda.
        df.loc[:, out_col] = df.apply(
            lambda row, col=method_col: proportion_of_width_covered(
                row['brute_force_bounds'], row['ground_truth_effect_calc'], row[col]),
            axis=1)
    return df

def calculate_bound_tightness(df):
    """Add per-method bound-tightness columns to *df* (in place) and return it.

    Each output column holds, row by row,
    ``bound_tightness(brute_force_bounds, <method bounds>)``. The brute-force
    column compares the reference bounds with themselves.
    """
    outputs = (
        ('lag_bound_tightness', 'lagrangian_bounds'),
        ('prob_bound_tightness', 'prob_bounds'),
        ('brute_bound_tightness', 'brute_force_bounds'),
    )
    for out_col, method_col in outputs:
        # Default argument binds the current method column into the lambda.
        df.loc[:, out_col] = df.apply(
            lambda row, col=method_col: bound_tightness(row['brute_force_bounds'], row[col]),
            axis=1)
    return df
    
def plot_times(df, category, filename=None, fixed_values=None):
    """Bar-plot each bounding method's execution time (minutes) against *category*.

    Args:
        df: DataFrame with a *category* column and the timing columns
            'lagrangian_time', 'prob_time' and 'brute_force_time'. These are
            divided by 60 for display, so presumably they are recorded in
            seconds — TODO confirm.
        category: Column used for the x axis. Its axis label is looked up in
            ``name_dict``, falling back to the raw column name.
        filename: If given, the figure is saved there (300 dpi, tight bbox)
            and then closed. Otherwise it is displayed with ``plt.show()``.
        fixed_values: Currently unused; kept so existing call sites still work.
    """
    fig = plt.figure(figsize=(width, height))
    melted_df = df.melt(id_vars=[category],
                        value_vars=['lagrangian_time', 'prob_time', 'brute_force_time'],
                        var_name='Time_Type', value_name='Time')
    melted_df['Time'] = melted_df['Time'] / 60.0
    ax = sns.barplot(data=melted_df, x=category, y='Time', hue='Time_Type', palette='muted')
    ax.set_xlabel(name_dict.get(category, category))
    ax.set_ylabel('Execution time (minutes)')

    # Map hue levels (the melted column names) to display names. Use the labels
    # returned alongside the handles, and fall back to the raw label instead of
    # raising KeyError on anything unexpected.
    # NOTE(review): name_dict labels the prob_* columns 'DP-DAG'; consider using
    # that label here instead of 'Prob' for consistency — confirm.
    legend_labels = {'lagrangian_time': 'Lagrangian', 'prob_time': 'Prob', 'brute_force_time': 'Brute Force'}
    handles, labels = ax.get_legend_handles_labels()
    ax.legend(handles, [legend_labels.get(label, label) for label in labels], title='Method')

    if category in ('sure_prob', 'forbid_prob'):
        # Reverse the order of the probability levels along the x axis.
        ax.invert_xaxis()
    if filename:
        fig.savefig(filename, dpi=300, bbox_inches='tight')
        # Release the figure so repeated calls don't accumulate open figures.
        plt.close(fig)
    else:
        plt.show()