import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
from matplotlib import gridspec
import datetime
from recourse.best_distribution import optimal_capacity

def get_result_folder(approach, gamma, beta, K, seed):
    """Create (if needed) and return a timestamped results directory for one run.

    The returned path is relative to the current working directory and has the form
    ``results/<YYYYmmdd-HHMMSS>_<approach>_gamma<gamma>_beta<beta>_K<K>_seed<seed>``.
    Missing parent directories are created; an already-existing directory is reused.
    """
    stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    run_tag = "{}_gamma{}_beta{}_K{}_seed{}".format(approach, gamma, beta, K, seed)

    target = os.path.join("results", stamp + "_" + run_tag)
    os.makedirs(target, exist_ok=True)
    return target

def save_matching_output(model, z, costs, weights, capacity=None, file_path=None):
    """Write a human-readable report of a solved seeker-to-bank matching.

    Parameters
    ----------
    model : object
        Solved optimization model; only its ``status`` attribute is read.
        A status of 2 (``GRB.OPTIMAL``) produces the full report; any other
        status produces a single "No optimal solution found." line.
    z : mapping
        Indexed as ``z[i, j]``; each entry exposes a numeric ``.X`` value.
        Entries with ``X > 0.5`` are reported as "seeker i assigned to bank j".
    costs : array-like
        ``costs[i, j]`` is the recourse cost printed for a selected assignment.
    weights : numpy.ndarray
        2-D array of shape ``(n, m)``. ``weights[i, j]`` is added to the total
        welfare for every selected assignment.
    capacity : sequence, optional
        Per-bank capacities; when given, ``capacity[j]`` is listed for each bank.
    file_path : str or os.PathLike, optional
        Destination of the report (written as UTF-8). If ``None``, the report
        is printed to stdout instead of being written to disk.

    Returns
    -------
    str
        The report text (lines joined with ``"\\n"``, no trailing newline).
    """
    n, m = weights.shape
    lines = []

    if model.status == 2:  # GRB.OPTIMAL
        total_welfare = 0
        lines.append("Optimal assignment and recommended recourse actions:")
        for i in range(n):
            for j in range(m):
                # Threshold at 0.5 rather than == 1, presumably to tolerate
                # solver round-off on binary variables.
                if z[i, j].X > 0.5:
                    total_welfare += weights[i, j]
                    lines.append(f"Seeker {i} assigned to Bank {j}")
                    lines.append(
                        f"  Recourse cost (minimal change required): {costs[i, j]:.4f}"
                    )
        if capacity is not None:
            lines.extend(f"Capacity for Bank {j}: {capacity[j]}" for j in range(m))
        lines.append(f"Total welfare: {total_welfare:.4f}")
    else:
        lines.append("No optimal solution found.")

    report = "\n".join(lines)
    if file_path is None:
        # No destination given: show the report instead of crashing on open(None).
        print(report)
    else:
        with open(file_path, "w", encoding="utf-8") as f:
            f.write(report)
        print(f"=> Output saved to {file_path}")
    return report
    

def generate_welfare_plot(best_dist, IW, n, m, file_path):
    """Plot best-distribution social welfare against total capacity K and save it.

    The x-axis is broken into two panels. The left panel spans ``K`` in ``[0, n]``.
    The right panel shows only the last unit ``[n*m - 1, n*m]``. Both panels draw
    ``best_dist`` (indexed by ``K``) and a horizontal line at the individual welfare
    ``IW``. The points ``K = n`` and ``K = n*m`` are marked and annotated.

    Parameters
    ----------
    best_dist : sequence of float
        Social welfare for each total capacity ``K = 0, 1, ...``. ``best_dist[n]``
        must exist. It is presumably of length ``n*m + 1``, so that
        ``best_dist[-1]`` is the ``K = n*m`` value (TODO: confirm with the caller).
    IW : float
        Individual welfare, drawn as a horizontal reference line.
    n, m : int
        Problem dimensions; the capacity axis ends at ``n * m``.
    file_path : str or os.PathLike
        Where the figure is saved.
    """
    nm = n * m
    kk = np.arange(len(best_dist))
    # Leave headroom above the taller of the two curves, so that a best_dist
    # value above IW is not clipped by the y-limit.
    y_top = max(IW, float(np.max(best_dist))) * 1.05

    font_size = 6
    style = {
        "font.size": font_size,
        "axes.labelsize": font_size + 1,
        "axes.titlesize": font_size + 1,
        "legend.fontsize": 6,
        "axes.linewidth": 0.6,
    }

    # Scope the style to this figure instead of mutating the global rcParams,
    # which would otherwise leak into every plot created afterwards.
    with plt.rc_context(style):
        fig = plt.figure(figsize=(3.35, 2.0), dpi=300)
        try:
            gs = gridspec.GridSpec(1, 2, width_ratios=[1, 1], wspace=0.05)
            ax_l = fig.add_subplot(gs[0])
            ax_r = fig.add_subplot(gs[1], sharey=ax_l)

            # Both panels get identical content; their x-limits select the slice of K shown.
            for ax in (ax_l, ax_r):
                ax.plot(kk, best_dist, ls='--', lw=1.0, color='tab:red',
                        label='Social welfare (best distribution)')
                ax.axhline(IW, lw=0.8, color='tab:blue', label='Individual welfare')
                ax.set_ylim(0, y_top)
                ax.grid(axis='y', color='0.85', lw=0.3, ls=':')
            # A single x-label shared by both panels, centred below them.
            fig.text(0.5, -0.03, r'Total capacity $K$', ha='center', va='top')

            # Left panel: K from 0 to n.
            ax_l.set_xlim(0, n)
            ax_l.set_xticks([0, n])
            ax_l.set_ylabel('Welfare')

            # Right panel: only the last unit of K, ending at n*m; y ticks on the right.
            ax_r.tick_params(labelleft=False)
            ax_r.set_xlim(nm - 1, nm)
            ax_r.set_xticks([nm])
            ax_r.yaxis.tick_right()
            ax_r.spines['left'].set_visible(False)
            ax_l.spines['right'].set_visible(False)

            ax_l.yaxis.set_major_locator(mtick.MaxNLocator(nbins=5))
            ax_l.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1f'))

            # Mark and label the two capacity levels of interest.
            ax_l.plot(n, best_dist[n], 'ko', ms=3)
            ax_l.annotate(r'$K=n$', xy=(n, best_dist[n]), xytext=(0, 8),
                          textcoords='offset points', ha='center', va='bottom')
            ax_r.plot(nm, best_dist[-1], 'ko', ms=3)
            ax_r.annotate(r'$K=nm$', xy=(nm, best_dist[-1]), xytext=(0, 8),
                          textcoords='offset points', ha='center', va='bottom')

            # Print the individual-welfare value just beneath its line, at x = 1.
            ax_l.annotate(f"{IW:.2f}", xy=(1, IW), xytext=(0, -2),
                          textcoords='offset points', color='tab:blue',
                          ha='center', va='top')

            # Small diagonal marks on the facing panel edges signal the broken x-axis.
            d = .008
            kwargs = dict(transform=ax_l.transAxes, color='k', clip_on=False)
            ax_l.plot([1 - d, 1 + d], [-d, +d], **kwargs)
            ax_l.plot([1 - d, 1 + d], [1 - d, 1 + d], **kwargs)
            kwargs.update(transform=ax_r.transAxes)
            ax_r.plot([-d, +d], [1 - d, 1 + d], **kwargs)
            ax_r.plot([-d, +d], [-d, +d], **kwargs)

            ax_r.legend(loc="lower right", frameon=False, handlelength=1.2)
            fig.tight_layout(pad=0.3)
            # The shared x-label sits below the canvas (y < 0). Let savefig grow the
            # saved bounding box to include it instead of cropping it off.
            fig.savefig(file_path, bbox_inches="tight")
        finally:
            plt.close(fig)
    print(f"=> Welfare plot saved to {file_path}")