import math
import torch
import torch.nn as nn
from torch.nn.functional import gumbel_softmax, softmax

from .cosine_temperature import CosineTemperature


class Generator(nn.Module):
    """MLP that turns an encoded input vector into four categorical choices.

    The input is projected to 128 features, passed through three residual
    ``Block`` layers, and fed to four independent linear + batch-norm heads
    named ``pex``, ``pey``, ``rf`` and ``df``.  Each head's logits are turned
    into a (relaxed or hard) categorical vector and the four vectors are
    concatenated along the last dimension.

    NOTE(review): the meaning of the four heads is not visible in this file;
    only their widths (``num_pex``, ``num_pey``, ``num_rf``, ``num_df``) are.

    Args:
        num_blocks, num_ops: Their product is the expected input feature size.
        num_pex, num_pey, num_rf, num_df: Number of categories of each head.
        use_gumbel: Sample with Gumbel-softmax while in training mode.
        gumbel_tau: Passed to ``CosineTemperature`` as ``eta_max``.
        output_one_hot: Emit hard one-hot vectors where the mode allows it
            (see ``forward``).
        total_epoch: Passed to ``CosineTemperature`` as ``total_epoch``.
    """

    def __init__(self, num_blocks=9, num_ops=7, num_pex=3, num_pey=5, num_rf=5, num_df=3,
                 use_gumbel=True, gumbel_tau=5, output_one_hot=True, total_epoch=120):
        super().__init__()
        # Plain softmax output during training (neither flag set) is rejected.
        assert use_gumbel or output_one_hot, \
            "at least one of use_gumbel / output_one_hot must be True"
        self.use_gumbel = use_gumbel
        # presumably a cosine-annealed temperature schedule — confirm in
        # cosine_temperature.py; here it is only initialised with step 0 and
        # its ``tau`` attribute is read in ``forward``.
        self.temperature = CosineTemperature(eta_max=gumbel_tau, eta_min=0.5, total_epoch=total_epoch)
        self.temperature.update_tau(0)
        self.output_one_hot = output_one_hot

        # Shared trunk: input projection followed by three residual blocks.
        self.layers = nn.ModuleList()
        self.layers.append(
            nn.Sequential(
                nn.Linear(num_blocks * num_ops, 128, bias=False),
                nn.BatchNorm1d(128),
                nn.Sigmoid(),
            )
        )
        for _ in range(3):
            self.layers.append(Block(128))

        # One linear + batch-norm head per categorical output.  Attribute
        # names and creation order are kept stable so existing checkpoints
        # and seeded initialisation are unaffected.
        self.fc_pex = nn.Sequential(
            nn.Linear(128, num_pex, bias=False),
            nn.BatchNorm1d(num_pex),
        )
        self.fc_pey = nn.Sequential(
            nn.Linear(128, num_pey, bias=False),
            nn.BatchNorm1d(num_pey),
        )
        self.fc_rf = nn.Sequential(
            nn.Linear(128, num_rf, bias=False),
            nn.BatchNorm1d(num_rf),
        )
        self.fc_df = nn.Sequential(
            nn.Linear(128, num_df, bias=False),
            nn.BatchNorm1d(num_df),
        )

    @staticmethod
    def _one_hot(logits):
        """Return a hard one-hot encoding of the arg-max over the last dim.

        The result has the same shape, dtype and device as ``logits`` and is
        not attached to the autograd graph.
        """
        # softmax is strictly monotonic, so the arg-max can be taken on the
        # raw logits directly.
        index = logits.argmax(dim=-1, keepdim=True)
        # zeros_like allocates directly on the logits' device and in their
        # dtype, so the output matches the rest of the model (e.g. after
        # ``.double()`` / ``.half()``) and needs no host-to-device copy.
        return torch.zeros_like(logits).scatter_(dim=-1, index=index, value=1)

    def forward(self, x, eval_gumbel=True):
        """Run the network and return the concatenated head outputs.

        Args:
            x: Input whose last dimension is ``num_blocks * num_ops``
                (assumed to be 2-D ``(batch, features)`` — TODO confirm).
            eval_gumbel: Only used in eval mode.  If true, each head is
                sampled with Gumbel-softmax (hard when ``output_one_hot``);
                otherwise each head is a deterministic arg-max one-hot,
                regardless of ``output_one_hot``.

        Returns:
            Tensor of width ``num_pex + num_pey + num_rf + num_df`` in the
            last dimension, ordered ``[pex, pey, rf, df]``.

            In training mode the heads are soft Gumbel-softmax samples when
            ``use_gumbel`` is set, else arg-max one-hot vectors when
            ``output_one_hot`` is set, else plain softmax probabilities.
        """
        for layer in self.layers:
            x = layer(x)

        # Head order fixes both the output layout and the order in which
        # Gumbel noise is drawn from the RNG.
        logits = [head(x) for head in (self.fc_pex, self.fc_pey, self.fc_rf, self.fc_df)]

        if self.training:
            if self.use_gumbel:
                tau = self.temperature.tau
                outputs = [gumbel_softmax(z, tau=tau, hard=False) for z in logits]
            elif self.output_one_hot:
                outputs = [self._one_hot(z) for z in logits]
            else:
                outputs = [softmax(z, dim=-1) for z in logits]
        elif eval_gumbel:
            tau = self.temperature.tau
            outputs = [gumbel_softmax(z, tau=tau, hard=self.output_one_hot) for z in logits]
        else:
            outputs = [self._one_hot(z) for z in logits]

        return torch.cat(outputs, dim=-1)


class Block(nn.Module):
    """Residual unit computing ``x + sigmoid(BatchNorm1d(Linear(x)))``.

    The linear layer is bias-free and keeps the feature width unchanged, so
    the skip connection can be added without any projection.
    """

    def __init__(self, num_features):
        super().__init__()
        linear = nn.Linear(num_features, num_features, bias=False)
        norm = nn.BatchNorm1d(num_features)
        self.layer = nn.Sequential(linear, norm)

    def forward(self, x):
        """Return the sigmoid-activated transform of ``x`` plus ``x`` itself."""
        activated = torch.sigmoid(self.layer(x))
        # Out-of-place add leaves both the input and the sigmoid output
        # (needed for its backward pass) untouched.
        return activated + x

