import einops
import torch
import torch.nn.functional as F

from models import model_from_kwargs
from models.base.composite_model_base import CompositeModelBase
from utils.factory import create


class LagrangianGnsModel(CompositeModelBase):
    """Composite Lagrangian GNS model.

    Wraps an optional conditioner (commented as a timestep embedding) and a single
    encode-process-decode submodel. Both are built via ``create(..., model_from_kwargs, ...)``
    and share the parent's update counter, path provider, contexts and data container.

    NOTE(review): ``forward`` only feeds ``x``, ``edge_index`` and ``edge_features`` to
    ``encode_process_decode``. The conditioner and the ``timestep`` argument are never used
    in this class. Confirm that this is intended.
    """

    def __init__(
            self,
            encode_process_decode,
            conditioner=None,
            **kwargs,
    ):
        """Build the submodels.

        Args:
            encode_process_decode: config/kwargs passed to ``create`` for the main submodel.
            conditioner: optional config/kwargs for the conditioner submodel.
            **kwargs: forwarded to ``CompositeModelBase.__init__``.
        """
        super().__init__(**kwargs)
        # Shared wiring handed to every submodel.
        common_kwargs = dict(
            update_counter=self.update_counter,
            path_provider=self.path_provider,
            dynamic_ctx=self.dynamic_ctx,
            static_ctx=self.static_ctx,
            data_container=self.data_container,
        )
        # timestep embed; presumably `create` returns None when `conditioner` is None
        # (the None-check below relies on that) — TODO confirm.
        self.conditioner = create(
            conditioner,
            model_from_kwargs,
            **common_kwargs,
            input_shape=self.input_shape,
        )
        # Publish the conditioner dim in the shared static_ctx *before* building
        # encode_process_decode, which receives the same static_ctx (presumably reads
        # "dim" from it — TODO confirm).
        self.static_ctx["dim"] = self.conditioner.dim if self.conditioner is not None else None
        # encoder process decoder all in one
        self.encode_process_decode = create(
            encode_process_decode,
            model_from_kwargs,
            **common_kwargs,
        )
        # Box for PBC, taken from the dataset; not read anywhere inside this class.
        self.box = self.data_container.get_dataset().box

    @property
    def submodels(self):
        """Return the named submodels; the conditioner is included only if it exists."""
        return dict(
            **(dict(conditioner_encoder=self.conditioner) if self.conditioner is not None else {}),
            encode_process_decode=self.encode_process_decode
        )

    # noinspection PyMethodOverriding
    def forward(
            self,
            x,
            timestep,
            curr_pos,
            curr_pos_decode,
            prev_pos_decode,
            edge_index,
            edge_features,
            batch_idx,
            unbatch_idx=None,
            unbatch_select=None,
            reconstruct_prev_a=False
    ):
        """Run one encode-process-decode pass.

        Only ``x``, ``edge_index`` and ``edge_features`` are used. The remaining
        parameters are accepted so that callers with the wider signature keep working.

        Args:
            x: node features passed straight to ``encode_process_decode``.
            edge_index: graph connectivity. It is transposed before use, so it presumably
                arrives as (num_edges, 2) and the submodel expects (2, num_edges)
                — TODO confirm.
            edge_features: per-edge features passed straight through.

        Returns:
            Whatever ``encode_process_decode`` returns.
        """
        return self.encode_process_decode(x=x, edge_index=edge_index.T, edge_features=edge_features)

    @torch.no_grad()
    def rollout_timing(
            self,
            x,
            timestep,
            curr_pos,
            edge_index,
            edge_features,
            batch_idx,
            unbatch_idx=None,
            unbatch_select=None,
            full_rollout=False,
            rollout_length=20
    ):
        """Call ``encode_process_decode`` ``rollout_length`` times on the same input.

        The input is not advanced between calls. Only the repeated forward cost is
        exercised, which is presumably used for timing measurements. Runs without
        gradient tracking.

        Args:
            x, edge_index, edge_features: same meaning as in ``forward``.
            rollout_length: number of repeated calls; must be >= 1.
            Other arguments are accepted for signature compatibility and are unused.

        Returns:
            The output of the last call.

        Raises:
            ValueError: if ``rollout_length`` < 1. Previously this case failed with
                ``UnboundLocalError`` because no output was ever produced.
        """
        if rollout_length < 1:
            raise ValueError(f"rollout_length must be >= 1, got {rollout_length}")
        # Loop-invariant: transpose once instead of on every iteration.
        edge_index_t = edge_index.T
        a_hat = None
        for _ in range(rollout_length):
            a_hat = self.encode_process_decode(x=x, edge_index=edge_index_t, edge_features=edge_features)
        return a_hat