import torch
from torch import nn
from torch.nn import functional as F
from torch_sparse import coalesce
from torch_geometric.nn import GAE
from torch_geometric.utils import dropout_adj, add_remaining_self_loops, remove_self_loops
from learner import Learner


class Meta(nn.Module):
    """Trainer (``model_name`` 'AL-GCN') coupling a label generator with a graph-refining model.

    Sub-networks:

    * ``label_generator_ext`` (GAE-wrapped ``Learner``) and ``label_generator_cla``
      (``Learner``) produce "generated" logits from the input graph.
    * ``extractor`` (GAE-wrapped ``Learner``) encodes node features over a learned,
      weighted graph kept in ``cache_feature_index`` / ``cache_feature_value``;
      ``classifier`` (``Learner``) maps those embeddings to logits.

    Every call to :meth:`forward` is one training iteration made of three optimizer
    steps (classifier, label generator, extractor); the learned graph is refreshed
    right after the classifier step.

    Args:
        label_gene_ext_config: ``Learner`` config for the label-generator encoder.
        label_gene_cla_config: ``Learner`` config for the label-generator classifier.
        extractor_config: ``Learner`` config for the feature-extractor encoder.
        classifier_config: ``Learner`` config for the model classifier.
        cargs: namespace providing ``lr``, ``sample_rate``, ``hidden`` and ``edge_tau``.
    """

    def __init__(self, label_gene_ext_config, label_gene_cla_config, extractor_config, classifier_config, cargs):
        super(Meta, self).__init__()
        self.model_name = 'AL-GCN'

        self.label_generator_ext = GAE(Learner(label_gene_ext_config))
        self.label_generator_cla = Learner(label_gene_cla_config)
        self.extractor = GAE(Learner(extractor_config))
        self.classifier = Learner(classifier_config)

        # Used both as the Adam learning rate and as the step size of the
        # differentiable one-step (fast-weight) classifier update.
        self.lr = cargs.lr
        # Fraction of input edges kept when building the GAE reconstruction input.
        self.edge_sample_rate = cargs.sample_rate
        self.hidden_len = cargs.hidden
        # Candidate edges whose rescaled similarity in (-1, 1) is not above this
        # threshold are pruned from the learned graph.
        self.edge_tau = cargs.edge_tau

        self.classifier_optim = torch.optim.Adam([{'params': self.classifier.parameters()}], lr=self.lr, weight_decay=5e-4)
        self.extractor_optim = torch.optim.Adam(self.extractor.parameters(), lr=self.lr, weight_decay=5e-4)
        self.label_gene_optim = torch.optim.Adam([{'params': self.label_generator_ext.parameters()},
                                                  {'params': self.label_generator_cla.parameters()}], lr=self.lr, weight_decay=5e-4)

        # Number of completed forward() calls; while it is 0 the learned-graph
        # cache has not been initialised yet.
        self.iter = 0

    def forward(self, data):
        """Run one full training iteration on ``data``.

        Phases, in order:

        1. step the classifier, then rebuild the learned graph cache;
        2. step the label generator through a one-step classifier look-ahead;
        3. step the feature extractor through the same look-ahead plus a GAE
           edge-reconstruction loss.

        ``data`` must provide ``x``, ``edge_index``, ``edge_attr``, ``y``,
        ``train_mask``, ``two_hop_index`` and ``num_nodes``.

        Returns:
            None; the effects are parameter updates and the refreshed
            ``cache_feature_index`` / ``cache_feature_value``.
        """
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr

        if self.iter == 0:
            # Seed the learned graph with the input graph at unit edge weight.
            self.cache_feature_index = data.edge_index
            self.cache_feature_value = self.cache_feature_index.new_ones((self.cache_feature_index.size(-1), ), dtype=torch.float)
            self.ori_edges_num = data.edge_index.size(-1)
        self.iter += 1

        self._step_classifier(data, x, edge_index, edge_attr)
        self._refine_feature_graph(data, x, edge_index, edge_attr)
        self._step_label_generator(data, x, edge_index, edge_attr)
        self._step_extractor(data, x, edge_index, edge_attr)

    def _zero_all_grads(self):
        """Clear accumulated gradients held by all three optimizers."""
        self.classifier_optim.zero_grad()
        self.extractor_optim.zero_grad()
        self.label_gene_optim.zero_grad()

    def _joint_forward(self, data, x, edge_index, edge_attr):
        """Run both branches and return ``(gene_logits, model_z, model_logits, loss)``.

        ``loss`` is the MSE between the two branches' logits plus 0.5 times the
        model branch's NLL on the training nodes.
        """
        # Label-generator branch on the input graph.
        gene_z = self.label_generator_ext.encode(x, edge_index, edge_attr)
        gene_logits = self.label_generator_cla(gene_z, edge_index, edge_attr)

        # Model branch: the extractor encodes over the learned (cached) graph,
        # while the classifier is given the input graph.
        model_z = self.extractor.encode(x, self.cache_feature_index, self.cache_feature_value)
        model_logits = self.classifier(model_z, edge_index, edge_attr)

        # nll_loss assumes the Learner outputs log-probabilities — TODO confirm.
        supervised_loss = F.nll_loss(model_logits[data.train_mask], data.y[data.train_mask])
        agreement_loss = F.mse_loss(model_logits, gene_logits)

        return gene_logits, model_z, model_logits, agreement_loss + 0.5 * supervised_loss

    def _fast_classifier_weights(self, loss):
        """Return classifier weights after one differentiable SGD step of size ``self.lr`` on ``loss``.

        ``create_graph=True`` keeps the step differentiable, so a loss computed with
        the returned weights back-propagates into whatever produced ``loss``.
        """
        params = list(self.classifier.parameters())
        grads = torch.autograd.grad(loss, params, create_graph=True)
        return [param - self.lr * grad for grad, param in zip(grads, params)]

    def _step_classifier(self, data, x, edge_index, edge_attr):
        """Phase 1: take one optimizer step on the classifier only."""
        self._zero_all_grads()
        _, _, _, model_loss = self._joint_forward(data, x, edge_index, edge_attr)
        model_loss.backward()
        self.classifier_optim.step()

    def _refine_feature_graph(self, data, x, edge_index, edge_attr):
        """Rebuild the learned graph cache from the extractor's current embeddings.

        Candidate edges are the union of the input edges, the current cache and a
        random sample of two-hop edges. Each candidate is scored with the
        extractor's decoder; candidates scoring at or below ``edge_tau`` are
        dropped, survivors keep their score as edge weight, and self loops are
        re-added with weight 1.
        """
        # The scores are detached below, so no autograd graph is needed here.
        # ``is_training=False`` presumably disables dropout in Learner — TODO confirm.
        with torch.no_grad():
            model_z = self.extractor.encode(x, edge_index, edge_attr, is_training=False)

        # dropout_adj's ``p`` is the drop probability: roughly 10% of two-hop edges survive.
        sampled_two_hop, _ = dropout_adj(data.two_hop_index, p=0.9, force_undirected=True)

        candidates = torch.cat([edge_index, self.cache_feature_index, sampled_two_hop], dim=-1)
        # Merge duplicate edges, then drop self loops (re-added with weight 1 below).
        candidates, _ = coalesce(candidates, None, data.num_nodes, data.num_nodes)
        candidates, _ = remove_self_loops(candidates)

        # Map the sigmoid score from (0, 1) to (-1, 1) before thresholding.
        feature_sim = self.extractor.decode(model_z, candidates, sigmoid=True).detach() * 2 - 1
        keep = feature_sim > self.edge_tau
        row, col = candidates
        self.cache_feature_index = torch.stack([row[keep], col[keep]], dim=0)
        self.cache_feature_value = feature_sim[keep]

        self.cache_feature_index, self.cache_feature_value = add_remaining_self_loops(
            self.cache_feature_index, self.cache_feature_value, num_nodes=data.num_nodes, fill_value=1.)

    def _step_label_generator(self, data, x, edge_index, edge_attr):
        """Phase 2: update the label generator through a classifier look-ahead.

        The objective averages (a) the generator's own NLL on the training nodes and
        (b) the training-node NLL of a classifier that has taken one differentiable
        step on the joint loss (which includes agreement with the generator's logits).
        """
        self._zero_all_grads()
        gene_logits, model_z, _, loss = self._joint_forward(data, x, edge_index, edge_attr)

        fast_weights = self._fast_classifier_weights(loss)
        # ``vars=`` presumably makes Learner use the supplied weights instead of its own — TODO confirm.
        after_model_logits = self.classifier(model_z, edge_index, edge_attr, vars=fast_weights)

        mask = data.train_mask
        lookahead_loss = F.nll_loss(after_model_logits[mask], data.y[mask])
        label_real_loss = F.nll_loss(gene_logits[mask], data.y[mask])

        label_loss = 0.5 * label_real_loss + 0.5 * lookahead_loss
        label_loss.backward()
        self.label_gene_optim.step()

    def _step_extractor(self, data, x, edge_index, edge_attr):
        """Phase 3: update the feature extractor.

        The objective is a GAE reconstruction loss (reconstructing all input edges
        from an encoding of an edge-subsampled graph) plus 0.5 times the look-ahead
        classifier's NLL on the training nodes.
        """
        self._zero_all_grads()
        _, model_z, _, loss = self._joint_forward(data, x, edge_index, edge_attr)

        fast_weights_cla = self._fast_classifier_weights(loss)
        after_model_logits_cla = self.classifier(model_z, edge_index, edge_attr, vars=fast_weights_cla)
        lookahead_loss = F.nll_loss(after_model_logits_cla[data.train_mask], data.y[data.train_mask])

        # Keep a ``edge_sample_rate`` fraction of edges. ``edge_attr`` is filtered
        # together with the edges so it stays aligned with ``tmp_edge_index``.
        tmp_edge_index, tmp_edge_attr = dropout_adj(
            data.edge_index, edge_attr=edge_attr, p=1 - self.edge_sample_rate, force_undirected=True)
        gae_z = self.extractor.encode(x, tmp_edge_index, tmp_edge_attr)
        gae_loss = self.extractor.recon_loss(gae_z, edge_index)

        adj_loss = gae_loss + 0.5 * lookahead_loss
        adj_loss.backward()
        self.extractor_optim.step()

    @torch.no_grad()
    def model_eval(self, data):
        """Return the model branch's accuracies for ``train_mask``, ``val_mask`` and ``test_mask``.

        Requires at least one prior :meth:`forward` call, because the extractor
        encodes over the learned graph cache that ``forward`` initialises.
        """
        x, edge_index, edge_attr, y = data.x, data.edge_index, data.edge_attr, data.y

        model_z = self.extractor.encode(x, self.cache_feature_index, self.cache_feature_value, is_training=False)
        model_logits = self.classifier(model_z, edge_index, edge_attr, is_training=False)

        accs = []
        for _, mask in data('train_mask', 'val_mask', 'test_mask'):
            pred = model_logits[mask].max(1)[1]
            accs.append(pred.eq(y[mask]).sum().item() / mask.sum().item())

        return accs
