import torch
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.animation as animation
# Global plotting defaults applied once at import time.
plt.rcParams['font.size'] = '14'
plt.rc('axes', axisbelow=True)  # keep ticks/grid lines beneath plotted artists
plt.rc('font', family="Arial")
# The nine entries of the qualitative "Set1" colormap, as RGBA colour tuples.
colors = list(map(plt.get_cmap('Set1'), range(9)))
# Place tensors on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def path(args):
    """Return an identifier string encoding the run configuration.

    The string has the form
    ``<data>_L<depth>[wbias]_<loss>_init<init>_lr<lr>``, where ``wbias`` is
    included only when ``args.bias`` is truthy.
    """
    bias_tag = 'wbias' if args.bias else ''
    return f'{args.data}_L{args.depth}{bias_tag}_{args.loss}_init{args.init}_lr{args.lr}'


def vis_classify(args, data, network, t=None, ax=None, plot_data=False):
    """Draw the network's output field over [-1.2, 1.2]^2 with labelled points.

    The network is evaluated on a 50 x 50 grid and drawn with ``pcolor`` using
    the diverging ``coolwarm`` colormap, with colour limits symmetric about
    zero. Points with ``y[:, 0] == 1`` are overlaid as black ``+`` markers and
    all other points as black ``_`` markers, on the same Axes as the heat map.

    Args:
        args: Run configuration. Not used by the drawing code (only by the
            commented-out save call).
        data: Mapping with ``'x'`` (indexed as ``x[:, 0]`` / ``x[:, 1]``) and
            ``'y'`` (indexed as ``y[:, 0]``, so presumably shape (N, 1)).
        network: Callable taking a (2500, 2) float tensor on ``device`` and
            returning a tensor with 2500 elements (reshaped to 50 x 50).
        t: Unused; accepted for call-site compatibility.
        ax: Axes to draw on. When None, a new 3 x 2.8 figure with fixed
            limits and ticks is created.
        plot_data: If True, also scatter the inputs shaded by label.

    Returns:
        The newly created Axes when ``ax`` is None; otherwise the collection
        returned by ``ax.pcolor`` (presumably collected by the caller as an
        animation frame -- confirm against callers).
    """
    grid = np.linspace(-1.2, 1.2, 50)
    g1, g2 = np.meshgrid(grid, grid)
    X = np.stack((g1.ravel(), g2.ravel()), -1)
    # Pure inference: no autograd graph is needed for the grid evaluation.
    with torch.no_grad():
        Y_hat = network(torch.tensor(X).float().to(device))
    Y_hat = Y_hat.cpu().detach().numpy().reshape(g1.shape)
    # Symmetric colour limits so that an output of 0 maps to the colormap centre.
    thresh = np.max(np.abs(Y_hat))
    mesh_kw = dict(vmax=thresh, vmin=-thresh, cmap='coolwarm', rasterized=True)

    if ax is None:
        _, ax = plt.subplots(figsize=(3, 2.8))
        ax.axis('equal')
        ax.set_xlim((-1.2, 1.2))
        ax.set_ylim((-1.2, 1.2))
        ax.set_xticks([-1, 0, 1])
        ax.set_yticks([-1, 0, 1])
        plt.tight_layout()
        ax.pcolor(g1, g2, Y_hat, **mesh_kw)
        if plot_data:
            ax.scatter(data['x'][:, 0], data['x'][:, 1], c=data['y'] + 2, cmap='Greys')
        result = ax
    else:
        if plot_data:
            ax.scatter(data['x'][:, 0], data['x'][:, 1], c=data['y'] + 2, cmap='Greys', zorder=5)
        result = ax.pcolor(g1, g2, Y_hat, **mesh_kw)

    # Draw label markers on the target Axes explicitly: plt.scatter would use
    # the *current* Axes, which need not be the ``ax`` passed in.
    positive = data['y'][:, 0] == 1
    pos_x = data['x'][positive, :]
    neg_x = data['x'][~positive, :]
    ax.scatter(pos_x[:, 0], pos_x[:, 1], s=100, linewidths=1.6, marker='+', c='k')
    ax.scatter(neg_x[:, 0], neg_x[:, 1], s=100, linewidths=1.6, marker='_', c='k')
    # plt.savefig(path(args)+'_vis.svg')
    return result


def vis_weight(W, t):
    """Display every weight matrix in ``W`` as a heat map in a single row.

    Panels are placed right-to-left: ``W[0]`` (titled W_1) goes in column
    ``L - 1`` and ``W[L - 1]`` in column 0. Each panel uses the ``RdGy``
    colormap with limits symmetric about zero, and has its own colorbar
    formatted in scientific notation. The figure is displayed with
    ``plt.show()``.

    Args:
        W: Sequence of torch tensors (any device). Each one is passed to
            ``imshow``, so presumably 2-D.
        t: Epoch shown in the figure title; no title is drawn when None.
    """
    L = len(W)
    plt.figure(figsize=(17, 6))
    if L <= 4:
        # Hand-tuned 4-column layout (narrow last column) for shallow nets.
        gs = matplotlib.gridspec.GridSpec(1, 4, width_ratios=[1.1, 1, 1, 0.4])
    else:
        # A fixed 4-column grid cannot hold more panels (gs[L-l-1] would raise
        # IndexError), so give each layer its own equal-width column.
        gs = matplotlib.gridspec.GridSpec(1, L)
    for l in range(L):
        W_l = W[l].cpu().detach().numpy()
        plt.subplot(gs[L - l - 1])
        cap = np.max(np.abs(W_l))
        plt.imshow(W_l, cmap='RdGy', vmin=-cap, vmax=cap)
        # Brace the index so multi-digit subscripts (e.g. W_10) render fully.
        plt.title("$W_{{{}}}$".format(l + 1))
        plt.xticks([])
        plt.yticks([])
        cbar = plt.colorbar()
        formatter = matplotlib.ticker.ScalarFormatter(useMathText=True)
        formatter.set_powerlimits((-1, 1))
        formatter.set_scientific(True)
        cbar.ax.yaxis.set_major_formatter(formatter)
    if t is not None:
        plt.suptitle("Epoch={}".format(t))
    plt.tight_layout(pad=0.5)
    # plt.savefig('frame/weight/{:04d}.jpg'.format(t), dpi=300)
    plt.show()


def plot_training(args, data, results):
    """Plot the training-loss curve on a new 4 x 3 inch figure.

    The loss sequence is drawn as a black line labelled "Loss" against its
    index, with the x-axis labelled ``t`` and limited to ``[0, args.epoch]``.

    Args:
        args: Run configuration; ``args.epoch`` sets the upper x-axis limit.
        data: Dataset supplied by the caller (NOTE(review): confirm how it
            is used).
        results: Mapping whose ``'Ls'`` entry is the loss sequence to plot.
    """
    plt.figure(figsize=(4, 3))    
    plt.plot(results['Ls'], linewidth=2, c='k', label="Loss")
    plt.xlabel(r"$t$")
    plt.ylabel("Loss")
    plt.xlim((0, args.epoch))
    plt.tight_layout(pad=0.5)