import torch
import torchvision
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import MultiStepLR
import cv2
import matplotlib.pyplot as plt
# Images are only converted to tensors (ToTensor); no normalization is applied.
transform = transforms.Compose([transforms.ToTensor()])

_cifar_root = "./data"

train_set = datasets.CIFAR10(root=_cifar_root, train=True, download=True, transform=transform)
# NOTE(review): 101 looks like a typo for 100; neither loader is used by the
# visible code in this file (it iterates train_set/test_set directly) — confirm.
train_loader = DataLoader(train_set, batch_size=101, shuffle=True, num_workers=0)

test_set = datasets.CIFAR10(root=_cifar_root, train=False, download=True, transform=transform)
test_loader = DataLoader(test_set, batch_size=100, shuffle=False, num_workers=0)

# Candidate mixing weights (gamma) for placing each separating hyperplane
# between a seed positive and a negative sample; one candidate hull is built
# per value and the best one is kept.
# Grids tried previously: [0.5, 0.4, 0.3, 0.6, 0.7] and [0.65, 0.6, 0.55].
Value_Q = torch.tensor([0.5])

# Length of one flattened CIFAR-10 image (3 channels x 32 x 32 = 3072).
dim = 3 * 32 * 32

# Number of training images used from the first / second class of each pair.
num1 = num2 = 3000
def _group_by_label(dataset, limit=None):
    """Flatten the images of *dataset* and bucket them by label, in dataset order.

    Args:
        dataset: iterable of ``(image_tensor, label)`` pairs.
        limit: keep at most this many images per label (``None`` keeps all);
            assumed positive.

    Returns:
        dict mapping each label to a ``(count, numel)`` tensor whose rows are the
        flattened images, in the order the dataset yields them.
    """
    buckets = {}
    for image, label in dataset:
        rows = buckets.setdefault(int(label), [])
        if limit is None or len(rows) < limit:
            rows.append(image.reshape(-1))
    return {label: torch.stack(rows) for label, rows in buckets.items()}


def _build_hull(seed, negatives, order, gamma):
    """Greedily carve a polyhedron around *seed* that excludes every negative.

    Negatives are visited in *order* (nearest first). Each one that is still on
    the inner side of all hyperplanes built so far gets a new hyperplane with
    normal ``seed - neg`` through the point ``gamma * seed + (1 - gamma) * neg``;
    every negative scoring below zero against it is then marked excluded.

    Returns:
        ``(normals, biases)`` of shapes ``(h, dim)`` and ``(h,)``. A point ``p``
        is inside the hull when ``normals @ p + biases >= 0`` for every row.
    """
    excluded = torch.zeros(negatives.shape[0], dtype=torch.bool)
    normals, biases = [], []
    for k in order.tolist():
        if excluded[k]:
            continue  # already cut off by an earlier hyperplane
        w = seed - negatives[k]
        b = -torch.sum(w * (gamma * seed + (1 - gamma) * negatives[k]))
        normals.append(w)
        biases.append(b)
        excluded |= torch.mv(negatives, w) + b < 0
        # negatives[k] scores -gamma*|w|^2, which is < 0 in exact arithmetic;
        # force it so a duplicate image (w == 0), gamma == 0 or rounding cannot
        # make the loop revisit it forever.
        excluded[k] = True
    if not normals:
        return negatives.new_zeros((0, negatives.shape[1])), negatives.new_zeros(0)
    return torch.stack(normals), torch.stack(biases)


def _inside(points, normals, biases):
    """Bool mask of the rows of *points* lying inside the hull (all scores >= 0)."""
    return (points @ normals.T + biases >= 0).all(dim=1)


# Read each dataset once and group it by label, keeping dataset order. The
# per-pair loop then slices these buckets instead of re-iterating the full
# dataset for every one of the 45 label pairs.
_train_by_label = _group_by_label(train_set, limit=max(num1, num2))
_test_by_label = _group_by_label(test_set)

for qqi in range(10):
    for qqj in range(qqi + 1, 10):
        yy1, yy2 = qqi, qqj
        uz1 = _train_by_label[yy1][:num1]  # positives: first num1 images of class yy1
        uz2 = _train_by_label[yy2][:num2]  # negatives: first num2 images of class yy2

        # Distance matrix. The non-mm mode avoids the less accurate
        # matrix-multiplication shortcut, so values track torch.norm closely.
        dist = torch.cdist(uz1, uz2, compute_mode="donot_use_mm_for_euclid_dist")
        nearest = dist.min(dim=1).values  # each positive's distance to its closest negative

        # Hulls are kept as a list sized to what is actually built. A dense
        # (num1, num2, dim) buffer would need ~110 GB for the default sizes.
        hulls = []
        covered = torch.zeros(uz1.shape[0], dtype=torch.bool)
        while not bool(covered.all()):
            # Seed: the uncovered positive closest to any negative. argmin
            # returns the first minimum, matching a row-major scan.
            seed_idx = int(nearest.masked_fill(covered, float("inf")).argmin())
            x = uz1[seed_idx]
            order = torch.sort(dist[seed_idx]).indices  # negatives, nearest first
            pending = uz1[~covered]

            # Try every gamma and keep the hull capturing the most still-uncovered
            # positives per hyperplane. Assumes Value_Q is non-empty.
            best, best_score = None, float("-inf")
            for vq in Value_Q:
                normals, biases = _build_hull(x, uz2, order, vq)
                captured = int(_inside(pending, normals, biases).sum())
                score = captured / max(normals.shape[0], 1)
                if score > best_score:
                    best, best_score = (normals, biases), score
            hulls.append(best)

            # Remove the positives this hull contains. The seed is always inside
            # in exact arithmetic ((1 - gamma)*|w|^2 >= 0); force it so rounding
            # cannot stall the loop.
            covered |= _inside(uz1, *best)
            covered[seed_idx] = True

        # Evaluate: a class-yy1 test image is correct if some hull contains it,
        # a class-yy2 test image is correct if no hull contains it.
        test1 = _test_by_label[yy1]
        test2 = _test_by_label[yy2]
        hit1 = torch.zeros(test1.shape[0], dtype=torch.bool)
        hit2 = torch.zeros(test2.shape[0], dtype=torch.bool)
        for normals, biases in hulls:
            hit1 |= _inside(test1, normals, biases)
            hit2 |= _inside(test2, normals, biases)
        zq = int(hit1.sum())
        zq1 = int((~hit2).sum())
        # Denominator is the real number of test images of the two classes
        # (2000 for CIFAR-10) instead of a hard-coded constant.
        jg = (zq + zq1) / (test1.shape[0] + test2.shape[0])
        print('labels are', yy1,'and',yy2)
        print('The accuracy on testset is', jg)