import pdb
import math
import torch
import torch.nn as nn

from dl.models.layers.base import CNNLayer
import torch.nn.functional as F

# Module-level default compute device: the CUDA device when torch reports one
# as available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class dSprites_Classifier(nn.Module):
    """CNN classifier with ten categorical heads that emits discrete label codes.

    The input is passed through four ``CNNLayer`` blocks and one dense layer.
    Ten linear heads (``classifier0`` .. ``classifier9``) then produce logits;
    each head is sampled with a hard Gumbel-softmax and the resulting one-hot
    vector is converted to a scalar label by a dot product with
    ``label_code`` (the values 0, 1, ..., 99).

    Only the logits of the first five heads (3 + 6 + 40 + 32 + 32 = 113
    columns) are returned by ``forward``; all ten scalar labels are returned.
    """

    # Output width of classifier0 .. classifier9, in head order.
    _HEAD_WIDTHS = (3, 6, 40, 32, 32, 100, 100, 100, 100, 100)
    # Gumbel-softmax temperature per head (1 is the F.gumbel_softmax default).
    _HEAD_TAUS = (1, 1, 1, 1, 1, 1000, 1000, 1000, 1000, 1000)
    # Number of leading logit columns returned (widths of heads 0-4).
    _NUM_RETURNED_LOGITS = 113

    def __init__(self, config):
        """Build the network.

        Args:
            config: object providing ``latent_dim``, ``hidden_states``,
                ``num_sampling``, ``nth_root``, ``dense_dim`` (indexable,
                ``[in_features, out_features]`` of the dense layer) and
                ``dataset``.
        """
        super(dSprites_Classifier, self).__init__()
        self.latent_dim = config.latent_dim
        self.hidden_states = config.hidden_states
        self.num_sampling = config.num_sampling
        self.n = config.nth_root
        # Fixed code table 0..99. nn.Parameter defaults to requires_grad=True
        # regardless of the wrapped tensor's flag, so the freeze has to be
        # requested on the Parameter itself. It stays a Parameter so the
        # state-dict key "label_code" is unchanged.
        self.label_code = nn.Parameter(
            torch.linspace(start=0, end=99, steps=100), requires_grad=False
        )
        # Convolutional stack (original note: "Design Encoder Factor-VAE ref").
        self.hidden_layers = nn.ModuleList(
            [
                CNNLayer(in_channels=1, out_channels=32),
                CNNLayer(in_channels=32, out_channels=32),
                CNNLayer(in_channels=32, out_channels=64),
                CNNLayer(in_channels=64, out_channels=64),
            ]
        )
        self.dense = nn.Linear(config.dense_dim[0], config.dense_dim[1])
        # Registered as classifier0 .. classifier9, in this order.
        for idx, width in enumerate(self._HEAD_WIDTHS):
            setattr(
                self,
                "classifier" + str(idx),
                nn.Linear(config.dense_dim[1], width),
            )
        self.active = nn.Sigmoid()
        # NOTE(review): active_dim only exists when the dataset is "dsprites";
        # for any other value the attribute is never created.
        if config.dataset == "dsprites":
            self.active_dim = 5

    def forward(self, input):
        """Classify a batch and return sampled labels plus raw logits.

        Args:
            input: batch tensor; its first dimension is the batch size and it
                must be accepted by the ``CNNLayer`` stack.

        Returns:
            tuple ``(label, label_logits)`` where ``label`` is ``[B, 10]``
            (one sampled code per head) and ``label_logits`` is ``[B, 113]``
            (concatenated logits of heads 0-4).
        """
        features = input
        for hidden_layer in self.hidden_layers:
            features = hidden_layer(features)
        # Flatten everything but the batch dimension before the dense layer.
        features = self.dense(features.contiguous().view(input.size(0), -1))

        # All heads are evaluated first, then sampled in head order, so the
        # random draws happen in the same sequence as before.
        head_logits = [
            getattr(self, "classifier" + str(idx))(features)
            for idx in range(len(self._HEAD_TAUS))
        ]

        labels = []
        for logit, tau in zip(head_logits, self._HEAD_TAUS):
            one_hot = F.gumbel_softmax(logit, tau=tau, hard=True)  # [B, width]
            codes = self.label_code[: logit.size(-1)]  # [width]
            labels.append((codes * one_hot).sum(dim=-1))  # [B]

        label = torch.stack(labels, dim=0).transpose(-1, -2)  # [10, B] -> [B, 10]
        label_logits = torch.cat(head_logits, dim=-1)  # [B, sum of widths]

        return label, label_logits[:, : self._NUM_RETURNED_LOGITS]


class Shapes3D_Classifier(nn.Module):
    """CNN classifier with six categorical heads that emits discrete label codes.

    The input is passed through four ``CNNLayer`` blocks (3 input channels)
    and one dense layer. Six linear heads (``classifier0`` ..
    ``classifier5``, widths 10, 10, 10, 8, 4, 15) produce logits; each head
    is sampled with a hard Gumbel-softmax and the one-hot sample is turned
    into a scalar label by a dot product with ``label_code`` (0, 1, ..., 99).
    """

    # Output width of classifier0 .. classifier5, in head order.
    _HEAD_WIDTHS = (10, 10, 10, 8, 4, 15)

    def __init__(self, config):
        """Build the network.

        Args:
            config: object providing ``latent_dim``, ``hidden_states``,
                ``num_sampling``, ``dense_dim`` (indexable,
                ``[in_features, out_features]`` of the dense layer) and
                ``dataset``.
        """
        super(Shapes3D_Classifier, self).__init__()
        self.latent_dim = config.latent_dim
        self.hidden_states = config.hidden_states
        self.num_sampling = config.num_sampling
        # Fixed code table 0..99. nn.Parameter defaults to requires_grad=True
        # regardless of the wrapped tensor's flag, so the freeze has to be
        # requested on the Parameter itself. It stays a Parameter so the
        # state-dict key "label_code" is unchanged.
        self.label_code = nn.Parameter(
            torch.linspace(start=0, end=99, steps=100), requires_grad=False
        )
        # Convolutional stack (original note: "Design Encoder Factor-VAE ref").
        self.hidden_layers = nn.ModuleList(
            [
                CNNLayer(in_channels=3, out_channels=32),
                CNNLayer(in_channels=32, out_channels=32),
                CNNLayer(in_channels=32, out_channels=64),
                CNNLayer(in_channels=64, out_channels=64),
            ]
        )
        self.dense = nn.Linear(config.dense_dim[0], config.dense_dim[1])
        # Registered as classifier0 .. classifier5, in this order.
        for idx, width in enumerate(self._HEAD_WIDTHS):
            setattr(
                self,
                "classifier" + str(idx),
                nn.Linear(config.dense_dim[1], width),
            )
        self.active = nn.Sigmoid()
        # NOTE(review): active_dim only exists when the dataset is "shapes3d";
        # for any other value the attribute is never created.
        if config.dataset == "shapes3d":
            self.active_dim = 6

    def forward(self, input):
        """Classify a batch and return sampled labels plus raw logits.

        Args:
            input: batch tensor; its first dimension is the batch size and it
                must be accepted by the ``CNNLayer`` stack.

        Returns:
            tuple ``(label, label_logits)`` where ``label`` is ``[B, 6]``
            (one sampled code per head) and ``label_logits`` is ``[B, 57]``
            (logits of all six heads concatenated in head order).
        """
        features = input
        for hidden_layer in self.hidden_layers:
            features = hidden_layer(features)
        # Flatten everything but the batch dimension before the dense layer.
        features = self.dense(features.contiguous().view(input.size(0), -1))

        # All heads are evaluated first, then sampled in head order, so the
        # random draws happen in the same sequence as before.
        head_logits = [
            getattr(self, "classifier" + str(idx))(features)
            for idx in range(len(self._HEAD_WIDTHS))
        ]

        labels = []
        for logit in head_logits:
            one_hot = F.gumbel_softmax(logit, hard=True)  # [B, width]
            codes = self.label_code[: logit.size(-1)]  # [width]
            labels.append((codes * one_hot).sum(dim=-1))  # [B]

        label = torch.stack(labels, dim=0).transpose(-1, -2)  # [6, B] -> [B, 6]
        label_logits = torch.cat(head_logits, dim=-1)  # [B, 57]

        return label, label_logits


class MPI3D_Complex_Classifier(nn.Module):
    """CNN classifier with ten categorical heads that emits discrete label codes.

    The input is passed through four ``CNNLayer`` blocks (3 input channels)
    and one dense layer. Ten linear heads (``classifier0`` ..
    ``classifier9``) produce logits; each head is sampled with a hard
    Gumbel-softmax and the one-hot sample is turned into a scalar label by a
    dot product with ``label_code`` (0, 1, ..., 99).

    Only the logits of the first seven heads
    (4 + 4 + 2 + 3 + 3 + 40 + 40 = 96 columns) are returned by ``forward``;
    all ten scalar labels are returned.
    """

    # Output width of classifier0 .. classifier9, in head order.
    _HEAD_WIDTHS = (4, 4, 2, 3, 3, 40, 40, 100, 100, 100)
    # Gumbel-softmax temperature per head (1 is the F.gumbel_softmax default).
    _HEAD_TAUS = (1, 1, 1, 1, 1, 1, 1, 1000, 1000, 1000)
    # Number of leading logit columns returned (widths of heads 0-6).
    _NUM_RETURNED_LOGITS = 96

    def __init__(self, config):
        """Build the network.

        Args:
            config: object providing ``latent_dim``, ``hidden_states``,
                ``num_sampling``, ``dense_dim`` (indexable,
                ``[in_features, out_features]`` of the dense layer) and
                ``dataset``.
        """
        super(MPI3D_Complex_Classifier, self).__init__()
        self.latent_dim = config.latent_dim
        self.hidden_states = config.hidden_states
        self.num_sampling = config.num_sampling
        # Fixed code table 0..99. nn.Parameter defaults to requires_grad=True
        # regardless of the wrapped tensor's flag, so the freeze has to be
        # requested on the Parameter itself. It stays a Parameter so the
        # state-dict key "label_code" is unchanged.
        self.label_code = nn.Parameter(
            torch.linspace(start=0, end=99, steps=100), requires_grad=False
        )

        # Convolutional stack (original note: "Design Encoder Factor-VAE ref").
        self.hidden_layers = nn.ModuleList(
            [
                CNNLayer(in_channels=3, out_channels=32),
                CNNLayer(in_channels=32, out_channels=32),
                CNNLayer(in_channels=32, out_channels=64),
                CNNLayer(in_channels=64, out_channels=64),
            ]
        )
        self.dense = nn.Linear(config.dense_dim[0], config.dense_dim[1])
        # Registered as classifier0 .. classifier9, in this order.
        for idx, width in enumerate(self._HEAD_WIDTHS):
            setattr(
                self,
                "classifier" + str(idx),
                nn.Linear(config.dense_dim[1], width),
            )
        self.active = nn.Sigmoid()
        # NOTE(review): this guard tests for "dsprites" even though the class
        # is named for MPI3D, so active_dim is only created for that dataset
        # string. Looks copy-pasted; confirm the intended dataset name and
        # value before relying on self.active_dim here.
        if config.dataset == "dsprites":
            self.active_dim = 5

    def forward(self, input):
        """Classify a batch and return sampled labels plus raw logits.

        Args:
            input: batch tensor; its first dimension is the batch size and it
                must be accepted by the ``CNNLayer`` stack.

        Returns:
            tuple ``(label, label_logits)`` where ``label`` is ``[B, 10]``
            (one sampled code per head) and ``label_logits`` is ``[B, 96]``
            (concatenated logits of heads 0-6).
        """
        features = input
        for hidden_layer in self.hidden_layers:
            features = hidden_layer(features)
        # Flatten everything but the batch dimension before the dense layer.
        features = self.dense(features.contiguous().view(input.size(0), -1))

        # All heads are evaluated first, then sampled in head order, so the
        # random draws happen in the same sequence as before.
        head_logits = [
            getattr(self, "classifier" + str(idx))(features)
            for idx in range(len(self._HEAD_TAUS))
        ]

        labels = []
        for logit, tau in zip(head_logits, self._HEAD_TAUS):
            one_hot = F.gumbel_softmax(logit, tau=tau, hard=True)  # [B, width]
            codes = self.label_code[: logit.size(-1)]  # [width]
            labels.append((codes * one_hot).sum(dim=-1))  # [B]

        label = torch.stack(labels, dim=0).transpose(-1, -2)  # [10, B] -> [B, 10]
        label_logits = torch.cat(head_logits, dim=-1)  # [B, sum of widths]

        return label, label_logits[:, : self._NUM_RETURNED_LOGITS]


class MPI3D_Classifier(MPI3D_Complex_Classifier):
    """Variant of ``MPI3D_Complex_Classifier`` whose first two heads have 6 classes.

    ``classifier0`` and ``classifier1`` are replaced by 6-way heads, so the
    first seven heads span 6 + 6 + 2 + 3 + 3 + 40 + 40 = 100 logit columns,
    which is the slice ``forward`` returns. Everything else is inherited.
    """

    # Gumbel-softmax temperature per head (1 is the F.gumbel_softmax default).
    _HEAD_TAUS = (1, 1, 1, 1, 1, 1, 1, 1000, 1000, 1000)
    # Number of leading logit columns returned (widths of heads 0-6).
    _NUM_RETURNED_LOGITS = 100

    def __init__(self, config):
        """Build the parent network, then swap in the two 6-way heads.

        Args:
            config: same configuration object the parent class expects;
                ``config.dense_dim[1]`` is the input width of the new heads.
        """
        super(MPI3D_Classifier, self).__init__(config)
        self.classifier0 = nn.Linear(config.dense_dim[1], 6)
        self.classifier1 = nn.Linear(config.dense_dim[1], 6)

    def forward(self, input):
        """Classify a batch and return sampled labels plus raw logits.

        Args:
            input: batch tensor; its first dimension is the batch size and it
                must be accepted by the inherited ``hidden_layers`` stack.

        Returns:
            tuple ``(label, label_logits)`` where ``label`` is ``[B, 10]``
            (one sampled code per head) and ``label_logits`` is ``[B, 100]``
            (concatenated logits of heads 0-6).
        """
        features = input
        for hidden_layer in self.hidden_layers:
            features = hidden_layer(features)
        # Flatten everything but the batch dimension before the dense layer.
        features = self.dense(features.contiguous().view(input.size(0), -1))

        # All heads are evaluated first, then sampled in head order, so the
        # random draws happen in the same sequence as before.
        head_logits = [
            getattr(self, "classifier" + str(idx))(features)
            for idx in range(len(self._HEAD_TAUS))
        ]

        labels = []
        for logit, tau in zip(head_logits, self._HEAD_TAUS):
            one_hot = F.gumbel_softmax(logit, tau=tau, hard=True)  # [B, width]
            # The code slice follows the head width (6, 6, 2, 3, 3, 40, 40,
            # 100, 100, 100), matching the previously hard-coded slices.
            codes = self.label_code[: logit.size(-1)]  # [width]
            labels.append((codes * one_hot).sum(dim=-1))  # [B]

        label = torch.stack(labels, dim=0).transpose(-1, -2)  # [10, B] -> [B, 10]
        label_logits = torch.cat(head_logits, dim=-1)  # [B, sum of widths]

        return label, label_logits[:, : self._NUM_RETURNED_LOGITS]


def target_change(target, factor_inform, repeat):
    """Shift large entries of ``target`` down by a per-factor amount, block by block.

    The columns of ``target`` are treated as consecutive blocks of ``repeat``
    columns, one block per entry of ``factor_inform``. Inside block ``i``,
    every value that is at least ``factor_inform[i] / 2`` has
    ``factor_inform[i]`` subtracted; smaller values are copied unchanged.

    Args:
        target: 2-D tensor indexed as ``[row, column]``.
        factor_inform: sequence with one value per block (presumably the
            period / class count of each factor -- confirm with the caller).
        repeat: number of columns in each block.

    Returns:
        A new tensor shaped like ``target``. Columns past
        ``len(factor_inform) * repeat`` are left at zero.
    """
    shifted = torch.zeros_like(target)
    for block_idx, amount in enumerate(factor_inform):
        cols = slice(block_idx * repeat, (block_idx + 1) * repeat)
        block = target[:, cols]
        needs_shift = block >= amount / 2
        shifted[:, cols] = torch.where(needs_shift, block - amount, block)
    return shifted


class Unsuper_Encoder_dsprites(nn.Module):
    """CNN encoder producing a Gaussian latent sample plus discrete "step" codes.

    The input goes through four ``CNNLayer`` blocks (1 input channel) and a
    dense layer. From the dense features the module computes:

    * ``mu`` / ``logvar`` and a reparameterized sample ``z``;
    * a pairwise "switch" over all ordered pairs of batch items, binarised
      with a straight-through estimator, from which ``prior_kld`` is derived;
    * one Gumbel-softmax sample (soft and hard) per entry of
      ``config.prior_list``; the index selected by each hard sample is
      returned in ``steps``.

    ``self.iter`` counts training-mode forward passes and drives the
    Gumbel-softmax temperature schedule. It is a plain Python attribute, so
    it is not part of the state dict.
    """

    def __init__(self, config):
        """Build the network.

        Args:
            config: object providing ``latent_dim``, ``hidden_states``,
                ``num_sampling``, ``prior_list`` (sequence of category counts,
                one per discrete head) and ``dense_dim`` (indexable,
                ``[in_features, out_features]`` of the dense layer).
        """
        super(Unsuper_Encoder_dsprites, self).__init__()
        self.latent_dim = config.latent_dim
        self.hidden_states = config.hidden_states
        self.num_sampling = config.num_sampling

        self.prior_list = config.prior_list
        # Convolutional stack (original note: "Design Encoder Factor-VAE ref").
        self.hidden_layers = nn.ModuleList(
            [
                CNNLayer(in_channels=1, out_channels=32),
                CNNLayer(in_channels=32, out_channels=32),
                CNNLayer(in_channels=32, out_channels=64),
                CNNLayer(in_channels=64, out_channels=64),
            ]
        )

        self.dense = nn.Linear(config.dense_dim[0], config.dense_dim[1])
        self.switch = nn.Linear(config.dense_dim[1], self.latent_dim + 1)
        self.mu = nn.Linear(config.dense_dim[1], self.latent_dim)
        self.logvar = nn.Linear(config.dense_dim[1], self.latent_dim)

        # One linear head per discrete prior, keyed "p0", "p1", ...
        p_modules = {}
        for idx, num_categories in enumerate(self.prior_list):
            p_modules['p' + str(idx)] = nn.Linear(config.dense_dim[1], num_categories)
        self.prior_dict = nn.ModuleDict(p_modules)
        self.iter = 0.0

    def _gumbel_temperature(self):
        """Return the Gumbel-softmax temperature for the current ``self.iter``.

        The schedule decreases linearly from 5 and is floored at 1.0. Without
        the floor the expression reaches zero and then negative values after
        roughly 41.7k training steps, which is not a valid temperature.
        """
        return max(1.0, 5 - 4 * self.iter * 3e-5)

    def forward(self, input):
        """Encode a batch.

        Args:
            input: batch tensor; its first dimension is the batch size and it
                must be accepted by the ``CNNLayer`` stack.

        Returns:
            tuple ``(z, steps, all_soft_states, all_hard_states, prior_kld)``:
            ``z`` is ``[B, latent_dim]``; ``steps`` is ``[B, len(prior_list)]``
            holding the index chosen by each hard sample; the two state tuples
            hold one ``[B, num_categories]`` tensor per prior head;
            ``prior_kld`` is ``[B * B]``, one value per ordered pair of batch
            items.
        """
        features = input
        for hidden_layer in self.hidden_layers:
            features = hidden_layer(features)
        # Flatten everything but the batch dimension before the dense layer.
        features = self.dense(features.contiguous().view(features.size(0), -1))

        mu = self.mu(features)  # [B, latent_dim]
        logvar = self.logvar(features)  # [B, latent_dim]
        z = self.reparameterization(mu, logvar)

        # Feature differences for every ordered pair of batch items.
        pairwise = features - features.unsqueeze(1)  # [B, B, D]
        pairwise = pairwise.view(-1, pairwise.size(-1))  # [B * B, D]
        switch = F.softmax(self.switch(pairwise), dim=-1)  # [B * B, latent_dim + 1]
        # Straight-through binarisation: the forward value is 1 where the
        # probability exceeds the uniform level 1 / (latent_dim + 1) and 0
        # elsewhere, while gradients flow through the soft probabilities.
        switch = torch.where(
            switch > (1 / (self.latent_dim + 1)),
            1.0 + switch - switch.detach(),
            0.0 + switch - switch.detach(),
        )

        num_dim = switch.sum(dim=-1)  # [B * B], number of active entries
        prior_kld = num_dim * math.log((1 / self.latent_dim + 1e-9) / (1e-9)) / self.latent_dim

        tau = self._gumbel_temperature()
        all_soft_states = ()
        all_hard_states = ()
        for prior_head in self.prior_dict.values():
            # Logits are computed once; the soft and hard states are still two
            # independent Gumbel draws, taken in the same order as before.
            prior_logits = prior_head(features)
            soft_state = F.gumbel_softmax(prior_logits, tau=tau, hard=False)
            hard_state = F.gumbel_softmax(prior_logits, tau=tau, hard=True)
            all_soft_states = all_soft_states + (soft_state,)
            all_hard_states = all_hard_states + (hard_state,)

        steps = []
        for hard_state in all_hard_states:
            # Build the index vector on the sample's own device rather than a
            # module-global one, so CPU and non-default-GPU runs both work.
            indices = torch.arange(hard_state.size(-1), device=hard_state.device)
            steps.append(torch.sum(hard_state * indices, dim=-1, keepdim=True))  # [B, 1]

        steps = torch.cat(steps, dim=-1)  # [B, len(prior_list)]
        if self.training:
            self.iter += 1
        return z, steps, all_soft_states, all_hard_states, prior_kld

    def reparameterization(self, mu, logvar):
        """Sample ``mu + exp(0.5 * logvar) * eps`` with ``eps`` standard normal.

        Args:
            mu: mean tensor.
            logvar: log-variance tensor; the noise takes its shape.

        Returns:
            The sampled tensor.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(logvar)
        return mu + std * eps















