import matplotlib
import matplotlib.pyplot as plt
from matplotlib import cm, colormaps, font_manager
import seaborn as sns

from convex import Convex

from optimizer import *

def plot3d(F, xl=11, data_type='pop'):  # data types: 'emp', 'pop'
    """Plot the mean of ``F``'s objectives as a 3D surface and save it as a PNG.

    The objectives are evaluated with ``F.batch_forward`` on a 500 x 500 grid
    covering ``[-xl, xl]^2`` and averaged over the second axis
    (``Ys.mean(1)``) before plotting.

    Args:
        F: Problem object exposing ``batch_forward(Xs, data_type)``. Its output
            is assumed to be a (num_points, num_objectives) tensor -- TODO
            confirm.
        xl: Half-width of the square plotting domain.
        data_type: Forwarded to ``F.batch_forward``; ``'pop'`` or ``'emp'``.

    Side effects:
        Writes ``./imgs/_3d-obj-{data_type}.png`` at dpi=1000 and closes the
        figure afterwards so repeated calls do not accumulate open figures.
    """
    n = 500
    x = np.linspace(-xl, xl, n)
    y = np.linspace(-xl, xl, n)
    X, Y = np.meshgrid(x, y)

    # (n*n, 2) grid points in row-major order, so a flat per-point result can
    # be reshaped to (n, n) and line up with X and Y. Built directly in float64
    # instead of round-tripping through a float32 ``torch.Tensor``.
    Xs = torch.tensor(np.column_stack((X.ravel(), Y.ravel())), dtype=torch.double)
    Ys = F.batch_forward(Xs, data_type)

    fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
    # Transparent panes and no grid for a clean surface plot.
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    ax.grid(False)

    Yv = Ys.mean(1).view(n, n)
    ax.plot_surface(X, Y, Yv.numpy(), cmap=cm.viridis)

    # Fixed z range and ticks; they do not adapt to the plotted values.
    ax.set_zticks([-16, -8, 0, 8])
    ax.set_zlim(-20, 10)

    ax.set_xticks([-10, 0, 10])
    ax.set_yticks([-10, 0, 10])
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        for tick in axis.get_major_ticks():
            tick.label1.set_fontsize(15)

    ax.view_init(25)
    plt.tight_layout()
    plt.savefig(f"./imgs/_3d-obj-{data_type}.png", dpi=1000)
    plt.close(fig)


# Plot Pareto stationarity at each point of the support
def plot3d_PS(F, xl=11):
    """Plot a Pareto-stationarity measure of ``F`` over a 2D grid and save PNGs.

    For each point of a 500 x 500 grid over ``[-xl, xl]^2``, the population
    gradients (``Gs``) and empirical gradients (``Gs_emp``) returned by
    ``F.batch_forward(..., compute_grad=True)`` are passed to ``mgd`` (from
    ``optimizer``). The norm of the result is used as the stationarity measure.

    Two figures are written:
      * ``./imgs/_3d-PS-pop.png`` -- 3D surface of the population measure;
      * ``./imgs/_2d-PS.png`` -- contours of the population measure (dotted)
        overlaid with the empirical measure (solid).

    Args:
        F: Problem object exposing ``batch_forward(Xs, compute_grad=True)``
            that returns ``(Ys, Gs, Gs_emp)``. ``Gs`` and ``Gs_emp`` must
            yield one entry per grid point, because the resulting norms are
            reshaped to (n, n).
        xl: Half-width of the square plotting domain.
    """
    n = 500
    x = np.linspace(-xl, xl, n)
    y = np.linspace(-xl, xl, n)
    X, Y = np.meshgrid(x, y)

    # Row-major (n*n, 2) grid points; gradients are requested from F.
    Xs = torch.tensor(np.column_stack((X.ravel(), Y.ravel())),
                      dtype=torch.double, requires_grad=True)
    _, Gs, Gs_emp = F.batch_forward(Xs, compute_grad=True)

    # Iterate the (population, empirical) gradient pairs directly; no index is
    # needed.
    ps_pop, ps_emp = [], []
    for g, g_emp in zip(Gs, Gs_emp):
        ps_pop.append(torch.norm(mgd(g)))
        ps_emp.append(torch.norm(mgd(g_emp)))
    # Xs requires grad, so these norms may carry autograd history.
    # Tensor.numpy() refuses such tensors, so detach before converting.
    PS = torch.stack(ps_pop).detach().view(n, n).numpy()
    PS_emp = torch.stack(ps_emp).detach().view(n, n).numpy()

    # --- 3D surface of the population measure ---
    fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    ax.grid(False)
    ax.plot_surface(X, Y, PS, cmap=cm.viridis)

    # NOTE(review): these fixed z limits/ticks are the same as plot3d's
    # objective plot. A norm is non-negative, so the negative range is unused
    # and values above 10 fall outside the z range. Consider deriving them
    # from PS.
    ax.set_zticks([-16, -8, 0, 8])
    ax.set_zlim(-20, 10)

    ax.set_xticks([-10, 0, 10])
    ax.set_yticks([-10, 0, 10])
    # ``Tick.label`` was removed in Matplotlib 3.8; ``label1`` is the
    # supported attribute.
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        for tick in axis.get_major_ticks():
            tick.label1.set_fontsize(15)

    ax.view_init(25)
    plt.tight_layout()
    plt.savefig("./imgs/_3d-PS-pop.png", dpi=1000)
    plt.close(fig)

    # --- 2D contours: population (dotted) vs empirical (solid) ---
    fig = plt.figure()
    ax = fig.add_subplot(111)
    plt.contour(X, Y, PS, cmap=cm.viridis,
                linewidths=4.0, linestyles='dotted')
    plt.contour(X, Y, PS_emp, cmap=cm.viridis, linewidths=4.0)
    ax.set_aspect(1.0 / ax.get_data_ratio(), adjustable='box')
    plt.xticks([-10, -5, 0, 5, 10], fontsize=15)
    plt.yticks([-10, -5, 0, 5, 10], fontsize=15)
    plt.tight_layout()
    plt.savefig("./imgs/_2d-PS.png", dpi=100)
    plt.close(fig)


def plot_contour(F, emp=True, task=1, traj=None,
                 xl=11, levels=12, plotbar=False, name="tmp", args=None):
    """Draw 2D contour lines of ``F`` over ``[-xl, xl]^2`` and save ``{name}.png``.

    Population contours are dotted. The green star marks the population grid
    minimizer of the plotted quantity. With ``emp=True``, solid empirical
    contours are overlaid and the empirical minimizer is marked with a black
    star.

    Args:
        F: Problem object exposing ``batch_forward(Xs, data_type=...)``. Its
            result is indexed as ``[:, 0]`` / ``[:, 1]``, so it is assumed to
            be (num_points, 2) -- TODO confirm.
        emp: Overlay the empirical (``data_type='emp'``) contours and minimizer.
        task: 0 plots the mean of both objectives, 1 the first objective,
            2 the second objective.
        traj: Optional iterable of iterate arrays, each indexed as
            ``tt[:, 0]`` / ``tt[:, 1]``. Drawn as scatter points whose color
            goes from red to yellow along the trajectory.
        xl: Half-width of the square plotting domain.
        levels: Contour levels, forwarded to ``plt.contour``.
        plotbar: Add a colorbar for the population contours.
        name: Output path without extension; ``.png`` is appended.
        args: Used when ``task == 0`` to mark the initial points
            ``args.init1``, ``args.init2``, ``args.init3`` (each ``(x, y)``).
            The markers are skipped when ``args`` is None.

    Raises:
        ValueError: If ``task`` is not 0, 1 or 2.
    """
    if task not in (0, 1, 2):
        raise ValueError(f"task must be 0, 1 or 2, got {task!r}")

    # Prefer Times New Roman for all text; findfont only probes for the font.
    font_manager.findfont("Times New Roman")
    plt.rcParams['font.family'] = ['serif']
    plt.rcParams['font.serif'] = ['Times New Roman']
    tick_fontsize = 22
    label_fontsize = 28

    n = 500
    x = np.linspace(-xl, xl, n)
    y = np.linspace(-xl, xl, n)
    X, Y = np.meshgrid(x, y)

    def grid_argmin(values):
        """Return the (x, y) grid coordinates where flat ``values`` is smallest."""
        # Points are laid out row-major from the meshgrid: flat index = iy * n + ix.
        iy, ix = np.unravel_index(np.argmin(values), (n, n))
        return x[ix], y[iy]

    # Task 2 and colorbar plots use a wider canvas.
    fig = plt.figure(figsize=(7, 6) if (task == 2 or plotbar) else (6, 6))
    ax = fig.add_subplot(111)

    Xs = torch.tensor(np.column_stack((X.ravel(), Y.ravel())), dtype=torch.double)
    Ys = F.batch_forward(Xs)
    Ys_emp = F.batch_forward(Xs, data_type='emp')

    cmap = colormaps.get_cmap('viridis')

    # Population grid minimizers of each objective and of their mean.
    xx1, yy1 = grid_argmin(Ys[:, 0].numpy())
    xx2, yy2 = grid_argmin(Ys[:, 1].numpy())
    xx, yy = grid_argmin(Ys.mean(1).numpy())

    if task == 0:
        # Mean of the objectives.
        Yv, Yv_emp = Ys.mean(1), Ys_emp.mean(1)
        if args is not None:
            for init in (args.init1, args.init2, args.init3):
                plt.plot(init[0], init[1], marker='o', markersize=10,
                         zorder=5, color='k')
        # Population Pareto front, drawn as the segment between the two
        # single-objective minimizers.
        plt.plot([xx1, xx2], [yy1, yy2], linewidth=8.0, zorder=0,
                 color='green', alpha=0.5)
        star_x, star_y = xx, yy
    elif task == 1:
        Yv, Yv_emp = Ys[:, 0], Ys_emp[:, 0]
        star_x, star_y = xx1, yy1
    else:
        Yv, Yv_emp = Ys[:, 1], Ys_emp[:, 1]
        star_x, star_y = xx2, yy2

    plt.plot(star_x, star_y, marker='*', markersize=15, zorder=5, color='g')
    Z = Yv.view(n, n)
    c = plt.contour(X, Y, Z, cmap=cmap,
                    linewidths=3., levels=levels,
                    linestyles='dotted')
    plt.contour(X, Y, Z, cmap=cmap,
                levels=levels, alpha=0.3,
                linewidths=0.5,
                linestyles='dotted')

    if emp:
        plt.contour(X, Y, Yv_emp.view(n, n), cmap=cmap,
                    levels=levels, linewidths=4.0)
        if task == 0:
            # Empirical Pareto front between the empirical minimizers.
            exx1, eyy1 = grid_argmin(Ys_emp[:, 0].numpy())
            exx2, eyy2 = grid_argmin(Ys_emp[:, 1].numpy())
            plt.plot([exx1, exx2], [eyy1, eyy2],
                     linewidth=8.0, zorder=0, color='gray')
        # Convert to numpy explicitly rather than calling np.argmin on a tensor.
        exx, eyy = grid_argmin(Yv_emp.numpy())
        plt.plot(exx, eyy, marker='*', markersize=15, zorder=5, color='k')

    if traj is not None:
        for tt in traj:
            steps = tt.shape[0]
            # Red (1, 0, 0) at the first iterate fading to yellow (1, 1, 0).
            colors = np.zeros((steps, 3))
            colors[:, 0] = 1.
            colors[:, 1] = np.linspace(0, 1, steps)
            ax.scatter(tt[:, 0], tt[:, 1], color=colors, s=6, zorder=10)

    if plotbar:
        cbar = fig.colorbar(c, ticks=[-18, -15, -13, -10, -5, 0, 3, 5])
        cbar.ax.tick_params(labelsize=tick_fontsize)

    ax.set_aspect(1.0 / ax.get_data_ratio(), adjustable='box')
    plt.xticks([-10, -5, 0, 5, 10], fontsize=tick_fontsize)
    plt.yticks([-10, -5, 0, 5, 10], fontsize=tick_fontsize)
    plt.xlabel(r"$x_1$", fontsize=label_fontsize)
    plt.ylabel(r"$x_2$", fontsize=label_fontsize)
    plt.tight_layout()
    plt.savefig(f"{name}.png", dpi=100)
    plt.close(fig)


def plot_2d_pareto(method, out_path="", data_type='pop', args=None, index=None):
    """Plot optimizer trajectories in objective space against the Pareto front.

    Three saved runs (``{folder}{args.toy}{0,1,2}-runs1.pt``) are loaded. The
    iterates stored under ``res[method][0]`` are mapped through
    ``Convex(args).batch_forward`` and scattered over the Pareto front. The
    front is approximated by evaluating ``F`` on 200 points of the segment
    between the grid minimizers of the two objectives.

    Args:
        method: ``"modo"`` (results in ``results/{args.number}/``), ``"sgd"``
            (``results/static/``) or ``"MGDA"`` (``results/MGDA/``). Also used
            as the key looked up in each loaded result.
        out_path: Prefix prepended to the output file name.
        data_type: ``'pop'`` or ``'emp'``. Forwarded to ``F.batch_forward``
            and selects the front's color/label and the axis labels.
        args: Run configuration passed to ``Convex``. Must provide ``toy``,
            plus ``number`` when ``method == "modo"``.
        index: Optional suffix inserted before ``-os.png``; ``None`` means
            no suffix.

    Raises:
        ValueError: For an unknown ``method`` or ``data_type``.

    Side effects:
        Changes the global seaborn style and Matplotlib font rcParams, writes
        ``{out_path}{method}-{data_type}{index}-os.png`` and closes the figure.
    """
    static_folders = {"sgd": "results/static/", "MGDA": "results/MGDA/"}
    if method == "modo":
        folder_name = "results/" + str(args.number) + "/"
    elif method in static_folders:
        folder_name = static_folders[method]
    else:
        raise ValueError(
            f"unknown method {method!r}; expected 'modo', 'sgd' or 'MGDA'")

    # (color, alpha, legend label) of the Pareto-front line per data type.
    front_styles = {
        'emp': ("#72727A", 1, "Empirical Pareto Front"),
        'pop': ("g", 0.9, "Population Pareto Front"),
    }
    if data_type not in front_styles:
        raise ValueError(
            f"unknown data_type {data_type!r}; expected 'pop' or 'emp'")
    color_PF, alpha_PF, label_PF = front_styles[data_type]

    sns.set_style("darkgrid", {"grid.linewidth": 1, "grid.color": "1",
                               'axes.facecolor': 'lightsteelblue'})
    font_manager.findfont("Times New Roman")
    plt.rcParams['font.family'] = ['serif']
    plt.rcParams['font.serif'] = ['Times New Roman']
    label_fontsize = 28

    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # objects; only point this at result files you produced yourself.
    runs = [
        torch.load(folder_name + args.toy + f"{k}-runs1.pt", weights_only=False)
        for k in range(3)
    ]

    fig, ax = plt.subplots(figsize=(6, 5))

    F = Convex(args)

    # Objective values along each run's iterates.
    losses = [F.batch_forward(res[method][0], data_type=data_type)
              for res in runs]

    # Locate each objective's minimizer on a dense grid (row-major layout).
    n = 1000
    xl = 11
    grid = np.linspace(-xl, xl, n)
    X, Y = np.meshgrid(grid, grid)
    Xs = torch.tensor(np.column_stack((X.ravel(), Y.ravel())), dtype=torch.double)
    Ys = F.batch_forward(Xs, data_type=data_type)

    def grid_argmin(values):
        """Return the (x, y) grid coordinates where flat ``values`` is smallest."""
        iy, ix = np.unravel_index(np.argmin(values), (n, n))
        return grid[ix], grid[iy]

    xx1, yy1 = grid_argmin(Ys[:, 0].numpy())
    xx2, yy2 = grid_argmin(Ys[:, 1].numpy())

    # Pareto front: objective values along the segment joining the minimizers.
    segment = np.column_stack((np.linspace(xx1, xx2, 200),
                               np.linspace(yy1, yy2, 200)))
    Yps = F.batch_forward(torch.from_numpy(segment).double(),
                          data_type=data_type).numpy()

    ax.plot(
        Yps[:, 0],
        Yps[:, 1],
        "-",
        linewidth=8,
        color=color_PF,
        label=label_PF,
        alpha=alpha_PF,
    )

    # (colormap, start, stop) used to color the iterates of run 0, 1, 2.
    color_ramps = (
        (matplotlib.cm.viridis, 0.0, 1.0),
        (matplotlib.cm.plasma_r, 1.0, 0.0),
        (matplotlib.cm.autumn, 0.0, 1.0),
    )
    for i, (tt, (ramp, start, stop)) in enumerate(zip(losses, color_ramps)):
        ax.scatter(
            tt[0, 0], tt[0, 1],
            color="k",
            s=150,
            zorder=10,
            label="Initial Iterate" if i == 0 else None,
        )
        head = tt[0:50000]  # cap the number of plotted iterates
        colors = ramp(np.linspace(start, stop, head.shape[0]))
        ax.scatter(head[:, 0], head[:, 1], color=colors, s=5, zorder=9)
        ax.scatter(head[-1, 0], head[-1, 1], color='yellow', s=150,
                   alpha=0.7, zorder=10,
                   label="Last Iterate" if i == 0 else None)

    sns.despine()
    if data_type == 'pop':
        ax.set_xlabel(r"$f_1$", size=label_fontsize)
        ax.set_ylabel(r"$f_2$", size=label_fontsize)
    else:
        ax.set_xlabel(r"$f_{S,1}$", size=label_fontsize)
        ax.set_ylabel(r"$f_{S,2}$", size=label_fontsize)

    for axis in (ax.xaxis, ax.yaxis):
        for tick in axis.get_major_ticks():
            tick.label1.set_fontsize(20)

    plt.tight_layout()

    if method == "modo":
        legend = ax.legend(
            loc=2, bbox_to_anchor=(0.2, 0.65),
            frameon=True, fontsize=20,
            framealpha=0.5,
        )
        legend.set_zorder(10)
        legend.get_frame().set_edgecolor('k')
        legend.get_frame().set_linewidth(1.0)
        ax.set_zorder(1)

    # ``index`` was concatenated directly before, so the default None crashed.
    suffix = "" if index is None else str(index)
    plt.savefig(
        out_path + f"{method}-{data_type}{suffix}-os.png",
        bbox_inches="tight",
        facecolor="white",
    )
    plt.close(fig)

