import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dl.src.constants import BASE_EN_DE, FACTOR_INFORM, FACTOR_CLASSIFIER

# Device chosen once at import time: CUDA when available, otherwise CPU.
# CMCS_Super_VAE moves its FACTOR_INFORM index tensors onto this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def comp_to_real(real, comp):
    """Map a point ``(real, comp)`` on the unit circle back to the real line.

    This inverts the map ``real = (M**2 - 1) / (M**2 + 1)``,
    ``comp = -2 * M / (M**2 + 1)`` and recovers ``M``. It is undefined at
    ``(1, 0)``, where the denominator vanishes.

    Works element-wise on Python floats or tensors.
    """
    denominator = (real - 1) ** 2 + comp**2
    return -2 * comp / denominator


class CMCS_Super_VAE(nn.Module):
    """VAE with circular, grid-quantised latent codes driven by pseudo labels.

    Each latent coordinate from the encoder goes through three steps:

    1. It is mapped to an angle in ``[-pi, pi]`` (``real_to_theta``).
    2. It is snapped to the nearest point of the learnable grid
       ``self.ground`` with a straight-through estimator (``select_code``).
    3. It is translated by label-derived shifts (``group_action``).

    ``forward`` treats its batch as two halves, rows ``[:B]`` and ``[B:]``
    with ``B = batch // 2``, and builds losses that relate each half to the
    other.

    Args:
        config: Object providing ``dataset``, ``nth_root``, ``t_iter`` and
            ``latent_dim``. ``dataset`` selects the encoder/decoder pair from
            ``BASE_EN_DE`` and the pseudo-label network from
            ``FACTOR_CLASSIFIER``.
    """

    def __init__(self, config):
        super().__init__()
        encoder, decoder = BASE_EN_DE[config.dataset]
        pseudo_label = FACTOR_CLASSIFIER[config.dataset]
        self.dataset = config.dataset
        self.pseudo_label = pseudo_label(config)
        self.encoder = encoder(config)
        self.decoder = decoder(config)
        self.n = config.nth_root
        self.t_iter = config.t_iter
        self.iter = 0.0

        # Learnable angle grid: the n - 2 interior points of an n-point
        # linspace over [-pi, pi] (both endpoints are dropped).
        temp_grid = torch.linspace(start=-math.pi, end=math.pi, steps=self.n)[1:-1]
        self.ground = nn.Parameter(temp_grid)
        # NOTE(review): self.scale is not read anywhere in this class;
        # forward() uses a local ``scale = n / 100`` instead.
        self.scale = nn.Parameter(4.5 * torch.ones(config.latent_dim))
        # Per-dimension canonical point, shifted by class labels in forward().
        self.cp = nn.Parameter(torch.ones(config.latent_dim))

        # Number of leading latent dimensions compared against codes and
        # labels in forward(). NOTE(review): datasets other than these three
        # leave ``active_dim`` unset, so forward() raises AttributeError.
        if config.dataset == "dsprites":
            self.active_dim = 5
        elif config.dataset == "shapes3d":
            self.active_dim = 6
        elif "mpi3d" in config.dataset:
            self.active_dim = 7

    def forward(self, input, loss_fn, target=None):
        """Encode, quantise, apply pseudo symmetries, decode and compute losses.

        Args:
            input: Batch whose first and second halves are paired with each
                other. The batch size is assumed even — TODO confirm with the
                data loader.
            loss_fn: Reconstruction loss, called as ``loss_fn(recon, input)``.
            target: Optional ground-truth factor labels. When given, the
                pseudo-label classifier is scored against them.

        Returns:
            ``(losses, encoder_output, decoder_output, code)``. ``losses`` is
            the dict returned by ``loss``, with ``code_loss``,
            ``canonical_loss``, ``label_loss`` and ``label_acc`` added under
            ``"obj"``.
        """
        batch = input.size(0)
        half = batch // 2
        # With this scale, group_action shifts by sym * 2*pi / 100,
        # independent of self.n.
        scale = self.n / 100.0

        encoder_output = self.encoder(input)

        # Pseudo labels, and the pseudo symmetry taking one half's labels to
        # the other half's. Negative differences are shifted by 100, and the
        # second half receives the complementary shift 100 - sym.
        class_label, label_logits = self.pseudo_label(input)
        sym = class_label[half:] - class_label[:half]
        sym = torch.where(sym >= 0, sym, 100 + sym)
        sym = torch.cat([sym, 100 - sym], dim=0)

        # Quantised codes and their images under the pseudo symmetry.
        theta_mu = self.real_to_theta(encoder_output[0])
        code = self.select_code(theta_mu)
        trans_code_by_g = self.group_action(code, sym, scale)
        trans_theta = self.group_action(theta_mu, sym, scale)
        trans_code = self.select_code(trans_theta)

        # Canonical point shifted by each sample's labels. group_action swaps
        # the two halves, so swap back to realign rows with the input.
        canonical_point = self.select_code(
            self.group_action(self.cp, class_label, scale)
        )
        canonical_point = torch.cat(
            [canonical_point[half:], canonical_point[:half]], dim=0
        )

        d = self.active_dim
        canonical_loss = (
            torch.abs(canonical_point[:half, :d] - code[:half, :d])
            .sum(dim=-1)
            .mean()
        ) + (
            torch.abs(
                canonical_point[half:, :d] - trans_code_by_g[half:, :d].detach()
            )
            .sum(dim=-1)
            .mean()
        )

        codes = torch.cat([code, trans_code], dim=0)
        decoder_output = self.decoder(codes)
        losses = self.loss(input, (encoder_output, decoder_output), loss_fn)

        # code_loss pairs (O = included below, X = not included):
        # 1. z_1, c_1 (O)
        # 2. z_2, c_2^\prime (O)
        # 3. z_1^\prime, c_1 (X)
        # 4. z_2^\prime, c_2^\prime (X)
        code_loss = (
            ((theta_mu[:half, :d] - code[:half, :d]) ** 2).sum(dim=-1).mean()
        ) + (
            ((theta_mu[half:, :d] - trans_code_by_g[half:, :d].detach()) ** 2)
            .sum(dim=-1)
            .mean()
        )

        label_loss, label_acc = self._label_loss_and_acc(
            class_label, label_logits, target, batch
        )

        losses["obj"]["code_loss"] = code_loss
        losses["obj"]["canonical_loss"] = canonical_loss
        losses["obj"]["label_loss"] = label_loss
        losses["obj"]["label_acc"] = label_acc
        return losses, encoder_output, decoder_output, code

    def _label_loss_and_acc(self, class_label, label_logits, target, batch):
        """Return ``(label_loss, label_acc)``; both are ``0.0`` when ``target`` is None."""
        if target is None:
            return 0.0, 0.0
        idx_info = FACTOR_INFORM[self.dataset].to(device)
        label_loss = 0.0
        for i in range(target.size(-1)):
            # Logit columns of factor i: [sum(idx_info[:i]), sum(idx_info[:i + 1])).
            start, end = idx_info[:i].sum(), idx_info[: i + 1].sum()
            label_loss = (
                label_loss
                + F.cross_entropy(
                    label_logits[:, start:end], target[:, i], reduction="sum"
                )
                / batch
            )
        # A sample scores 1 only when all active_dim labels match the target.
        matches = (class_label[:, : self.active_dim] == target).sum(dim=-1)
        label_acc = (matches // self.active_dim).sum() / batch
        return label_loss, label_acc

    def batch_cos(self, canonical):
        """Return the pairwise |cosine similarity| matrix of the rows of ``canonical``."""
        norm = torch.norm(canonical, dim=-1, keepdim=True)
        output = torch.mm(canonical, canonical.transpose(-1, -2))
        norm = torch.mm(norm, norm.transpose(-1, -2))
        return torch.abs(output / norm)

    def loss(self, input, outputs, loss_fn):
        """Compute the reconstruction, KL and decoder-equivariance terms.

        Args:
            input: The batch fed to ``forward``.
            outputs: ``(encoder_output, decoder_output)``. ``encoder_output``
                unpacks as ``(z, mu, logvar, ...)``; ``decoder_output[0]``
                holds reconstructions of ``[code; trans_code]``.
            loss_fn: Reconstruction loss, called as ``loss_fn(recon, input)``.

        Returns:
            Dict with keys ``"elbo"``, ``"obj"`` and ``"id"``. Under ``"obj"``
            it holds ``reconst``, ``kld`` and ``dec_equiv``.
        """
        result = {"elbo": {}, "obj": {}, "id": {}}

        reconsted_images = outputs[1][0]
        # The decoder saw [code; trans_code], so each half matches input's size.
        batch = reconsted_images.size(0) // 2
        mu = outputs[0][1].squeeze()
        logvar = outputs[0][2].squeeze()
        kld_err = torch.mean(
            -0.5 * torch.sum(1 + logvar - mu**2 - logvar.exp(), dim=-1)
        )
        reconst_err = loss_fn(reconsted_images[:batch], input) / batch
        dec_equiv_loss = loss_fn(reconsted_images[batch:], input) / batch

        result["obj"]["reconst"] = reconst_err
        result["obj"]["kld"] = kld_err
        result["obj"]["dec_equiv"] = dec_equiv_loss
        return result

    def enc_equiv_loss(self, tz, mu, logvar):
        """Gaussian negative log-likelihood of ``tz`` under ``N(mu, exp(logvar))``.

        ``tz`` is detached, so gradients flow only into ``mu`` and ``logvar``.
        The result is summed over the last dimension and averaged over all
        remaining elements.
        """
        std = torch.exp(0.5 * logvar)
        # log(sqrt(2*pi) * std) is written in terms of logvar so it cannot
        # hit log(0) when exp underflows. The epsilon guards the division
        # rather than shifting the residual.
        nll = (
            0.5 * math.log(2 * math.pi)
            + 0.5 * logvar
            + 0.5 * ((tz.detach() - mu) / (std + 1e-9)) ** 2
        )
        return nll.sum(dim=-1).mean()

    def select_code(self, z):
        """Snap every entry of ``z`` to its nearest angle in ``self.ground``.

        This is a straight-through estimator: the forward value is the grid
        point, and the gradient w.r.t. ``z`` is the identity. ``z`` keeps its
        shape; size-1 batch or latent dimensions are not squeezed away.
        """
        diff = torch.abs(z.unsqueeze(-1) - self.ground)  # [..., len(ground)]
        _, indices = torch.min(diff, dim=-1)
        return z + self.ground[indices] - z.detach()

    def real_to_theta(self, M):
        """Map real values to angles in ``[-pi, pi]`` through ``real_to_comp``.

        Computes ``acos(r) - pi`` (which lies in ``[-pi, 0]``) and negates it
        where the imaginary part is negative. ``Acosine`` supplies an
        epsilon-guarded backward pass.
        """
        acos_func = Acosine.apply
        r, c = self.real_to_comp(M)
        theta = acos_func(r) - math.pi
        return torch.where(c >= 0, theta, -theta)

    def real_to_comp(self, M):
        """Conformal map from the real line onto the unit circle; returns ``(real, imag)``."""
        real = (M**2 - 1) / (M**2 + 1)
        comp = -2 * M / (M**2 + 1)
        return real, comp

    def comp_to_real(self, real, comp):
        """Inverse of ``real_to_comp``; undefined at ``(1, 0)``."""
        return -2 * comp / ((real - 1) ** 2 + comp**2)

    def group_action(self, theta, sym, scale):
        """Rotate angles by ``scale * sym * 2*pi / n``, wrap them, and swap batch halves.

        ``theta`` may be a batched tensor or the 1-D canonical point
        ``self.cp``, which broadcasts against ``sym``. The half swap therefore
        uses the batch size of the broadcast result, not ``theta.size(0)``.
        """
        dtheta = scale * sym * 2 * math.pi / self.n
        output = theta + dtheta
        # Single wrap into [-pi, pi]; enough while theta is in [-pi, pi] and
        # |dtheta| <= 2*pi.
        output = torch.where(output > math.pi, output - 2 * math.pi, output)
        output = torch.where(output < -math.pi, output + 2 * math.pi, output)

        half = output.size(0) // 2
        return torch.cat([output[half:], output[:half]], dim=0)

    def random_disent_group_action(self, theta):
        """Shuffle one randomly chosen active column of ``theta`` across rows.

        Returns ``(new_theta, selected_col, shuffled_idx)``.
        """
        row, _ = theta.size()
        mask = torch.zeros_like(theta)
        selected_col = random.randint(0, self.active_dim - 1)
        mask[:, selected_col] = 1.0

        shuffled_idx = torch.randperm(row)
        changed_theta = (mask * theta)[shuffled_idx]

        theta = torch.where(mask == 0, theta, changed_theta)
        return theta, selected_col, shuffled_idx

    def disent_pseudo_label(self, label, col, idx):
        """Return a copy of ``label`` whose factor-``col`` columns are permuted by ``idx``."""
        new_label = label.clone()
        idx_info = FACTOR_INFORM[self.dataset].to(device)
        start, end = idx_info[:col].sum(), idx_info[: col + 1].sum()
        new_label[:, start:end] = label[idx, start:end]
        return new_label

    def init_weights(self):
        """Xavier-init parameters with 2+ dims and zero the 1-D ones.

        Parameters whose names contain "ground", "grid" or "label_code" are
        skipped. NOTE(review): this also zeroes ``scale`` and ``cp``,
        overwriting their 4.5 and 1.0 initial values — confirm this is
        intended.
        """
        for n, p in self.named_parameters():
            if "ground" in n or "grid" in n or "label_code" in n:
                continue
            if p.data.ndimension() >= 2:
                nn.init.xavier_uniform_(p.data)
            else:
                nn.init.zeros_(p.data)

    def freeze(self):
        """Freeze "ground"/"label_code" parameters; make every other parameter trainable."""
        for n, p in self.named_parameters():
            p.requires_grad = not ("ground" in n or "label_code" in n)


# Manual arccosine autograd function: its backward adds a small epsilon so the gradient stays finite at |x| = 1.
class Acosine(torch.autograd.Function):
    """Element-wise ``acos`` with a hand-written, epsilon-guarded backward.

    The analytic derivative ``-1 / sqrt(1 - x**2)`` is infinite at
    ``|x| = 1``. Adding ``1e-9`` to the square root keeps it finite; for
    ``|x| <= 1`` its magnitude is at most ``1e9``.
    """

    @staticmethod
    def forward(ctx, input):
        """Save ``input`` for the backward pass and return ``torch.acos(input)``."""
        ctx.save_for_backward(input)
        return torch.acos(input)

    @staticmethod
    def backward(ctx, grad_output):
        """Chain ``grad_output`` with the guarded derivative of ``acos``."""
        saved = ctx.saved_tensors[0]
        guarded_root = torch.sqrt(1 - saved**2) + 1e-9
        # Alternative kept from the original: damp very large gradients.
        # acos_grad = torch.where(acos_grad > 1e2, acos_grad * 1e-1, acos_grad)
        return grad_output * (-1 / guarded_root)
