""" """

from functools import reduce

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import Module

from group_discovery.utils import blogm

# helper functions
# Short alias for torch's matrix exponential, used by Flow below.
expm = torch.linalg.matrix_exp


# Network
class MLP(Module):
    """Fully connected ReLU network that takes a time value as an extra input.

    The stack is ``Linear -> ReLU`` three times followed by a final ``Linear``.
    The first layer expects ``in_dim + 1`` features because ``forward``
    appends one time column to the flattened input.

    Args:
        in_dim: Number of features of the flattened input (without time).
        out_dim: Number of features produced by the last linear layer.
        hidden_dim: Width of every hidden layer.
    """

    def __init__(self, in_dim, out_dim, hidden_dim):
        super().__init__()
        # Consecutive pairs of widths give the three hidden Linear layers;
        # the "+ 1" is the slot for the time value.
        widths = [in_dim + 1, hidden_dim, hidden_dim, hidden_dim]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(hidden_dim, out_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x_t, t):
        """Run the network on ``x_t`` with ``t`` appended as a feature.

        Args:
            x_t: Input tensor, documented by the original author as [B, N, D].
            t: Time tensor, documented by the original author as [B].

        Returns:
            The network output reshaped to ``[B, -1, D]`` where ``B`` and ``D``
            are the first and last sizes of ``x_t``.
        """
        batch_size = x_t.shape[0]
        last_dim = x_t.shape[-1]

        # Collapse everything after the batch axis, then add the time column.
        flat = torch.flatten(x_t, start_dim=1)
        net_input = torch.cat([flat, t[:, None]], dim=1)

        return self.net(net_input).view(batch_size, -1, last_dim)


# Model
class Flow(Module):
    """Flow model that moves points with matrix exponentials of a predicted matrix.

    The wrapped ``net`` is called as ``net(x_t, t)``; its output is scaled by a
    step size, passed through ``expm`` and right-multiplied onto the points.

    Args:
        net: Module/callable invoked as ``net(x_t, t)``. Its output must be a
            batch of square matrices compatible with ``x_t @ expm(...)``
            (presumably [B, D, D] for ``x_t`` of shape [B, N, D]).
        prior_dist: Object exposing ``sample((B,))`` that returns a batch of
            matrices used to perturb the data (``x_1 @ sample``).
        device: Device onto which batches and sampled tensors are moved.
    """

    def __init__(self, net, prior_dist, device):
        super().__init__()
        self.net = net
        self.device = device
        self.prior_dist = prior_dist

    def forward(self, x_t, t):
        """Return the network prediction for ``x_t`` at time ``t``."""
        return self.net(x_t, t)

    @torch.no_grad()
    def step(
        self,
        x_t: torch.Tensor,
        t: torch.Tensor,
        Δt: torch.Tensor,
        return_transform=False,
    ):
        """Return x_(t + Δt) given x_t and t.

        Args:
            x_t: Current points.
            t: Current time for each batch element.
            Δt: Step size multiplied onto the prediction; ``sample`` passes a
                plain Python float here.
            return_transform: When true, also return the matrix that was
                applied, ``expm(Δt * net(x_t, t))``.

        Returns:
            ``x_t @ expm(Δt * pred)``, or a tuple of that and the transform.
        """
        pred = self.net(x_t, t)

        # Compute the exponential once and reuse it for both return values.
        transform = expm(Δt * pred)
        out = x_t @ transform

        if return_transform:
            return out, transform
        return out

    @torch.no_grad()
    def sample(self, x_1, n_steps, return_transform=False):
        """Perturb ``x_1`` with a prior sample and integrate ``n_steps`` steps.

        Args:
            x_1: Batch of points; its first dimension is the batch size.
            n_steps: Number of equal steps of size ``1 / n_steps`` taken at
                times ``0, 1/n_steps, ...``.
            return_transform: When true, also return the product of all step
                transforms in application order. The initial prior
                transformation is not included in that product.

        Returns:
            The final points, or ``(points, transform)``. With zero steps the
            transform is a batch of identity matrices because no step has been
            applied.
        """
        B = x_1.shape[0]
        rand_transforms = self.prior_dist.sample((B,)).to(self.device)
        x = x_1 @ rand_transforms
        transforms = []

        for i in range(n_steps):
            # Allocate directly on the target device instead of copying over.
            t = torch.full((B,), i / n_steps, device=self.device)
            x, tf = self.step(x, t, 1.0 / n_steps, return_transform=True)
            if return_transform:
                transforms.append(tf)

        if not return_transform:
            return x

        if transforms:
            # x_final = x_0 @ T_0 @ T_1 @ ..., so fold left to right.
            total = reduce(torch.bmm, transforms)
        else:
            # No step was taken: the accumulated transform is the identity.
            D = x.shape[-1]
            total = torch.eye(D, dtype=x.dtype, device=x.device).repeat(B, 1, 1)
        return x, total

    def _make_flow_batch(self, x_1):
        """Build ``(x_t, t, A_t)`` for one data batch; shared by train and eval.

        Random draws happen in a fixed order (first ``t``, then the prior
        sample) so both callers consume the RNG identically.
        """
        x_1 = x_1.to(self.device)
        B = x_1.shape[0]
        t = torch.rand(B, device=self.device)

        rand_transforms = self.prior_dist.sample((B,)).to(self.device)
        x_0 = x_1 @ rand_transforms

        # Least-squares solve of x_0 @ G = x_1. Per the original author's note,
        # G plays the role of g_t^T and A_t = blogm(g_t^T) is the target, so
        # the model learns to predict the transpose of transformation matrices.
        # The same driver is used for training and evaluation so both measure
        # the same target.
        G = torch.linalg.lstsq(x_0, x_1, driver="gels").solution
        # blogm comes from group_discovery.utils (presumably a batched matrix
        # logarithm - confirm against that module).
        A_t = blogm(G)

        # t is broadcast over the two trailing matrix dimensions of A_t.
        x_t = x_0 @ expm(t[:, None, None] * A_t)
        return x_t, t, A_t

    def train_net(self, train_loader, optimizer):
        """Run one optimisation pass over ``train_loader``.

        Args:
            train_loader: Sized iterable yielding data batches ``x_1``.
            optimizer: Optimizer whose ``zero_grad``/``step`` are called once
                per batch.

        Returns:
            Sum of per-batch MSE losses divided by ``len(train_loader)``.

        Raises:
            ZeroDivisionError: If ``train_loader`` has length zero.
        """
        self.train()
        total_loss = 0.0
        for x_1 in train_loader:
            x_t, t, A_t = self._make_flow_batch(x_1)

            optimizer.zero_grad()
            pred = self.net(x_t, t)

            loss = F.mse_loss(pred, A_t)
            total_loss += loss.item()

            loss.backward()
            optimizer.step()

        return total_loss / len(train_loader)

    @torch.no_grad()
    def eval_net(self, test_loader):
        """Compute the mean per-batch MSE loss over ``test_loader``.

        Uses the same batch construction as ``train_net`` (including freshly
        sampled ``t`` and prior transformations), without gradient tracking.

        Args:
            test_loader: Sized iterable yielding data batches ``x_1``.

        Returns:
            Sum of per-batch MSE losses divided by ``len(test_loader)``.

        Raises:
            ZeroDivisionError: If ``test_loader`` has length zero.
        """
        self.eval()
        total_loss = 0.0
        for x_1 in test_loader:
            x_t, t, A_t = self._make_flow_batch(x_1)

            pred = self.net(x_t, t)
            loss = F.mse_loss(pred, A_t)
            total_loss += loss.item()

        return total_loss / len(test_loader)
