from os import dup2
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch.autograd import Variable
from torchvision.datasets import CIFAR10
from torchvision import transforms
import torch.nn.functional as F
#from utils import train
import torchvision.models as models
import time
import matplotlib.pyplot as plt
import cvxpy as cvx
import scipy.io as scio
import scipy as spy
# Wall-clock start; the elapsed time is reported at the very end of the script.
time_start = time.time()

# Run on the first CUDA device when one is available, otherwise on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
print(device)

# Pre-extracted feature matrices (x*) and their label vectors (y*), one pair
# per domain.  Net_f consumes 4096-dimensional rows, matching the "_4096"
# suffix of the feature files.
# NOTE(review): the a/d/w suffixes presumably stand for the Amazon / DSLR /
# Webcam domains of Office-31 -- confirm against the feature-extraction code.
x1, y1 = (
    np.load('./../feature/feature_a_4096.npy'),
    np.load('./../feature/label_a.npy'),
)
x2, y2 = (
    np.load('./../feature/feature_d_4096.npy'),
    np.load('./../feature/label_d.npy'),
)
# x3 / y3 are loaded but not referenced by the rest of the visible script.
x3, y3 = (
    np.load('./../feature/feature_w_4096.npy'),
    np.load('./../feature/label_w.npy'),
)

# Few-shot labelled target set: for each of the 31 classes, draw 3 rows of
# (x2, y2) without replacement.  Classes are visited in ascending order so
# the sequence of np.random.choice calls (and hence the RNG stream) is fixed.
ref_rows = []
for cls in range(31):
    target_rows = x2[y2 == cls]
    picked = np.random.choice(target_rows.shape[0], 3, replace=False)
    ref_rows.append(target_rows[picked])
refx = np.vstack(ref_rows)
# Labels follow the same class-major layout: 0,0,0,1,1,1,...,30,30,30.
refy = np.repeat(np.arange(31), 3).astype(int)

# Labelled source training set: for each of the 31 classes, draw 20 rows of
# (x1, y1) without replacement, again in ascending class order.
source_rows = []
for cls in range(31):
    domain_rows = x1[y1 == cls]
    picked = np.random.choice(domain_rows.shape[0], 20, replace=False)
    source_rows.append(domain_rows[picked])
sourcex = np.vstack(source_rows)
sourcey = np.repeat(np.arange(31), 20).astype(int)

# ---- Feature extractor and label-embedding networks ----
class Net_f(nn.Module):
    """Two-layer MLP mapping 4096-dimensional inputs to 64-dimensional features.

    Architecture: Linear(4096 -> 1024), ReLU, Linear(1024 -> 64).  No
    activation is applied after the second layer.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in this order so parameter initialisation
        # consumes the RNG stream in a fixed sequence.
        self.fc1 = nn.Linear(in_features=4096, out_features=1024)
        self.fc2 = nn.Linear(in_features=1024, out_features=64)

    def forward(self, x):
        """Return the 64-dimensional embedding of ``x`` (last dim must be 4096)."""
        hidden = torch.relu(self.fc1(x))
        return self.fc2(hidden)


class Net_g(nn.Module):
    """Single linear layer mapping a ``num_class``-wide input to ``dim`` outputs.

    In this script it is fed one-hot label vectors, so row ``k`` of the output
    for an identity-matrix input is the embedding of class ``k``.

    Args:
        num_class: Width of the input vectors (default 31).
        dim: Width of the output embedding (default 64).
    """

    def __init__(self, num_class=31, dim=64):
        super().__init__()
        self.fc = nn.Linear(in_features=num_class, out_features=dim)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result."""
        return self.fc(x)

def corr(f, g):
    """Return the batch mean of the row-wise inner products of ``f`` and ``g``.

    Each row of ``f`` is multiplied element-wise with the matching row of
    ``g`` and summed along dimension 1; the resulting per-row values are then
    averaged into a scalar tensor.  The inputs are used as given (no centering
    is performed here).
    """
    rowwise_dot = (f * g).sum(dim=1)
    return rowwise_dot.mean()
    
def cov_trace(f, g):
    """Return ``trace(C_f @ C_g)`` with ``C_x = x^T x / (rows(x) - 1)``.

    ``f`` and ``g`` must be 2-D with the same number of columns.  No mean is
    subtracted inside this function, so ``C_f`` and ``C_g`` are covariance
    matrices only when the caller passes already-centered inputs.
    """
    rows_f = f.shape[0]
    rows_g = g.shape[0]
    second_moment_f = f.t().mm(f) / (rows_f - 1.)
    second_moment_g = g.t().mm(g) / (rows_g - 1.)
    return second_moment_f.mm(second_moment_g).trace()

def neg_hscore(f, g):
    """Return ``-E[f0 . g0] + trace(cov(f) @ cov(g)) / 2`` as a scalar tensor.

    Both inputs are first centered by subtracting their column means.  The
    first term is the batch mean of the row-wise inner products of the
    centered tensors; the covariance matrices use the ``rows - 1`` divisor.
    """
    f_centered = f - f.mean(dim=0)
    g_centered = g - g.mean(dim=0)

    # Named "alignment" rather than "corr" so the module-level corr() helper
    # is not shadowed inside this function.
    alignment = (f_centered * g_centered).sum(dim=1).mean()

    cov_f = f_centered.t().mm(f_centered) / (f_centered.shape[0] - 1.)
    cov_g = g_centered.t().mm(g_centered) / (g_centered.shape[0] - 1.)
    penalty = cov_f.mm(cov_g).trace() / 2.

    return -alignment + penalty

# ---- Hyper-parameters ----
lr = 0.00002          # Adam learning rate
epoch = 500           # number of full-batch optimisation steps
ind = 0               # not referenced by the visible script; kept as-is
dim = 64              # embedding width (matches the Net_f / Net_g defaults)

# ---- Models and optimiser ----
# One shared feature extractor and two label embedders: model_gt is fed the
# few-shot target ("ref") labels, model_gs the source ("trans") labels.
# Construction order is kept so parameter initialisation draws from the RNG
# in the same sequence.
model_f = Net_f().to(device)
model_gt = Net_g().to(device)
model_gs = Net_g().to(device)
optimizer_fg = torch.optim.Adam(
    list(model_f.parameters())
    + list(model_gt.parameters())
    + list(model_gs.parameters()),
    lr=lr,
)

# ---- Training history ----
losslist = []
acclist = [0]         # seeded with 0 so max(acclist) is defined on epoch 0
alpha = [0.7, 0.3]    # loss weights: [target (ref) term, source (trans) term]
alphalist = []

# ---- Few-shot target tensors ----
samples_ref = torch.from_numpy(refx)
nt = samples_ref.size()[0]
# scatter_ requires an int64 index tensor.  refy.astype(int) is only int64
# where the platform's default NumPy integer is 64-bit, so cast explicitly.
labels_ref = torch.from_numpy(refy).long()
labels_one_hot_ref = torch.zeros(len(labels_ref), 31).scatter_(
    1, labels_ref.view(-1, 1), 1)

# ---- Source tensors ----
samples_trans = torch.from_numpy(sourcex)
ns = samples_trans.size()[0]
labels_trans = torch.from_numpy(sourcey).long()
labels_one_hot_trans = torch.zeros(len(labels_trans), 31).scatter_(
    1, labels_trans.view(-1, 1), 1)

def _sym_sqrt(mat):
    """Return V diag(sqrt(w)) V^T for the eigendecomposition of symmetric ``mat``."""
    # eigh is the symmetric/Hermitian solver: it always returns real
    # eigenvalues and orthonormal eigenvectors.  The general-purpose eig may
    # return complex output when the matrix is symmetric only up to rounding,
    # which would later break the real-valued comparison on the weight `aa`.
    eigvals, eigvecs = np.linalg.eigh(mat)
    return eigvecs.dot(np.diag(np.sqrt(eigvals))).dot(eigvecs.T)


def _snapshot(model):
    """Return a detached copy of ``model.state_dict()``."""
    # state_dict() hands back tensors that share storage with the live
    # parameters; without cloning, later optimiser steps would silently
    # overwrite the stored "best" weights.
    return {name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()}


# Best-checkpoint bookkeeping.  Initialised here so the final report does not
# raise NameError when no epoch ever exceeds the 0.5 accuracy threshold.
finalacc = None
paraf = None
parag = None

# Loop-invariant inputs, converted and moved to the device once.
ref_inputs = samples_ref.float().to(device)
ref_onehot = labels_one_hot_ref.float().to(device)
src_inputs = samples_trans.float().to(device)
src_onehot = labels_one_hot_trans.float().to(device)
target_inputs = torch.from_numpy(x2).float().to(device)
class_onehot = torch.Tensor(np.eye(31)).to(device)   # one row per class
regdim = dim*31/2

for i in range(epoch):
    # ------------------------------------------------------------------
    # Training step (full batch).
    # ------------------------------------------------------------------
    model_f.train()
    model_gt.train()
    model_gs.train()

    # Few-shot target ("ref") pair, centered over the batch.
    f_ref = model_f(ref_inputs)
    g_ref = model_gt(ref_onehot)
    f0_ref = f_ref - torch.mean(f_ref, 0)
    g0_ref = g_ref - torch.mean(g_ref, 0)

    # Source ("trans") pair, centered over its own batch.
    f_trans = model_f(src_inputs)
    g_trans = model_gs(src_onehot)
    f_trans = f_trans - torch.mean(f_trans, 0)
    g_trans = g_trans - torch.mean(g_trans, 0)

    # Weighted objective: for centered inputs, -2*corr + cov_trace equals
    # 2 * neg_hscore, so this is 2 * (alpha[0]*ref + alpha[1]*source).
    optimizer_fg.zero_grad()
    loss = (-2)*alpha[0]*corr(f0_ref, g0_ref)
    loss += (-2)*alpha[1]*corr(f_trans, g_trans)
    loss += alpha[0]*cov_trace(f0_ref, g0_ref)
    loss += alpha[1]*cov_trace(f_trans, g_trans)
    losslist.append(loss.item())
    loss.backward()
    optimizer_fg.step()

    # ------------------------------------------------------------------
    # Evaluation with the updated weights (no autograd graph needed).
    # ------------------------------------------------------------------
    model_f.eval()
    model_gt.eval()
    model_gs.eval()
    with torch.no_grad():
        fc = model_f(ref_inputs).cpu().numpy()
        fcs = model_f(src_inputs).cpu().numpy()
        fc_target = model_f(target_inputs).cpu().numpy()
        gc = model_gt(class_onehot).cpu().numpy()
        gcs = model_gs(class_onehot).cpu().numpy()

    # Center features by their own batch mean and class embeddings by the
    # mean over the 31 classes.
    f_mean = np.sum(fc, axis=0)/fc.shape[0]
    fc = fc - f_mean
    f_means = np.sum(fcs, axis=0)/fcs.shape[0]
    fcs = fcs - f_means
    gcp = gc - np.sum(gc, axis=0)/gc.shape[0]
    gcs = gcs - np.sum(gcs, axis=0)/gcs.shape[0]

    # Symmetric square roots of the target / source feature second moments.
    lambdath = _sym_sqrt(np.dot(fc.T, fc)/fc.shape[0])
    lambdash = _sym_sqrt(np.dot(fcs.T, fcs)/fcs.shape[0])

    # Data-driven mixing weight between the target and source embeddings.
    dd = 1/31*np.sum(np.square(lambdath.dot(gcp.T) - lambdash.dot(gcs.T)))
    rdd = dd/(1/nt/2 + 1/ns/2)
    aa = (1/ns)/(1/nt+1/ns) + (1/nt)/(1/nt+1/ns)*(2*regdim-rdd)/(2*regdim)
    alphalist.append(aa)   # the raw value is logged before the override below
    # NOTE(review): weights below 0.8 are replaced by the constant 0.874
    # (not clamped to 0.8); kept exactly as in the original -- confirm intent.
    if aa < 0.8:
        aa = 0.874
    gcp = gcp*aa + (1-aa)*(np.dot(np.linalg.inv(lambdath), lambdash).dot(gcs.T)).T

    # Accuracy on the target samples that were NOT drawn into the few-shot
    # set: hits over all of x2 minus hits over refx, divided by the count
    # difference.  `fc` is already the f_mean-centered refx features, so the
    # second forward pass over refx is not repeated.
    target_hits = (np.argmax(np.dot(fc_target - f_mean, gcp.T), axis=1) == y2).sum()
    ref_hits = (np.argmax(np.dot(fc, gcp.T), axis=1) == refy).sum()
    acc = (target_hits - ref_hits)/(len(x2) - len(refx))

    # Keep a snapshot whenever accuracy beats both 0.5 and every earlier epoch.
    if acc > 0.5 and acc > max(acclist):
        paraf = _snapshot(model_f)
        parag = _snapshot(model_gt)
        finalacc = acc
    acclist.append(acc)





# Final report: best recorded accuracy, the per-epoch accuracy history, and
# the per-epoch mixing weights, each on its own line.
for summary in (finalacc, acclist, alphalist):
    print(summary)

# Total wall-clock time in seconds since time_start was taken.
time_end = time.time()
elapsed = time_end - time_start
print(elapsed)
