import torch
import torch.nn.functional as F
from torch import Tensor, nn

from group_discovery.geometry_2d import matrix_to_angle, wrap_angle

# Short aliases for the batched linear-algebra routines used by the model below.
inv, expm = torch.linalg.inv, torch.linalg.matrix_exp


# Model
class Canonicalization(nn.Module):
    """MLP that predicts a 2x2 matrix ``A`` per sample; ``expm(A)`` is the
    estimated transform of that sample.

    Training is self-supervised: each batch is multiplied by two independent
    matrices drawn from ``prior_dist``, the network predicts a matrix for each
    copy, and the resulting composition is driven towards the identity via an
    MSE loss on the input. Evaluation additionally compares ``expm(A)`` with
    the ground-truth transform supplied by the test loader.

    Args:
        in_dim: Number of features per sample after ``x.flatten(1)``.
        out_dim: Width of the last linear layer. The output is reshaped to
            ``(-1, 2, 2)``, so ``out_dim == 4`` yields one matrix per sample.
        hidden_dim: Width of the three hidden layers.
        prior_dist: Object with a ``sample(shape)`` method (e.g. a
            ``torch.distributions`` distribution). Assumed to return a
            ``(*shape, 2, 2)`` tensor of invertible matrices -- TODO confirm.
        device: Device that batches and sampled matrices are moved to. The
            module's own parameters are not moved here; the caller is
            expected to call ``.to(device)``.
    """

    def __init__(self, in_dim, out_dim, hidden_dim, prior_dist, device):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )
        self.device = device
        self.prior_dist = prior_dist

    def forward(self, x: Tensor) -> Tensor:
        """Return the raw ``(-1, 2, 2)`` matrices predicted for ``x``.

        All dimensions after the first are flattened before the MLP. The
        matrix exponential of the result is the predicted transform.
        """
        out = self.net(x.flatten(1))
        return out.view(-1, 2, 2)

    def _sample_prior_matrix(self, batch_size: int) -> Tensor:
        """Draw ``batch_size`` prior matrices and transpose their last two axes.

        Transposing means ``x @ m`` applies each sampled matrix to the rows
        of ``x``. Both random views in the consistency loss go through this
        helper so they are drawn from the same distribution.
        """
        sample = self.prior_dist.sample((batch_size,)).to(self.device)
        return sample.transpose(-1, -2)

    def _consistency_loss(self, x: Tensor) -> Tensor:
        """Self-supervised loss comparing two randomly transformed views of ``x``.

        Each view is ``x @ m_i`` with ``m_i`` from ``_sample_prior_matrix``.
        The composition ``m_1 @ expm(-A_1) @ expm(A_2) @ inv(m_2)`` of the two
        predictions ``A_i`` is pushed towards the identity by requiring that
        ``x`` multiplied by it reproduces ``x`` (MSE).
        """
        batch_size = x.shape[0]

        rand_matrix1 = self._sample_prior_matrix(batch_size)
        pred1 = self(x @ rand_matrix1)

        rand_matrix2 = self._sample_prior_matrix(batch_size)
        pred2 = self(x @ rand_matrix2)

        # expm(-A) is the inverse of expm(A), so this undoes view 1's
        # prediction and re-applies view 2's.
        composed_transform = (
            rand_matrix1 @ expm(-pred1) @ expm(pred2) @ inv(rand_matrix2)
        )
        return F.mse_loss(x @ composed_transform, x)

    @staticmethod
    def _average(totals: dict, num_batches: int) -> dict:
        """Divide each accumulated total by the number of batches seen.

        Returns NaN for every key when no batches were seen, instead of
        raising ``ZeroDivisionError``.
        """
        if num_batches == 0:
            return {key: float("nan") for key in totals}
        return {key: total / num_batches for key, total in totals.items()}

    def train_net(self, train_loader, optimizer):
        """Run one optimisation pass over ``train_loader``.

        Args:
            train_loader: Iterable yielding input batches (tensors only, no
                labels).
            optimizer: Optimizer over this module's parameters.

        Returns:
            ``{"mse_loss": float}``: the consistency loss averaged over the
            batches seen (NaN if the loader was empty).
        """
        self.train()
        total_loss = 0.0
        num_batches = 0

        for x in train_loader:
            x = x.to(self.device)

            optimizer.zero_grad()
            loss = self._consistency_loss(x)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        return self._average({"mse_loss": total_loss}, num_batches)

    @torch.no_grad()
    def eval_net(self, test_loader):
        """Evaluate on ``test_loader`` without tracking gradients.

        Args:
            test_loader: Iterable yielding ``(x, gt_transform)`` pairs.

        Returns:
            Dict of per-batch averages (NaN if the loader was empty):
            ``"loss"`` is the consistency loss, ``"angle_mse"`` compares the
            predicted and ground-truth angles, and ``"matrix_mse"`` compares
            ``expm`` of the prediction with ``gt_transform``.
        """
        self.eval()
        results = {"loss": 0.0, "angle_mse": 0.0, "matrix_mse": 0.0}
        num_batches = 0

        for x, gt_transform in test_loader:
            x = x.to(self.device)
            gt_transform = gt_transform.to(self.device)

            pred_matrix = expm(self(x))
            gt_angle = matrix_to_angle(gt_transform)

            # NOTE(review): this draws fresh prior samples, so "loss" varies
            # between calls unless the RNG is seeded.
            results["loss"] += self._consistency_loss(x).item()
            # NOTE(review): each angle is wrapped separately, not their
            # difference. If wrap_angle maps to a principal interval, angles
            # on opposite sides of the cut get a large error -- confirm intent.
            results["angle_mse"] += F.mse_loss(
                wrap_angle(matrix_to_angle(pred_matrix)), wrap_angle(gt_angle)
            ).item()
            results["matrix_mse"] += F.mse_loss(pred_matrix, gt_transform).item()
            num_batches += 1

        return self._average(results, num_batches)
