import os
import argparse
import random
import numpy as np
import matplotlib.pyplot as plt
import datasets
import nasspace
import torch
import torch.nn as nn
from tqdm import tqdm

def get_batch_jacobian(net, x, target, device, args=None):
    """Return the gradient of the network output with respect to the input batch.

    The network is called on ``x`` and must return a 2-tuple whose first
    element is the output tensor; the second element is ignored.  The output
    is back-propagated with an all-ones upstream gradient, so the value
    returned for ``x`` is the gradient of the sum of all output entries.

    Args:
        net: callable model exposing ``zero_grad()``; ``net(x)`` must return
            a pair ``(output, extra)``.
        x: input tensor.  It is switched to ``requires_grad=True`` in place.
        target: tensor passed through unchanged apart from ``detach()``.
        device: unused; kept for interface compatibility.
        args: unused; kept for interface compatibility.

    Returns:
        A tuple ``(x.grad.detach(), target.detach())``.
    """
    # Clear parameter gradients left over from any previous call.
    net.zero_grad()
    x.requires_grad_(True)

    output, _ = net(x)
    # Seed the backward pass with ones so every output entry contributes.
    output.backward(gradient=torch.ones_like(output))

    return x.grad.detach(), target.detach()

def plot_hist(jacob, ax, colour):
    """Draw a 100-bin histogram of the row-wise correlation coefficients of ``jacob``.

    ``jacob`` is flattened to two dimensions (first dimension kept, the rest
    merged), moved to the CPU and converted to NumPy.  ``np.corrcoef`` then
    produces the correlation matrix between its rows, and every entry of that
    matrix (including the diagonal) goes into the histogram.

    Args:
        jacob: tensor supporting ``reshape``, ``size``, ``cpu`` and ``numpy``.
        ax: object with a matplotlib-style ``hist`` method.
        colour: colour forwarded to ``ax.hist`` as ``color``.
    """
    per_row = jacob.reshape(jacob.size(0), -1)
    flat_rows = per_row.cpu().numpy()
    corr_values = np.corrcoef(flat_rows).flatten()
    _counts, _edges, _patches = ax.hist(corr_values, bins=100, color=colour)

def decide_plot(acc, plt_cts, num_rows, boundaries=(60., 70., 80., 90.), nasspace=None):
    """Pick the grid cell in which an architecture with accuracy ``acc`` is plotted.

    The accuracy is assigned to a column by the first entry of ``boundaries``
    it is strictly below; if it is below none of them it goes to the extra
    last column.  Within a column, rows are handed out in order until
    ``num_rows`` cells have been used.

    Args:
        acc: accuracy of the architecture.  When the search space is
            ``'nasbench101'`` it is multiplied by 100 before binning.
        plt_cts: mutable list with one counter per column (so it needs
            ``len(boundaries) + 1`` entries).  The counter of the selected
            column is incremented in place when a cell is handed out.
        num_rows: maximum number of plots per column.
        boundaries: non-empty sequence of upper bin edges.
        nasspace: name of the search space.  When ``None`` the value is read
            from a module-level ``args.nasspace`` if such a global exists;
            otherwise no rescaling is applied.

    Returns:
        ``(can_plot, plt_row, plt_col, accrange)`` where ``can_plot`` tells
        whether a free cell was available, ``plt_row`` is the assigned row
        (0 when ``can_plot`` is False), ``plt_col`` is the column index and
        ``accrange`` is a human-readable label for the accuracy bin.
    """
    if nasspace is None:
        # Fall back to the script-level ``args`` so existing callers keep
        # working, but do not fail with NameError when it is not defined
        # (e.g. when this function is imported from another module).
        nasspace = getattr(globals().get('args'), 'nasspace', None)
    if nasspace == 'nasbench101':
        acc = acc * 100

    num_bounds = len(boundaries)
    # Index of the first boundary that acc is strictly below; accuracies that
    # are below none of them land in the extra last column.
    plt_col = next((i for i, upper in enumerate(boundaries) if acc < upper), num_bounds)

    if plt_col == 0:
        accrange = f'< {boundaries[0]}%'
    elif plt_col == num_bounds:
        accrange = f'>= {boundaries[-1]}%'
    else:
        accrange = f'[{boundaries[plt_col - 1]}% , {boundaries[plt_col]}%)'

    can_plot = False
    plt_row = 0
    if plt_cts[plt_col] < num_rows:
        can_plot = True
        plt_row = plt_cts[plt_col]
        plt_cts[plt_col] += 1

    return can_plot, plt_row, plt_col, accrange



if __name__ == '__main__':
    # Script entry point: walk the architectures of the search space in random
    # order and fill a (rows x 5) grid of histograms, one column per accuracy
    # bin, then save and show the figure once every cell is filled.
    parser = argparse.ArgumentParser(description='Plot histograms of correlation matrix')
    parser.add_argument('--data_loc', default='../datasets/cifar/', type=str, help='dataset folder')
    parser.add_argument('--api_loc', default='NAS-Bench-201-v1_1-096897.pth',
                    type=str, help='path to API')
    parser.add_argument('--arch_start', default=0, type=int)
    parser.add_argument('--arch_end', default=15625, type=int)
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--GPU', default='0', type=str)
    parser.add_argument('--batch_size', default=256, type=int)
    parser.add_argument('--dataset', default='cifar10', type=str)
    parser.add_argument('--trainval', action='store_true')
    parser.add_argument('--rows', default=5, type=int)
    parser.add_argument('--repeat', default=256, type=int, help='how often to repeat a single image with a batch')
    parser.add_argument('--augtype', default='gaussnoise', type=str, help='which perturbations to use')
    parser.add_argument('--nasspace', default='nasbench201', type=str)
    parser.add_argument('--sigma', default=0.01, type=float)
    parser.add_argument('--stem_out_channels', default=16, type=int, help='output channels of stem convolution (nasbench101)')
    parser.add_argument('--num_stacks', default=3, type=int, help='#stacks of modules (nasbench101)')
    parser.add_argument('--num_modules_per_stack', default=3, type=int, help='#modules per stack (nasbench101)')
    parser.add_argument('--num_labels', default=1, type=int, help='#classes (nasbench101)')

    # NOTE: ``args`` must stay a module-level name; decide_plot() reads it.
    args = parser.parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = args.GPU

    # Reproducibility
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    searchspace = nasspace.get_search_space(args)
    train_loader = datasets.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args)

    plot_shape = (args.rows, 5)
    num_plots = plot_shape[0] * plot_shape[1]
    # squeeze=False keeps ``axes`` two-dimensional even for --rows 1, so the
    # axes[row, col] indexing below is always valid.
    fig, axes = plt.subplots(*plot_shape, sharex=True, squeeze=False, figsize=(7., 8.))
    # Number of histograms already placed in each column.
    plt_cts = [0] * plot_shape[1]

    # One colour per accuracy column.
    colours = ['#811F41', '#A92941', '#D15141', '#EF7941', '#F99C4B']
    # Accuracy bin edges (in percent) separating the five columns.
    boundaries = [60., 70., 80., 90.]

    archs = np.arange(len(searchspace))
    random.shuffle(archs)
    for arch in tqdm(archs):
        acc = searchspace.get_accuracy(searchspace[arch], 'x-valid')
        can_plt, row, col, accrange = decide_plot(acc, plt_cts, plot_shape[0], boundaries)
        if not can_plt:
            continue

        ax = axes[row, col]
        try:
            # Only the per-architecture work is best-effort: an architecture
            # that cannot be built or evaluated is skipped.
            network = searchspace.get_network(searchspace[arch])
            x, target = next(iter(train_loader))
            x, target = x.to(device), target.to(device)
            network = network.to(device)
            jacobs, _ = get_batch_jacobian(network, x, target, device, args)
            ax.axis('off')
            plot_hist(jacobs, ax, colours[col])
        except Exception as e:
            print(f'Skipping architecture {arch}: {e}')
            # Give the cell reserved by decide_plot() back to its column.
            plt_cts[col] -= 1
            continue

        print(plt_cts)
        if row == 0:
            ax.set_title(f'{accrange}')

        if row + 1 == plot_shape[0]:
            # Only the bottom row shows an x-axis.
            ax.axis('on')
            plt.setp(ax.get_xticklabels(), fontsize=12)
            ax.spines["top"].set_visible(False)
            ax.spines["right"].set_visible(False)
            ax.spines["left"].set_visible(False)
            ax.set_yticks([])

        if sum(plt_cts) == num_plots:
            plt.tight_layout()
            # Saving happens outside the try block so that an I/O failure is
            # reported instead of being treated as a bad architecture.
            os.makedirs('results', exist_ok=True)
            plt.savefig(f'results/{args.nasspace}_histograms_cifar10val_batch{args.batch_size}_aug{args.augtype}_rows{args.rows}.pdf')
            plt.show()
            break
