from typing import Sequence, Union, Callable, Optional, Dict

import logging

import torch
import torch.nn as nn
import torch.nn.functional as F

import schnetpack as spk
import schnetpack.properties as properties
from schnetpack.diffusion.utils import scatter_mean


class DiffAtomwise(spk.atomistic.Atomwise):
    """Atomwise output head that can use the diffusion time step as a feature.

    When ``include_time`` is set, the parent ``Atomwise`` constructor is asked
    for one additional input node and, in ``forward``, a time value is
    concatenated to ``inputs["scalar_representation"]`` along the last
    dimension before the parent forward pass runs. The time value is either
    read from ``inputs["diff_step"]`` or predicted by ``time_head``.

    Args:
        n_in: number of input features, excluding the optional time feature.
        n_out: number of outputs, forwarded to the parent constructor.
        include_time: if True, append the time step as an extra input feature.
        time_head: optional module that predicts the time step. It is called
            as ``time_head(inputs)`` and must return a mapping that contains
            ``time_output_key``. It must expose ``aggregation_mode``; if it
            exposes a truthy ``use_classification`` it must also expose ``T``.
        time_output_key: key under which ``time_head`` stores its prediction.
        detach_time_head: if True, detach the predicted time before use.
        train_on_true_time: if True, use ``inputs["diff_step"]`` instead of the
            prediction while the module is in training mode.
        output_activation: optional callable applied to the output as
            ``output_activation(x, dim=-1)``.
        optional_output_keys: extra keys appended to ``model_outputs``.
        **kwargs: forwarded unchanged to the parent constructor.
    """

    def __init__(
        self,
        n_in: int,
        n_out: int = 1,
        include_time: bool = False,
        time_head: Optional[nn.Module] = None,
        time_output_key: str = "diff_step_pred",
        detach_time_head: bool = False,
        train_on_true_time: bool = False,
        output_activation: Optional[Callable] = None,
        optional_output_keys: Optional[Sequence[str]] = None,
        **kwargs
    ):
        # The time step is appended as one extra scalar feature, so the parent
        # network needs one more input node.
        if include_time:
            n_in += 1

        super().__init__(n_in, n_out=n_out, **kwargs)

        self.include_time = include_time

        self.optional_output_keys = optional_output_keys
        if optional_output_keys is not None:
            # Accept any sequence (tuple, config list, ...) and treat a lone
            # string as a single key instead of a sequence of characters.
            if isinstance(optional_output_keys, str):
                extra_keys = [optional_output_keys]
            else:
                extra_keys = list(optional_output_keys)
            self.model_outputs = list(self.model_outputs) + extra_keys

        self.output_activation = output_activation

        # predict time
        self.time_outnet = time_head
        self.time_output_key = time_output_key
        self.detach_time_head = detach_time_head
        self.train_on_true_time = train_on_true_time

    def _resolve_time(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Return the time value to append: the true step or the prediction."""
        if self.time_outnet is None or (self.training and self.train_on_true_time):
            return inputs["diff_step"]

        t = self.time_outnet(inputs)[self.time_output_key]
        # A classifying time head returns one score per time step; convert the
        # most likely class into a time value, as EqHead does.
        if getattr(self.time_outnet, "use_classification", False):
            t = t.argmax(-1) / (self.time_outnet.T * 1.0)
        # An aggregated prediction is indexed with idx_m to obtain one value
        # per entry of idx_m (presumably one per atom -- confirm with caller).
        if self.time_outnet.aggregation_mode is not None:
            t = t[inputs[properties.idx_m]]
        if self.detach_time_head:
            t = t.detach()
        return t

    def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Run the atomwise prediction, optionally conditioned on time.

        Args:
            inputs: batch dictionary. ``"scalar_representation"`` is required
                when ``include_time`` is set, and ``"diff_step"`` is required
                whenever the true time step is used.

        Returns:
            The dictionary returned by the parent forward pass, with the
            optional output activation applied to ``self.output_key``. The
            entry ``"scalar_representation"`` is left as it was before the
            time feature was appended.
        """
        representation = None
        if self.include_time:
            # Resolve the time first: the time head receives the batch before
            # the representation is widened.
            t = self._resolve_time(inputs)
            representation = inputs["scalar_representation"]
            inputs["scalar_representation"] = torch.cat(
                (representation, t.unsqueeze(-1)), dim=-1
            )

        # predict atomwise contributions
        inputs = super().forward(inputs)

        if self.include_time:
            # Undo the temporary widening so that other heads sharing this
            # dictionary (or a repeated call) see the unmodified representation.
            inputs["scalar_representation"] = representation

        if self.output_activation is not None:
            inputs[self.output_key] = self.output_activation(
                inputs[self.output_key], dim=-1
            )

        return inputs


class EqHead(nn.Module):
    """Equivariant output head built from gated equivariant blocks.

    The head feeds the scalar and vector representations through a gated
    equivariant MLP with a single output and stores the vector output,
    squeezed along its last dimension, under ``output_key``. Optionally the
    diffusion time step is appended to the scalar features first; it is either
    read from ``inputs["diff_step"]`` or predicted by ``time_head``.

    Args:
        n_in: number of input features of the network.
        n_hidden: hidden layer size(s), forwarded to the network builder.
        n_layers: number of layers, forwarded to the network builder.
        activation: activation used for both the blocks and their scalar path.
        output_key: key under which the prediction is stored.
        include_time: if True, the first block is rebuilt with one additional
            scalar input to take the time feature.
        time_head: optional module predicting the time step. It must expose
            ``use_classification``, ``aggregation_mode`` and, when classifying,
            ``T``.
        time_output_key: key under which ``time_head`` stores its prediction.
        detach_time_head: if True, detach the predicted time before use.
        train_on_true_time: if True, use ``inputs["diff_step"]`` instead of the
            prediction while the module is in training mode.
    """

    def __init__(
        self,
        n_in: int,
        n_hidden: Optional[Union[int, Sequence[int]]] = None,
        n_layers: int = 2,
        activation: Callable = F.silu,
        output_key: str = "eps_pred",
        include_time: bool = False,
        time_head: nn.Module = None,
        time_output_key: str = "diff_step_pred",
        detach_time_head: bool = False,
        train_on_true_time: bool = False,
    ):
        super().__init__()
        self.include_time = include_time
        self.output_key = output_key
        self.model_outputs = [output_key]

        network = spk.nn.build_gated_equivariant_mlp(
            n_in=n_in,
            n_out=1,
            n_hidden=n_hidden,
            n_layers=n_layers,
            activation=activation,
            sactivation=activation,
        )
        if include_time:
            # Swap the first block for one with the same sizes except for a
            # single extra scalar input that receives the time feature.
            first_block = network[0]
            network[0] = spk.nn.GatedEquivariantBlock(
                n_sin=first_block.n_sin + 1,
                n_vin=first_block.n_vin,
                n_sout=first_block.n_sout,
                n_vout=first_block.n_vout,
                n_hidden=first_block.n_hidden,
                activation=activation,
                sactivation=activation,
            )
        self.outnet = network

        # settings for the time prediction
        self.time_outnet = time_head
        self.time_output_key = time_output_key
        self.detach_time_head = detach_time_head
        self.train_on_true_time = train_on_true_time

    def _time_feature(self, inputs):
        """Return the time value (true or predicted) before it is unsqueezed."""
        use_true_time = self.training and self.train_on_true_time
        if use_true_time or self.time_outnet is None:
            return inputs["diff_step"]

        time_head = self.time_outnet
        step = time_head(inputs)[self.time_output_key]
        if time_head.use_classification:
            # most likely class index, divided by the number of time steps
            step = step.argmax(-1) / (time_head.T * 1.0)
        if time_head.aggregation_mode is not None:
            # expand the aggregated prediction by indexing with idx_m
            step = step[inputs[properties.idx_m]]
        if self.detach_time_head:
            step = step.detach()
        return step

    def forward(self, inputs):
        """Predict the equivariant output and store it in ``inputs``.

        Reads ``"scalar_representation"`` and ``"vector_representation"`` and,
        when a true time step is used, ``"diff_step"``. Returns the same
        dictionary with ``self.output_key`` added.
        """
        scalars = inputs["scalar_representation"]
        vectors = inputs["vector_representation"]

        # The time feature is concatenated to a local tensor only; the entry
        # in ``inputs`` is not reassigned here.
        if self.include_time:
            step = self._time_feature(inputs).unsqueeze(-1)
            scalars = torch.cat((scalars, step), dim=-1)

        # Only the vector output of the network is kept.
        _scalar_out, vector_out = self.outnet((scalars, vectors))
        inputs[self.output_key] = torch.squeeze(vector_out, -1)
        return inputs


class TimeHead(nn.Module):
    """Output head that predicts the diffusion time step.

    The prediction is produced by an ``Atomwise`` network and stored under
    ``time_output_key``. With ``use_classification`` the network has ``T``
    outputs (one per time step); otherwise it has a single output.

    Args:
        n_in: number of input features of the atomwise network.
        n_hidden: hidden layer size(s), forwarded to the atomwise network.
        n_layers: number of layers, forwarded to the atomwise network.
        activation: activation function, forwarded to the atomwise network.
        time_output_key: key under which the prediction is stored.
        aggregation_mode: aggregation mode forwarded to the atomwise network;
            ``None`` keeps the un-aggregated prediction.
        use_classification: if True, predict ``T`` values instead of one.
        detach_representation: if True, run the network on detached copies of
            the input tensors so that no gradient flows back through them.
        representation: optional path to a saved object whose
            ``representation`` attribute is applied to the inputs first.
        T: number of diffusion time steps.
    """

    def __init__(
        self,
        n_in: int,
        n_hidden: Optional[Union[int, Sequence[int]]] = None,
        n_layers: int = 2,
        activation: Callable = F.silu,
        time_output_key: str = "diff_step_pred",
        aggregation_mode: Optional[str] = None,
        use_classification: bool = False,
        detach_representation: bool = False,
        representation: Optional[str] = None,
        T: int = 1000,
    ):
        super().__init__()

        self.time_output_key = time_output_key
        self.use_classification = use_classification
        self.aggregation_mode = aggregation_mode
        self.model_outputs = [time_output_key]

        self.detach_representation = detach_representation

        if representation is None:
            self.representation = None
        else:
            logging.info("Loading representation from %s", representation)
            # SECURITY NOTE: torch.load unpickles the file and can therefore
            # execute arbitrary code. Only pass paths to trusted files here.
            self.representation = torch.load(representation).representation

        # predict time: one output per time step when classifying, else one
        self.n_out = T if self.use_classification else 1
        self.T = T

        self.time_outnet = spk.atomistic.Atomwise(
            n_in=n_in,
            n_out=self.n_out,
            n_hidden=n_hidden,
            n_layers=n_layers,
            activation=activation,
            aggregation_mode=self.aggregation_mode,
            output_key=self.time_output_key,
            per_atom_output_key=self.time_output_key,
        )

    def forward(self, inputs):
        """Predict the time step and store it under ``time_output_key``.

        Args:
            inputs: batch dictionary passed to the (optional) pre-trained
                representation and to the atomwise network.

        Returns:
            The batch dictionary including the time prediction. Without
            aggregation and without classification the trailing dimension of
            the prediction is squeezed.
        """
        # Use pre-trained representation
        if self.representation is not None:
            inputs = self.representation(inputs)

        if self.detach_representation:
            # No backpropagation to the representation from the time head:
            # run on detached copies and copy only the prediction back.
            # Non-tensor entries are passed through untouched.
            det_inputs = {
                key: val.detach() if isinstance(val, torch.Tensor) else val
                for key, val in inputs.items()
            }
            det_inputs = self.time_outnet(det_inputs)
            inputs[self.time_output_key] = det_inputs[self.time_output_key]
        else:
            inputs = self.time_outnet(inputs)

        if self.aggregation_mode is None and not self.use_classification:
            inputs[self.time_output_key] = inputs[self.time_output_key].squeeze(-1)
        return inputs
