from os import dup2
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch.autograd import Variable
from torchvision.datasets import CIFAR10
from torchvision import transforms
import torch.nn.functional as F
#from utils import train
import torchvision.models as models
import time
import matplotlib.pyplot as plt
import cvxpy as cvx
import scipy.io as scio
import scipy as spy
# Wall-clock reference taken at the start of the run.
time_start = time.time()
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# parameter
numtar = 100   # target samples drawn per class
oplabel0 = 2   # training label used as source class 0
oplabel1 = 3   # training label used as source class 1

# Feature and label arrays are read from the working directory, in the same
# order as the names on the left-hand side.
feature_train, feature_test, label_train, label_test = [
    np.load(path)
    for path in (
        './train_feature_1024_cifar10.npy',
        './test_feature_1024_cifar10.npy',
        './train_label_cifar10.npy',
        './test_label_cifar10.npy',
    )
]

#functions
class Net_f(nn.Module):
    """Single linear layer mapping input features to a ``dim``-sized embedding.

    Args:
        in_dim: Size of each input feature vector (default 1024).
        dim: Size of the output embedding (default 10).

    The layer is stored as ``fc2`` so that existing state dicts keep loading.
    """

    def __init__(self, in_dim=1024, dim=10):
        super().__init__()
        self.fc2 = nn.Linear(in_dim, dim)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result."""
        return self.fc2(x)

 

class Net_g(nn.Module):
    """Linear map from a ``num_class``-sized input to a ``dim``-sized output."""

    def __init__(self, num_class=2, dim=10):
        super().__init__()
        # Single fully connected layer; kept under the attribute name ``fc``.
        self.fc = nn.Linear(num_class, dim)

    def forward(self, x):
        """Return the output of the linear layer applied to ``x``."""
        return self.fc(x)

def corr(f,g):
    """Sum ``f * g`` along dimension 1, then average the result.

    For 2-D inputs this is the mean of the row-wise inner products of
    ``f`` and ``g``. Returns a 0-dim tensor.
    """
    rowwise = (f * g).sum(dim=1)
    return rowwise.mean()
    
def cov_trace(f,g):
    """Return ``trace(C_f @ C_g)`` with ``C_x = x^T x / (rows(x) - 1)``.

    The inputs are not centered here; ``C_x`` is a covariance matrix only
    if the caller has already subtracted the column means.
    """
    def scaled_gram(x):
        # x^T x divided by (number of rows - 1).
        return torch.mm(x.t(), x) / (x.shape[0] - 1.)

    return torch.trace(torch.mm(scaled_gram(f), scaled_gram(g)))


# preprocess
# Per-class sample counts for the source and test sets.
numsour = 1000
numtest = 1000

# Target set: numtar random training rows from each of labels 0 and 1.
# The sampling pool is the actual number of rows per class, so every row is
# eligible and no index can fall outside the array.
data0 = feature_train[label_train == 0]
data1 = feature_train[label_train == 1]
index1 = np.random.choice(len(data0), numtar, replace=False)
index2 = np.random.choice(len(data1), numtar, replace=False)
targetset = np.vstack((data0[index1], data1[index2]))
targetlabel = np.append(np.zeros(numtar), np.ones(numtar))

# Source set: numsour random training rows from each of labels oplabel0 and
# oplabel1, relabelled as 0 and 1 respectively.
sourcedata0 = feature_train[label_train == oplabel0]
sourcedata1 = feature_train[label_train == oplabel1]
index1 = np.random.choice(len(sourcedata0), numsour, replace=False)
index2 = np.random.choice(len(sourcedata1), numsour, replace=False)
sourceset = np.vstack((sourcedata0[index1], sourcedata1[index2]))
sourcelabel = np.append(np.zeros(numsour), np.ones(numsour))

# Test set: the first numtest test rows of labels 0 and 1. The slice may
# return fewer rows than requested, so labels are sized from the data.
testdata0 = feature_test[label_test == 0][0:numtest]
testdata1 = feature_test[label_test == 1][0:numtest]
testset = np.vstack((testdata0, testdata1))
testlabel = np.append(np.zeros(len(testdata0)), np.ones(len(testdata1)))
# nn param
lr = 0.0001
epoch = 300
ind = 0
dim = 10

# Build f before g so parameter initialisation consumes the RNG in the same
# order as before; one Adam optimizer updates both networks.
model_f = Net_f().to(device)
model_g = Net_g().to(device)
optimizer_fg = torch.optim.Adam(
    [*model_f.parameters(), *model_g.parameters()],
    lr=lr,
)

# Per-epoch traces and loss weights.
losslist = []
acclist = [0]
alpha = [0.8, 0.2]
alphadd = [1., 0.]
alphalist = []

# Target samples and their one-hot labels.
samples_tar = torch.from_numpy(targetset)
nt = samples_tar.shape[0]
labels_tar = torch.from_numpy(targetlabel)
labels_one_hot_tar = F.one_hot(labels_tar.long())

# Source samples and their one-hot labels.
samples_sour = torch.from_numpy(sourceset)
ns = samples_sour.shape[0]
labels_sour = torch.from_numpy(sourcelabel)
labels_one_hot_sour = F.one_hot(labels_sour.long())

# Test samples; the test labels stay as the NumPy array `testlabel`.
samples_test = torch.from_numpy(testset)

# Loop-invariant inputs: cast to float32 and move to the device once instead
# of repeating the conversion on every epoch.
x_tar = samples_tar.float().to(device)
y_tar = labels_one_hot_tar.float().to(device)
x_sour = samples_sour.float().to(device)
y_sour = labels_one_hot_sour.float().to(device)
x_test = samples_test.float().to(device)

# One-hot rows for the two classes, fed to model_g at evaluation time.
labellist = torch.Tensor(np.eye(2))
labellist_dev = labellist.to(device)

# Row masks per class, used to compare target and source class means.
tar_class_rows = [targetlabel == c for c in (0, 1)]
sour_class_rows = [sourcelabel == c for c in (0, 1)]

for i in range(epoch):
    model_f.train()
    model_g.train()

    # Both target and source embeddings are centered with the TARGET means.
    f_tar = model_f(x_tar)
    g_tar = model_g(y_tar)
    f_tar_mean = torch.mean(f_tar, 0)
    g_tar_mean = torch.mean(g_tar, 0)
    ff_tar = f_tar - f_tar_mean
    gg_tar = g_tar - g_tar_mean
    ff_sour = model_f(x_sour) - f_tar_mean
    gg_sour = model_g(y_sour) - g_tar_mean

    # -------renew alpha
    # Tensors must be detached and copied to host memory before conversion to
    # NumPy; calling .numpy() directly on a CUDA tensor raises.
    ft = ff_tar.detach().cpu().numpy()
    fs = ff_sour.detach().cpu().numpy()
    lambdaf = ft.T.dot(ft) / nt
    lambdaf_inv = np.linalg.inv(lambdaf)  # computed once, reused per class

    # dd accumulates, over the two classes, half the quadratic form of the
    # target-vs-source class-mean gap under inv(lambdaf).
    dd = 0
    for tar_rows, sour_rows in zip(tar_class_rows, sour_class_rows):
        mean_gap = np.mean(ft[tar_rows], 0) - np.mean(fs[sour_rows], 0)
        dd += 1 / 2 * mean_gap.dot(lambdaf_inv).dot(mean_gap)

    # A variant that weighted by a ratio of modified Bessel functions
    # (spy.special.iv) was previously tried here and is disabled.
    rdd = dd / (1 / nt + 1 / ns)
    regdim = dim
    alphat = (1 / ns) / (1 / nt + 1 / ns) + (1 / nt) / (1 / nt + 1 / ns) * (rdd - regdim) / rdd
    alphalist.append(alphat)
    alpha = [alphat, 1 - alphat]

    optimizer_fg.zero_grad()

    # NOTE(review): the loss is weighted by the fixed `alphadd`, not by the
    # `alpha` estimated above -- confirm this is intended.
    loss = (-2) * alphadd[0] * corr(ff_tar, gg_tar)
    loss += (-2) * alphadd[1] * corr(ff_sour, gg_sour)
    loss += cov_trace(ff_tar, gg_tar)
    losslist.append(loss.item())
    loss.backward()
    optimizer_fg.step()

    # ------acc
    model_f.eval()
    model_g.eval()
    with torch.no_grad():
        # Mean target embedding after this update, used to center test features.
        fc = model_f(x_tar).cpu().numpy()
        f_mean = np.sum(fc, axis=0) / fc.shape[0]
        # Centered label embeddings, one row per class.
        gc = model_g(labellist_dev).cpu().numpy()
        gce = np.sum(gc, axis=0) / gc.shape[0]
        gcp = gc - gce

        fc = model_f(x_test).cpu().numpy()
    fcp = fc - f_mean
    # Predict the class whose centered label embedding scores highest.
    fgp = np.dot(fcp, gcp.T)
    acc = (np.argmax(fgp, axis=1) == testlabel).sum()
    total = len(samples_test)

    acc = acc / total
    acclist.append(acc)

# Report the accuracy trace, its best value, and the alpha trace, one per line.
print(acclist, max(acclist), alphalist, sep="\n")