"""Classifier Model"""
import torch
import torch.nn.functional as F
from torch import nn

from codes.model.base_model import BaseModel
from codes.model.dgl.gcn import GCN
from codes.model.dgl.gat import GAT
from codes.model.dgl.gat_maml import GAT as MAML_GAT
from codes.model.dgl.rgcn import EntityClassify
import copy


class RelationGCNClassifier(BaseModel):
    """Two-layer GCN classifier with learned relation embeddings on edges.

    Node input features are the node in-degrees. Relation ids stored in
    ``g.edata["rel"]`` are embedded into ``g.edata["e"]``. The final node
    states are written to ``g.ndata["h"]`` and scored by the inherited
    ``dgl_classify``.
    """

    def __init__(self, config):
        super().__init__(config)
        model_cfg = config.model
        gcn_cfg = model_cfg.gcn
        self.criteria = nn.NLLLoss()
        self.relation_embedding = nn.Embedding(
            model_cfg.num_classes, model_cfg.relation_embedding_dim
        )
        input_layer = GCN(gcn_cfg.in_dim, gcn_cfg.hidden_dim, F.relu)
        output_layer = GCN(gcn_cfg.hidden_dim, model_cfg.out_dim, F.relu)
        self.layers = nn.ModuleList([input_layer, output_layer])

    def forward(self, g, queries):
        """Run the GCN stack over ``g`` and classify it.

        ``queries`` is accepted for interface compatibility but not used here.
        """
        device = self.config.general.device
        # Node feature = in-degree as a (num_nodes, 1) float column; for
        # undirected graphs in-degree and out-degree coincide.
        node_feat = g.in_degrees().view(-1, 1).float().to(device)
        g.edata["e"] = self.relation_embedding(g.edata["rel"])
        for layer in self.layers:
            node_feat = layer((g, node_feat))
        g.ndata["h"] = node_feat
        return self.dgl_classify(g)


class RelationGATClassifier(BaseModel):
    """GAT-based classifier with learned relation embeddings on edges.

    Relation ids in ``g.edata["rel"]`` are embedded into ``g.edata["e"]``.
    The embedding width is also passed to ``GAT`` as its edge-feature size.
    """

    def __init__(self, config):
        super().__init__(config)
        model_cfg = config.model
        gat_cfg = model_cfg.gat
        self.criteria = nn.NLLLoss()
        self.relation_embedding = nn.Embedding(
            model_cfg.num_classes, model_cfg.relation_embedding_dim
        )
        # One head count per hidden layer, followed by the output layer's.
        heads = [gat_cfg.num_heads for _ in range(gat_cfg.num_layers)]
        heads.append(gat_cfg.num_out_heads)
        self.gat = GAT(
            gat_cfg.num_layers,
            gat_cfg.in_dim,
            gat_cfg.num_hidden,
            model_cfg.out_dim,
            heads,
            nn.ReLU(),
            gat_cfg.feat_drop,
            gat_cfg.attn_drop,
            gat_cfg.alpha,
            model_cfg.relation_embedding_dim,
            gat_cfg.residual,
        )

    def forward(self, g, queries):
        """Run the GAT over ``g`` and classify it.

        ``queries`` is accepted for interface compatibility but not used here.
        """
        device = self.config.general.device
        # Node feature = in-degree as a (num_nodes, 1) float column.
        degree_feat = g.in_degrees().view(-1, 1).float().to(device)
        g.edata.update({"e": self.relation_embedding(g.edata["rel"])})
        g.ndata["h"] = self.gat(g, degree_feat)
        return self.dgl_classify(g)


class RelationGATMAMLClassifier(BaseModel):
    """Edge GAT classifier whose trainable tensors are held explicitly.

    Every tensor used by ``forward`` lives in the flat list ``self.weights``.
    The names are kept positionally in ``self.weight_names``: first the GAT's
    own weights, then ``"relation_embedding"``, ``"classify.weight"`` and
    ``"classify.bias"``. This presumably lets MAML-style inner-loop updates
    operate on the list directly. Confirm against the training loop.
    """

    def __init__(self, config):
        super().__init__(config)
        model_cfg = config.model
        gat_cfg = model_cfg.gat
        device = config.general.device
        self.criteria = nn.NLLLoss()
        # NOTE(review): forward() does not use this module; it reads the
        # "relation_embedding" tensor from self.weights instead. It is kept
        # so the module layout / state_dict keys stay unchanged.
        self.relation_embedding = nn.Embedding(
            model_cfg.num_classes, model_cfg.relation_embedding_dim
        )
        heads = ([gat_cfg.num_heads] * gat_cfg.num_layers) + [gat_cfg.num_out_heads]
        self.gat = MAML_GAT(
            gat_cfg.num_layers,
            gat_cfg.in_dim,
            gat_cfg.num_hidden,
            model_cfg.out_dim,
            heads,
            nn.ReLU(),
            gat_cfg.feat_drop,
            gat_cfg.attn_drop,
            gat_cfg.alpha,
            model_cfg.relation_embedding_dim,
            gat_cfg.residual,
            device,
        )
        # Take ownership of the GAT's weights. forward() hands them back to
        # the GAT explicitly, so the submodule's copy is removed.
        self.weights = copy.deepcopy(self.gat.weights)
        del self.gat.weights
        # Copy (not alias) the names so the appends below do not mutate the
        # GAT submodule's own list.
        self.weight_names = list(self.gat.weight_names)

        # Relation-embedding table, initialized N(0, 1) like nn.Embedding.
        # (torch.Tensor(size=...) would leave the memory uninitialized.)
        rel_w = torch.empty(
            model_cfg.num_classes, model_cfg.relation_embedding_dim, device=device
        )
        nn.init.normal_(rel_w)
        rel_w.requires_grad_(True)
        self.weights.append(rel_w)
        self.weight_names.append("relation_embedding")

        # Classifier weight (in_features, num_classes) and bias (num_classes,).
        # Both are initialized U(-1/sqrt(in), 1/sqrt(in)), mirroring
        # nn.Linear's default. in_features = 2 * out_dim; how dgl_classify
        # combines node states into that width is not visible here.
        in_features = model_cfg.out_dim * 2
        bound = 1.0 / max(in_features, 1) ** 0.5
        cls_w = torch.empty(in_features, model_cfg.num_classes, device=device)
        cls_b = torch.empty(model_cfg.num_classes, device=device)
        nn.init.uniform_(cls_w, -bound, bound)
        nn.init.uniform_(cls_b, -bound, bound)
        cls_w.requires_grad_(True)
        cls_b.requires_grad_(True)
        self.weights.append(cls_w)
        self.weight_names.append("classify.weight")
        self.weights.append(cls_b)
        self.weight_names.append("classify.bias")

    def forward(self, g, queries):
        """Run the functional GAT with ``self.weights`` and classify ``g``.

        ``queries`` is accepted for interface compatibility but not used here.
        """
        # Name -> position in self.weights (the last occurrence wins on duplicates).
        param_name_to_idx = {name: idx for idx, name in enumerate(self.weight_names)}
        # Node feature = in-degree as a (num_nodes, 1) float column.
        h = g.in_degrees().view(-1, 1).float().to(self.config.general.device)
        rel_table = self.weights[
            self.get_param_id(param_name_to_idx, "relation_embedding")
        ]
        g.edata.update({"e": F.embedding(g.edata["rel"], rel_table)})
        h = self.gat(g, h, self.weights)
        g.ndata["h"] = h
        return self.dgl_classify(g, self.weights, param_name_to_idx)


class RGCNClassifier(BaseModel):
    """
    RGCN classifier.

    Wraps the project's ``EntityClassify``. Relation ids from
    ``g.edata["rel"]`` become the RGCN edge types, and every edge gets a
    norm of 1.
    """

    def __init__(self, config):
        super(RGCNClassifier, self).__init__(config)
        self.criteria = nn.NLLLoss()
        # Accept "cuda", "cuda:N" and torch.device values. An exact
        # comparison with "cuda" treats e.g. "cuda:0" as CPU.
        use_cuda = str(config.general.device).startswith("cuda")
        rgcn_cfg = config.model.rgcn
        self.rgcn = EntityClassify(
            rgcn_cfg.n_hidden,
            config.model.out_dim,
            # NOTE(review): presumably the number of relation types — confirm.
            config.model.num_classes,
            num_bases=rgcn_cfg.n_bases,
            # n_layers presumably counts the input and output layers too.
            num_hidden_layers=rgcn_cfg.n_layers - 2,
            dropout=rgcn_cfg.dropout,
            use_cuda=use_cuda,
        )

    def forward(self, g, queries):
        """Run the RGCN over ``g`` and classify it.

        ``queries`` is accepted for interface compatibility but not used here.
        """
        device = self.config.general.device
        # In-degree node feature, (num_nodes, 1). NOTE(review): this is
        # overwritten by the RGCN output below; it is kept in case
        # EntityClassify reads g.ndata["h"].
        g.ndata["h"] = g.in_degrees().view(-1, 1).float().to(device)
        # squeeze(1) requires "rel" to be at least 2-D, presumably
        # (num_edges, 1) — confirm.
        g.edata["type"] = g.edata["rel"].squeeze(1)
        num_edges = g.edata["rel"].size(0)
        g.edata["norm"] = torch.ones(
            num_edges, 1, device=device, dtype=torch.float32
        )
        # Use the documented node-count API rather than len() on the node view.
        self.rgcn.features = self.rgcn.create_features(g.number_of_nodes())
        g.ndata["h"] = self.rgcn(g)
        return self.dgl_classify(g)
