import os
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
import torch

def savefig(name, transparent=False, pdf=False):
    """Save the current matplotlib figure under ``figures/``, then show and clear it.

    Args:
        name: File stem relative to ``figures/``. It may contain subdirectories
            (e.g. ``"linear/run_t=1"``); missing directories are created.
        transparent: If True, save with a transparent background; otherwise
            the figure is saved with a white facecolor.
        pdf: If True, also write a ``.pdf`` copy alongside the ``.png``.
    """
    FIG_ROOT = "figures"
    # Create the directory that will actually hold the file. `name` may include
    # subdirectories, so creating only FIG_ROOT is not enough for plt.savefig.
    out_dir = os.path.dirname(f"{FIG_ROOT}/{name}")
    os.makedirs(out_dir, exist_ok=True)
    modes = ["png"]
    if pdf:
        modes += ["pdf"]
    for mode in modes:
        file_name = f"{FIG_ROOT}/{name}.{mode}"
        if transparent:
            plt.savefig(file_name, dpi=300, bbox_inches="tight", transparent=True)
        else:
            plt.savefig(file_name, dpi=300, facecolor="white", bbox_inches="tight")
    plt.show()
    # Clears the figure's contents; the figure itself stays open.
    plt.clf()

def _snapshot_times(train_iter):
    """Return the sorted training steps at which a snapshot plot is drawn.

    Candidates are powers of two (1 .. 2**14) plus multiples of 1000 (0 .. 199000),
    restricted to 0 < step < train_iter. NOTE: no snapshots are taken past step
    199000 regardless of train_iter.
    """
    candidates = sorted([2**n for n in range(15)] + [1000 * n for n in range(200)])
    return [step for step in candidates if 0 < step < train_iter]


def _linear_params(net):
    """Return ``(weights, bias)`` of ``net`` as detached CPU tensors.

    ``net`` is either a single ``nn.Linear`` or (D-BAT) an iterable of
    ``nn.Linear`` heads whose parameters are concatenated along dim 0.
    """
    if isinstance(net, nn.Linear):
        return net.weight.detach().cpu(), net.bias.detach().cpu()
    # Otherwise it is D-BAT: a list of linear layers, one per head.
    weights = torch.cat([layer.weight.detach().cpu() for layer in net])
    bias = torch.cat([layer.bias.detach().cpu() for layer in net])
    return weights, bias


def plot_current_status(args, net, t, xent, repulsion_loss, training_data, test_data, exp_name):
    """Log losses and, at snapshot steps, plot every head's linear decision boundary.

    Args:
        args: Config object; ``args.log_every`` and ``args.train_iter`` are read.
        net: A single ``nn.Linear`` or a list of ``nn.Linear`` heads (D-BAT).
            Only the first two input weights of each output row are used, i.e.
            this assumes 2-D inputs (TODO confirm with callers).
        t: Current training step.
        xent: Scalar loss with ``.item()``; printed every ``args.log_every`` steps.
        repulsion_loss: Scalar auxiliary loss with ``.item()``; printed alongside.
        training_data: ``(x, y)`` labeled points; up to 30 per class are drawn.
        test_data: ``(x, y)`` points drawn unlabeled (grey), up to 30 per class.
        exp_name: Prefix of the saved figure name (``linear/{exp_name}_t=<t>``).

    Returns:
        ``(times, slopes)``. ``times`` is the list of snapshot steps. ``slopes``
        holds one boundary slope per head when a plot was drawn at step ``t``,
        and is an empty list otherwise. A vertical boundary gets ``inf``; a
        degenerate head (both weights zero) gets ``nan``.
    """
    if t % args.log_every == 0:
        print(f"{t=} xent {xent.item():.5f} aux {repulsion_loss.item():.5f}")

    slopes = []
    times = _snapshot_times(args.train_iter)
    if t not in times:
        return times, slopes

    fig = plt.figure(figsize=(4, 4))
    weights, bias = _linear_params(net)

    xs = np.arange(-1.05, 1.05, 0.01)
    plt.xlim([-1.05, 1.05])
    plt.ylim([-1.05, 1.05])

    colors = [[252/256, 175/256, 124/256], [135/256, 201/256, 195/256]]
    for head_idx, (w, b) in enumerate(zip(weights, bias)):
        w_0, w_1 = w[0].item(), w[1].item()
        _bias = b.item()
        # Cycle colors so nets with more than two heads do not raise IndexError.
        color = colors[head_idx % len(colors)]
        if w_1 != 0:
            # Boundary w_0*x + w_1*y + b = 0  =>  y = -(w_0/w_1) x - b/w_1.
            slope = -w_0 / w_1
            intercept = -_bias / w_1
            plt.plot(xs, slope * xs + intercept, c=color, alpha=0.8)
        elif w_0 != 0:
            # Vertical boundary x = -b/w_0: draw it directly instead of dividing by zero.
            slope = float("inf")
            plt.axvline(-_bias / w_0, c=color, alpha=0.8)
        else:
            # Both weights zero: the head defines no boundary line.
            slope = float("nan")
        slopes.append(slope)

    tr_x, tr_y = training_data
    for g, c, m, s in [(0, "firebrick", "s", 50), (1, "royalblue", "^", 70)]:
        tr_g = tr_x[tr_y.flatten() == g][:30]
        plt.scatter(
            tr_g[:, 0],
            tr_g[:, 1],
            marker=m,
            zorder=10,
            s=s,
            c=c,
            edgecolors="k",
            linewidth=1,
            alpha=0.6
        )

    # Test points are drawn without labels: both classes share the same grey style.
    te_x, te_y = test_data
    for g in (0, 1):
        te_g = te_x[te_y.flatten() == g][:30]
        plt.scatter(
            te_g[:, 0], te_g[:, 1], zorder=0, s=40, c="silver", edgecolors="k", linewidth=1, alpha=0.6
        )

    ax = plt.gca()
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)
    savefig(f"linear/{exp_name}_{t=}")
    # savefig only clears the figure (plt.clf), so close it explicitly to avoid
    # accumulating one open figure per snapshot.
    plt.close(fig)

    return times, slopes
