from os import dup2
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch.autograd import Variable
from torchvision.datasets import CIFAR10
from torchvision import transforms
import torch.nn.functional as F
#from utils import train
import torchvision.models as models
import time
import matplotlib.pyplot as plt
import cvxpy as cvx
import scipy.io as scio
import scipy as spy
# Wall-clock reference; the elapsed time is printed at the very end of the script.
time_start = time.time()

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

# Pre-extracted features and labels, one (feature, label) file pair per
# dataset suffix (a, d, w).
# NOTE(review): the "_4096" in the file names presumably is the feature
# dimension (Net_f expects 4096 inputs) — confirm against the extractor.
x1, y1 = np.load('./../feature/feature_a_4096.npy'), np.load('./../feature/label_a.npy')
x2, y2 = np.load('./../feature/feature_d_4096.npy'), np.load('./../feature/label_d.npy')
x3, y3 = np.load('./../feature/feature_w_4096.npy'), np.load('./../feature/label_w.npy')

def _sample_per_class(features, labels, per_class, num_class=31):
    """Draw ``per_class`` random rows (without replacement) from every class.

    Classes are visited in the order ``0 .. num_class-1`` and one
    ``np.random.choice`` call is made per class, so the result is grouped by
    class: rows ``[c*per_class, (c+1)*per_class)`` belong to class ``c``.

    Args:
        features: 2-D array with one sample per row.
        labels: 1-D array of integer class ids aligned with ``features``.
        per_class: number of rows to draw from each class.
        num_class: number of classes (ids ``0 .. num_class-1``).

    Returns:
        ``(sampled_features, sampled_labels)`` where ``sampled_labels`` is an
        integer array of length ``num_class * per_class``.

    Raises:
        ValueError: if some class has fewer than ``per_class`` samples.
    """
    picked = []
    for cls in range(num_class):
        # Evaluate the class mask once and reuse it for counting and indexing.
        members = features[labels == cls]
        if members.shape[0] < per_class:
            raise ValueError(
                "class %d has only %d samples, cannot draw %d without replacement"
                % (cls, members.shape[0], per_class)
            )
        chosen = np.random.choice(members.shape[0], per_class, replace=False)
        picked.append(members[chosen])
    # Stack once at the end instead of re-allocating the array for every class.
    sampled_features = np.vstack(picked)
    sampled_labels = np.repeat(np.arange(num_class), per_class).astype(int)
    return sampled_features, sampled_labels


# Labelled target (reference) training set: 3 random samples per class from x2/y2.
refx, refy = _sample_per_class(x2, y2, 3)

# Source training set: 20 random samples per class from x1/y1.
sourcex, sourcey = _sample_per_class(x1, y1, 20)

#feature---------------
class Net_f(nn.Module):
    """Feature network: two linear layers (4096 -> 1024 -> 64) with a ReLU between them."""

    def __init__(self):
        super().__init__()
        # Layer creation order is kept so seeded initialisation is unchanged.
        self.fc1 = nn.Linear(in_features=4096, out_features=1024)
        self.fc2 = nn.Linear(in_features=1024, out_features=64)

    def forward(self, x):
        """Map a batch of 4096-wide inputs to 64-wide outputs (no final activation)."""
        hidden = torch.relu(self.fc1(x))
        return self.fc2(hidden)


class Net_g(nn.Module):
    """Label network: a single linear layer from ``num_class`` inputs to ``dim`` outputs."""

    def __init__(self, num_class=31, dim=64):
        super().__init__()
        self.fc = nn.Linear(in_features=num_class, out_features=dim)

    def forward(self, x):
        """Apply the linear layer to ``x`` (last dimension must be ``num_class``)."""
        return self.fc(x)

def corr(f, g):
    """Return the batch mean of the row-wise inner products of ``f`` and ``g``.

    Both inputs are multiplied element-wise, summed over dimension 1 and the
    resulting per-row values are averaged into a scalar tensor.
    """
    per_sample = (f * g).sum(dim=1)
    return per_sample.mean()
    
def cov_trace(f, g):
    """Return ``trace(C_f @ C_g)`` with ``C_x = x^T x / (rows(x) - 1)``.

    No mean is subtracted here; callers that want covariances must pass
    already-centred matrices.
    """
    rows_f = f.shape[0]
    rows_g = g.shape[0]
    second_moment_f = f.t().mm(f) / (rows_f - 1.)
    second_moment_g = g.t().mm(g) / (rows_g - 1.)
    return second_moment_f.mm(second_moment_g).trace()

def neg_hscore(f, g):
    """Return ``-E[f0 . g0] + trace(cov(f0) @ cov(g0)) / 2`` for centred ``f0``, ``g0``.

    ``f`` and ``g`` are first centred by subtracting their column means. The
    first term is the batch mean of the row-wise inner products; the second
    uses ``x0^T x0 / (rows - 1)`` as the covariance estimate.
    """
    f_centered = f - f.mean(dim=0)
    g_centered = g - g.mean(dim=0)

    # Named "cross" so the module-level corr() function is not shadowed.
    cross = (f_centered * g_centered).sum(dim=1).mean()

    cov_f = f_centered.t().mm(f_centered) / (f_centered.shape[0] - 1.)
    cov_g = g_centered.t().mm(g_centered) / (g_centered.shape[0] - 1.)
    return -cross + cov_f.mm(cov_g).trace() / 2.

# ---- hyper-parameters ------------------------------------------------------
lr = 0.0001
epoch = 200
ind = 0
dim = 64  # width of the f / g outputs (matches Net_f's last layer and Net_g's default)

# ---- models and a single optimiser over both parameter sets ---------------
model_f = Net_f().to(device)
model_g = Net_g().to(device)
optimizer_fg = torch.optim.Adam(
    list(model_f.parameters()) + list(model_g.parameters()), lr=lr
)

# ---- bookkeeping -----------------------------------------------------------
losslist = []
acclist = [0]  # seeded with 0 so max(acclist) is defined on the first epoch
alpha = [0.8, 0.2]  # initial [reference, source] weights; overwritten every epoch
alphalist = []

# ---- reference (target) tensors -------------------------------------------
samples_ref = torch.from_numpy(refx)
nt = samples_ref.size()[0]
# scatter_ requires an int64 index; astype(int) yields int32 on platforms whose
# default NumPy integer is 32-bit, so cast explicitly.
labels_ref = torch.from_numpy(refy).long()
labels_one_hot_ref = torch.zeros(len(labels_ref), 31).scatter_(1, labels_ref.view(-1, 1), 1)

# ---- source tensors --------------------------------------------------------
samples_trans = torch.from_numpy(sourcex)
ns = samples_trans.size()[0]
labels_trans = torch.from_numpy(sourcey).long()
labels_one_hot_trans = torch.zeros(len(labels_trans), 31).scatter_(1, labels_trans.view(-1, 1), 1)

NUM_CLASS = 31  # number of classes; equals the one-hot width of the label tensors


def _class_means(features, labels, num_class):
    """Return a (num_class, feature_dim) array holding the mean feature row of each class."""
    return np.array([features[labels == c].mean(axis=0) for c in range(num_class)])


def _count_correct(features, labels, feature_mean, centered_codes):
    """Count rows whose best-scoring class (centred f . centred g) equals the label."""
    scores = np.dot(features - feature_mean, centered_codes.T)
    return (np.argmax(scores, axis=1) == labels).sum()


# Loop-invariant inputs: convert and move to the device once, not every epoch.
ref_inputs = samples_ref.float().to(device)
ref_targets = labels_one_hot_ref.float().to(device)
trans_inputs = samples_trans.float().to(device)
trans_targets = labels_one_hot_trans.float().to(device)
test_inputs = torch.from_numpy(x2).float().to(device)
class_codes = torch.eye(NUM_CLASS, device=device)  # one one-hot row per class

# Defined up front so the final report cannot hit a NameError when no epoch
# ever clears the 0.5 accuracy threshold.
finalacc = None
paraf = None
parag = None

for epoch_idx in range(epoch):
    model_f.train()
    model_g.train()

    # Forward pass. Source outputs are centred with the *reference* means.
    f_ref = model_f(ref_inputs)
    g_ref = model_g(ref_targets)
    f_ref_mean = torch.mean(f_ref, 0)
    g_ref_mean = torch.mean(g_ref, 0)
    f0_ref = f_ref - f_ref_mean
    g0_ref = g_ref - g_ref_mean
    f_trans = model_f(trans_inputs) - f_ref_mean
    g_trans = model_g(trans_targets) - g_ref_mean

    # ------- renew alpha (NumPy side; .cpu() is required for CUDA tensors)
    ft = f0_ref.detach().cpu().numpy()
    fs = f_trans.detach().cpu().numpy()
    lambdaf = ft.T.dot(ft) / nt
    lambdaf_inv = np.linalg.inv(lambdaf)  # loop-invariant: invert once per epoch
    # Per-class means are selected by label instead of hard-coded slice widths,
    # so they stay correct for any number of samples per class.
    gap = _class_means(ft, refy, NUM_CLASS) - _class_means(fs, sourcey, NUM_CLASS)
    # Average over classes of gap_c^T * lambdaf^{-1} * gap_c.
    dd = np.sum(gap.dot(lambdaf_inv) * gap) / NUM_CLASS
    rdd = dd / (1 / nt + 1 / ns)
    regdim = dim * NUM_CLASS / 2
    alphat = (1 / ns) / (1 / nt + 1 / ns) + (1 / nt) / (1 / nt + 1 / ns) * (rdd - regdim) / rdd
    alphalist.append(alphat)
    alpha = [alphat, 1 - alphat]

    # ------- optimisation step
    optimizer_fg.zero_grad()
    loss = (-2) * alpha[0] * corr(f0_ref, g0_ref)
    loss += (-2) * alpha[1] * corr(f_trans, g_trans)
    loss += cov_trace(f0_ref, g0_ref)
    losslist.append(loss.item())
    loss.backward()
    optimizer_fg.step()

    # ------- accuracy (no autograd graph is needed for evaluation)
    model_f.eval()
    model_g.eval()
    with torch.no_grad():
        ref_feat = model_f(ref_inputs).cpu().numpy()
        test_feat = model_f(test_inputs).cpu().numpy()
        gc = model_g(class_codes).cpu().numpy()
    f_mean = np.sum(ref_feat, axis=0) / ref_feat.shape[0]
    gcp = gc - np.sum(gc, axis=0) / gc.shape[0]

    # Correct predictions on all of x2 minus those on the reference samples
    # (which were drawn from x2), normalised by the remaining sample count.
    correct_all = _count_correct(test_feat, y2, f_mean, gcp)
    correct_ref = _count_correct(ref_feat, refy, f_mean, gcp)
    acc = (correct_all - correct_ref) / (len(x2) - len(refx))

    if acc > 0.5 and acc > max(acclist):
        # state_dict() returns references to the live parameters; clone them so
        # the snapshot is not overwritten by later optimisation steps.
        paraf = {k: v.detach().clone() for k, v in model_f.state_dict().items()}
        parag = {k: v.detach().clone() for k, v in model_g.state_dict().items()}
        finalacc = acc
    acclist.append(acc)


# Final report: best accuracy above 0.5 (None if never reached), the per-epoch
# accuracy and alpha histories, and the elapsed wall-clock time in seconds.
print(finalacc)
print(acclist)
print(alphalist)
time_end = time.time()
print(time_end - time_start)
