import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import os

from src.data import utils
from src.models.transformer_model import GraphTransformer

from sklearn.metrics import roc_auc_score
from torch_geometric.data import Data
from torch_geometric.data.collate import collate
from torch_geometric.utils import scatter

from pdb import set_trace


def remove_dummy_nodes(data):
    """Strip dummy product nodes from every graph of a batch and re-collate it.

    A node is treated as a dummy when the last column of ``data.p_x`` is
    non-zero. For each graph the remaining nodes are renumbered to the
    contiguous range ``0..N-1`` and ``p_edge_index`` is rewritten accordingly.

    Args:
        data: batched graph object providing ``batch``, ``ptr``, ``p_x``,
            ``p_edge_index``, ``p_edge_attr``, ``p_smiles`` and ``idx``.

    Returns:
        The first element returned by ``collate(Data, ...)`` applied to the
        per-graph ``Data`` objects that no longer contain dummy nodes.

    Raises:
        AssertionError: if the edges selected by source node and by target node
            differ for a graph, or if an edge touches a dummy node.
    """
    num_graphs = len(data.batch.unique())
    src_all, dst_all = data.p_edge_index[0], data.p_edge_index[1]

    graphs = []
    for i in range(num_graphs):
        min_idx = data.ptr[i]
        max_idx = data.ptr[i + 1]

        # Removing dummy atoms
        p_x = data.p_x[data.batch == i]
        keep = p_x[:, -1] == 0
        p_x_without_dummy_nodes = p_x[keep]

        # Getting edge index enumerated from 0 (not like it was in batch)
        edge_ids = torch.where((min_idx <= src_all) & (src_all < max_idx))[0]
        edge_ids_by_dst = torch.where((min_idx <= dst_all) & (dst_all < max_idx))[0]
        # torch.equal also reports a mismatch when the two selections differ in
        # length, where an elementwise `==` would raise a broadcasting error.
        assert torch.equal(edge_ids, edge_ids_by_dst)

        src = src_all[edge_ids] - min_idx
        dst = dst_all[edge_ids] - min_idx

        # Checking that edges do not connect dummy nodes
        is_dummy = ~keep
        assert not is_dummy[src].any()
        assert not is_dummy[dst].any()

        # Mapping remaining node indices to the continuous range from 0 to N:
        # the running count of kept nodes (minus one) is exactly the new index
        # of every kept node. Entries at dummy positions are never read because
        # of the checks above. Integer dtype avoids a float round-trip.
        new_index = torch.cumsum(keep.long(), dim=0) - 1

        # Renumbering indices in edge_index
        p_edge_index = torch.stack([new_index[src], new_index[dst]], dim=0)
        p_edge_attr = data.p_edge_attr[edge_ids]

        batch = torch.full(
            (len(p_x_without_dummy_nodes),), i,
            dtype=data.batch.dtype, device=data.batch.device,
        )
        graphs.append(Data(
            p_x=p_x_without_dummy_nodes, p_edge_index=p_edge_index, p_edge_attr=p_edge_attr,
            p_smiles=data.p_smiles[i], idx=data.idx[i], batch=batch
        ))

    data, _, _ = collate(Data, graphs)
    return data


class BinarySizeModel(pl.LightningModule):
    """Binary classifier over product graphs built on ``GraphTransformer``.

    The target of a graph is 1 when the last column of ``p_x`` sums to exactly
    one over the graph's nodes (i.e. exactly one dummy node) and 0 otherwise.
    Dummy nodes are removed from the input before the forward pass.
    """

    def __init__(
            self,
            experiment_name,
            checkpoints_dir,
            lr,
            weight_decay,
            n_layers,
            hidden_mlp_dims,
            hidden_dims,
            dataset_infos,
            extra_features,
            domain_features,
            log_every_steps,
    ):
        """Build the transformer and store the training configuration.

        Args:
            experiment_name: stored as ``self.name``.
            checkpoints_dir: directory where ``last.ckpt`` is written after
                every validation epoch.
            lr: learning rate passed to AdamW.
            weight_decay: weight decay passed to AdamW.
            n_layers: forwarded to ``GraphTransformer``.
            hidden_mlp_dims: forwarded to ``GraphTransformer``.
            hidden_dims: forwarded to ``GraphTransformer``.
            dataset_infos: object whose ``size_input_dims`` is used as the
                transformer input dimensions.
            extra_features: callable applied to the input dict; its result
                must expose ``X``, ``E`` and ``y``.
            domain_features: callable applied to the input dict; its result
                must expose ``X``, ``E`` and ``y``.
            log_every_steps: metrics are logged on steps whose index is a
                multiple of this value.
        """
        super().__init__()

        self.model_dtype = torch.float32

        self.name = experiment_name
        self.checkpoints_dir = checkpoints_dir
        self.lr = lr
        self.weight_decay = weight_decay
        self.extra_features = extra_features
        self.domain_features = domain_features

        # Only a single graph-level output is predicted.
        output_dims = {
            'X': 0,
            'E': 0,
            'y': 1,
        }
        self.model = GraphTransformer(
            n_layers=n_layers,
            input_dims=dataset_infos.size_input_dims,
            hidden_mlp_dims=hidden_mlp_dims,
            hidden_dims=hidden_dims,
            output_dims=output_dims,
            act_fn_in=nn.ReLU(),
            act_fn_out=nn.ReLU()
        )

        # NOTE(review): `ignore` receives the object rather than the string
        # 'dataset_infos'; Lightning documents `ignore` as argument names, so
        # this probably does not exclude anything. Left unchanged because it
        # affects what existing checkpoints store -- confirm before changing.
        self.save_hyperparameters(ignore=[dataset_infos])

        self.start_epoch_time = None
        self.train_iterations = None
        self.val_iterations = None

        self.log_every_steps = log_every_steps
        self.val_counter = 0

    def configure_optimizers(self):
        """Return an AMSGrad AdamW optimizer over the transformer parameters."""
        return torch.optim.AdamW(
            params=self.model.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
            amsgrad=True,
        )

    def process_and_forward(self, data):
        """Densify a batch, run the network and return per-graph probabilities.

        Args:
            data: batched graph object with ``p_x``, ``p_edge_index``,
                ``p_edge_attr`` and ``batch``.

        Returns:
            Sigmoid of the network's ``y`` output with its last axis squeezed.
        """
        product, node_mask = utils.to_dense(data.p_x, data.p_edge_index, data.p_edge_attr, data.batch)
        product = product.mask(node_mask)

        input_data = {
            'X_t': product.X,
            'E_t': product.E,
            'y_t': product.y,
            'node_mask': node_mask,
        }
        # Extra features are computed on the full node features; the last
        # feature column is dropped only afterwards, before the forward pass.
        extra_data = self.compute_extra_data(input_data)
        input_data['X_t'] = input_data['X_t'][..., :-1]

        pred = self.forward(input_data, extra_data, node_mask)
        return torch.sigmoid(pred.y.squeeze(-1))

    def _binary_targets(self, data):
        """Return a float vector with 1 for graphs whose dummy flags sum to one."""
        dummy_counts = scatter(data.p_x[:, -1], data.batch, reduce='sum')
        # No squeeze here: the scatter output is already one value per graph,
        # and squeezing would turn a one-graph batch into a 0-d tensor that no
        # longer matches the prediction shape.
        return (dummy_counts == 1).float()

    def training_step(self, data, i):
        """Compute the BCE loss for one training batch.

        Args:
            data: batched graph object (still containing dummy nodes).
            i: batch index, used to decide whether to log.

        Returns:
            Dict with the loss under the key ``'loss'``.
        """
        # Targets must be derived before the dummy nodes are removed.
        true_y = self._binary_targets(data)

        data = remove_dummy_nodes(data)
        pred_y = self.process_and_forward(data)

        assert true_y.shape == pred_y.shape

        loss = F.binary_cross_entropy(pred_y, true_y)
        if i % self.log_every_steps == 0:
            self.log('train_size_BCE', loss.detach())

        return {'loss': loss}

    def validation_step(self, data, i):
        """Compute the BCE loss for one validation batch and log metrics.

        Accuracy and ROC AUC are computed only on steps that are logged. ROC
        AUC is skipped for batches that contain a single class, for which the
        score is undefined.

        Args:
            data: batched graph object (still containing dummy nodes).
            i: batch index, used to decide whether to log.

        Returns:
            Dict with the loss under the key ``'loss'``.
        """
        # Targets must be derived before the dummy nodes are removed.
        true_y = self._binary_targets(data)

        data = remove_dummy_nodes(data)
        pred_y = self.process_and_forward(data)

        assert true_y.shape == pred_y.shape

        loss = F.binary_cross_entropy(pred_y, true_y)

        if i % self.log_every_steps == 0:
            accuracy = ((pred_y > 0.5).int() == true_y.int()).sum() / len(pred_y)
            self.log('val_size_BCE', loss.detach())
            self.log('val_size_acc', accuracy.detach())

            if true_y.unique().numel() > 1:
                roc_auc = roc_auc_score(
                    y_true=true_y.detach().cpu().numpy(),
                    y_score=pred_y.detach().cpu().numpy(),
                )
                self.log('val_size_acc_roc_auc', roc_auc)

        return {'loss': loss}

    def validation_epoch_end(self, outs):
        """Save the trainer state to ``<checkpoints_dir>/last.ckpt``."""
        self.trainer.save_checkpoint(os.path.join(self.checkpoints_dir, 'last.ckpt'))

    def forward(self, noisy_data, extra_data, node_mask):
        """Concatenate inputs with extra features and run the transformer.

        Args:
            noisy_data: dict with ``'X_t'``, ``'E_t'`` and ``'y_t'`` tensors.
            extra_data: object with ``X``, ``E`` and ``y`` tensors that are
                appended along axis 2, axis 3 and horizontally, respectively.
            node_mask: mask forwarded to the transformer.

        Returns:
            Whatever ``self.model`` returns for the concatenated inputs.
        """
        X = torch.cat((noisy_data['X_t'], extra_data.X), dim=2).float()
        E = torch.cat((noisy_data['E_t'], extra_data.E), dim=3).float()
        y = torch.hstack((noisy_data['y_t'], extra_data.y)).float()
        return self.model(X, E, y, node_mask)

    def compute_extra_data(self, noisy_data):
        """Merge generic and domain-specific extra features.

        Args:
            noisy_data: dict passed to both feature extractors.

        Returns:
            ``utils.PlaceHolder`` whose ``X``, ``E`` and ``y`` are the
            last-axis concatenations of the two extractors' outputs.
        """
        extra_features = self.extra_features(noisy_data)
        extra_molecular_features = self.domain_features(noisy_data)

        extra_X = torch.cat((extra_features.X, extra_molecular_features.X), dim=-1)
        extra_E = torch.cat((extra_features.E, extra_molecular_features.E), dim=-1)
        extra_y = torch.cat((extra_features.y, extra_molecular_features.y), dim=-1)

        return utils.PlaceHolder(X=extra_X, E=extra_E, y=extra_y)


class MultilabelSizeModel(pl.LightningModule):
    """Multi-class classifier over product graphs built on ``GraphTransformer``.

    The target of a graph is the sum of the last column of ``p_x`` over its
    nodes (the number of dummy nodes); the network has
    ``dataset_infos.max_n_dummy_nodes + 1`` outputs. Dummy nodes are removed
    from the input before the forward pass.
    """

    def __init__(
            self,
            experiment_name,
            checkpoints_dir,
            lr,
            weight_decay,
            n_layers,
            hidden_mlp_dims,
            hidden_dims,
            dataset_infos,
            extra_features,
            domain_features,
            log_every_steps,
    ):
        """Build the transformer and store the training configuration.

        Args:
            experiment_name: stored as ``self.name``.
            checkpoints_dir: directory where ``last.ckpt`` is written after
                every validation epoch.
            lr: learning rate passed to AdamW.
            weight_decay: weight decay passed to AdamW.
            n_layers: forwarded to ``GraphTransformer``.
            hidden_mlp_dims: forwarded to ``GraphTransformer``.
            hidden_dims: forwarded to ``GraphTransformer``.
            dataset_infos: object providing ``size_input_dims`` (transformer
                input dimensions) and ``max_n_dummy_nodes`` (largest class).
            extra_features: callable applied to the input dict; its result
                must expose ``X``, ``E`` and ``y``.
            domain_features: callable applied to the input dict; its result
                must expose ``X``, ``E`` and ``y``.
            log_every_steps: metrics are logged on steps whose index is a
                multiple of this value.
        """
        super().__init__()

        self.model_dtype = torch.float32

        self.name = experiment_name
        self.checkpoints_dir = checkpoints_dir
        self.lr = lr
        self.weight_decay = weight_decay
        self.extra_features = extra_features
        self.domain_features = domain_features
        self.dataset_infos = dataset_infos

        # One class per possible dummy-node count, including zero.
        output_dims = {
            'X': 0,
            'E': 0,
            'y': dataset_infos.max_n_dummy_nodes + 1,
        }
        self.model = GraphTransformer(
            n_layers=n_layers,
            input_dims=dataset_infos.size_input_dims,
            hidden_mlp_dims=hidden_mlp_dims,
            hidden_dims=hidden_dims,
            output_dims=output_dims,
            act_fn_in=nn.ReLU(),
            act_fn_out=nn.ReLU(),
            addition=False,
        )

        # NOTE(review): `ignore` receives the object rather than the string
        # 'dataset_infos'; Lightning documents `ignore` as argument names, so
        # this probably does not exclude anything. Left unchanged because it
        # affects what existing checkpoints store -- confirm before changing.
        self.save_hyperparameters(ignore=[dataset_infos])

        self.start_epoch_time = None
        self.train_iterations = None
        self.val_iterations = None

        self.log_every_steps = log_every_steps
        self.val_counter = 0

    def configure_optimizers(self):
        """Return an AMSGrad AdamW optimizer over the transformer parameters."""
        return torch.optim.AdamW(
            params=self.model.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
            amsgrad=True,
        )

    def process_and_forward(self, data, return_logits=False):
        """Densify a batch, run the network and return class scores.

        Args:
            data: batched graph object with ``p_x``, ``p_edge_index``,
                ``p_edge_attr`` and ``batch``.
            return_logits: when ``True`` return the raw network outputs with
                the leading axis preserved (what ``F.cross_entropy`` expects);
                when ``False`` (default) return softmax probabilities of the
                squeezed outputs, exactly as before this flag existed.

        Returns:
            Logits or class probabilities, depending on ``return_logits``.
        """
        product, node_mask = utils.to_dense(data.p_x, data.p_edge_index, data.p_edge_attr, data.batch)
        product = product.mask(node_mask)

        input_data = {
            'X_t': product.X,
            'E_t': product.E,
            'y_t': product.y,
            'node_mask': node_mask,
        }
        # Extra features are computed on the full node features; the last
        # feature column is dropped only afterwards, before the forward pass.
        extra_data = self.compute_extra_data(input_data)
        input_data['X_t'] = input_data['X_t'][..., :-1]

        pred = self.forward(input_data, extra_data, node_mask)

        if return_logits:
            # Keep the leading axis so a one-graph batch is not collapsed to a
            # 1-D vector (assumes axis 0 of pred.y is the batch -- TODO confirm).
            return pred.y.reshape(pred.y.shape[0], -1)

        return torch.softmax(pred.y.squeeze(), dim=-1)

    def training_step(self, data, i):
        """Compute the cross-entropy loss for one training batch.

        Args:
            data: batched graph object (still containing dummy nodes).
            i: batch index, used to decide whether to log.

        Returns:
            Dict with the loss under the key ``'loss'``.
        """
        # Targets must be derived before the dummy nodes are removed.
        true_y = scatter(data.p_x[:, -1], data.batch, reduce='sum').long()

        data = remove_dummy_nodes(data)
        # F.cross_entropy applies log-softmax itself, so it must be fed logits;
        # passing probabilities would apply softmax twice.
        logits = self.process_and_forward(data, return_logits=True)

        loss = F.cross_entropy(logits, true_y)
        if i % self.log_every_steps == 0:
            self.log('train_size_BCE', loss.detach())

        return {'loss': loss}

    def validation_step(self, data, i):
        """Compute the cross-entropy loss for one validation batch and log metrics.

        Args:
            data: batched graph object (still containing dummy nodes).
            i: batch index, used to decide whether to log.

        Returns:
            Dict with the loss under the key ``'loss'``.
        """
        # Targets must be derived before the dummy nodes are removed.
        true_y = scatter(data.p_x[:, -1], data.batch, reduce='sum').long()

        data = remove_dummy_nodes(data)
        # See training_step: the loss needs logits, not probabilities.
        logits = self.process_and_forward(data, return_logits=True)

        loss = F.cross_entropy(logits, true_y)
        if i % self.log_every_steps == 0:
            # argmax over logits equals argmax over their softmax.
            accuracy = (logits.argmax(dim=-1) == true_y.int()).sum() / len(logits)
            self.log('val_size_BCE', loss.detach())
            self.log('val_size_acc', accuracy.detach())

        return {'loss': loss}

    def validation_epoch_end(self, outs):
        """Save the trainer state to ``<checkpoints_dir>/last.ckpt``."""
        self.trainer.save_checkpoint(os.path.join(self.checkpoints_dir, 'last.ckpt'))

    def forward(self, noisy_data, extra_data, node_mask):
        """Concatenate inputs with extra features and run the transformer.

        Args:
            noisy_data: dict with ``'X_t'``, ``'E_t'`` and ``'y_t'`` tensors.
            extra_data: object with ``X``, ``E`` and ``y`` tensors that are
                appended along axis 2, axis 3 and horizontally, respectively.
            node_mask: mask forwarded to the transformer.

        Returns:
            Whatever ``self.model`` returns for the concatenated inputs.
        """
        X = torch.cat((noisy_data['X_t'], extra_data.X), dim=2).float()
        E = torch.cat((noisy_data['E_t'], extra_data.E), dim=3).float()
        y = torch.hstack((noisy_data['y_t'], extra_data.y)).float()
        return self.model(X, E, y, node_mask)

    def compute_extra_data(self, noisy_data):
        """Merge generic and domain-specific extra features.

        Args:
            noisy_data: dict passed to both feature extractors.

        Returns:
            ``utils.PlaceHolder`` whose ``X``, ``E`` and ``y`` are the
            last-axis concatenations of the two extractors' outputs.
        """
        extra_features = self.extra_features(noisy_data)
        extra_molecular_features = self.domain_features(noisy_data)

        extra_X = torch.cat((extra_features.X, extra_molecular_features.X), dim=-1)
        extra_E = torch.cat((extra_features.E, extra_molecular_features.E), dim=-1)
        extra_y = torch.cat((extra_features.y, extra_molecular_features.y), dim=-1)

        return utils.PlaceHolder(X=extra_X, E=extra_E, y=extra_y)
