import torch
import torch.nn as nn
import math
from dl.src.constants import BASE_EN_DE

# Default compute device: the CUDA GPU when one is visible, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


def comp_to_real(real, comp):
    """Map a point ``(real, comp)`` on the unit circle back to a real number.

    This inverts ``real = (M**2 - 1) / (M**2 + 1)``,
    ``comp = -2 * M / (M**2 + 1)`` and returns ``M``. The point
    ``(real, comp) == (1, 0)`` makes the denominator zero.

    Works element-wise on Python floats or tensors.
    """
    denom = (real - 1) ** 2 + comp**2
    return -2 * comp / denom


class CMCS_GT_VAE(nn.Module):
    """VAE whose latent coordinates are turned into angles and snapped to a grid.

    Each latent value is mapped onto the unit circle (``real_to_comp``) and
    converted to an angle in roughly ``[-pi, pi]`` (``real_to_theta``). The
    angle is then quantized to the nearest entry of the learnable grid
    ``self.ground`` with a straight-through estimator (``select_code``).

    On the training path (``target`` given), the angles are also shifted by
    ``group_action``. The decoded shifted codes are scored against ``input``
    (the ``dec_equiv`` term), and extra code/canonical-point losses are added.
    """

    def __init__(self, config):
        super().__init__()
        encoder, decoder = BASE_EN_DE[config.dataset]
        self.encoder = encoder(config)
        self.decoder = decoder(config)
        self.n = config.nth_root
        self.t_iter = config.t_iter
        self.iter = 0.0  # incremented once per training-path forward call

        # n evenly spaced angles on [-pi, pi]; both endpoints are dropped,
        # leaving n - 2 learnable grid points.
        temp_grid = torch.linspace(
            start=-math.pi, end=math.pi, steps=self.n, requires_grad=False
        )[1:-1]
        self.ground = nn.Parameter(temp_grid)
        self.scale = nn.Parameter(torch.full((config.latent_dim,), 4.5))
        self.cp = nn.Parameter(torch.ones(config.latent_dim))

        # Number of leading latent dimensions the code/canonical losses act on.
        if config.dataset == "dsprites":
            self.active_dim = 5
        elif config.dataset == "shapes3d":
            self.active_dim = 6
        elif "mpi3d" in config.dataset:
            self.active_dim = 7
        else:
            # Unknown dataset: constrain every latent dimension instead of
            # leaving the attribute undefined (which fails on the first
            # training step).
            self.active_dim = config.latent_dim

    def forward(self, input, loss_fn, target=None, class_label=None):
        """Encode ``input``, quantize the latent angles and decode.

        Args:
            input: batch passed to ``self.encoder``.
            loss_fn: criterion called as ``loss_fn(reconstruction, input)``
                inside ``self.loss``.
            target: per-sample shift amounts for ``group_action``. When it is
                not ``None``, the training path runs.
                NOTE(review): the original notes describe input as [2B, D]
                and target as [B, D]; confirm the shapes against the caller.
            class_label: per-sample shift amounts applied to ``self.cp`` to
                build canonical points. It is only read on the training path.

        Returns:
            Training path: ``(loss_dict, encoder_output, decoder_output)``,
            where ``loss_dict["obj"]`` holds every loss term.
            Otherwise: ``(code, decoder_output)``.
        """
        if target is None:
            encoder_output = self.encoder(input)
            theta_mu = self.real_to_theta(encoder_output[0])
            code = self.select_code(theta_mu)
            decoder_output = self.decoder(code)
            return (code, decoder_output)

        g_scale = torch.sigmoid(self.scale)
        batch = target.size(0)
        half = batch // 2
        k = self.active_dim
        scale = self.n / 100.0

        self.iter += 1.0
        encoder_output = self.encoder(input)

        # Quantized codes, and quantized codes after the group action.
        theta_mu = self.real_to_theta(encoder_output[0])
        code = self.select_code(theta_mu)
        trans_code_by_g = self.group_action(code, target, scale)
        trans_theta = self.group_action(theta_mu, target, scale)
        trans_code = self.select_code(trans_theta)

        # Canonical points: self.cp shifted by class_label, then quantized.
        # group_action swaps the batch halves internally. Swapping again
        # restores the original row order when class_label has the same even
        # batch size as target (assumption; confirm with the caller).
        canonical_point = self.select_code(
            self.group_action(self.cp, class_label, scale)
        )
        canonical_point = torch.cat(
            [canonical_point[half:], canonical_point[:half]], dim=0
        )
        canonical_loss = (
            (canonical_point[:half, :k] - code[:half, :k])
            .abs()
            .sum(dim=-1)
            .mean()
            + (canonical_point[half:, :k] - trans_code_by_g[half:, :k].detach())
            .abs()
            .sum(dim=-1)
            .mean()
        )

        codes = torch.cat([code, trans_code], dim=0)  # [2B, D]
        decoder_output = self.decoder(codes)
        outputs = (encoder_output, decoder_output)
        loss = self.loss(input, outputs, loss_fn)

        # code_loss (original author notes):
        # 1. z_1, c_1 (O)
        # 2. z_2, c_2^\prime (O)
        # 3. z_1^\prime, c_1 (X)
        # 4. z_2^\prime, c_2^\prime (X)
        code_loss = (
            ((theta_mu[:half, :k] - code[:half, :k]) ** 2).sum(dim=-1).mean()
            + (
                (theta_mu[half:, :k] - trans_code_by_g[half:, :k].detach()) ** 2
            )
            .sum(dim=-1)
            .mean()
        )

        loss["obj"]["code_loss"] = code_loss
        loss["obj"]["canonical_loss"] = canonical_loss
        loss["obj"]["ground_scale"] = g_scale.mean()

        return (loss, encoder_output, decoder_output)

    def batch_cos(self, canonical):
        """Return the absolute pairwise cosine similarity between rows of ``canonical``.

        ``canonical`` must be 2-D because ``torch.mm`` is used. A row with zero
        norm produces a division by zero (NaN/inf) in the result.
        """
        norm = torch.norm(canonical, dim=-1, keepdim=True)
        output = torch.mm(canonical, canonical.transpose(-1, -2))
        norm = torch.mm(norm, norm.transpose(-1, -2))
        cos = torch.abs(output / norm)
        return cos

    def loss(self, input, outputs, loss_fn):
        """Compute the KL, reconstruction and decoder-equivariance terms.

        Args:
            input: the batch the reconstructions are compared against.
            outputs: ``(encoder_output, decoder_output)``.
                ``encoder_output[1]`` and ``encoder_output[2]`` are used as
                ``mu`` and ``logvar``. ``decoder_output[0]`` holds the
                reconstructions of ``[code; trans_code]`` stacked along dim 0.
            loss_fn: criterion called as ``loss_fn(prediction, input)``.

        Returns:
            ``{"elbo": {}, "obj": {...}, "id": {}}``. ``"obj"`` contains
            ``reconst``, ``kld`` and ``dec_equiv``; the two reconstruction
            terms are divided by half the number of reconstructions.
        """
        result = {"elbo": {}, "obj": {}, "id": {}}

        reconsted_images = outputs[1][0]
        batch = reconsted_images.size(0) // 2
        # squeeze() presumably drops singleton spatial dims from the encoder
        # output; confirm the encoder's output shape.
        mu = outputs[0][1].squeeze()
        logvar = outputs[0][2].squeeze()

        kld_err = torch.mean(
            -0.5 * torch.sum(1 + logvar - mu**2 - logvar.exp(), dim=-1)
        )
        reconst_err = loss_fn(reconsted_images[:batch], input) / batch
        dec_equiv_loss = loss_fn(reconsted_images[batch:], input) / batch

        result["obj"]["reconst"] = reconst_err
        result["obj"]["kld"] = kld_err
        result["obj"]["dec_equiv"] = dec_equiv_loss

        return result

    def enc_equiv_loss(self, tz, mu, logvar):
        """Gaussian negative log-likelihood of ``tz`` under ``N(mu, exp(logvar))``.

        ``tz`` is detached, so gradients reach only ``mu`` and ``logvar``. The
        result is summed over the last dim and averaged over the rest.
        """
        std = torch.exp(0.5 * logvar)
        # log(sqrt(2*pi) * std) written directly in terms of logvar avoids
        # overflowing exp() for large logvar.
        log_norm = 0.5 * (math.log(2 * math.pi) + logvar)
        nll = log_norm + 0.5 * ((tz.detach() - mu) / std + 1e-9) ** 2
        return nll.sum(dim=-1).mean()

    def select_code(self, z):
        """Snap every angle in ``z`` to its nearest ``self.ground`` point.

        Distance is measured on the real line, not modulo 2*pi. A
        straight-through estimator is used: the value equals
        ``ground[idx]``, while gradients flow unchanged to ``z`` and to the
        selected ``ground`` entries. The output has the same shape as ``z``.
        """
        z = z.unsqueeze(-1)  # compare every element against every grid point
        _, indices = torch.min(torch.abs(z - self.ground), dim=-1)
        # Remove only the axis added above. A bare squeeze() also drops
        # size-1 batch/latent dims and broadcasts the result to the wrong shape.
        z = z.squeeze(-1)
        return z + self.ground[indices] - z.detach()

    def real_to_theta(self, M):
        """Map real values ``M`` to angles via the unit-circle map ``real_to_comp``."""
        r, c = self.real_to_comp(M)
        theta = Acosine.apply(r) - math.pi
        # acos alone covers half the circle; the sign of the imaginary part
        # picks the side.
        return torch.where(c >= 0, theta, -theta)

    # conformal map from real number to complex number
    def real_to_comp(self, M):
        """Return ``((M^2 - 1) / (M^2 + 1), -2M / (M^2 + 1))``, a point on the unit circle."""
        real = (M**2 - 1) / (M**2 + 1)
        comp = -2 * M / (M**2 + 1)
        return real, comp

    # conformal map from complex number to real number
    def comp_to_real(self, real, comp):
        """Invert ``real_to_comp``; the point ``(1, 0)`` divides by zero."""
        output = -2 * comp / ((real - 1) ** 2 + comp**2)
        return output

    def group_action(self, theta, target, scale):
        """Rotate angles by ``scale * target * 2*pi / n``, then swap batch halves.

        Results are wrapped into ``[-pi, pi]`` by a single +/-2*pi correction.
        ``theta`` may be unbatched (e.g. ``self.cp`` of shape ``[D]``) and is
        broadcast against ``target``. The swap along dim 0 uses the broadcast
        result's batch size, so rows ``[h:]`` come first, then ``[:h]``, with
        ``h = batch // 2``.
        """
        dtheta = scale * target * 2 * math.pi / self.n
        output = theta + dtheta
        output = torch.where(output > math.pi, output - 2 * math.pi, output)
        output = torch.where(output < -math.pi, output + 2 * math.pi, output)

        # Use the post-broadcast size: for an unbatched theta, theta.size(0)
        # is the latent size, not the batch size.
        half = output.size(0) // 2
        return torch.cat([output[half:], output[:half]], dim=0)

    def init_weights(self):
        """Re-initialize parameters: Xavier-uniform for >=2-D tensors, zeros otherwise.

        Parameters whose name contains ``ground`` or ``grid`` are left untouched.
        NOTE(review): ``scale`` and ``cp`` are 1-D and not excluded, so their
        constructor values (4.5 and 1) are overwritten with zeros; confirm
        this is intended.
        """
        for n, p in self.named_parameters():
            if "ground" in n or "grid" in n:
                continue
            if p.data.ndimension() >= 2:
                nn.init.xavier_uniform_(p.data)
            else:
                nn.init.zeros_(p.data)

    def freeze(self):
        """Freeze parameters named like ``ground``/``label_code``; make all others trainable."""
        for n, p in self.named_parameters():
            p.requires_grad = not ("ground" in n or "label_code" in n)


# Hand-written arccos whose backward adds a small epsilon to the derivative's
# denominator, so the gradient stays finite (about -1e9) instead of infinite
# when the input is exactly +/-1.
class Acosine(torch.autograd.Function):
    """``torch.acos`` with an epsilon-stabilized backward pass."""

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return torch.acos(input)

    @staticmethod
    def backward(ctx, grad_output):
        (saved_input,) = ctx.saved_tensors
        # d/dx acos(x) = -1 / sqrt(1 - x^2); 1e-9 guards the |x| == 1 case.
        denom = torch.sqrt(1 - saved_input**2) + 1e-9
        return grad_output * (-1 / denom)
