from functools import reduce

import torch
from torch import nn

from group_discovery.geometry_3d import robust_blogm_so3 as blogm

# Helper aliases.
# Matrix exponential (``torch.linalg.matrix_exp``; batched over leading dimensions).
expm = torch.linalg.matrix_exp


# Model
class OracleFlow(nn.Module):
    """Deterministic "oracle" flow from a start configuration to a target ``x_1``.

    The start ``x_0`` is either supplied by the caller or produced by applying a
    random transform ``R`` drawn from ``prior_dist`` to the target
    (``x_0 = x_1 @ R^T``). The flow solves once for the batched linear map ``M``
    with ``x_0 @ M ~= x_1`` (least squares) and takes its matrix logarithm ``A``
    with ``robust_blogm_so3`` (imported as ``blogm``). It then steps
    ``x_t <- x_t @ expm(dt * A)`` over a uniform grid on ``[0, 1]``.

    Every step uses the same generator ``A`` and the step sizes sum to one.
    The composed transform is therefore ``expm(A)``, which reproduces ``M``
    whenever ``blogm`` inverts ``expm`` on it.
    """

    def __init__(self, prior_dist, device):
        """
        Args:
            prior_dist: Object exposing ``sample(sample_shape)`` and ``event_shape``
                (presumably a ``torch.distributions`` distribution over transform
                matrices -- TODO confirm). Used to draw the random start transforms
                when ``x_0`` is not given and to size the ``transforms`` buffer.
            device: Device on which trajectories and transforms are created.
        """
        super().__init__()
        self.device = device
        self.prior_dist = prior_dist

    @torch.no_grad()
    def sample_all(self, x_1, n_steps, return_transform=False, x_0=None):
        """Integrate the oracle flow and return the whole trajectory.

        Args:
            x_1: Target configurations. The leading dim is the batch size ``B``
                and the last dim ``d`` must match the transform size (presumably
                ``(B, N, d)`` point sets -- TODO confirm).
            n_steps: Number of integration steps (``n_steps + 1`` grid points on
                ``[0, 1]``).
            return_transform: If True, also return the per-step transforms, plus
                the sampled start transform when ``x_0`` is None.
            x_0: Optional start configurations with the same shape as ``x_1``.
                When None, ``x_0 = x_1 @ R^T`` with ``R`` sampled from
                ``prior_dist``.

        Returns:
            ``x_t_all`` of shape ``(n_steps + 1, *x_0.shape)`` (the start state
            followed by each stepped state).

            With ``return_transform`` the result is ``(x_t_all, R, transforms)``
            when ``x_0`` is None, and ``(x_t_all, transforms)`` otherwise.

            ``transforms`` has shape ``(n_steps + 1, B, *prior_dist.event_shape)``.
            Entry ``i < n_steps`` is the transform right-multiplied at step ``i``;
            the last entry is the identity.
        """
        x_1 = x_1.to(self.device)
        if x_0 is None:
            B = x_1.shape[0]
            rand_transform = self.prior_dist.sample((B,)).to(self.device)
            # Start state: the target moved by the transpose of the sampled transform.
            x_t = x_1 @ rand_transform.transpose(-2, -1)
        else:
            B = x_0.shape[0]
            x_t = x_0.clone().to(self.device)

        # The generator depends only on the start state and the target, so it is
        # solved once rather than on every step. The solve uses the on-device
        # start state instead of the ``x_0`` argument, which is None in the
        # default path.
        x_start = x_t
        A = blogm(torch.linalg.lstsq(x_start, x_1, driver="gels").solution)

        x_t_all = [x_t]
        # Match the state's dtype so step transforms are not silently downcast.
        transforms = torch.zeros(
            n_steps + 1,
            B,
            *self.prior_dist.event_shape,
            device=self.device,
            dtype=x_t.dtype,
        )
        t_values = torch.linspace(0.0, 1.0, n_steps + 1, device=self.device)

        for i, delta_t in enumerate(torch.diff(t_values)):
            tf = expm(delta_t * A)
            x_t = x_t @ tf
            x_t_all.append(x_t)
            transforms[i] = tf

        x_t_all = torch.stack(x_t_all, dim=0)  # [n_steps + 1, *x_start.shape]
        # Pad the last slot with the identity so the transforms line up with the
        # n_steps + 1 states.
        transforms[-1] = (
            torch.eye(x_1.shape[-1], device=self.device).unsqueeze(0).expand(B, -1, -1)
        )

        if not return_transform:
            return x_t_all
        if x_0 is None:
            return x_t_all, rand_transform, transforms
        return x_t_all, transforms

    @torch.no_grad()
    def sample(self, x_1, n_steps, return_transform=False, x_0=None):
        """Run the flow and return only the final state.

        Takes the same arguments as :meth:`sample_all`.

        Returns:
            The final state ``x_t_all[-1]`` when ``return_transform`` is False.

            Otherwise ``(final_state, R, total)`` when ``x_0`` is None, and
            ``(final_state, total)`` when it is given. Here ``total`` is the
            ``torch.bmm`` product of all per-step transforms, taken in step order.
        """
        out = self.sample_all(
            x_1, n_steps, return_transform=return_transform, x_0=x_0
        )
        if not return_transform:
            return out[-1]
        # ``out`` is (x_t_all, transforms) or (x_t_all, rand_transform, transforms);
        # any middle element (the sampled start transform) is passed through as-is.
        x_t_all, *middle, tfs = out
        # The left-to-right product T_0 @ T_1 @ ... matches the right-multiplication
        # order used when stepping x_t. The trailing identity leaves it unchanged.
        return (x_t_all[-1], *middle, reduce(torch.bmm, tfs))
