from data import *
from net import *
from lib import *
import datetime
from tqdm import tqdm
if is_in_notebook():
    from tqdm import tqdm_notebook as tqdm
from torch import optim
from tensorboardX import SummaryWriter
import torch.backends.cudnn as cudnn
import config


class CLS_graph(nn.Module):
    """Graph-convolutional classification head over running class prototypes.

    Each forward pass builds one graph whose first ``num_classes`` nodes are
    running per-class feature means (``mean_feats``) and whose remaining nodes
    are the samples of the current batch. A GCN (``projection_graph``) projects
    every node, and each sample is scored by cosine similarity between its
    projection and pre-computed class embeddings loaded from a pickle file.

    Args:
        in_dim: input width of ``bottleneck`` (only used by a disabled variant).
        out_dim: number of classes. Also selects the embedding file:
            15 -> ``officehome.pkl``, anything else -> ``office31_reorganized.pkl``.
        source_classes: indices of the embedding-table rows to use.
            ``construct_adj`` multiplies an ``(out_dim, out_dim)`` matrix by a
            ``len(source_classes)``-square one, so the two sizes must match.
        embedding_dim: output width of ``bottleneck``.

    NOTE(review): ``inp``, ``mean_feats`` and ``adj_feats`` are plain attributes
    placed with ``.cuda()``, not registered buffers, so ``module.to(...)`` does
    not move them and ``state_dict()`` does not save them. They are left as-is
    to keep existing checkpoints loadable.
    """

    def __init__(self, in_dim, out_dim, source_classes, embedding_dim=256):
        super(CLS_graph, self).__init__()
        # Only referenced by the disabled "middle EGLayer" variant in forward();
        # kept so the parameter layout of existing checkpoints is unchanged.
        self.bottleneck = nn.Linear(in_dim, embedding_dim)
        if out_dim == 15:
            embedding_file = "officehome.pkl"
        else:
            embedding_file = "office31_reorganized.pkl"
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(embedding_file, 'rb') as f:
            loadData = pickle.load(f)
        loadData = np.array(loadData)
        # Embeddings of the source classes: the cosine-similarity targets.
        self.inp = torch.tensor(loadData[source_classes], dtype=torch.float).cuda()
        # Only referenced by the disabled 2-layer / EGLayer variants in forward().
        self.projection_graph1 = GCN(nfeat=2048, nclasses=2048)
        # Active projection; its output width must match the width of ``inp``
        # for the similarity matmul in forward().
        self.projection_graph = GCN(nfeat=2048, nclasses=300)
        # Running (EMA) per-class feature means and their pairwise affinity.
        self.mean_feats = torch.zeros(out_dim, 2048).cuda()
        self.adj_feats = torch.zeros(out_dim, out_dim).cuda()
        self.sigma = 0.05   # bandwidth of the Gaussian kernel exp(-d / (2 sigma^2))
        self.beta = 0.8     # EMA momentum: new = beta * old + (1 - beta) * batch
        self.epison = 1e-5  # (sic) keeps the per-class mean finite for absent classes
        # Learnable mixing applied to the class-class affinity in construct_adj().
        self.w_adj = torch.nn.Parameter(
            torch.rand(len(source_classes), len(source_classes)).cuda(),
            requires_grad=True,
        )

        self.num_classes = out_dim
        self.softmax = nn.Softmax(dim=-1)  # currently unused (see forward)

    def euclid_dist(self, x, y):
        """Return pairwise *mean* squared distances between the rows of x and y.

        For ``x`` of shape (n, d) and ``y`` of shape (m, d) the result is (n, m)
        with ``out[i, j] = mean_k (x[i, k] - y[j, k]) ** 2``, computed through
        the expansion ``|x|^2 + |y|^2 - 2 x.y`` (every term divided by d).
        """
        x_sq = (x ** 2).mean(-1)                   # (n,)
        y_sq = (y ** 2).mean(-1)                   # (m,)
        xy = torch.mm(x, y.t()) / x.size(-1)       # (n, m)
        # Broadcasting replaces the explicit stack()-based tiling.
        return x_sq.unsqueeze(1) + y_sq.unsqueeze(0) - 2 * xy

    def update_mean_adj(self, x_feat, y_true):
        """Fold a labelled batch into the running class means and rebuild adj_feats.

        Classes that occur in ``y_true`` get an EMA update with momentum
        ``beta``; classes absent from the batch keep their previous mean.
        ``adj_feats`` is then recomputed as the Gaussian affinity between all
        class means.

        Args:
            x_feat: batch features, shape (B, D) with D == mean_feats.size(1).
            y_true: integer labels in ``[0, num_classes)``, shape (B,).
        """
        onehot_label = F.one_hot(y_true.long(), self.num_classes).to(
            dtype=x_feat.dtype, device=x_feat.device
        )                                                   # (B, C)
        counts = onehot_label.sum(0)                        # (C,) samples per class
        # Per-class batch mean via one matmul instead of a (B, C, D) product.
        tmp_mean = onehot_label.t() @ x_feat / (counts.unsqueeze(-1) + self.epison)
        # Presence is decided from the labels, not from the feature sum, so a
        # present class whose mean happens to sum to 0 is still updated.
        present = (counts > 0).to(tmp_mean.dtype).unsqueeze(-1)  # (C, 1)
        old = self.mean_feats.detach()
        self.mean_feats = old * (1 - present) + (
            old * self.beta + tmp_mean * (1 - self.beta)) * present
        curr_dist = self.euclid_dist(self.mean_feats, self.mean_feats)
        self.adj_feats = torch.exp(-curr_dist / (2 * self.sigma ** 2))
        # Optional sparsification (disabled):
        # self.adj_feats = self.bin_over_smooth(self.adj_feats)

    def bin_over_smooth(self, adj, k=4):
        """Binarise ``adj`` to each row's top-k entries, then column-normalise.

        Row i of the binary mask has ones at the k largest entries of
        ``adj[i]``; each column of the mask is then divided by its sum (+1e-6).
        Its call in ``update_mean_adj`` is currently commented out.

        Args:
            adj: affinity matrix, assumed 2-D.
            k: entries kept per row; clipped to the number of columns.
        """
        k = min(k, adj.size(-1))
        _, indices = torch.topk(adj, k=k, largest=True)
        mask = torch.zeros_like(adj).scatter(1, indices, 1)
        return mask / (mask.sum(0, keepdim=True) + 1e-6)

    def construct_adj(self, feats):
        """Build the (C + B) x (C + B) adjacency over [class nodes; sample nodes].

        Layout::

            [ adj_feats @ w_adj   sim ]   C class-mean rows
            [ sim.T               I_B ]   B sample rows

        where ``sim[c, b] = exp(-euclid_dist(mean_c, feats_b) / (2 sigma^2))``.
        Sample nodes are connected only to class nodes and to themselves.

        Args:
            feats: batch features, shape (B, D) with D == mean_feats.size(1).
        """
        dist = self.euclid_dist(self.mean_feats, feats)     # (C, B)
        sim = torch.exp(-dist / (2 * self.sigma ** 2))
        eye = torch.eye(feats.shape[0], dtype=sim.dtype, device=sim.device)
        class_block = torch.mm(self.adj_feats, self.w_adj)  # (C, C)
        top = torch.cat([class_block, sim], dim=1)
        bottom = torch.cat([sim.t(), eye], dim=1)
        return torch.cat([top, bottom], dim=0)

    def forward(self, x, y_true=None):
        """Score a batch against the class embeddings through the class graph.

        Args:
            x: batch features, shape (B, 2048) (must match mean_feats' width).
            y_true: optional labels; when given, the class means and
                ``adj_feats`` are updated from this batch first.

        Returns:
            list ``[x, emb, logits, logits, graph_adj, inp, class_nodes]``:
            the input features; the projected sample nodes ``emb`` (B, E);
            cosine similarities ``logits`` (B, len(source_classes)), returned
            twice because the softmax that used to sit between the two slots
            is disabled; the (C + B) x (C + B) adjacency; the raw class
            embeddings ``self.inp``; and the projected class-mean nodes (C, E).
        """
        if y_true is not None:
            self.update_mean_adj(x, y_true)

        graph_feats = torch.cat([self.mean_feats, x], dim=0)
        graph_adj = self.construct_adj(x)

        # Single GCN projection. Disabled alternatives kept for experiments:
        #   2-layer:        self.projection_graph(self.projection_graph1(graph_feats, graph_adj), graph_adj)
        #   middle EGLayer: self.bottleneck(self.projection_graph1(graph_feats, graph_adj))
        nodes = self.projection_graph(graph_feats, graph_adj)
        class_nodes = nodes[:self.num_classes, :]
        emb = nodes[self.num_classes:, :]

        # Cosine similarity to each class embedding. No squeeze(): with B == 1
        # it would collapse the logits to 1-D.
        logits = F.normalize(emb, dim=-1) @ F.normalize(self.inp, dim=-1).t()
        return [x, emb, logits, logits, graph_adj, self.inp, class_nodes]
