import argparse
import datasets
import os
import torch
import torch.nn as nn
import random
import numpy as np
from tqdm import tqdm
from scipy.stats import rankdata



class JacobianDataset:
    """In-memory dataset of (correlation vector, accuracy) pairs read from .npy files.

    Rows are kept only if their correlation entries sum to a positive value and
    their accuracy exceeds ``min_acc``. After filtering, at most ``max_items``
    rows are retained, in file order.

    Args:
        datafilename: path of the .npy array of correlation vectors; assumed
            2-D ``(samples, features)`` (it is indexed as ``[rows, :]``).
        accsfilename: path of the .npy array of accuracies; assumed 1-D with
            the same number of samples.
        min_acc: rows with accuracy ``<= min_acc`` are dropped (default 0.5).
        max_items: maximum number of rows to keep after filtering, or ``None``
            for no cap (default 100001).
    """

    def __init__(self, datafilename, accsfilename, min_acc=0.5, max_items=100001):
        corrs = np.load(datafilename)
        accs = np.load(accsfilename)
        # One combined mask keeps corrs and accs aligned and selects exactly
        # the rows that pass both filters (a NaN sum compares False -> dropped).
        keep = (corrs.sum(axis=1) > 0.) & (accs > min_acc)
        self.corrs = corrs[keep, :][:max_items, :]
        self.accs = accs[keep][:max_items]

    def __getitem__(self, index):
        """Return the ``(correlation_vector, accuracy)`` pair at ``index``."""
        return self.corrs[index], self.accs[index]

    def __len__(self):
        """Return the number of retained rows."""
        return self.corrs.shape[0]



class Predictor(nn.Module):
    """Map a batch of correlation vectors to one score per sample.

    Each input entry ``c`` is rescaled to ``x = (c + 1) / 2`` and clamped to
    ``[0, 1]``. It is then passed through the kernel
    ``x**a * (1 - x**(a + 1))**b``, where ``a`` and ``b`` are the softplus of
    two learnable scalars. The score is the mean of the kernel over dim 1.

    The MLP head (``lin1``..``lin3`` with ReLU/Dropout) is constructed, and so
    registered as parameters, but ``forward`` does not use it.

    Args:
        init_a: initial raw (pre-softplus) value of ``a``.
        init_b: initial raw (pre-softplus) value of ``b``.
    """

    def __init__(self, init_a=1.0, init_b=1.0):
        super(Predictor, self).__init__()
        # Initialise explicitly: torch.Tensor(1) allocates uninitialised memory,
        # which can start training from arbitrary (even NaN/inf) values.
        self.a = torch.nn.parameter.Parameter(torch.full((1,), float(init_a)))
        self.b = torch.nn.parameter.Parameter(torch.full((1,), float(init_b)))
        # Optional MLP head; currently unused by forward().
        self.lin1 = nn.Linear(1, 128)
        self.relu1 = nn.ReLU(inplace=True)
        self.drop1 = nn.Dropout()
        self.lin2 = nn.Linear(128, 128)
        self.relu2 = nn.ReLU(inplace=True)
        self.drop2 = nn.Dropout()
        self.lin3 = nn.Linear(128, 1)

    def forward(self, x):
        """Return the mean kernel score of each row of ``x``.

        Args:
            x: tensor of correlation values; assumed 2-D ``(batch, features)``
               with entries presumably in ``[-1, 1]`` (TODO confirm).

        Returns:
            Tensor of shape ``(batch, 1)`` for 2-D input.
        """
        # Rescale to [0, 1]. Clamping stops values pushed just past 1 by
        # rounding from making the base of the fractional power below
        # negative, which would produce NaN.
        x = ((x + 1.) / 2.).clamp(0., 1.)
        # Keep the raw parameters within [0.1, 10]. This must be in-place:
        # Tensor.clamp returns a new tensor and leaves the parameter unchanged.
        with torch.no_grad():
            self.a.clamp_(0.1, 10.)
            self.b.clamp_(0.1, 10.)
        a = torch.nn.functional.softplus(self.a)
        b = torch.nn.functional.softplus(self.b)

        ps = (x ** a) * (1. - x ** (a + 1.)) ** b
        return ps.mean(dim=1, keepdim=True)


# Command-line configuration. Apart from --GPU and --seed, the options read
# in this script are only interpolated into input/output file names.
parser = argparse.ArgumentParser(description='NAS Without Training')
parser.add_argument('--data_loc', default='../fishersearch_randomwirenetworks/cifardata/', type=str, help='dataset folder')
parser.add_argument('--api_loc', default='../fimflam/NAS-Bench-201-v1_0-e61699.pth',
                    type=str, help='path to API')
parser.add_argument('--save_loc', default='results', type=str, help='folder to save results')
parser.add_argument('--save_string', default='naswot', type=str, help='prefix of results file')
parser.add_argument('--score', default='corrdistintegral0_025', type=str, help='the score to evaluate')
parser.add_argument('--nasspace', default='nasbench201', type=str, help='the nas search space to use')
parser.add_argument('--GPU', default='0', type=str)
parser.add_argument('--dataset', default='cifar10', type=str)
parser.add_argument('--batch_size', default=256, type=int)
parser.add_argument('--seed', default=1, type=int)
parser.add_argument('--trainval', action='store_true')
parser.add_argument('--repeat', default=256, type=int, help='how often to repeat a single image with a batch')
parser.add_argument('--augtype', default='gaussnoise', type=str, help='which perturbations to use')
# argparse applies `type` only to string defaults, so a float default would
# leave args.sigma a float unless the flag is given. A string default keeps
# it a str either way; the resulting file names are unchanged ('0.001').
parser.add_argument('--sigma', default='0.001', type=str, help='noise level if augtype is "gaussnoise"')
parser.add_argument('--stem_out_channels', default=16, type=int, help='output channels of stem convolution (nasbench101)')
parser.add_argument('--num_stacks', default=3, type=int, help='#stacks of modules (nasbench101)')
parser.add_argument('--num_modules_per_stack', default=3, type=int, help='#modules per stack (nasbench101)')
parser.add_argument('--num_labels', default=1, type=int, help='#classes (nasbench101)')
args = parser.parse_args()
# CUDA reads this variable when it initialises, so set it before any CUDA use.
os.environ['CUDA_VISIBLE_DEVICES'] = args.GPU

# Reproducibility: deterministic cuDNN kernels and seeded Python/NumPy/torch RNGs.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)




# Model, optimiser, learning-rate schedule and loss.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
network = Predictor().to(device)
optimizer = torch.optim.SGD(
    network.parameters(),
    lr=0.1,
    momentum=0.9,
    weight_decay=0.0002,
)
# Multiply the learning rate by 0.1 at each milestone epoch.
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[30, 60, 90, 120, 150],
    gamma=0.1,
)
# Pairwise ranking loss with zero margin: max(0, -y * (x1 - x2)).
criterion = nn.MarginRankingLoss(margin=0.0).to(device)

# Load the cached correlation data and build the train/test loaders.
os.makedirs(args.save_loc, exist_ok=True)
datafilename = f'{args.save_loc}/{args.save_string}_correlationmatrix_{args.nasspace}_{args.dataset}_{args.augtype}_{args.sigma}_{args.repeat}_{args.trainval}_{args.batch_size}_{args.seed}.npy'
accsfilename = f'{args.save_loc}/{args.save_string}_correlationmatrixaccs_{args.nasspace}_{args.dataset}_{args.augtype}_{args.sigma}_{args.repeat}_{args.trainval}_{args.batch_size}_{args.seed}.npy'
corrdata = JacobianDataset(datafilename, accsfilename)

# Shuffle every index once, then take two disjoint slices: up to 5000 rows
# for training and up to the next 5000 for testing.
inds = list(range(len(corrdata)))
random.shuffle(inds)
train_split, test_split = inds[:5000], inds[5000:10000]

# Each loader visits its own subset in a random order.
trainloader = torch.utils.data.DataLoader(
    corrdata,
    sampler=torch.utils.data.SubsetRandomSampler(train_split),
    batch_size=args.batch_size,
    num_workers=10,
    pin_memory=False,
)
testloader = torch.utils.data.DataLoader(
    corrdata,
    sampler=torch.utils.data.SubsetRandomSampler(test_split),
    batch_size=args.batch_size,
    num_workers=10,
    pin_memory=False,
)




# Output paths for per-epoch predictions and targets (np.save appends .npy).
testpred_filename = f'{args.save_loc}/{args.save_string}_testpredictions_{args.nasspace}_{args.dataset}_{args.augtype}_{args.sigma}_{args.repeat}_{args.trainval}_{args.batch_size}_{args.seed}'
testacc_filename = f'{args.save_loc}/{args.save_string}_testaccs_{args.nasspace}_{args.dataset}_{args.augtype}_{args.sigma}_{args.repeat}_{args.trainval}_{args.batch_size}_{args.seed}'
trainpred_filename = f'{args.save_loc}/{args.save_string}_trainpredictions_{args.nasspace}_{args.dataset}_{args.augtype}_{args.sigma}_{args.repeat}_{args.trainval}_{args.batch_size}_{args.seed}'
trainacc_filename = f'{args.save_loc}/{args.save_string}_trainaccs_{args.nasspace}_{args.dataset}_{args.augtype}_{args.sigma}_{args.repeat}_{args.trainval}_{args.batch_size}_{args.seed}'


def _flat_numpy(tensor):
    """Return ``tensor`` detached, on the CPU, as a flat NumPy array."""
    return tensor.detach().cpu().numpy().flatten()


def _concat(chunks):
    """Concatenate per-batch arrays once; an empty list yields an empty array."""
    return np.concatenate(chunks) if chunks else np.empty(0)


def _mse(preds, accs):
    """Mean squared difference of two aligned arrays, or NaN if they are empty."""
    if preds.size == 0:
        return float('nan')
    return float(((preds - accs) ** 2).mean())


epochs = 100
with tqdm(total=epochs, bar_format="Epoch {postfix[0]}\ttrain_mse: {postfix[1][train_mse]:>8.2g}\ttest_mse: {postfix[1][test_mse]:>8.2g}", postfix=["Epoch", dict(epoch=0, train_mse=0, test_mse=0)]) as t:
    for epoch in range(epochs):
        network.train()
        pred_chunks, acc_chunks = [], []
        for xs, lbl in trainloader:
            xs = xs.to(device)
            ls = lbl.to(device)

            optimizer.zero_grad()
            # Flatten (B, 1) -> (B,) so scores and targets have the same shape
            # as margin_ranking_loss expects; otherwise (B, 1) vs (B,) is either
            # rejected or broadcast to a wrong (B, B) loss.
            ys = network(xs).flatten()

            # Pair the k-th smallest label with the k-th largest. Target -1
            # means the second score of the pair should be higher; +1 (also
            # used for ties) means the first should be. Each pair appears
            # twice, once in each order.
            inds1 = torch.argsort(ls, descending=False)
            inds2 = torch.argsort(ls, descending=True)
            zs = torch.ones_like(ls)
            zs[ls[inds1] < ls[inds2]] = -1
            loss = criterion(ys[inds1], ys[inds2], zs)

            loss.backward()
            optimizer.step()
            pred_chunks.append(_flat_numpy(ys))
            acc_chunks.append(_flat_numpy(ls))

        preds = _concat(pred_chunks)
        accs = _concat(acc_chunks)
        np.save(trainpred_filename, preds)
        np.save(trainacc_filename, accs)
        t.postfix[0] = epoch
        # Diagnostic only: the model is trained with a ranking loss, so its
        # scores are not on the accuracy scale.
        t.postfix[1]["train_mse"] = _mse(preds, accs)

        scheduler.step()

        # Evaluation: no parameter updates, so skip building autograd graphs.
        network.eval()
        pred_chunks, acc_chunks = [], []
        with torch.no_grad():
            for xs, lbl in testloader:
                ys = network(xs.to(device))
                pred_chunks.append(_flat_numpy(ys))
                acc_chunks.append(_flat_numpy(lbl))
        preds = _concat(pred_chunks)
        accs = _concat(acc_chunks)
        np.save(testpred_filename, preds)
        np.save(testacc_filename, accs)
        t.postfix[1]["test_mse"] = _mse(preds, accs)
        t.update()



