# 5/8: It seems I forgot to add self.convex * y to f_ast2, so effectively convex = 0 for all previous experiments (see l. 540).

import os
#from memory_profiler import profile
import numpy as np
import numpy.random as npr

import torch
from torch import nn
from torch.distributions.normal import Normal
from torch.distributions.independent import Independent
from torch.utils.data import TensorDataset
import torch.nn.functional as F


class PositiveLinear(nn.Module):
    """Bias-free linear layer that only ever applies non-negative weights.

    The raw ``weight`` parameter is unconstrained; at call time every entry
    is clamped into ``[0, 1e10]`` before the matrix product, so negative
    entries act as zeros.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        use_bias: accepted for signature compatibility but ignored; no bias
            term is ever created.
    """

    def __init__(self, in_features, out_features, use_bias=False):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features
        # Uninitialised storage; reset_parameters() fills it right away.
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw the raw weights with Xavier (Glorot) normal initialisation."""
        nn.init.xavier_normal_(self.weight)

    def forward(self, input):
        """Return ``input @ clamp(weight, 0, 1e10).T``."""
        effective_weight = self.weight.clamp(min=0.0, max=1e10)
        return nn.functional.linear(input, effective_weight)

class BLNN(nn.Module):
    """
    Bi-Lipschitz Neural network

    The network owns one input-convex potential (``icnn1``) and, when
    ``args.composite`` is truthy, a second one (``icnn2``) plus a linear
    ``convert`` layer between them.  Each potential is regularised with
    ``||x||^2 / (2 * smooth)``; :meth:`f1` / :meth:`f2` return its gradient.
    :meth:`legendre` numerically solves ``f1(x) = z`` for ``x`` and returns
    ``x + convex * z``.

    Args:
        gamma: stored as ``self.gamma`` (not used by the code in this class).
        length_scale: stored as ``self.sigma`` (not used by the code in this
            class).
        args: namespace providing ``contiz_dim``, ``h_dim``, ``out_dim``,
            ``composite``, ``brute_force``, ``num_hidden_layers1``,
            ``num_hidden_layers2``, ``optimizer`` ("GD", "Adam", "Adagrad",
            "RMSprop" or "LBFGS"), ``learn_params``, ``smooth``, ``convex``,
            ``embedding_size`` and ``num_classes``.
    """

    def __init__(
        self,
        gamma,
        length_scale,
        args):

        super(BLNN, self).__init__()
        self.contiz_dim = args.contiz_dim
        self.h_dim = args.h_dim
        self.out_dim = args.out_dim
        self.composite = bool(args.composite)
        self.qy_layers = []
        self.brute_force = bool(args.brute_force)
        # Prefer CUDA (the original hard-coded choice) but stay usable on
        # machines without a GPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.px1_num_hidden_layers = args.num_hidden_layers1
        self.px2_num_hidden_layers = args.num_hidden_layers2
        self.opt = args.optimizer

        (self.icnn1_Wy0,
         self.icnn1_Wy_layers,
         self.icnn1_Wz_layers) = self._build_icnn(self.contiz_dim, self.px1_num_hidden_layers)

        if self.composite:
            self.convert = nn.Linear(self.contiz_dim, self.out_dim).to(self.device)
            (self.icnn2_Wy0,
             self.icnn2_Wy_layers,
             self.icnn2_Wz_layers) = self._build_icnn(self.out_dim, self.px2_num_hidden_layers)

        self.gamma = gamma
        self.sigma = length_scale
        if bool(args.learn_params):
            # Created on self.device so they can be combined with inputs that
            # live there even if the caller never calls .to(device).
            self.convex = nn.Parameter(torch.full((1,), float(args.convex), device=self.device))
            self.smooth = nn.Parameter(torch.full((1,), float(args.smooth), device=self.device))
        else:
            self.smooth = args.smooth
            self.convex = args.convex

        self.W = nn.Parameter((torch.normal(torch.zeros(args.embedding_size, args.num_classes, self.out_dim), 0.05)).to(self.device))

        self.register_buffer('N', (torch.ones(args.num_classes) * 12))
        self.register_buffer('m', torch.normal(torch.zeros(args.embedding_size, args.num_classes), 1))

        self.m = self.m * self.N.unsqueeze(0)
        # Warm-start cache for the GD solver, indexed by the `id` argument of
        # legendre(); the row count 60000 is hard-coded.
        self.init_points1 = torch.ones((60000, args.contiz_dim)).to(self.device)
        self.init_points2 = torch.ones((60000, args.out_dim)).to(self.device)

    def _build_icnn(self, in_dim, num_layers):
        """Create the layers of one ICNN taking ``in_dim``-dimensional input.

        Returns ``(first_layer, Wy_layers, Wz_layers)``.  Layers are created
        in the same order as the original inline code so that random
        initialisation is unchanged for a given seed.
        """
        first_layer = nn.Linear(in_dim, self.h_dim).to(self.device)
        wy_layers = []
        wz_layers = []
        for _ in range(num_layers - 1):
            wy_layers.append(nn.Linear(in_dim, self.h_dim).to(self.device))
            wz_layers.append(PositiveLinear(self.h_dim, self.h_dim).to(self.device))
        wy_layers.append(nn.Linear(in_dim, 1).to(self.device))
        wz_layers.append(PositiveLinear(self.h_dim, 1).to(self.device))
        return first_layer, nn.ModuleList(wy_layers), nn.ModuleList(wz_layers)

    def _potential_gradient(self, input, first_layer, wz_layers, wy_layers,
                            num_layers, with_output, create_graph):
        """Evaluate one regularised ICNN potential and its gradient w.r.t. ``input``.

        Autograd is force-enabled, so this also works when called inside
        ``torch.no_grad()``.  ``input`` is switched to ``requires_grad=True``
        in place.
        """
        with torch.enable_grad():
            input.requires_grad_(True)
            hidden = F.softplus(first_layer(input))
            for i in range(num_layers):
                hidden = F.softplus(wz_layers[i](hidden) + wy_layers[i](input))

            # The quadratic term adds 1/smooth to the curvature of the potential.
            icnn_output = hidden + 1/(2*self.smooth)*(torch.norm(input, dim=1)**2).view(-1, 1)
            grad_icnn = torch.autograd.grad(
                icnn_output, [input], torch.ones_like(icnn_output),
                create_graph=create_graph)[0]

        if with_output:
            return grad_icnn, icnn_output
        return grad_icnn

    def f1(self, input, with_output=False, create_graph=True):
        """Gradient of the first potential at ``input``.

        Args:
            input: tensor with one sample per row; mutated to require grad.
            with_output: if true, return ``(gradient, potential)`` instead of
                only the gradient.
            create_graph: forwarded to ``torch.autograd.grad``.
        """
        return self._potential_gradient(
            input, self.icnn1_Wy0, self.icnn1_Wz_layers, self.icnn1_Wy_layers,
            self.px1_num_hidden_layers, with_output, create_graph)

    def f2(self, input, with_output=False, create_graph=True):
        """Gradient of the second potential at ``input``.

        Only usable when the model was built with ``composite`` enabled,
        because the ``icnn2`` layers exist only in that case.  Arguments are
        the same as for :meth:`f1`.
        """
        return self._potential_gradient(
            input, self.icnn2_Wy0, self.icnn2_Wz_layers, self.icnn2_Wy_layers,
            self.px2_num_hidden_layers, with_output, create_graph)

    def _log_iterations(self, tag, iteration):
        """Append ``iteration`` to ``<outfilename><tag><optimizer>.log``.

        ``outfilename`` is a module-level global that is not defined in the
        visible part of this file; when it is absent logging is skipped
        instead of raising NameError.
        """
        prefix = globals().get("outfilename")
        if prefix is None:
            return
        with open(prefix + tag + self.opt + '.log', 'a') as f:
            f.write(str(iteration) + "\n")

    def _build_optimizer(self, variable, learning_rate, tol):
        """Instantiate the torch optimizer selected by ``self.opt`` for ``variable``."""
        if self.opt == "Adam":
            return torch.optim.Adam([variable], lr=learning_rate, eps=tol)
        if self.opt == "Adagrad":
            return torch.optim.Adagrad([variable], lr=learning_rate, eps=tol)
        if self.opt == "RMSprop":
            return torch.optim.RMSprop([variable], lr=learning_rate, eps=tol)
        if self.opt == "LBFGS":
            return torch.optim.LBFGS([variable], lr=learning_rate, line_search_fn="strong_wolfe",
                                     max_iter=500, tolerance_grad=tol, tolerance_change=tol)
        raise ValueError("unsupported optimizer: {!r}".format(self.opt))

    def _solve_with_optimizer(self, grad_fn, target, learning_rate, max_iter, tol):
        """Solve ``grad_fn(x) = target`` by minimising ``F(x) - <x, target>``.

        ``grad_fn`` is :meth:`f1` or :meth:`f2`.  The iterate starts at all
        ones and is returned after ``max_iter`` optimizer steps.
        """
        point = torch.ones(target.size(), device=target.device)
        optimizer = self._build_optimizer(point, learning_rate, tol)

        def closure():
            # The gradient is supplied by hand, so no backward() call is needed.
            with torch.no_grad():
                grad, potential = grad_fn(point, with_output=True)
                loss = torch.sum(potential) - torch.sum(point * target)
                point.grad = grad - target
            return loss

        for _ in range(max_iter):
            optimizer.step(closure)
        return point

    def legendre(self, z, id=None, eval=False):
        """Invert the gradient maps at ``z``.

        Args:
            z: target gradients, one sample per row.
            id: optional indices into the warm-start cache (GD solver only);
                the solution found is written back to those rows.
            eval: if true, use the larger iteration budget (5000 instead of
                500 for GD; 1000000 instead of 1000 for Adam/Adagrad/RMSprop).

        Returns:
            ``(f_ast1, f_ast2)`` where ``f_ast1 = x1 + convex * z`` and ``x1``
            approximately satisfies ``f1(x1) = z``.  ``f_ast2`` is ``None``
            when ``composite`` is disabled.
        """
        x2 = None
        z2 = None
        if self.opt == "GD":
            if id is None:
                x1 = torch.ones(z.size(), device=z.device)
            else:
                x1 = self.init_points1[id]
            step = 2*self.smooth
            max_it = 5000 if eval else 500
            i = -1
            for i in range(max_it):
                grad = self.f1(x1)
                # Diminishing step size; the graph is kept so the result stays
                # differentiable w.r.t. the network parameters.
                x1 = x1 + step/(i+1) * (z-grad)
                if torch.mean(torch.norm(z-grad, dim=1)) < 0.001:
                    break
            # Report the index of the last iteration, whether the loop
            # converged or ran out of budget.
            print("i", i)
            self._log_iterations("it1_", i)
            if id is not None:
                self.init_points1[id] = x1.clone().detach()

            if self.composite:
                z2 = self.convert(x1+self.convex*z)
                x2 = z2-self.convex*z2
        else:
            learning_rate = 2*self.smooth
            max_iter = 1000000 if eval else 1000
            tol = 1e-12
            if self.opt == "LBFGS":
                # LBFGS iterates internally (max_iter=500 per step call).
                max_iter = 1

            x1 = self._solve_with_optimizer(self.f1, z, learning_rate, max_iter, tol)

            if self.composite:
                z2 = torch.matmul(
                    x1+self.convex*z,
                    torch.eye(self.contiz_dim, self.out_dim, device=z.device))
                x2 = self._solve_with_optimizer(self.f2, z2, learning_rate, max_iter, tol)

        f_ast1 = x1+self.convex*z
        if x2 is None:
            # Non-composite model: there is no second stage to report.
            return f_ast1, None
        return f_ast1, x2+self.convex*z2

    def forward(self, x, id=None, extended=False):
        """Run :meth:`legendre` on ``x`` and classify the result.

        NOTE(review): ``z`` and ``self.bilinear`` are not defined anywhere in
        the visible part of this file, so this method raises as written.
        Presumably a line computing ``z`` from ``f_ast2`` and a ``bilinear``
        method were dropped -- confirm against the full source.
        """
        f_ast1, f_ast2 = self.legendre(x, id)

        y_pred = self.bilinear(z)
        if extended:
            return z, f_ast1, f_ast2, y_pred
        return y_pred

class PICNN(nn.Module):
    """Conditional convex potential in ``y`` given ``x`` with gradient inversion.

    The class name suggests a partially input-convex neural network: ``x``
    drives an unconstrained path (``u``) that gates a path in ``y`` whose
    recurrent weights are :class:`PositiveLinear`.  :meth:`f` returns the
    gradient of the potential (plus ``||y||^2 / (2 * smooth)``) with respect
    to ``y``; :meth:`legendre` inverts that gradient map by gradient steps.

    Args:
        dimx: feature size of the conditioning input ``x``.
        dimy: feature size of ``y``.
        dimh: hidden width.
        num_hidden_layers: number of layers on the ``y`` path after the first.
        smooth: coefficient of the quadratic regulariser (``1 / (2 * smooth)``).
        convex: multiple of ``y`` added to the inverted point.
    """

    def __init__(self, dimx, dimy, dimh, num_hidden_layers, smooth, convex):
        super(PICNN, self).__init__()

        self.dimx = dimx
        self.dimy = dimy
        self.dimh = dimh

        self.out_dim = 101
        self.composite = True
        # Prefer CUDA (the original hard-coded choice) but stay usable on
        # machines without a GPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.px1_num_hidden_layers = num_hidden_layers
        self.px2_num_hidden_layers = num_hidden_layers

        self.smooth = smooth
        self.convex = convex

        self.act = nn.Softplus()

        hidden = range(num_hidden_layers - 1)

        # z -> z path (non-negative weights)
        self.Wzs = torch.nn.ModuleList(
            [PositiveLinear(dimh, dimh) for _ in hidden] + [PositiveLinear(dimh, 1)])

        # u-dependent gate applied to z before Wzs
        self.Wzus = torch.nn.ModuleList(
            [nn.Linear(dimh, dimh, bias=True) for _ in hidden]
            + [nn.Linear(dimh, 1, bias=True)])

        # y -> z skip connections
        self.Wys = torch.nn.ModuleList(
            [nn.Linear(dimy, dimh, bias=False)]
            + [nn.Linear(dimy, dimh, bias=False) for _ in hidden]
            + [nn.Linear(dimy, 1, bias=False)])

        # u-dependent gate applied to y before Wys
        self.Wyus = torch.nn.ModuleList(
            [nn.Linear(dimx, dimy, bias=True)]
            + [nn.Linear(dimh, dimy, bias=True) for _ in hidden]
            + [nn.Linear(dimh, dimy, bias=True)])

        # u -> z direct term
        self.Wus = torch.nn.ModuleList(
            [nn.Linear(dimx, dimh, bias=True)]
            + [nn.Linear(dimh, dimh, bias=True) for _ in hidden]
            + [nn.Linear(dimh, 1, bias=True)])

        # u -> u path
        self.Wuus = torch.nn.ModuleList(
            [nn.Linear(dimx, dimh, bias=True)]
            + [nn.Linear(dimh, dimh, bias=True) for _ in hidden])

        self.convert = nn.Linear(self.dimy, self.out_dim).requires_grad_(False)

        # Warm-start cache for legendre(), indexed by its `id` argument; the
        # row count 600000 is hard-coded.
        self.init_points1 = torch.ones((600000, dimy)).to(self.device)
        self.init_points2 = torch.ones((600000, 1)).to(self.device)

    def f(self, x, y, pr=False):
        """Return the gradient of the regularised potential with respect to ``y``.

        Args:
            x: conditioning input, one sample per row.
            y: point at which the gradient is taken; switched to
                ``requires_grad=True`` in place.
            pr: if true, print the inputs and first-layer intermediates.
        """
        if pr:
            print("x", x)
            print("y", y)
        with torch.enable_grad():

            y.requires_grad_(True)

            prevU = x
            u = self.act(self.Wuus[0](prevU))
            yu_u = self.Wyus[0](prevU)
            z_yu = self.Wys[0](y * yu_u)
            z_u = self.Wus[0](prevU)
            z = self.act(z_yu+z_u)

            prevZ = z
            prevU = u
            if pr:
                print("u", u)
                print("yu_u", yu_u)
                print("z_yu", z_yu)
                print("z_u", z_u)
                print("z", z)

            # Wuus has no output-layer entry, so its hidden layers are
            # Wuus[1:].  The previous slice Wuus[1:-1] was one element shorter
            # than the others and made zip() silently skip the last hidden
            # layer (all of them when num_hidden_layers == 2).
            for Wz, Wzu, Wy, Wyu, Wu, Wuu in zip(
                    self.Wzs[:-1], self.Wzus[:-1],
                    self.Wys[1:-1], self.Wyus[1:-1], self.Wus[1:-1],
                    self.Wuus[1:]):
                u = self.act(Wuu(prevU))

                zu_u = self.act(Wzu(prevU))
                z_zu = Wz(prevZ * zu_u)

                yu_u = Wyu(prevU)
                z_yu = Wy(y * yu_u)

                z_u = Wu(prevU)

                z = self.act(z_zu+z_yu+z_u)

                prevU = u
                prevZ = z

            # Output layer: same structure, no activation.
            zu_u = self.act(self.Wzus[-1](prevU))
            z_zu = self.Wzs[-1](prevZ * zu_u)

            yu_u = self.Wyus[-1](prevU)
            z_yu = self.Wys[-1](y * yu_u)

            z_u = self.Wus[-1](prevU)

            z = (z_zu+z_yu+z_u)

            icnn_output = z + 1/(2*self.smooth)*(torch.norm(y, dim=1)**2).view(-1, 1)
            grad_icnn = torch.autograd.grad(
                icnn_output, [y], torch.ones_like(icnn_output), create_graph=True)[0]
        return grad_icnn

    def legendre(self, x, y, id=None, pr=False, eval=False):
        """Find ``y1`` with ``f(x, y1) ~= y`` by gradient steps.

        Args:
            x: conditioning input.
            y: target gradient, one sample per row.
            id: optional indices into the warm-start cache; the solution is
                written back to those rows.
            pr: if true, print per-iteration diagnostics.
            eval: if true, allow 5000 iterations instead of 500.

        Returns:
            ``(y1 + convex * y, y2)`` where ``y2`` is the row-wise mean of the
            first element reshaped to a column, or ``None`` if
            ``self.composite`` is false.
        """
        if id is None:
            y1 = torch.ones(y.size(), device=y.device)
        else:
            y1 = self.init_points1[id]
        step = 2*self.smooth
        max_it = 5000 if eval else 500
        for i in range(max_it):
            grad = self.f(x, y1, pr=pr)
            if pr:
                print("y1bef", y1)

            # Diminishing step size.
            y1 = y1 + step/(i+1) * (y-grad)
            if pr:
                print("i", i)
                print((torch.norm(y-grad, dim=1)))
            if torch.mean(torch.norm(y-grad, dim=1)) < 0.001:
                break
        if id is not None:
            self.init_points1[id] = y1.clone().detach()

        f_ast1 = y1+self.convex*y
        if not self.composite:
            # No second output is defined in this mode.
            return f_ast1, None
        y2 = torch.mean(f_ast1, dim=1)
        return f_ast1, y2.view(-1, 1)

    def forward(self, x, y, id=None, pr=False):
        """Return the second output of :meth:`legendre` for ``(x, y)``."""
        f_ast1, f_ast2 = self.legendre(x, y, id, pr)
        return f_ast2

class PICNN_multiclass(PICNN):
    """Bank of ``class_nb`` independent :class:`PICNN` models.

    The k-th model treats column ``y_idx[k]`` of the input as its ``y`` and
    all remaining columns as its ``x``; :meth:`forward` averages the outputs
    of all models.

    Args:
        dimx, dimy, dimh, num_hidden_layers, smooth, convex: forwarded to
            every :class:`PICNN` (and to the base-class constructor).
        class_nb: number of sub-models to create.
    """

    def __init__(self, dimx, dimy, dimh, num_hidden_layers, smooth, convex, class_nb):
        super().__init__(dimx, dimy, dimh, num_hidden_layers, smooth, convex)

        self.dimx, self.dimh = dimx, dimh
        self.num_hidden_layers = num_hidden_layers
        self.smooth, self.convex = smooth, convex
        self.class_nb = class_nb

        # One sub-model per class; self.dimy was set by the base constructor.
        self.modules_list = nn.ModuleList()
        for _ in range(class_nb):
            sub_model = PICNN(self.dimx, self.dimy, self.dimh,
                              self.num_hidden_layers, self.smooth, self.convex)
            self.modules_list.append(sub_model)

    def forward(self, x, y_idx, id=None, pr=False):
        """Average the sub-model outputs into a single column.

        Args:
            x: 2-D input; columns are split per sub-model as described on the
                class.
            y_idx: column index used as ``y`` for each sub-model.
            id: forwarded unchanged to every sub-model.
            pr: accepted but not used.
        """
        # For every entry of y_idx, the list of all other column indices.
        remaining_columns = []
        for target in y_idx:
            columns = list(range(len(x[0])))
            columns.pop(target)
            remaining_columns.append(columns)

        outputs = []
        for k, sub_model in enumerate(self.modules_list):
            conditioning = x[:, remaining_columns[k]]
            target_column = x[:, y_idx[k]].view(-1, 1)
            outputs.append(sub_model(conditioning, target_column, id))

        stacked = torch.cat(outputs, dim=1)
        return stacked.mean(dim=1).view(-1, 1)


class PICNN_class101(PICNN):
    """Bank of 101 independent :class:`PICNN` models sharing the same inputs.

    :meth:`forward` evaluates every sub-model on the same ``(x, y)`` and
    concatenates the results along dimension 1.

    Args:
        dimx, dimy, dimh, num_hidden_layers, smooth, convex: forwarded to
            every :class:`PICNN` (and to the base-class constructor).
    """

    def __init__(self, dimx, dimy, dimh, num_hidden_layers, smooth, convex):
        super().__init__(dimx, dimy, dimh, num_hidden_layers, smooth, convex)

        self.dimx, self.dimh = dimx, dimh
        self.num_hidden_layers = num_hidden_layers
        self.smooth, self.convex = smooth, convex

        # self.dimy was set by the base constructor.
        self.modules_list = nn.ModuleList()
        for _ in range(101):
            sub_model = PICNN(self.dimx, self.dimy, self.dimh,
                              self.num_hidden_layers, self.smooth, self.convex)
            self.modules_list.append(sub_model)

    def forward(self, x, y, id=None, extended=False):
        """Return the sub-model outputs concatenated column-wise.

        ``extended`` is accepted but not used.
        """
        per_model = [sub_model(x, y, id) for sub_model in self.modules_list]
        return torch.cat(per_model, dim=1)
