import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from omegaconf import open_dict, OmegaConf

from models.transformer_model import GraphTransformer
import utils.utils as utils



def add_new_key(cfg):
    """Back-fill optional config entries that older config files may lack.

    ``cfg`` is modified in place. A default is written only when the key is
    absent from its section; values that are already present are left alone:

    * ``cfg.model.level_ar``    -> ``False``
    * ``cfg.train.dropout``     -> ``0.1``
    * ``cfg.general.sample_num`` -> ``1``

    As before, ``cfg.model`` and ``cfg.general`` are switched to struct mode
    as a side effect; the struct flag of ``cfg.train`` is not touched.

    Args:
        cfg: OmegaConf config exposing ``model``, ``train`` and ``general``
            sections.
    """
    # (section name, key, default value, force struct mode on the section)
    optional_entries = (
        ("model", "level_ar", False, True),
        ("train", "dropout", 0.1, False),
        ("general", "sample_num", 1, True),
    )

    for section, key, default, make_struct in optional_entries:
        node = getattr(cfg, section)
        if make_struct:
            OmegaConf.set_struct(node, True)

        # Membership is used instead of ``hasattr``: ``hasattr`` only detects
        # a missing key when attribute access raises, which is not guaranteed
        # for a node that is not in struct mode (``cfg.train`` here), so the
        # default could be silently skipped.
        # NOTE(review): a key whose value is the mandatory-missing marker
        # ('???') is presumably reported as absent by ``in`` and would get the
        # default -- confirm against the OmegaConf version in use.
        if key not in node:
            # open_dict temporarily lifts struct mode so the key can be added.
            with open_dict(node):
                node[key] = default


class MoleculeTargetPrediction(pl.LightningModule):
    """Lightning module that regresses a target ``y`` from a batch of graphs.

    A ``GraphTransformer`` configured with a 1-dimensional ``y`` output (and
    no node/edge outputs) is trained with mean-squared error against
    ``data.y``. The same step logic is shared by train, validation and test.
    """

    def __init__(self, cfg, dataset_infos, extra_features):
        """Build the network and store the collaborators used at step time.

        Args:
            cfg: OmegaConf config. Reads ``general.name``, ``model.n_layers``,
                ``model.hidden_mlp_dims``, ``model.hidden_dims`` and
                ``train.dropout`` here; ``train.lr`` and
                ``train.weight_decay`` are read in ``configure_optimizers``.
                Missing optional keys are back-filled by ``add_new_key``.
            dataset_infos: object exposing ``input_dims`` (mapping with
                ``'X'``, ``'E'``, ``'y'`` entries) and ``node_types``.
            extra_features: callable applied to the per-step data dict; its
                result must expose ``X``, ``E`` and ``y``.
        """
        super().__init__()
        # Must stay inside __init__: it captures the constructor arguments.
        self.save_hyperparameters()

        # Work on a private copy so dataset_infos.input_dims is not modified.
        input_dims = copy.deepcopy(dataset_infos.input_dims)
        # NOTE(review): the reason for shrinking X by 1 and y by 2 is not
        # visible in this file -- confirm against the dataset definition.
        for dim_key, shrink_by in (('X', 1), ('y', 2)):
            input_dims[dim_key] -= shrink_by
        # Only a single global value is predicted.
        output_dims = dict(X=0, E=0, y=1)

        add_new_key(cfg)
        self.cfg = cfg
        self.name = cfg.general.name
        self.model_dtype = torch.float32

        self.Xdim, self.Edim, self.ydim = (
            input_dims[dim_key] for dim_key in ('X', 'E', 'y')
        )

        self.dataset_info = dataset_infos
        self.extra_features = extra_features

        model_cfg = cfg.model
        transformer_kwargs = dict(
            n_layers=model_cfg.n_layers,
            input_dims=input_dims,
            hidden_mlp_dims=model_cfg.hidden_mlp_dims,
            hidden_dims=model_cfg.hidden_dims,
            output_dims=output_dims,
            act_fn_in=nn.ReLU(),
            act_fn_out=nn.ReLU(),
            dropout=cfg.train.dropout,
            node_vocab_size=dataset_infos.node_types.shape[0],
        )
        self.model = GraphTransformer(**transformer_kwargs)

    def configure_optimizers(self):
        """Return an AMSGrad AdamW optimizer over all module parameters."""
        train_cfg = self.cfg.train
        return torch.optim.AdamW(
            self.parameters(),
            lr=train_cfg.lr,
            weight_decay=train_cfg.weight_decay,
            amsgrad=True,
        )

    def shared_step(self, data, split = "train"):
        """Run one forward pass on a batch and log its MSE loss.

        Args:
            data: batch exposing ``x``, ``edge_index``, ``edge_attr``,
                ``batch``, ``level`` and ``y``.
            split: prefix for the logged metric name (``'<split>/mse_loss'``).

        Returns:
            ``{'loss': loss}`` with the MSE between the prediction and
            ``data.y``.
        """
        dense, node_mask, node_level = utils.to_dense(
            data.x, data.edge_index, data.edge_attr, data.batch, data.level
        )
        dense = dense.mask(node_mask)
        node_feats, edge_feats = dense.X, dense.E
        target = data.y

        # Rescale levels by their maximum along the last axis; the epsilon
        # guards against division by zero when that maximum is 0.
        level_peak = node_level.max(dim=-1, keepdim=True).values
        node_level = node_level / (level_peak + 1e-7)

        net_input = {
            'X_t': node_feats,
            'E_t': edge_feats,
            # Zero-width placeholder: no global features are fed from the batch.
            'y_t': target.new_zeros(target.shape[0], 0),
            'node_mask': node_mask,
            "node_level": node_level.unsqueeze(-1),
        }
        extra_data = self.compute_extra_data(net_input)

        prediction = self.forward(net_input, extra_data, node_mask)
        # NOTE(review): assumes prediction and data.y have matching shapes;
        # otherwise mse_loss would broadcast -- TODO confirm.
        loss = F.mse_loss(prediction, target)
        self.log_dict(
            {f'{split}/mse_loss': loss},
            on_epoch=True,
            batch_size=node_feats.shape[0],
        )

        return {'loss': loss}

    def training_step(self, data, i):
        """Training hook; delegates to ``shared_step`` with the 'train' prefix."""
        return self.shared_step(data, split="train")

    def validation_step(self, data, i):
        """Validation hook; delegates to ``shared_step`` with the 'val' prefix."""
        return self.shared_step(data, split="val")

    def test_step(self, data, i):
        """Test hook; delegates to ``shared_step`` with the 'test' prefix."""
        return self.shared_step(data, split="test")

    def forward(self, noisy_data, extra_data, node_mask):
        """Concatenate base and extra features and return the model's ``y``.

        Args:
            noisy_data: dict providing ``'X_t'``, ``'E_t'`` and ``'y_t'``.
            extra_data: object with ``X``, ``E`` and ``y`` to append to them.
            node_mask: mask forwarded unchanged to the model.

        Returns:
            The ``y`` attribute of the model output.
        """
        node_in = torch.cat((noisy_data['X_t'], extra_data.X), dim=2)
        edge_in = torch.cat((noisy_data['E_t'], extra_data.E), dim=3)
        global_in = torch.hstack((noisy_data['y_t'], extra_data.y))
        output = self.model(
            node_in.float(), edge_in.float(), global_in.float(), node_mask
        )
        return output.y

    def compute_extra_data(self, data):
        """Compute the additional features appended to the network input.

        Calls ``self.extra_features`` on ``data`` and appends
        ``data["node_level"]`` to the resulting node features along the last
        axis; edge and global extra features are passed through unchanged.

        Returns:
            A ``utils.PlaceHolder`` holding the extra ``X``, ``E`` and ``y``.
        """
        feats = self.extra_features(data)
        return utils.PlaceHolder(
            X=torch.cat((feats.X, data["node_level"]), dim=-1),
            E=feats.E,
            y=feats.y,
        )
