from typing import Optional, Sequence

import torch
import torch.nn as nn
from torch import Tensor

from graphormer.collator import Batch as BatchedGraphData
from graphormer.model import Graphormer


class TeacherModel(nn.Module):
    """Wrap a Graphormer checkpoint and capture intermediate outputs for distillation.

    Forward hooks are registered on the requested ``self.model.layers`` entries.
    On each forward pass they record either a normalized feature slice or the
    layer's second output (attention weights). These are returned as a
    ``TeacherOutput``.

    Args:
        checkpoint_dir: Path passed to ``Graphormer.load_from_checkpoint``.
        feat_dist_layers: Indices of layers whose features are captured.
            ``None`` disables feature capture.
        attnw_dist_layers: Indices of layers whose attention weights are
            captured. ``None`` disables attention capture.
    """

    def __init__(self,
        checkpoint_dir: str,
        feat_dist_layers: Optional[Sequence[int]] = None,
        attnw_dist_layers: Optional[Sequence[int]] = None
    ):
        super(TeacherModel, self).__init__()
        self.model = Graphormer.load_from_checkpoint(checkpoint_dir)
        self.hidden_dim = self.model.hidden_dim
        self.nhead = self.model.num_heads
        self.feat_dist_layers = feat_dist_layers
        self.attnw_dist_layers = attnw_dist_layers
        # Filled by the hooks during each forward pass, keyed by layer index.
        self.interm_feat, self.interm_attnw = {}, {}
        # Parameter-free LayerNorm applied to the captured feature slice.
        self.feat_norm = nn.LayerNorm(normalized_shape=self.hidden_dim, elementwise_affine=False)

        if feat_dist_layers is not None:
            self.add_feat_hook()
        if attnw_dist_layers is not None:
            self.add_attn_hook()

    def feat_hook(self, ind: int):
        """Return a forward hook that stores layer ``ind``'s normalized feature slice."""
        def fn(_, __, output):
            # NOTE(review): assumes the layer returns a tuple whose first item is
            # a rank-3 feature tensor (batch, tokens, hidden) — confirm.
            feat = output[0]
            # Index 0 along dim 1 is presumably the graph-level token — confirm.
            mol_feat = self.feat_norm(feat[:, 0, :]).detach().clone()
            self.interm_feat[ind] = mol_feat
        return fn

    def add_feat_hook(self):
        """Register ``feat_hook`` on every layer listed in ``feat_dist_layers``."""
        for ind in self.feat_dist_layers:
            # Presumably makes the layer return a (features, attn) tuple so that
            # ``output[0]`` in the hook is the feature tensor — confirm.
            self.model.layers[ind].need_attn_weight = True
            self.model.layers[ind].register_forward_hook(self.feat_hook(ind))

    def attnw_hook(self, ind: int):
        """Return a forward hook that stores layer ``ind``'s second output (attention weights)."""
        def fn(_, __, output):
            attnw = output[1].detach().clone()
            self.interm_attnw[ind] = attnw
        return fn

    def add_attn_hook(self):
        """Register ``attnw_hook`` on every layer listed in ``attnw_dist_layers``."""
        for ind in self.attnw_dist_layers:
            # Ask the layer to emit attention weights so ``output[1]`` exists.
            self.model.layers[ind].need_attn_weight = True
            self.model.layers[ind].register_forward_hook(self.attnw_hook(ind))

    def forward(self, data: BatchedGraphData) -> "TeacherOutput":
        """Run the teacher on ``data`` and collect the hooked activations.

        Returns:
            ``TeacherOutput`` with lists ordered like ``feat_dist_layers`` and
            ``attnw_dist_layers``. A list is empty when the matching layer list
            is ``None``.

        Raises:
            KeyError: if a requested layer's hook did not fire during the pass.
        """
        self.interm_feat, self.interm_attnw = {}, {}
        # The prediction is discarded and every captured tensor is detached,
        # so building an autograd graph for the teacher would only waste memory.
        with torch.no_grad():
            self.model(data)
        # A ``None`` layer list means "capture nothing"; iterating it would raise.
        feat_layers = self.feat_dist_layers if self.feat_dist_layers is not None else ()
        attnw_layers = self.attnw_dist_layers if self.attnw_dist_layers is not None else ()
        feat = [self.interm_feat[ind] for ind in feat_layers]
        attnw = [self.interm_attnw[ind] for ind in attnw_layers]
        return TeacherOutput(feat=feat, attnw=attnw)

        
class TeacherOutput:
    """Container for intermediate teacher activations used in distillation.

    Attributes:
        feat: Per-layer captured feature tensors, or ``None``.
        attnw: Per-layer captured attention-weight tensors, or ``None``.
    """

    def __init__(
        self,
        feat: Optional[Sequence[Tensor]] = None,
        attnw: Optional[Sequence[Tensor]] = None
    ):
        self.feat = feat
        self.attnw = attnw

    def to(self, device):
        """Convert every stored tensor with ``Tensor.to(device)`` and return ``self``.

        ``device`` is forwarded unchanged to ``Tensor.to``, so a device or a
        dtype both work. The stored sequences are replaced by lists of the
        converted tensors; ``None`` fields are left as ``None``.
        """
        # ``Tensor.to`` is out-of-place: its result must be stored back, not
        # just bound to a loop variable.
        if self.feat is not None:
            self.feat = [t.to(device) for t in self.feat]
        if self.attnw is not None:
            self.attnw = [t.to(device) for t in self.attnw]
        return self