import torch
import torchvision
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import MultiStepLR
import cv2
import matplotlib.pyplot as plt
# MNIST splits; ToTensor converts each image to a float tensor scaled to [0, 1].
train_dataset = datasets.MNIST(root='./data/',train=True,transform=transforms.ToTensor(),download=True)
test_dataset = datasets.MNIST(root='./data/',train=False,transform=transforms.ToTensor(),download=True)

# Gamma: candidate positions of each separating hyperplane on the segment from
# the excluded uz2 point (gamma = 0) to the seed point x (gamma = 1); the search
# keeps the best-scoring value for every polyhedron.  Other grids that were
# tried: [0.5, 0.4, 0.3, 0.6, 0.7] and [0.5, 0.4, 0.6].
Gamma = torch.tensor([0.5])
yyinput = 0  # label of the class covered by polyhedra (network output > 0)
yy2 = 4      # label of the class the polyhedra must exclude (network output < 0)
dim = 784    # number of pixels of a flattened 28 x 28 MNIST image
num1 = 500   # training samples taken from class yyinput
num2 = 500   # training samples taken from class yy2
# uz1 / uz2: one flattened training image per row.
uz1 = torch.zeros(num1, dim)
uz2 = torch.zeros(num2, dim)
jo = 0   # rows of uz1 filled so far
jo1 = 0  # rows of uz2 filled so far
for x, y in train_dataset:
    if y == yyinput and jo < num1:
        uz1[jo, :] += x.view(dim)
        jo += 1
    if y == yy2 and jo1 < num2:
        uz2[jo1, :] += x.view(dim)
        jo1 += 1
    if jo >= num1 and jo1 >= num2:
        # Both quotas are filled and every remaining sample would be skipped,
        # so stop instead of decoding the rest of the training set.
        break
# NOTE(review): if a class had fewer than num1 / num2 training samples, the
# unfilled rows would silently stay all-zero images.

print('Begin To Find Convex Polyhedron')
def mad(k):
    """Return 1 if every entry of the tensor ``k`` is non-negative, else 0.

    NaN entries count as failing, ``+inf`` counts as non-negative, and an
    empty tensor vacuously passes.  Works for tensors of any shape.

    NOTE(review): not called anywhere in the visible file; it mirrors the
    "all constraint values >= 0" membership test used by the polyhedron search.
    """
    # NaN >= 0 is False, so NaN entries make the result 0.
    return int(bool(torch.all(k >= 0)))


# Number of convex polyhedra found so far.
chang = 0
# Pairwise Euclidean distances: julijuzhen[i, j] = ||uz1[i] - uz2[j]||.
# NOTE(review): kept as per-pair torch.norm calls so every value -- and hence
# every tie-break in the greedy search below -- is reproduced bit for bit;
# torch.cdist would be much faster but may round differently.
julijuzhen = torch.zeros(num1, num2)
for i in range(num1):
    for j in range(num2):
        julijuzhen[i, j] += torch.norm(uz1[i] - uz2[j])

# c1[i] is -1 while uz1[i] is not yet covered by any polyhedron, 0 once it is.
c1 = torch.zeros(num1) - 1
# Polyhedra are collected in lists and stacked at the end rather than
# preallocating a (num1, num2, dim) float tensor (~784 MB for 500 x 500 x 784)
# of which only the first `chang` slices would ever be used.
tuji_list = []
bias_list = []
# Greedy cover: repeat until every uz1 point lies inside some polyhedron.
while bool((c1 < 0).any()):
    # Seed: the uncovered uz1 sample closest to any uz2 sample.  Rows of
    # covered points are masked with +inf; argmin returns the first minimum in
    # row-major order, i.e. the same (i, j) a strict-'<' double loop would pick.
    masked = julijuzhen.masked_fill((c1 >= 0).view(num1, 1), float('inf'))
    dxx = int(torch.argmin(masked)) // num2
    x = uz1[dxx]
    # uz2 indices ordered by distance to the seed; julijuzhen[dxx] already holds
    # exactly these distances, so they are not recomputed.
    dis2 = torch.sort(julijuzhen[dxx])[1]
    kuan = 0                               # hyperplane count of the chosen polyhedron
    xzd = torch.zeros(1) - 1000000000000   # best score so far (sentinel below any score)
    # Fallback used only when Gamma is empty: an all-zero polyhedron.
    best_tuji = torch.zeros(num2, dim)
    best_bias = torch.zeros(num2)
    for vq in Gamma:  # try every gamma, keep the best-scoring polyhedron
        yixia = torch.zeros(num2) - 1   # -1: uz2[j] not yet cut off by any hyperplane
        kuan_z = 0
        tuji_z = torch.zeros(num2, dim)   # row r: normal of hyperplane r
        bias_z = torch.zeros(num2)        # row r: offset of hyperplane r
        # Build the polyhedron: while some uz2 point is still inside, take the
        # one nearest to the seed and add the hyperplane with normal
        # x - uz2[ks] passing through vq * x + (1 - vq) * uz2[ks].
        # NOTE(review): if x equals uz2[ks] exactly, w1 is zero, ks is never
        # marked and the loop fails with IndexError once num2 rows are used.
        while True:
            pending = torch.nonzero(yixia[dis2] == -1)
            if pending.numel() == 0:
                break
            ks = int(dis2[pending[0, 0]])
            w1 = x - uz2[ks, :]
            b1 = -torch.sum(w1 * (vq * x + (1 - vq) * uz2[ks, :]))
            tuji_z[kuan_z, :] += w1
            bias_z[kuan_z] += b1
            kuan_z += 1
            # Mark every uz2 point on the negative side of the new hyperplane.
            for j in range(num2):
                if torch.sum(w1 * uz2[j, :]) + b1 < 0:
                    yixia[j] = 1
        # Score = number of still-uncovered uz1 points inside the polyhedron
        # (all constraint values >= 0) per hyperplane used.
        linshijigeshu = torch.zeros(1)
        for i in range(num1):
            if c1[i] == -1:
                k1 = torch.mm(tuji_z, uz1[i].view(dim, 1)).view(num2) + bias_z.view(num2)
                if bool((k1 >= 0).all()):
                    linshijigeshu += 1
        tbs = linshijigeshu / kuan_z
        if tbs > xzd:
            best_tuji = tuji_z
            best_bias = bias_z
            kuan = kuan_z
            xzd = tbs
    # Mark every uncovered uz1 point inside the chosen polyhedron as covered.
    for i in range(num1):
        if c1[i] == -1:
            k1 = torch.mm(best_tuji, uz1[i].view(dim, 1)).view(-1) + best_bias.view(-1)
            if bool((k1 >= 0).all()):
                c1[i] = 0
    tuji_list.append(best_tuji)
    bias_list.append(best_bias)
    chang += 1
    print('This is the', chang, '-th convex polyhedron', 'which has', kuan, 'boundaries')

# tuji[p] (num2, dim) / bias[p] (num2,) hold polyhedron p's hyperplanes; rows
# past its hyperplane count are all-zero and give constraint value 0.
tuji = torch.stack(tuji_list) if tuji_list else torch.zeros(0, num2, dim)
bias = torch.stack(bias_list) if bias_list else torch.zeros(0, num2)
del tuji_list, bias_list

print('Convex Polyhedron Search Completed')
print('Begin To Construct Memorization Network')
# u1 / b1 alias the polyhedron weights and biases; Net.forward reads them.
u1 = tuji
b1 = bias
# u = smallest total constraint violation sum(max(0, -(W p + b))) of any uz2
# point p over all polyhedra.  Net.forward adds relu(1 - 2 * violation / u) per
# polyhedron, so a polyhedron's contribution drops to 0 at half of this margin.
# It starts at +inf: a finite sentinel (previously 100) would silently cap u
# whenever every violation exceeded it.
u = torch.full((1,), float('inf'))
for j in range(num2):
    for i in range(chang):
        kx = torch.mm(u1[i, :, :], uz2[j, :].view(-1, 1)).view(-1) + b1[i, :].view(-1)
        kx1 = torch.sum(F.relu(kx) - kx)  # relu(kx) - kx == max(0, -kx) per row
        if kx1 < u:
            # Direct assignment: the old `u * 0 + kx1` would give NaN for u = inf.
            u = kx1.reshape(1)
# NOTE(review): u == 0 (a uz2 point on a polyhedron boundary) would make
# Net.forward divide by zero.



class Net(nn.Module):
    """Memorization network built from the convex polyhedra found by the search.

    For an input ``x`` and each polyhedron ``p`` with hyperplane normals ``W_p``
    and offsets ``b_p``, the forward pass computes the total constraint
    violation ``v_p = sum(relu(W_p x + b_p) - (W_p x + b_p))`` (the summed
    negative parts) and accumulates ``relu(1 - 2 * v_p / u)``.  It returns that
    sum minus 0.5 as a shape-(1,) tensor; the evaluation loop treats a
    positive output as class ``yyinput`` and a negative one as class ``yy2``.

    Args:
        weights: optional (P, R, D) tensor of hyperplane normals.  When None,
            the module-level ``u1[:chang]`` is read at call time.
        biases: optional (P, R) tensor of hyperplane offsets.  When None,
            the module-level ``b1[:chang]`` is read at call time.
        margin: optional normaliser ``u`` (shape-(1,) tensor).  When None,
            the module-level ``u`` is read at call time.
    """

    def __init__(self, weights=None, biases=None, margin=None):
        super(Net, self).__init__()
        self.weights = weights
        self.biases = biases
        self.margin = margin

    def forward(self, x):
        # Globals are resolved lazily so Net() keeps using the search results.
        weights = u1[:chang] if self.weights is None else self.weights
        biases = b1[:chang] if self.biases is None else self.biases
        margin = u if self.margin is None else self.margin
        x1 = x.reshape(-1, 1)  # flatten the input into a column vector
        k = torch.zeros(1)
        for w_p, b_p in zip(weights, biases):
            kx = torch.mm(w_p, x1).view(-1) + b_p.view(-1)
            kx1 = torch.sum(F.relu(kx) - kx)  # total violation of this polyhedron
            k += F.relu(1 - kx1 * 2 / margin)
        return k - 0.5


net = Net()
print('Memorization Network Construction Completed')
print('Begin To Measure The Accuracy On Testset')
# he / zq: total / correctly classified test samples of class yyinput
# (correct when the network output is positive).
# hej / zq1: total / correctly classified test samples of class yy2
# (correct when the network output is negative).
zq = zq1 = he = hej = 0
for x, y in test_dataset:
    if y == yyinput:
        he += 1
        zq += int(torch.sign(net(x)) > 0)
    if y == yy2:
        hej += 1
        zq1 += int(torch.sign(net(x)) < 0)
print('The Accuracy on Testset:', (zq + zq1) / (hej + he))
