import torch
from torchvision import datasets, transforms

from os import path

def w(x):
    """Return ``x`` moved to the default CUDA device, or ``x`` itself on CPU-only hosts.

    Used on both tensors and ``nn.Module`` objects (anything exposing ``.cuda()``).
    """
    return x.cuda() if torch.cuda.is_available() else x
# Directory holding the MNIST files; download=True lets torchvision fetch them here
# when they are missing.
data_path = 'data/'
# Both splits apply only ToTensor to every image; the transform is stateless, so a
# single pipeline instance is shared.
_to_tensor = transforms.Compose([transforms.ToTensor()])
dataset_train = datasets.MNIST(data_path, train=True, download=True, transform=_to_tensor)
dataset_test = datasets.MNIST(data_path, train=False, download=True, transform=_to_tensor)


class taskLoader(object):
    """Iterator that lazily builds one MlpForMNIST problem per slice of MNIST.

    Task ``k`` uses sample indices ``k*problem_size .. (k+1)*problem_size - 1`` of the
    MNIST test split when ``test`` is True, otherwise of the training split.
    ``n_data`` defaults to the size of that split. A trailing partial slice still
    produces a task; MlpForMNIST wraps indices past the end of the dataset.
    """

    def __init__(self, problem_size=1000, hl=1, units=20, actvRelu=False, test=False, n_data=None):
        if n_data is None:
            n_data = len(dataset_test) if test else len(dataset_train)
        # Seeds torch's global RNG once, when the loader is constructed.
        torch.random.manual_seed(0)
        self.problem_size = problem_size
        self.hl = hl
        self.units = units
        self.actvRelu = actvRelu
        self.test = test
        self.n_data = n_data
        # Number of tasks: ceil(n_data / problem_size), as an int.
        self.n = -(-n_data // problem_size)
        self.num = 0

    def __iter__(self):
        return self

    # Python 3 compatibility
    def __next__(self):
        return self.next()

    def _task_indices(self, task):
        """Return the ``problem_size`` consecutive sample indices used by task number ``task``."""
        start = task * self.problem_size
        return range(start, start + self.problem_size)

    def next(self):
        """Build and return the next problem; raise StopIteration after ``self.n`` tasks."""
        if self.num >= self.n:
            raise StopIteration()
        indices = self._task_indices(self.num)
        self.num += 1
        return w(MlpForMNIST(hl=self.hl,
                             units=self.units,
                             actvRelu=self.actvRelu,
                             n_p=indices,
                             test=self.test))

def _task_index_ranges(n_data, problem_size):
    """Split indices ``0 .. n_data-1`` into consecutive ranges of ``problem_size`` indices.

    The last range is not clipped to ``n_data``; MlpForMNIST takes indices modulo the
    dataset length, so every task still receives ``problem_size`` samples.
    """
    return [range(start, start + problem_size) for start in range(0, n_data, problem_size)]


def load_training(problem_size=1000, hl=1, units=20, actvRelu=False, test=False, n_data=None):
    """Eagerly build one MlpForMNIST problem per ``problem_size`` slice of MNIST.

    Args:
        problem_size: number of samples in each problem's full batch.
        hl, units, actvRelu: forwarded to MlpForMNIST.
        test: take samples from the MNIST test split instead of the training split.
        n_data: number of samples to cover; defaults to the size of the selected split.

    Returns:
        A list with ``ceil(n_data / problem_size)`` problems, each passed through ``w``.
    """
    if n_data is None:
        n_data = len(dataset_test) if test else len(dataset_train)
    # Seed once so that the whole list of problems is reproducible.
    torch.random.manual_seed(0)
    return [w(MlpForMNIST(hl=hl, units=units, actvRelu=actvRelu, n_p=indices, test=test))
            for indices in _task_index_ranges(n_data, problem_size)]


class MlpForMNIST(torch.nn.Module):
    """Fully connected MNIST classifier exposed as an objective over a flat weight vector.

    Each layer's ``weight`` is turned from a registered ``nn.Parameter`` into a plain
    tensor that requires grad. This lets :meth:`setPars` replace it with a slice of a
    flat vector, possibly a non-leaf tensor from an outer graph. Only the weight
    matrices form that vector. Biases remain ordinary parameters and are not touched
    by ``getPars``/``setPars``.

    Args:
        hl: number of hidden ``units -> units`` layers.
        units: width of every hidden layer.
        actvRelu: use ReLU activations when True, sigmoid otherwise.
        res: image side length; inputs have ``res ** 2`` features.
        n_p: dataset indices, taken modulo the dataset length, whose samples form this
            problem's single full batch.
        test: draw samples from ``dataset_test`` instead of ``dataset_train``.
    """

    def __init__(self, hl=1, units=20, actvRelu=False, res=28, n_p=range(100), test=False):
        super(MlpForMNIST, self).__init__()
        self.loss_fun = torch.nn.CrossEntropyLoss()
        self.loss_value = None
        self.hl = hl
        self.units = units
        self.actvRelu = actvRelu
        self.res = res
        self.inputLayer = w(torch.nn.Linear(res ** 2, units))
        self.hiddenLayer = w(torch.nn.ModuleList([torch.nn.Linear(units, units) for _ in range(hl)]))
        self.outputLayer = w(torch.nn.Linear(units, 10))
        # Layers whose weights make up the flat parameter vector, in vector order.
        self.w = [self.inputLayer] + list(self.hiddenLayer) + [self.outputLayer]
        for layer in self.w:
            self._detach_weight(layer)

        self.n = self.getPars().nelement()
        self.n_p = n_p
        dataset = dataset_test if test else dataset_train
        samples = [dataset[i % len(dataset)] for i in n_p]
        # Match the input dtype to the weights. A hard-coded .double() makes the
        # Linear layers fail unless float64 was made the default dtype elsewhere.
        dtype = self.inputLayer.weight.dtype
        self.set = w(torch.cat([image.view(1, -1).to(dtype) for image, _ in samples], 0))
        self.target = w(torch.tensor([label for _, label in samples]))
        # Initial flat weights; reset() restores them.
        self.x0 = self.getPars().clone()

    @staticmethod
    def _detach_weight(layer):
        """Replace ``layer.weight`` (an nn.Parameter) with a plain leaf tensor that requires grad."""
        weight = layer.weight.detach()
        # Unregister the Parameter so that a plain tensor can be assigned to the name.
        del layer.weight
        layer.weight = weight.requires_grad_()

    def forward(self, x__, y__):
        """Return the CrossEntropyLoss (default mean reduction) of the network on ``x__`` vs ``y__``."""
        act = torch.relu if self.actvRelu else torch.sigmoid
        h = act(self.inputLayer(x__))
        for i in range(self.hl):
            h = act(self.hiddenLayer[i](h))
        return self.loss_fun(self.outputLayer(h), y__)

    def f(self):
        """Evaluate the loss on this problem's full batch and cache it in ``loss_value``."""
        self.loss_value = self.forward(self.set, self.target)
        return self.loss_value

    def fx(self, x_=None):
        """Return the loss at flat weights ``x_`` (or the cached loss when ``x_`` is None).

        The current weights are restored afterwards, even if the evaluation raises.
        """
        if x_ is None:
            return self.loss_value
        pars = self.getPars()
        self.setPars(x_)
        try:
            return self.forward(self.set, self.target)
        finally:
            self.setPars(pars)

    def df(self):
        """Return the flat gradient of the cached f() loss w.r.t. the layer weights.

        Requires a prior call to f(). Each layer's gradient is also stored in
        ``weight.grad``, and the graph is retained so df() can be called again.
        """
        # zero_grad() only reaches registered parameters, i.e. the biases here.
        self.zero_grad()
        grads = torch.autograd.grad(self.loss_value, self.soft_parameters(), retain_graph=True)
        for layer, g in zip(self.w, grads):
            layer.weight.grad = g
        return torch.cat([g.view(-1) for g in grads])

    def dfx(self, x_):
        """Return the detached flat weight gradient of the loss evaluated at flat weights ``x_``.

        Gradients are computed even when called under ``torch.no_grad()``. The caller's
        grad mode and the current weights are both restored afterwards, also on error.
        """
        self.zero_grad()
        pars = self.getPars()
        self.setPars(x_)
        try:
            # enable_grad() restores the previous grad mode on exit, even if forward raises.
            with torch.enable_grad():
                loss = self.forward(self.set, self.target)
                grads = torch.autograd.grad(loss, self.soft_parameters(), retain_graph=True)
        finally:
            self.setPars(pars)
        return torch.cat([g.view(-1).detach() for g in grads])

    def getPars(self):
        """Return a copy of all layer weights flattened and concatenated in ``self.w`` order."""
        return torch.cat([layer.weight.view(-1).clone() for layer in self.w])

    def setPars(self, newPars):
        """Point each layer's weight at the matching slice of flat vector ``newPars``.

        The weights become views of ``newPars`` (no copy), so gradients flow back into
        ``newPars`` when it is part of a graph.
        """
        offset = 0
        for layer in self.w:
            numel = layer.weight.nelement()
            layer.weight = newPars[offset:offset + numel].view(layer.weight.shape).requires_grad_()
            offset += numel

    def resetPars(self, newPars):
        """Like setPars, but each weight becomes a fresh leaf copy detached from ``newPars``."""
        offset = 0
        for layer in self.w:
            numel = layer.weight.nelement()
            chunk = newPars[offset:offset + numel].clone().view(layer.weight.shape)
            layer.weight = chunk.detach().requires_grad_()
            offset += numel

    def reset(self):
        """Restore the initial weights and clear the cached loss."""
        self.resetPars(self.x0)
        self.loss_value = None

    def detachState(self):
        """Cut the weights from any outer graph while keeping their current values."""
        self.resetPars(self.getPars().detach().requires_grad_())

    def soft_parameters(self):
        """Return the current weight tensors (not nn.Parameters) in flat-vector order."""
        return [layer.weight for layer in self.w]


if __name__ == "__main__":
    # Smoke test: the gradient returned by dfx() at a shifted point should agree with
    # the gradient obtained by moving the weights there (setPars + f + df).
    x = MlpForMNIST()
    x.f()
    g = x.df()
    g_a = x.dfx(x.getPars() - 0.3 * g)
    x.setPars(x.getPars() - 0.3 * g)
    x.f()
    g_b = x.df()
    # Both gradients are evaluated at the same weights, so this should be ~0.
    print("max |g_a - g_b| =", (g_a - g_b).abs().max().item())
