import copy
import math
from argparse import ArgumentParser
from itertools import chain
from typing import List, Optional, Sequence, Tuple

import einops
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import Tensor
from torch.nn import Module
from torch.nn.init import xavier_uniform_
from torch.nn.modules.normalization import LayerNorm
from torch.optim import AdamW, optimizer
from torch.utils.data import dataset

from data import BatchedSmilesData, BatchedDistillationData, get_smiles_dataset_info
from teacher import TeacherModel, TeacherOutput
from utils import get_sinusoid_pe_tensor


def attention(
    query: Tensor, key: Tensor, value: Tensor, 
    mask: Optional[Tensor] = None, bias: Optional[Tensor] = None,
    dropout: Optional[Module] = None) -> Tuple[Tensor, Tensor]:
    r"""
    Compute scaled dot product attention.

    Tensor shapes:
        L: sequence length, B: batch size, H: attention heads.
        Inputs:
            query: (B, H, L, d_q),
            key: (B, H, L, d_k),
            value: (B, H, L, d_v)
            mask: (B, 1, L, L), positions equal to 0 are masked out
            bias: added to the raw scores before masking; must broadcast
                against (B, H, L, L)
            dropout: optional module applied to the attention weights
        Outputs:
            result: (B, H, L, d_v)
            weights: (B, H, L, L), the softmax weights *before* dropout

    Raises:
        ValueError: if the last dimensions of query and key differ.
    """
    d_q, d_k = query.size(-1), key.size(-1)
    if d_q != d_k:
        raise ValueError(
            f"query and key feature sizes must match, got {d_q} and {d_k}")
    scores = torch.einsum("bhik,bhjk->bhij", query, key) / math.sqrt(d_k)
    # attention bias
    if bias is not None:
        scores = scores + bias
    # mask and score; a large finite negative is used instead of -inf
    # (presumably to stay representable in half precision -- TODO confirm)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -5e4)
    scores = F.softmax(scores, dim=-1)
    # Without a dropout module the weights are used unchanged; previously
    # `weights` was left unbound in that case.
    weights = scores if dropout is None else dropout(scores)
    result = torch.einsum("bhik,bhkj->bhij", weights, value)
    return result, scores


class MultiHeadedAttention(nn.Module):
    r"""
    Multi-head attention with an optional learned additive score bias.

    When ``use_bias`` is set, a second pair of query/key projections
    produces a per-head pairwise term that is added to the attention
    scores and also returned to the caller.
    """

    def __init__(self, h: int, d_model: int, d_k: Optional[int] = None, 
                 dropout: float = 0.1, use_bias: bool = False):
        super().__init__()
        assert d_model % h == 0
        per_head = d_model // h
        self.h = h
        self.d_k = per_head if d_k is None else d_k
        self.d_v = per_head
        self.d_model = d_model

        qk_width = self.d_k * h
        v_width = self.d_v * h
        self.q_proj = nn.Linear(d_model, qk_width)
        self.k_proj = nn.Linear(d_model, qk_width)
        self.v_proj = nn.Linear(d_model, v_width)
        self.o_proj = nn.Linear(v_width, d_model)
        self.dropout = nn.Dropout(p=dropout)

        self.use_bias = use_bias
        if use_bias:
            self.bias_proj_q = nn.Linear(d_model, qk_width)
            self.bias_proj_k = nn.Linear(d_model, qk_width)

    def _split_heads(self, projected: Tensor) -> Tensor:
        # (B, L, H * d) -> (B, H, L, d)
        return rearrange(projected, 'b l (h d) -> b h l d', h=self.h)

    def _pairwise_bias(
        self, query: Tensor, key: Tensor, bias_mask: Optional[Tensor]
    ) -> Tensor:
        # Dot product of the dedicated bias projections, one map per head.
        left = self._split_heads(self.bias_proj_q(query))
        right = self._split_heads(self.bias_proj_k(key))
        pairwise = torch.einsum('b h i k, b h j k -> b h i j', left, right)
        if bias_mask is None:
            return pairwise
        if bias_mask.dim() == 3:
            # add a singleton head axis so the mask broadcasts over heads
            bias_mask = repeat(bias_mask, 'b i j -> b c i j', c=1)
        return pairwise * bias_mask

    def forward(
        self, query: Tensor, key: Tensor,
        value: Tensor, mask: Optional[Tensor] = None,
        bias_mask: Optional[Tensor] = None,
    ):
        r"""
        Tensor shapes:
            L: sequence length, B: batch size.
            Inputs:
                query: (B, L, d_model)
                key: (B, L, d_model)
                value: (B, L, d_model)
                mask: (B, L, L)
                bias_mask: (B, L, L)
            Outputs:
                x: (B, L, d_model)
                bias: the (masked) additive score bias with a head axis,
                    or None when ``use_bias`` is off
        """
        heads_q = self._split_heads(self.q_proj(query))
        heads_k = self._split_heads(self.k_proj(key))
        heads_v = self._split_heads(self.v_proj(value))

        # bias should have shape (b, h, l, l) or (b, 1, l, l)
        bias = self._pairwise_bias(query, key, bias_mask) if self.use_bias else None

        if mask is not None:
            mask = repeat(mask, 'b l1 l2 -> b k l1 l2', k=1)

        context, _ = attention(
            heads_q, heads_k, heads_v, mask, bias=bias, dropout=self.dropout)
        # context: (B, H, L, d_v) -> merge heads back into the feature axis
        merged = rearrange(context, 'b h l dv -> b l (h dv)').contiguous()
        return self.o_proj(merged), bias


class TransformerEncoderLayer(nn.Module):
    r"""
    Post-norm encoder layer: self-attention and a GELU feed-forward
    block, each followed by dropout, a residual connection and LayerNorm.
    """

    def __init__(self, d_model: int, nhead: int, d_k: Optional[int] = None, 
                 dim_feedforward: int = 2048, dropout: float = 0.1, 
                 layer_norm_eps: float = 1e-5, use_bias: bool = False):

        super().__init__()
        self.self_attn = MultiHeadedAttention(
            h=nhead,
            d_model=d_model,
            d_k=d_k,
            dropout=dropout,
            use_bias=use_bias,
        )
        # feed-forward sub-block
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # one norm / dropout pair per residual branch
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = F.gelu

    def _feed_forward(self, hidden: Tensor) -> Tensor:
        # linear -> GELU -> dropout -> linear
        expanded = self.activation(self.linear1(hidden))
        return self.linear2(self.dropout(expanded))

    def forward(
        self, src: Tensor,
        src_mask: Optional[Tensor] = None,
        attn_bias_mask: Optional[Tensor] = None
    ):
        r""" 
        Tensor shapes:
            L: sequence length, B: batch size.
            Inputs:
                src: (B, L, d_model)
                src_mask: (B, L, L)
            Outputs:
                x: (B, L, d_model)
                bias: attention bias returned by the self-attention
                    module (None when the layer has no bias)

        Notice: src_mask is the attention weights mask, 0 for masked.
        """
        attended, bias = self.self_attn(src, src, src, src_mask, attn_bias_mask)
        hidden = self.norm1(src + self.dropout1(attended))
        transformed = self._feed_forward(hidden)
        hidden = self.norm2(hidden + self.dropout2(transformed))
        return hidden, bias


class EncoderOutput:
    r"""
    Plain container for the tensors produced by the encoder.

    Attributes:
        intermediate_states: per-layer hidden states, or None.
        attn_bias: per-layer attention bias tensors, or None.
        feature_pred: final feature tensor, or None.
        y_pred: task prediction, or None until a head fills it in.
    """

    def __init__(
        self,
        intermediate_states: Optional[Sequence[Tensor]] = None,
        attn_bias: Optional[Sequence[Tensor]] = None,
        feature_pred: Optional[Tensor] = None,
        y_pred: Optional[Tensor] = None
    ):
        self.intermediate_states = intermediate_states
        self.attn_bias = attn_bias
        self.feature_pred = feature_pred
        self.y_pred = y_pred

    def to(self, device):
        r"""
        Move every stored tensor with ``Tensor.to(device)`` and return self.

        Fields that are None are left untouched. The tensor sequences are
        replaced by new lists, because ``Tensor.to`` returns a new tensor
        rather than moving in place.
        """
        if self.intermediate_states is not None:
            self.intermediate_states = [
                t.to(device) for t in self.intermediate_states]
        if self.attn_bias is not None:
            self.attn_bias = [t.to(device) for t in self.attn_bias]
        if self.feature_pred is not None:
            self.feature_pred = self.feature_pred.to(device)
        if self.y_pred is not None:
            self.y_pred = self.y_pred.to(device)
        return self


class TransformerEncoder(nn.Module):
    r"""
    Stack of ``TransformerEncoderLayer`` modules preceded by a LayerNorm.

    Only the first sequence position of the final (and of each requested
    intermediate) hidden state is kept in the returned ``EncoderOutput``.
    """

    def __init__(
        self,
        # encoder layer params
        d_model: int, nhead: int, 
        d_k: Optional[int] = None, 
        dim_feedforward: int = 2048, 
        dropout: float = 0.1, 
        layer_norm_eps: float = 1e-5,
        # encoder params
        num_encoder_layers: int = 6, 
        biased_attn_layers: Sequence[int] = None,
        output_hidden_states_layers: Sequence[int] = None,
    ):
        super().__init__()
        # indices of the layers that get a learned attention bias
        biased = biased_attn_layers if biased_attn_layers is not None else ()
        stack = [
            TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                d_k=d_k,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                layer_norm_eps=layer_norm_eps,
                use_bias=index in biased,
            )
            for index in range(num_encoder_layers)
        ]
        self.initial_norm = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.layers = nn.ModuleList(stack)
        self.num_layers = num_encoder_layers
        self.output_hidden_states_layers = output_hidden_states_layers or []

    def __len__(self):
        # number of encoder layers
        return self.num_layers
    
    def forward(
        self,
        src: Tensor,
        src_mask: Optional[Tensor],
        attn_bias_mask: Optional[Tensor]
    ) -> EncoderOutput:
        r""" 
        Tensor shapes:
            L: sequence length, B: batch size.
            Inputs:
                src: (B, L, d_model)
                src_mask: (B, L, L)
            Outputs:
                EncoderOutput whose ``feature_pred`` is the final hidden
                state at sequence position 0, i.e. (B, d_model).
                ``intermediate_states`` / ``attn_bias`` are lists, or
                None when nothing was collected.

        Notice: src_mask is the attention weights mask, 0 for masked.
        """
        hidden = self.initial_norm(src)
        kept_states, kept_biases = [], []

        for index, layer in enumerate(self.layers):
            hidden, layer_bias = layer(hidden, src_mask, attn_bias_mask)
            if layer_bias is not None:
                kept_biases.append(layer_bias)
            if index in self.output_hidden_states_layers:
                kept_states.append(hidden[:, 0, :])

        # empty collections are reported as None
        return EncoderOutput(
            intermediate_states=kept_states or None,
            attn_bias=kept_biases or None,
            feature_pred=hidden[:, 0, :],
        )
        

class SmilesTransformerPE(nn.Module):
    r"""
    Positional encoding block for SMILES transformer.

    Supported ``pe_type`` values:
        'sinusoid':  fixed table from ``get_sinusoid_pe_tensor``, stored as
                     a buffer.
        'learnable': trainable table with ``max_len - 1`` rows; a constant
                     all-zero row is prepended at lookup time, so index 0
                     adds nothing to the input.
        'none':      the input is returned unchanged.

    Raises:
        ValueError: if ``pe_type`` is not one of the values above.
    """
    def __init__(self, pe_type: str, pe_dim: int, max_len: int, scale_factor: float = 1.0):
        super(SmilesTransformerPE, self).__init__()
        self.scale_factor = scale_factor
        self.pe_type = pe_type
        self.pe_dim = pe_dim
        if pe_type == 'sinusoid':
            pe_tensor = get_sinusoid_pe_tensor(d_model=pe_dim, maxlen=max_len)
            pe_tensor.requires_grad = False
            self.register_buffer('pe_tensor', pe_tensor)
        elif pe_type == 'learnable':
            learnable_pe_tensor = torch.rand((max_len-1, pe_dim))
            self.learnable_pe_tensor = nn.Parameter(learnable_pe_tensor)
        elif pe_type != 'none':
            # Fail at construction time instead of with an AttributeError
            # on the first forward pass.
            raise ValueError(
                "pe_type must be 'sinusoid', 'learnable' or 'none', "
                f"got {pe_type!r}")

    def _pe_table(self) -> Tensor:
        # Lookup table used by forward for the current pe_type.
        if self.pe_type == 'learnable':
            # Build the zero row on the parameter's own device/dtype so the
            # module works on CPU and on any accelerator. The table is kept
            # local: storing a non-leaf tensor on the module would hold on
            # to the autograd graph and break deepcopy / pickling.
            zero_row = self.learnable_pe_tensor.new_zeros((1, self.pe_dim))
            return torch.cat([zero_row, self.learnable_pe_tensor])
        return self.pe_tensor
    
    def forward(self, x: Tensor, pe_index: Tensor) -> Tensor:
        r"""
        Add the encodings selected by both index channels to the scaled input.

        Tensor shapes:
            x: (B, L, d_model)
            pe_index: (B, L, 2)
        Returns:
            ``x * scale_factor + pe`` with the shape of ``x``, or ``x``
            itself when ``pe_type`` is 'none'.
        """
        if self.pe_type == 'none':
            return x
        table = self._pe_table()
        pe = F.embedding(pe_index[:, :, 0], table) + \
            F.embedding(pe_index[:, :, 1], table)
        return x * self.scale_factor + pe


class OfflineDistillationLoss():
    r"""
    Weighted sum of a task loss, a hidden-state distillation loss and an
    attention-bias loss, with targets read from the batch itself.
    """

    def __init__(
        self,
        task_loss_fn,
        dist_loss_fn, # pass None if no distillation loss
        attn_bias_loss_fn, # must be loss function from torch.nn.functional
        coef_task_loss: float = 1.0,
        coef_dist_loss: float = 1e-3,
        coef_attn_bias_loss: float = 1e-2,
    ):
        # loss callables
        self.task_loss_fn = task_loss_fn
        self.dist_loss_fn = dist_loss_fn
        self.attn_bias_loss_fn = attn_bias_loss_fn
        # mixing coefficients
        self.coef_task_loss = coef_task_loss
        self.coef_dist_loss = coef_dist_loss
        self.coef_attn_bias_loss = coef_attn_bias_loss

    def _hidden_state_term(self, input_data, target_data, dist_transform):
        # Sum of dist_loss_fn over paired hidden states; 0.0 when disabled.
        if self.dist_loss_fn is None:
            return 0.0
        student_states = input_data.intermediate_states
        teacher_states = target_data.feature
        assert len(student_states) == len(teacher_states)
        # a ModuleList supplies one transform per pair, anything else is shared
        per_pair = isinstance(dist_transform, nn.ModuleList)
        total = 0.0
        for position, (student, teacher) in enumerate(
                zip(student_states, teacher_states)):
            if dist_transform is not None:
                transform = dist_transform[position] if per_pair else dist_transform
                student = transform(student)
            total += self.dist_loss_fn(student, teacher)
        return total

    def _attention_bias_term(self, input_data, target_data):
        # Masked, sum-reduced bias loss divided by the mask sum, summed
        # over all bias tensors; 0.0 when disabled.
        if self.attn_bias_loss_fn is None:
            return 0.0
        predicted = input_data.attn_bias
        target = target_data.target_attnb
        mask = target_data.attnb_mask
        total = 0.0
        for layer_bias in predicted:
            masked = layer_bias * mask
            denominator = mask.sum()
            total += self.attn_bias_loss_fn(
                masked, target, reduction='sum') / denominator
        return total

    def __call__(
        self, input_data: EncoderOutput,
        target_data: BatchedSmilesData,
        dist_transform = None,
    ):
        r"""
        Return ``(loss, loss_dict)``.

        ``loss_dict`` always holds 'train_loss' and 'task_loss'; the two
        auxiliary entries are added only when their value exceeds 1e-9.
        """
        task_term = self.task_loss_fn(input_data.y_pred, target_data.y)
        dist_term = self._hidden_state_term(input_data, target_data, dist_transform)
        attnb_term = self._attention_bias_term(input_data, target_data)

        total = self.coef_task_loss * task_term + self.coef_dist_loss * dist_term + self.coef_attn_bias_loss * attnb_term

        report = {
            'train_loss': total,
            'task_loss': task_term,
        }
        if dist_term > 1e-9:
            report['distillation_loss'] = dist_term
        if attnb_term > 1e-9:
            report['attention_bias_loss'] = attnb_term
        return total, report


class DistillationLoss:
    r"""
    Task loss combined with feature and attention-weight distillation
    terms, with a separate set of mixing weights for warm-up epochs.

    While ``current_epoch < warmup_epochs`` the ``warmup_*`` weights are
    used; afterwards the regular weights apply. A distillation term is
    computed whenever its regular weight is non-zero, or when its warm-up
    weight is non-zero during warm-up.
    """

    # weights at or below this threshold are treated as "disabled"
    _WEIGHT_EPS = 1e-8

    def __init__(
        self,
        task_loss_fn,
        task_loss_weight: float = 0.0,
        feat_dist_loss_weight: float = 0.0,
        attnw_dist_loss_weight: float = 0.0,
        warmup_epochs: int = -1,
        warmup_task_loss_weight: float = 0.0,
        warmup_feat_dist_loss_weight: float = 0.0,
        warmup_attnw_dist_loss_weight: float = 0.0,
    ):
        self.task_loss_fn = task_loss_fn
        self.task_loss_weight = task_loss_weight
        self.feat_dist_loss_weight = feat_dist_loss_weight
        self.attnw_dist_loss_weight = attnw_dist_loss_weight
        self.warmup_epochs = warmup_epochs
        self.warmup_task_loss_weight = warmup_task_loss_weight
        self.warmup_feat_dist_loss_weight = warmup_feat_dist_loss_weight
        self.warmup_attnw_dist_loss_weight = warmup_attnw_dist_loss_weight

    @classmethod
    def _term_enabled(cls, weight: float, warmup_weight: float, in_warmup: bool) -> bool:
        # A term is needed if its regular weight is set (previous behaviour)
        # or if it carries weight in the currently active warm-up phase.
        return weight > cls._WEIGHT_EPS or (
            in_warmup and warmup_weight > cls._WEIGHT_EPS)

    def __call__(
        self,
        student_data: "EncoderOutput",
        teacher_data: "TeacherOutput",
        smiles_data: "BatchedSmilesData",
        current_epoch: int = 10000,
        feat_dist_transform = None,
        attnw_dist_transform = None,
    ):
        r"""
        Return ``(loss, loss_dict)``.

        Args:
            student_data: provides ``y_pred``, ``intermediate_states`` and
                ``attn_bias``.
            teacher_data: provides ``feat`` and ``attnw`` (one entry per
                distilled layer).
            smiles_data: provides ``y`` and ``atom_token_attn_mask``.
            current_epoch: selects warm-up vs. regular weights.
            feat_dist_transform: optional callable applied to every
                student feature before comparison.
            attnw_dist_transform: optional callable applied over the head
                axis of the student attention bias.

        ``loss_dict`` always holds 'train_loss' and 'task_loss'; the
        distillation entries are added only when their value exceeds 1e-9.

        Raises:
            ValueError: if student and teacher provide a different number
                of tensors for an enabled distillation term.
        """
        in_warmup = current_epoch < self.warmup_epochs
        task_loss = self.task_loss_fn(student_data.y_pred, smiles_data.y)

        feat_dist_loss = 0.0
        if self._term_enabled(self.feat_dist_loss_weight,
                              self.warmup_feat_dist_loss_weight, in_warmup):
            student_feat_list = student_data.intermediate_states
            teacher_feat_list = teacher_data.feat
            if len(student_feat_list) != len(teacher_feat_list):
                raise ValueError(
                    "student and teacher feature lists differ in length: "
                    f"{len(student_feat_list)} vs {len(teacher_feat_list)}")
            for student_feat, teacher_feat in zip(student_feat_list, teacher_feat_list):
                if feat_dist_transform is not None:
                    student_feat = feat_dist_transform(student_feat)
                feat_dist_loss += F.mse_loss(student_feat, teacher_feat, reduction='mean')

        attnw_dist_loss = 0.0
        if self._term_enabled(self.attnw_dist_loss_weight,
                              self.warmup_attnw_dist_loss_weight, in_warmup):
            student_attnb_list = student_data.attn_bias
            teacher_attnw_list = teacher_data.attnw
            if len(student_attnb_list) != len(teacher_attnw_list):
                raise ValueError(
                    "student and teacher attention lists differ in length: "
                    f"{len(student_attnb_list)} vs {len(teacher_attnw_list)}")
            # The mask does not depend on the layer; add the head axis once.
            full_mask = repeat(
                smiles_data.atom_token_attn_mask, 'b i j -> b c i j', c=1)
            for student_attnb, teacher_attnw in zip(student_attnb_list, teacher_attnw_list):
                # crop student bias and mask to the teacher's sequence length
                seq_length = teacher_attnw.shape[2]
                student_attnb = student_attnb[:, :, 0 : seq_length, 0 : seq_length]
                mask = full_mask[:, :, 0 : seq_length, 0 : seq_length]
                if attnw_dist_transform is not None:
                    # the transform acts on the last axis, so move heads there
                    student_attnb = rearrange(student_attnb, 'b h i j -> b i j h')
                    student_attnb = attnw_dist_transform(student_attnb)
                    student_attnb = rearrange(student_attnb, 'b i j h -> b h i j')
                student_attnw = F.softmax(student_attnb, dim=-1)
                attnw_dist_loss += F.mse_loss(student_attnw, mask * teacher_attnw, reduction='mean')

        if in_warmup:
            loss = self.warmup_task_loss_weight * task_loss + \
                self.warmup_feat_dist_loss_weight * feat_dist_loss + \
                self.warmup_attnw_dist_loss_weight * attnw_dist_loss
        else:
            loss = self.task_loss_weight * task_loss + \
                self.feat_dist_loss_weight * feat_dist_loss + \
                self.attnw_dist_loss_weight * attnw_dist_loss
        
        loss_dict = {
            'train_loss': loss,
            'task_loss': task_loss,
        }
        if feat_dist_loss > 1e-9:
            loss_dict['feature_distillation_loss'] = feat_dist_loss
        if attnw_dist_loss > 1e-9:
            loss_dict['attention_weight_distillation_loss'] = attnw_dist_loss
        return loss, loss_dict


class SmilesTransformer(nn.Module):
    r"""
    Smiles Transformer module: token embedding, positional encoding
    (``SmilesTransformerPE``) and a ``TransformerEncoder`` stack.
    """
    def __init__(
        self, vocab_size: int = 40,
        d_model: int = 256,
        nhead: int = 4,
        dim_feedforward: int = 256,
        dropout: float = 0.1,
        num_layers: int = 4, 
        layer_norm_eps: float = 1e-5,
        max_len: int = 100,
        pe_type: str = 'learnable',
        pe_scale_factor: float = 1.0,
        biased_attn_layers: Sequence[int] = None,
        output_hidden_states_layers: Sequence[int] = None,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.nhead = nhead

        # token id -> d_model vector
        self.embed_src = nn.Embedding(vocab_size, d_model)
        # positional encoding added on top of the token embedding
        self.pe = SmilesTransformerPE(
            pe_type=pe_type,
            pe_dim=d_model,
            max_len=max_len,
            scale_factor=pe_scale_factor,
        )
        self.encoder = TransformerEncoder(
            d_model=d_model,
            nhead=nhead,
            d_k=None,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            layer_norm_eps=layer_norm_eps,
            num_encoder_layers=num_layers,
            biased_attn_layers=biased_attn_layers,
            output_hidden_states_layers=output_hidden_states_layers,
        )

    
    def forward(
        self, src: Tensor,
        pe_index: Tensor,
        src_key_padding_mask: Tensor,
        attn_bias_mask: Tensor,
    ) -> EncoderOutput:
        r"""
        Tensor shapes and notations:
            L: sequence length, B: batch size.
            Inputs:
                src: (B, L),
                pe_index: (B, L, 2),
                src_key_padding_mask: (B, L), mask for padding tokens, true 
                    for paddings.
                attn_bias_mask: forwarded unchanged to the encoder.
            Outputs:
                The ``EncoderOutput`` returned by the encoder.
        """
        # The encoder expects 0 for masked positions, so invert the padding
        # mask and give it a singleton middle axis: (B, L) -> (B, 1, L).
        visible = ~src_key_padding_mask
        src_mask = repeat(visible, 'b l -> b h l', h=1)
        token_embeddings = self.embed_src(src)
        encoder_input = self.pe(token_embeddings, pe_index)
        return self.encoder(encoder_input, src_mask, attn_bias_mask)


class SmilesTransformerFinetuneModel(pl.LightningModule):
    r"""
    Fine-tunes the student transformer taken from a distillation
    checkpoint on a downstream dataset, with a fresh linear head.

    Args:
        checkpoint_path: checkpoint of a
            ``SmilesTransformerDistillationModel``.
        dataset_name: key passed to ``get_smiles_dataset_info``.
        change_dropout: new dropout probability for every ``nn.Dropout``
            in the transformer; values >= 1.0 leave dropout untouched.
        freeze_layers: indices of encoder layers to exclude from the
            optimizer; None trains the whole transformer.
        d_model: kept for backward compatibility and hyper-parameter
            logging. The head is sized from the loaded transformer's own
            ``d_model``, since any other value cannot match its output.
        learning_rate, weight_decay: AdamW settings.
    """
    def __init__(
        self,
        checkpoint_path: str,
        dataset_name: str,
        change_dropout: float = 10.0,
        freeze_layers: Sequence[int] = None,
        d_model: int = 512,
        learning_rate: float = 3e-4,
        weight_decay: float = 1e-2,
    ):
        super(SmilesTransformerFinetuneModel, self).__init__()
        self.save_hyperparameters()
        print(checkpoint_path)
        dist_model = SmilesTransformerDistillationModel.load_from_checkpoint(checkpoint_path)
        self.transformer = dist_model.student_transformer
        # The list of layers whose hidden states are collected lives on the
        # encoder, not on the wrapper; clear it there so fine-tuning does
        # not gather intermediate states it never uses.
        self.transformer.encoder.output_hidden_states_layers = []

        dataset_dict = get_smiles_dataset_info(dataset_name)
        # Size the head from the loaded model: its feature width is fixed
        # by the checkpoint, not by this constructor's `d_model` argument.
        self.fclayer = nn.Linear(self.transformer.d_model, dataset_dict['output_dim'])
        self.loss_fn = dataset_dict['loss_fn']
        self.metric = dataset_dict['metric']
        self.evaluator = dataset_dict['evaluator']

        if change_dropout < 1.0:
            self._set_dropout(self.transformer, change_dropout)
        
        trainable_params = [self.fclayer.parameters()]
        if freeze_layers is not None:
            trainable_params.append(self.transformer.embed_src.parameters())
            trainable_params.append(self.transformer.pe.parameters())
            # NOTE(review): encoder.initial_norm is not added here, so it is
            # never optimised when freeze_layers is given -- confirm intended.
            for index, layer in enumerate(self.transformer.encoder.layers):
                if index not in freeze_layers:
                    trainable_params.append(layer.parameters())
        else:
            trainable_params.append(self.transformer.parameters())
        print(trainable_params)
        self.trainable_params = list(chain(*trainable_params))
        for p in self.trainable_params:
            print(p.shape)
        
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay

    @staticmethod
    def _set_dropout(module, p):
        # Recursively set the probability of every nn.Dropout below `module`.
        for name, child in module.named_children():
            if isinstance(child, nn.Dropout):
                child.p = p
                print('dropout rate changed on ', name)
            SmilesTransformerFinetuneModel._set_dropout(child, p)
    
    def configure_optimizers(self):
        # Only the parameters selected in __init__ are optimised.
        return AdamW(self.trainable_params, lr=self.learning_rate,
                     weight_decay=self.weight_decay)

    def forward(self, batched_data: BatchedSmilesData):
        # Encode the batch and map the pooled feature to task predictions.
        output = self.transformer(
            src=batched_data.x,
            pe_index=batched_data.pe_index,
            src_key_padding_mask=batched_data.pad_mask,
            attn_bias_mask=batched_data.atom_token_attn_mask
        )
        y_pred = self.fclayer(output.feature_pred)
        return y_pred
    
    def training_step(self, batched_data: BatchedSmilesData, batch_idx: int):
        y_pred = self(batched_data)
        y = batched_data.y
        # NaN targets are excluded from the loss
        mask = ~torch.isnan(y)
        loss = self.loss_fn(y_pred[mask], y[mask])
        loss_dict = {'train_loss': loss}
        self.log('metrics', loss_dict, on_step=True, on_epoch=True)
        return loss

    def _eval_step(self, batched_data: BatchedSmilesData):
        # Shared by validation and test: return predictions and targets.
        return {
            'y_pred': self(batched_data),
            'y_true': batched_data.y
        }

    def _log_epoch_metric(self, outputs, prefix: str):
        # Concatenate all step outputs, run the evaluator and log the metric.
        y_pred = torch.cat([output['y_pred'] for output in outputs])
        y_true = torch.cat([output['y_true'] for output in outputs])
        score = self.evaluator.eval({'y_pred': y_pred, 'y_true': y_true})[self.metric]
        self.log('{}_{}'.format(prefix, self.metric), score)

    def validation_step(self, batched_data: BatchedSmilesData, batch_idx: int):
        return self._eval_step(batched_data)

    def validation_epoch_end(self, outputs):
        self._log_epoch_metric(outputs, 'valid')

    def test_step(self, batched_data: BatchedSmilesData, batch_idx: int):
        return self._eval_step(batched_data)

    def test_epoch_end(self, outputs):
        self._log_epoch_metric(outputs, 'test')


class SmilesTransformerDistillationModel(pl.LightningModule):
    r"""
    Trains a ``SmilesTransformer`` student on a task loss plus feature
    and attention-weight distillation from a ``TeacherModel``.

    Linear adapters are created when the student's ``d_model`` or
    ``nhead`` differ from the teacher's ``hidden_dim`` / ``nhead``.
    Only the student, the prediction head and these adapters are handed
    to the optimizer.
    """
    def __init__(
        self,
        dataset_name: str,
        # model parameters
        vocab_size: int = 40,
        d_model: int = 256,
        nhead: int = 8,
        dim_feedforward: int = 256,
        dropout: float = 0.1,
        num_layers: int = 6, 
        max_len: int = 100,
        pe_type: str = 'learnable',
        pe_scale_factor: float = 1.0,
        # use graphormer as teacher model
        teacher_save_path: str = '',
        # feature distillation parameters
        feat_dist_layers_s: Sequence[int] = None,
        feat_dist_layers_t: Sequence[int] = None,
        feat_dist_loss_weight: float = 0.0,
        # attention weight distillation parameters
        attnw_dist_layers_s: Sequence[int] = None,
        attnw_dist_layers_t: Sequence[int] = None,
        attnw_dist_loss_weight: float = 0.0,
        # warm up settings
        warmup_epochs: int = -1,
        warmup_task_loss_weight: float = 0.0,
        warmup_feat_dist_loss_weight: float = 0.0,
        warmup_attnw_dist_loss_weight: float = 0.0,
        # optimizer parameters
        learning_rate: float = 3e-4,
        weight_decay: float = 1e-2,
    ):
        super(SmilesTransformerDistillationModel, self).__init__()
        self.save_hyperparameters()
        
        self.teacher = TeacherModel(
            checkpoint_dir=teacher_save_path,
            feat_dist_layers=feat_dist_layers_t,
            attnw_dist_layers=attnw_dist_layers_t,
        )
        self.student_transformer = SmilesTransformer(
            vocab_size=vocab_size, d_model=d_model, nhead=nhead,
            dim_feedforward=dim_feedforward, dropout=dropout,
            num_layers=num_layers, max_len=max_len, pe_type=pe_type,
            pe_scale_factor=pe_scale_factor,
            biased_attn_layers=attnw_dist_layers_s,
            output_hidden_states_layers=feat_dist_layers_s,
        )

        self.dataset_name = dataset_name
        dataset_dict = get_smiles_dataset_info(dataset_name)

        self.fclayer = nn.Linear(d_model, 1)
        # `self.loss_fn` is assigned exactly once, further down. Binding the
        # raw dataset loss here first was dead code, and if that loss were
        # an nn.Module the later reassignment to a plain object would be
        # rejected by nn.Module.__setattr__.
        self.metric = dataset_dict['metric']
        self.evaluator = dataset_dict['evaluator']
        parameters = [self.student_transformer.parameters(), self.fclayer.parameters()]

        # adapter from student feature width to teacher feature width
        self.feat_dist_transform = None
        if d_model != self.teacher.hidden_dim:
            self.feat_dist_transform = nn.Linear(d_model, self.teacher.hidden_dim, bias=False)
            parameters.append(self.feat_dist_transform.parameters())
        # adapter from student head count to teacher head count
        self.attnw_dist_transform = None
        if nhead != self.teacher.nhead:
            self.attnw_dist_transform = nn.Linear(nhead, self.teacher.nhead, bias=False)
            parameters.append(self.attnw_dist_transform.parameters())
        
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.learnable_parameters = list(chain(*parameters))

        # Xavier-initialise every learnable tensor with more than one axis.
        for p in self.learnable_parameters:
            print(p.shape)
            if p.dim() > 1:
                nn.init.xavier_normal_(p)

        self.loss_fn = DistillationLoss(
            task_loss_fn=dataset_dict['loss_fn'],
            task_loss_weight=1.0,
            feat_dist_loss_weight=feat_dist_loss_weight,
            attnw_dist_loss_weight=attnw_dist_loss_weight,
            warmup_epochs=warmup_epochs,
            warmup_task_loss_weight=warmup_task_loss_weight,
            warmup_feat_dist_loss_weight=warmup_feat_dist_loss_weight,
            warmup_attnw_dist_loss_weight=warmup_attnw_dist_loss_weight
        )

    def configure_optimizers(self):
        # The teacher's parameters are not part of `learnable_parameters`.
        return AdamW(self.learnable_parameters, lr=self.learning_rate, 
                     weight_decay=self.weight_decay)

    def forward(self, batched_data: BatchedSmilesData):
        # Run the student and attach the task prediction to its output.
        output = self.student_transformer(
            src=batched_data.x,
            pe_index=batched_data.pe_index,
            src_key_padding_mask=batched_data.pad_mask,
            attn_bias_mask=batched_data.atom_token_attn_mask
        )
        output.y_pred = self.fclayer(output.feature_pred)
        return output

    def training_step(self, batched_data: BatchedDistillationData, batch_idx: int):
        smiles_data = batched_data.smiles_data
        graph_data = batched_data.graph_data
        # student sees the SMILES view, teacher the graph view of the batch
        student_output = self(smiles_data)
        teacher_output = self.teacher(graph_data)
        loss, loss_dict = self.loss_fn(
            student_data=student_output,
            teacher_data=teacher_output,
            smiles_data=smiles_data,
            current_epoch=self.current_epoch,
            feat_dist_transform=self.feat_dist_transform,
            attnw_dist_transform=self.attnw_dist_transform
        )
        self.log('metrics', loss_dict, on_step=True, on_epoch=True)
        return loss

    def _eval_step(self, batched_data: BatchedDistillationData):
        # Shared by validation and test: student predictions and targets.
        smiles_data = batched_data.smiles_data
        output = self(smiles_data)
        return {
            'y_pred': output.y_pred,
            'y_true': smiles_data.y
        }

    def _log_epoch_metric(self, outputs, prefix: str):
        # Concatenate all step outputs, run the evaluator and log the metric.
        y_pred = torch.cat([output['y_pred'] for output in outputs])
        y_true = torch.cat([output['y_true'] for output in outputs])
        score = self.evaluator.eval({'y_pred': y_pred, 'y_true': y_true})[self.metric]
        self.log('{}_{}'.format(prefix, self.metric), score)

    def validation_step(self, batched_data: BatchedDistillationData, batch_idx: int):
        return self._eval_step(batched_data)

    def validation_epoch_end(self, outputs):
        self._log_epoch_metric(outputs, 'valid')

    def test_step(self, batched_data: BatchedDistillationData, batch_idx: int):
        return self._eval_step(batched_data)

    def test_epoch_end(self, outputs):
        self._log_epoch_metric(outputs, 'test')
