import torch.optim as optim
import os
import time
import argparse
import random
import numpy as np
import re
from config_utils import load_config
import torch
import torchvision
from torch.distributions import Categorical
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
from jacob_scores import _scores
from nasbench import api
from nas_101_api.model import Network
from nas_101_api.model_spec import ModelSpec
#import termplotlib as tpl
import matplotlib.pyplot as plt



def get_batch_jacobian(net, data_loader, to, device, acc, args, randlabel=False, dydx=False):
    """Compute input gradients of the network's mean output on one batch.

    Takes the first batch yielded by ``data_loader`` and runs it through
    ``net``. It then back-propagates the mean over dim 1 of the first element
    of the network's output pair, summed over the batch. Row ``i`` of the
    result is therefore d(sum_j mean(y_j)) / d x_i. That equals the gradient
    of sample ``i``'s own output whenever samples do not interact in the
    forward pass; train-mode BatchNorm, for example, couples them.

    Args:
        net: module whose forward returns a pair; only the first element is
            used, and it must have at least two dimensions.
        data_loader: iterable yielding ``(inputs, targets)`` batches.
        to: number of copies per input; only ``1`` is supported.
        device: device the batch is moved to before the forward pass.
        acc, args, dydx: unused; kept for call-site compatibility.
        randlabel: if True, replace the targets with random labels in [0, 10).

    Returns:
        ``(jacob, target)`` as detached CPU tensors. ``jacob`` has the same
        shape as the input batch.

    Raises:
        ValueError: if ``to != 1``.
    """
    if to != 1:
        # After averaging over dim 1 there is a single scalar per input copy,
        # so an identity-seeded multi-copy Jacobian has nothing to select;
        # the previous reshape to (batch, to, to) could only succeed for to=1.
        raise ValueError(f"get_batch_jacobian only supports to=1, got to={to}")

    net.zero_grad()
    x, target = next(iter(data_loader))

    if randlabel:
        # NOTE(review): 10 classes is hard-coded regardless of the dataset.
        target = torch.randint_like(target, 10)

    # Work on a private leaf copy so the caller's tensor is never marked as
    # requiring grad.
    x = x.to(device).clone().requires_grad_(True)
    target = target.to(device)

    y, _ = net(x)
    y = y.mean(1)
    # Seeding backward with ones is the same as back-propagating y.sum().
    y.backward(torch.ones_like(y))

    jacob = x.grad.detach().cpu()
    return jacob, target.detach().cpu()

def get_batch_y(net, data_loader, to, device, acc, randlabel=False):
    """Return the network's second output for the first batch of ``data_loader``.

    Args:
        net: module whose forward returns a pair; the second element is returned.
        data_loader: iterable yielding ``(inputs, targets)`` batches.
        to, acc: unused; kept for call-site compatibility.
        device: device the inputs are moved to before the forward pass.
        randlabel: if True, replace the targets with random labels in [0, 10),
            the same way ``get_batch_jacobian`` does.

    Returns:
        ``(y, target)``. ``y`` stays on ``device`` with whatever autograd graph
        the forward pass built. ``target`` is a detached CPU tensor.
    """
    # Use the arguments rather than the module-level ``train_loader`` /
    # ``network`` globals, and honour ``device`` instead of forcing CUDA.
    x, target = next(iter(data_loader))
    if randlabel:
        # NOTE(review): 10 classes is hard-coded regardless of the dataset.
        target = torch.randint_like(target, 10)
    _, y = net(x.to(device))

    return y, target.detach().cpu()

# ---------------------------------------------------------------------------
# Command-line configuration
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='Calc mask Fish matrix')
parser.add_argument('--data_loc', default='../fishersearch_randomwirenetworks/cifardata/', type=str, help='dataset folder')
# NOTE(review): NAS-Bench-201 path; nothing in this script reads it.
parser.add_argument('--api_loc', default='NAS-Bench-201-v1_0-e61699.pth',
                    type=str, help='path to API')
# The next four are not read directly here; presumably consumed by
# Network(spec, args) — confirm.
parser.add_argument('--stem_out_channels', default=16, type=int, help='output channels of stem convolution')
parser.add_argument('--num_stacks', default=3, type=int, help='#stacks of modules')
parser.add_argument('--num_modules_per_stack', default=3, type=int, help='#modules per stack')
parser.add_argument('--num_labels', default=10, type=int, help='#classes')
parser.add_argument('--save_loc', default='fishmat', type=str, help='folder to save results')
parser.add_argument('--save_string', default='jacobian101', type=str, help='save string')
parser.add_argument('--batch_size', default=256, type=int)
parser.add_argument('--GPU', default='0', type=str)
# Copied into ARCH_START / ARCH_END below, which nothing in this script reads.
parser.add_argument('--arch_start', default=0, type=int)
parser.add_argument('--arch_end', default=15625, type=int)
parser.add_argument('--seed', default=42, type=int)
parser.add_argument('--batchnormstep', action='store_true')
parser.add_argument('--randlabel', action='store_true')
parser.add_argument('--fake_data', action='store_true')
parser.add_argument('--trainval', action='store_true')
parser.add_argument('--oneweirdtrick', action='store_true')
parser.add_argument('--score', default='None', type=str, help='Score to evaluate. If "None" this will save the jacobians')
parser.add_argument('--dataset', default='cifar10', help='cifar10, cifar100, imagenet-16-120') # note - this doesn't work for imagenet
# Previously hard-coded; the default keeps the old location.
parser.add_argument('--nasbench_path', default='./data/nasbench_only108.tfrecord', type=str,
                    help='path to the NAS-Bench-101 tfrecord file')

args = parser.parse_args()
# Set before any CUDA work happens so only the requested GPU is visible.
os.environ['CUDA_VISIBLE_DEVICES'] = args.GPU

from bench_models import get_cell_based_tiny_net
import torch
import torch.nn as nn

# Reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

import torchvision.transforms as transforms
from datasets import get_datasets

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
THE_START = time.time()

# NAS-Bench-101 database file (configurable via --nasbench_path).
NASBENCH_TFRECORD = args.nasbench_path
nasbench = api.NASBench(NASBENCH_TFRECORD)
ARCH_START = args.arch_start
ARCH_END = args.arch_end
os.makedirs(args.save_loc, exist_ok=True)

criterion = nn.CrossEntropyLoss()
train_data, valid_data, xshape, class_num = get_datasets(args.dataset, args.data_loc, 0, fake=args.fake_data)

from PIL import ImageFilter

class RandomKernel(object):
    """PIL transform that filters an image with a randomly perturbed 3x3 kernel.

    The kernel is an identity (centre-tap) kernel plus a random non-negative
    3x3 kernel whose weights sum to one. ``ImageFilter.Kernel`` divides by
    the sum of the weights by default (here 2). The output is therefore an
    even blend of the original pixel and a random local weighted average.
    """

    def __init__(self, radius=2):
        # NOTE: ``radius`` is only reported by __repr__; the kernel is always 3x3.
        self.radius = radius

    @staticmethod
    def _sample_weights():
        """Return the 9 row-major weights of a fresh random 3x3 kernel (sum == 2)."""
        ws = np.abs(np.random.randn(9))
        ws = ws / ws.sum()
        identity = np.zeros(9)
        # Index 4 is the centre tap of a row-major 3x3 kernel, so this term
        # passes the original pixel through unshifted.
        identity[4] = 1.
        return identity + ws

    def __call__(self, img):
        return img.filter(ImageFilter.Kernel((3, 3), self._sample_weights().tolist()))

    def __repr__(self):
        return self.__class__.__name__ + '(radius={0})'.format(self.radius)


class AddGaussianNoise:
    """Tensor transform that adds Gaussian noise with the given mean and std.

    Noise is drawn with ``torch.randn`` at the input's size (default dtype
    and device), scaled by ``std``, added to the input, and then ``mean`` is
    added on top.
    """

    def __init__(self, mean=0., std=0.001):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        noise = torch.randn(tensor.size()) * self.std
        return tensor + noise + self.mean

    def __repr__(self):
        return f'{type(self).__name__}(mean={self.mean}, std={self.std})'




class RepeatSampler(torch.utils.data.Sampler):
    """Yield each index produced by ``samp`` ``repeat`` times in a row.

    With ``repeat`` equal to the batch size, every batch holds copies of a
    single image. Random dataset transforms then give each copy a different
    augmented view.
    """

    def __init__(self, samp, repeat):
        self.samp = samp
        self.repeat = repeat

    def __iter__(self):
        for idx in self.samp:
            for _ in range(self.repeat):
                yield idx

    def __len__(self):
        return self.repeat * len(self.samp)


# Augmentation used for this run. Options handled below: '_gaussnoise',
# '_randomkernel', '_colourjitter' and '_cutout'. ('_gaussblur' was also
# listed, but has no branch and so adds no transform.) The value is also
# embedded in the output filenames.
augtype = '_gaussnoise'
# '_20' scores each architecture with 20 independently initialised networks
# and keeps the max; any other value uses a single network.
runs = '_20'


def _add_augmentation(transform_list, aug):
    """Add the transform selected by ``aug`` to ``transform_list`` in place.

    Appended transforms run after the existing pipeline. Inserted ones land
    at position 2, i.e. right after the first two transforms.
    """
    if aug == '_gaussnoise':
        transform_list.append(AddGaussianNoise(std=0.001))
    elif aug == '_randomkernel':
        print('randomkernel')
        transform_list.insert(2, RandomKernel())
    elif aug == '_colourjitter':
        print('colourjitter')
        transform_list.insert(2, torchvision.transforms.ColorJitter(0.01, 0.01, 0.01, 0.01))
    elif aug == '_cutout':
        transform_list.append(torchvision.transforms.RandomErasing(p=0.9, scale=(0.02, 0.04)))


if args.trainval:
    cifar_split = load_config('configs/nas-benchmark/cifar-split.txt', None, None)
    train_split, valid_split = cifar_split.train, cifar_split.valid
    sample_indices = train_split
else:
    sample_indices = range(len(train_data))

_add_augmentation(train_data.transform.transforms, augtype)
# Drop the first two transforms of the pipeline; anything inserted at index 2
# above survives. NOTE(review): presumably these are the random crop/flip set
# up by get_datasets — confirm.
train_data.transform.transforms = train_data.transform.transforms[2:]

# Each sampled index is yielded 256 times in a row, so with the default
# batch_size of 256 one batch is 256 differently augmented copies of one image.
train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=args.batch_size,
    num_workers=0,
    pin_memory=True,
    sampler=RepeatSampler(torch.utils.data.sampler.SubsetRandomSampler(sample_indices), 256),
)

if args.dataset == 'ImageNet16-120':
    imagenet_split = load_config('configs/nas-benchmark/ImageNet16-120-split.txt', None, None)
    train_split, valid_split = imagenet_split.train, imagenet_split.valid
    # NOTE(review): this replaces the repeat-sampled loader above with a plain
    # sequential one and ignores the split just loaded — confirm intent.
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, num_workers=0, pin_memory=True)

def extend_sub(module):
    """Recursively pass the Conv2d/Linear submodules of ``module`` through ``extend``.

    Each level is handled in two passes.

    1. Every attribute listed by ``dir(module)`` that is a Conv2d or Linear is
       printed and replaced by ``extend(attr)``.
    2. Every direct child that is a Conv2d is replaced in ``module._modules``
       by ``extend(child)``, and the function recurses into each child.

    NOTE(review): ``extend`` is not defined or imported anywhere in this file
    (possibly BackPACK's ``extend`` — confirm), so calling this raises
    NameError. No caller is visible in this script.
    NOTE(review): if ``dir(module)`` lists child modules, Conv2d children are
    passed to ``extend`` twice — confirm that ``extend`` is idempotent.
    """
    for attr_name in dir(module):
        candidate = getattr(module, attr_name)
        if isinstance(candidate, (nn.Conv2d, nn.Linear)):
            print(candidate)
            setattr(module, attr_name, extend(candidate))

    children = list(module.named_children())
    for child_name, child in children:
        if isinstance(child, nn.Conv2d):
            module._modules[child_name] = extend(child)
        extend_sub(child)

def train_batch_once(net, data_loader, device):
    """Run one forward pass of ``net`` on the first batch of ``data_loader``.

    No parameters are updated and nothing is returned. The only lasting
    effect is on module state that changes during forward, e.g. BatchNorm
    running statistics when ``net`` is in training mode.

    Args:
        net: module to run; its outputs are discarded.
        data_loader: iterable yielding ``(inputs, targets)``; targets are ignored.
        device: device the inputs are moved to.
    """
    x, _ = next(iter(data_loader))
    # No backward pass follows, so skip building an autograd graph. BatchNorm
    # still updates its running statistics under no_grad.
    with torch.no_grad():
        net(x.to(device))

# Alternative score names that were previously listed here (each is looked up
# in ``_scores``): 'correiglowerquartile', 'corrabsdevfromzero',
# 'maxabsupperquartile', 'gausslaplace', 'gaussentropycov', 'gaussentropycorr',
# 'entropy', 'mi', 'correig2', 'gaussmicov', 'gaussmicorr',
# 'coveiglowerquartile', 'evidenceapprox'.
score_names = ['corrdistintegral']

num_nets = len(nasbench.hash_iterator())
print('########################################')
print(f'num_nets {num_nets}')

# One score per architecture, in hash_iterator() order; NaN marks a failure.
scores = {score_name: np.zeros(num_nets) for score_name in score_names}


def _result_filename(metric):
    """Return the ``np.save`` target for ``metric``; np.save appends '.npy'."""
    return (
        f"{args.save_loc}/{args.save_string}{augtype}{runs}_{metric}_{args.dataset}_"
        f"{args.batchnormstep}_{args.randlabel}_{args.seed}_{args.trainval}_"
        f"{args.batch_size}_{args.oneweirdtrick}"
    )


def _save_scores():
    """Write every metric's score array to disk."""
    for metric, vals in scores.items():
        np.save(_result_filename(metric), vals)


n_runs = 20 if runs == '_20' else 1
accs = []
for ii, unique_hash in enumerate(nasbench.hash_iterator()):
    stats = nasbench.fixed_statistics[unique_hash]
    spec = ModelSpec(stats['module_adjacency'], stats['module_operations'])
    data = nasbench.query(spec)

    acc = data['test_accuracy']
    accs.append(acc)

    # TODO: for score == 'mi', get_batch_y should be used instead.
    for score in score_names:
        try:
            ss = []
            for _ in range(n_runs):
                network = Network(spec, args).to(device)
                if args.batchnormstep:
                    # Done per freshly built network: previously this ran
                    # before any network existed and raised NameError.
                    # NOTE(review): this draws its own batch, and unless
                    # Network puts itself in eval mode, BatchNorm uses batch
                    # statistics in the scoring pass anyway — confirm intent.
                    train_batch_once(network, train_loader, device)
                jacobians, labels = get_batch_jacobian(network, train_loader, 1, device, acc, args, args.randlabel, dydx=True)
                ss.append(_scores[score](jacobians, labels))
                del jacobians
                del network
            # Keep the best score over the independently initialised networks.
            scores[score][ii] = np.max(ss)
        except Exception as e:
            # Best effort: record NaN for this architecture and keep sweeping.
            print(e)
            print('nan')
            scores[score][ii] = np.nan

    # Checkpoint partial results every 100 architectures.
    if ii % 100 == 0 and ii > 1:
        _save_scores()

_save_scores()
THE_END = time.time()
print(THE_END - THE_START)
