from os import dup2
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch.autograd import Variable
from torchvision.datasets import CIFAR10
from torchvision import transforms
import torch.nn.functional as F
#from utils import train
import torchvision.models as models
import time
import matplotlib.pyplot as plt
import cvxpy as cvx
import scipy.io as scio
import scipy as spy
time_start = time.time()  # wall-clock start of the run

# Run on the first CUDA device when one is present, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

# ---- experiment parameters ----
numtar = 6    # labelled target samples drawn per class
oplabel0 = 8  # original label whose samples become source class 0
oplabel1 = 9  # original label whose samples become source class 1

# Pre-extracted 1024-d features and their labels for the CIFAR-10 splits
# (loaded in the order: train features, test features, train labels, test labels).
feature_train, feature_test = (
    np.load('./train_feature_1024_cifar10.npy'),
    np.load('./test_feature_1024_cifar10.npy'),
)
label_train, label_test = (
    np.load('./train_label_cifar10.npy'),
    np.load('./test_label_cifar10.npy'),
)

# ---- model and loss definitions ----
class Net_f(nn.Module):
    """Feature head: one linear map from 1024-d inputs to 10-d features."""

    def __init__(self):
        super().__init__()
        # The attribute is named `fc2` so that state_dict keys stay unchanged.
        self.fc2 = nn.Linear(1024, 10)

    def forward(self, x):
        """Apply the linear layer to ``x``, whose last dimension must be 1024."""
        return self.fc2(x)

 

class Net_g(nn.Module):
    """Label head: linear map from a ``num_class``-wide input to ``dim`` features.

    In this script it is fed one-hot label vectors.
    """

    def __init__(self, num_class=2, dim=10):
        super().__init__()
        self.fc = nn.Linear(num_class, dim)

    def forward(self, x):
        """Return ``fc(x)``; the last dimension of ``x`` must be ``num_class``."""
        return self.fc(x)

def corr(f, g):
    """Return the batch mean of the row-wise inner products <f_i, g_i>.

    ``f`` and ``g`` are 2-D tensors of matching shape (N, D). The result is a
    0-dim tensor. Despite the name, no normalisation is applied; callers
    center the inputs first.
    """
    per_sample = (f * g).sum(dim=1)
    return per_sample.mean()
    
def cov_trace(f, g):
    """Return tr(C_f @ C_g), where C_x = x^T x / (N - 1).

    ``C_x`` is the sample covariance only when ``x``'s columns are zero-mean;
    the callers in this script pass centered tensors. ``N`` is each input's
    own row count, so N == 1 divides by zero.
    """
    denom_f = f.shape[0] - 1.
    denom_g = g.shape[0] - 1.
    cov_f = f.t() @ f / denom_f
    cov_g = g.t() @ g / denom_g
    return (cov_f @ cov_g).trace()


# ---- preprocess: build target / source / test splits ----
numsour = 1000  # source samples drawn per class
numtest = 1000  # maximum test samples kept per class

# Target task (classes 0 vs 1): only `numtar` labelled samples per class.
# Pool sizes come from the data rather than a hard-coded 5000, so a feature
# file with different class counts cannot cause out-of-range indices.
data0 = feature_train[label_train == 0]
data1 = feature_train[label_train == 1]
index1 = np.random.choice(len(data0), numtar, replace=False)
index2 = np.random.choice(len(data1), numtar, replace=False)
targetset = np.vstack((data0[index1], data1[index2]))
targetlabel = np.append(np.zeros(numtar), np.ones(numtar))

# Source task (oplabel0 vs oplabel1), relabelled to 0/1.
sourcedata0 = feature_train[label_train == oplabel0]
sourcedata1 = feature_train[label_train == oplabel1]
index1 = np.random.choice(len(sourcedata0), numsour, replace=False)
index2 = np.random.choice(len(sourcedata1), numsour, replace=False)
sourceset = np.vstack((sourcedata0[index1], sourcedata1[index2]))
sourcelabel = np.append(np.zeros(numsour), np.ones(numsour))

# Test set for the target task. Labels are sized from the rows actually
# selected, so they stay aligned with testset even if a class has fewer
# than `numtest` test samples.
testdata0 = feature_test[label_test == 0][:numtest]
testdata1 = feature_test[label_test == 1][:numtest]
testset = np.vstack((testdata0, testdata1))
testlabel = np.append(np.zeros(len(testdata0)), np.ones(len(testdata1)))
# ---- optimisation hyper-parameters, models and input tensors ----
lr = 0.0002            # Adam learning rate
epoch = 400            # number of full-batch training iterations
ind = 0                # NOTE(review): unused below; kept in case other code reads it
dim = 10               # feature width; must match Net_f / Net_g output width (10)
pyrt = 1 / np.sqrt(2)  # NOTE(review): unused below
num_class = 2          # both tasks are binary (labels built as 0/1 above)

model_f = Net_f().to(device)                        # shared feature extractor
model_gt = Net_g(num_class=num_class).to(device)    # target label embedding
model_gs = Net_g(num_class=num_class).to(device)    # source label embedding
optimizer_fg = torch.optim.Adam(
    [*model_f.parameters(), *model_gt.parameters(), *model_gs.parameters()],
    lr=lr,
)

losslist = []       # training loss per epoch
acclist = [0]       # test accuracy per epoch; leading 0 is a placeholder
ddlist = []         # per-epoch distance between target/source label embeddings
alpha = [0.5, 0.5]  # loss weights for [target, source]
alphalist = []      # per-epoch mixing weight between target/source embeddings

# num_classes is passed explicitly. Inferring it from the data gives too
# narrow a one-hot matrix when a class is absent, and an error on empty
# input, whereas Net_g always expects `num_class` input columns.
samples_tar = torch.from_numpy(targetset)
nt = samples_tar.size()[0]
labels_tar = torch.from_numpy(targetlabel)
labels_one_hot_tar = F.one_hot(labels_tar.long(), num_classes=num_class)

samples_sour = torch.from_numpy(sourceset)
ns = samples_sour.size()[0]
labels_sour = torch.from_numpy(sourcelabel)
labels_one_hot_sour = F.one_hot(labels_sour.long(), num_classes=num_class)

samples_test = torch.from_numpy(testset)

from scipy.special import ive  # explicit: do not rely on scipy.special being pre-imported


def _psd_sqrt(mat):
    """Return the symmetric square root of a symmetric PSD matrix.

    ``eigh`` (unlike ``eig``) guarantees real eigenvalues and orthonormal
    eigenvectors, so ``V diag(sqrt(w)) V^T`` is a valid square root even when
    eigenvalues repeat. Round-off can produce tiny negative eigenvalues;
    they are clipped to 0 so ``sqrt`` never yields NaN.
    """
    w, v = np.linalg.eigh(mat)
    return (v * np.sqrt(np.clip(w, 0.0, None))) @ v.T


def _bessel_ratio(order, x):
    """Return I_order(x) / I_{order-1}(x) without overflow.

    ``ive(v, x) = iv(v, x) * exp(-|x|)``, so the scaling cancels in the ratio.
    The unscaled ``iv`` ratio becomes inf/inf = NaN once x reaches a few
    hundred. For x == 0 both Bessel values are 0; the ratio's limit is 0.
    """
    if x <= 0:
        return 0.0
    return ive(order, x) / ive(order - 1, x)


# Inputs do not change across epochs, so convert and move them once.
x_tar = samples_tar.float().to(device)
y_tar = labels_one_hot_tar.float().to(device)
x_sour = samples_sour.float().to(device)
y_sour = labels_one_hot_sour.float().to(device)
x_test = samples_test.float().to(device)
labellist = torch.Tensor(np.eye(2)).to(device)  # one one-hot row per class
total = len(samples_test)
# Epoch-independent constants of the mixing rule.
w_sour = (1 / ns) / (1 / nt + 1 / ns)
w_tar = (1 / nt) / (1 / nt + 1 / ns)
rdd_scale = 1 / nt / 2 + 1 / ns / 2

for i in range(epoch):
    # ---- one full-batch training step on target + source ----
    model_f.train()
    model_gt.train()
    model_gs.train()

    f_tar = model_f(x_tar)
    g_tar = model_gt(y_tar)
    ff_tar = f_tar - torch.mean(f_tar, 0)
    gg_tar = g_tar - torch.mean(g_tar, 0)
    f_sour = model_f(x_sour)
    g_sour = model_gs(y_sour)
    ff_sour = f_sour - torch.mean(f_sour, 0)
    gg_sour = g_sour - torch.mean(g_sour, 0)

    optimizer_fg.zero_grad()
    # Per domain: -2 * E[<f, g>] + tr(cov_f cov_g), weighted by alpha.
    loss = (-2) * alpha[0] * corr(ff_tar, gg_tar)
    loss += (-2) * alpha[1] * corr(ff_sour, gg_sour)
    loss += alpha[0] * cov_trace(ff_tar, gg_tar)
    loss += alpha[1] * cov_trace(ff_sour, gg_sour)
    losslist.append(loss.item())
    loss.backward()
    optimizer_fg.step()

    # ---- evaluation (no autograd graph needed) ----
    model_f.eval()
    model_gt.eval()
    model_gs.eval()
    with torch.no_grad():
        fc = model_f(x_tar).cpu().numpy()
        fcs = model_f(x_sour).cpu().numpy()
        gc = model_gt(labellist).cpu().numpy()
        gcs = model_gs(labellist).cpu().numpy()
        fc_test = model_f(x_test).cpu().numpy()

    # Center features and label embeddings. The target feature mean is also
    # used to center the test features below.
    f_mean = fc.mean(axis=0)
    fc = fc - f_mean
    fcs = fcs - fcs.mean(axis=0)
    gcp = gc - gc.mean(axis=0)
    gcs = gcs - gcs.mean(axis=0)

    # Second moments of the centered features (divided by N), and their
    # symmetric square roots.
    lambdat = fc.T @ fc / fc.shape[0]
    lambdas = fcs.T @ fcs / fcs.shape[0]
    lambdath = _psd_sqrt(lambdat)
    lambdash = _psd_sqrt(lambdas)

    dd = 0.5 * np.sum(np.square(lambdath @ gcp.T - lambdash @ gcs.T))
    ddlist.append(dd)
    rdd = dd / rdd_scale
    aa = w_sour + w_tar * _bessel_ratio(dim, rdd)
    alphalist.append(aa)
    # Blend the target embedding (weight aa) with the source embedding
    # transformed by lambdath^+ @ lambdash (weight 1 - aa). pinv equals inv
    # when the target covariance is full rank, and avoids LinAlgError when it
    # is not (e.g. 2 * numtar <= dim).
    gcp = gcp * aa + (1 - aa) * (np.linalg.pinv(lambdath) @ lambdash @ gcs.T).T

    # Classify test samples by the largest score against each class embedding.
    fgp = (fc_test - f_mean) @ gcp.T
    acc = (np.argmax(fgp, axis=1) == testlabel).sum() / total
    acclist.append(acc)

# Run summary, one item per line:
#   1. per-epoch test accuracy (the leading 0 is acclist's initial placeholder)
#   2. the best accuracy
#   3. per-epoch mixing weights
#   4. per-epoch embedding distances
# acclist always starts non-empty, so max() cannot raise here.
for summary in (acclist, max(acclist), alphalist, ddlist):
    print(summary)