# -*- encoding: utf-8 -*-

import torch
import torch.nn as nn
from asteroid_filterbanks import Encoder, ParamSincFB

from .RawNetBasicBlock import PreEmphasis
from hyperparameters import args
import math
import torch.nn.functional as F

class AFMS(nn.Module):
    """
    Alpha-Feature map scaling (AFMS), added to the output of each residual block [1, 2].

    Each channel is first shifted by a learned offset ``alpha`` and then
    multiplied by a per-channel sigmoid gate computed from the channel's
    time-averaged activation.

    Reference:
    [1] RawNet2 : https://www.isca-speech.org/archive/Interspeech_2020/pdfs/1011.pdf
    [2] AFMS    : https://www.koreascience.or.kr/article/JAKO202029757857763.page
    """

    def __init__(self, nb_dim: int) -> None:
        super().__init__()
        # One additive offset per channel; the trailing 1 broadcasts it over time.
        self.alpha = nn.Parameter(torch.ones((nb_dim, 1)))
        self.fc = nn.Linear(nb_dim, nb_dim)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        """Scale ``x`` (batch, channels, time) by ``(x + alpha) * sigmoid(fc(mean_t(x)))``."""
        batch, channels = x.size(0), x.size(1)
        # Average over time -> (batch, channels).
        pooled = F.adaptive_avg_pool1d(x, 1).view(batch, -1)
        gate = self.sig(self.fc(pooled)).view(batch, channels, -1)
        return (x + self.alpha) * gate

class Bottle2neck(nn.Module):
    """Res2Net-style bottleneck block (RawNet3) with LayerNorm instead of BatchNorm.

    Pipeline: 1x1 conv -> split into ``scale`` channel groups -> hierarchical
    dilated convs over the groups -> concat -> 1x1 conv -> + residual ->
    optional max-pool -> AFMS.

    Every ``nn.LayerNorm`` here normalizes the LAST axis (time), so
    ``layernorm_dim`` must equal the time length of the block's input.
    """

    def __init__(
        self,
        inplanes,
        planes,
        kernel_size=None,
        dilation=None,
        scale=4,
        pool=False,
        layernorm_dim=9575,
    ):
        """
        :param inplanes: number of input channels.
        :param planes: number of output channels.
        :param kernel_size: kernel size of the group convolutions (required).
        :param dilation: dilation of the group convolutions (required).
        :param scale: number of channel groups. ``scale == 1`` runs a single
            conv over all channels (as in Res2Net) instead of splitting.
        :param pool: if truthy, kernel size of a MaxPool1d applied after the
            residual sum.
        :param layernorm_dim: time length normalized by every LayerNorm.
        :raises ValueError: if kernel_size/dilation are missing, scale < 1,
            or planes < scale (empty channel groups).
        """
        super().__init__()

        if kernel_size is None or dilation is None:
            raise ValueError("Bottle2neck requires kernel_size and dilation")
        if scale < 1:
            raise ValueError(f"scale must be >= 1, got {scale}")

        width = int(math.floor(planes / scale))
        if width < 1:
            raise ValueError(f"planes ({planes}) must be >= scale ({scale})")

        self.conv1 = nn.Conv1d(inplanes, width * scale, kernel_size=1)
        self.bn1 = nn.LayerNorm(normalized_shape=layernorm_dim)

        # Number of groups that go through a conv. With scale > 1 the last
        # group is passed through untouched; with scale == 1 the single group
        # must be convolved, otherwise the concat doubles the channel count.
        self.nums = 1 if scale == 1 else scale - 1

        # "same" padding for an odd kernel with the given dilation
        num_pad = (kernel_size // 2) * dilation

        self.convs = nn.ModuleList(
            [
                nn.Conv1d(
                    width,
                    width,
                    kernel_size=kernel_size,
                    dilation=dilation,
                    padding=num_pad,
                )
                for _ in range(self.nums)
            ]
        )
        self.bns = nn.ModuleList(
            [nn.LayerNorm(normalized_shape=layernorm_dim) for _ in range(self.nums)]
        )

        self.conv3 = nn.Conv1d(width * scale, planes, kernel_size=1)
        self.bn3 = nn.LayerNorm(normalized_shape=layernorm_dim)

        self.relu = nn.ReLU()

        self.width = width

        self.mp = nn.MaxPool1d(pool) if pool else False
        self.afms = AFMS(planes)

        if inplanes != planes:  # 1x1 projection when the channel count changes
            self.residual = nn.Sequential(
                nn.Conv1d(inplanes, planes, kernel_size=1, stride=1, bias=False)
            )
        else:
            self.residual = nn.Identity()

    def forward(self, x):
        """Apply the block to ``x``; channels become ``planes``, time shrinks by ``pool`` if set."""
        residual = self.residual(x)

        out = self.bn1(self.relu(self.conv1(x)))

        spx = torch.split(out, self.width, 1)
        branches = []
        sp = None
        for i in range(self.nums):
            # Hierarchical residual: each group also receives the previous group's output.
            sp = spx[i] if i == 0 else sp + spx[i]
            sp = self.bns[i](self.relu(self.convs[i](sp)))
            branches.append(sp)
        if len(spx) > self.nums:
            # scale > 1: the last group bypasses the convolutions unchanged.
            branches.append(spx[self.nums])
        # Single concat instead of growing the tensor inside the loop.
        out = torch.cat(branches, 1)

        out = self.bn3(self.relu(self.conv3(out)))

        out += residual
        if self.mp:
            out = self.mp(out)
        return self.afms(out)



class RawNet3_detect_sparse_experts(nn.Module):
    """RawNet3-style detector whose every stage is a sparse mixture of experts.

    Each stage (sinc front-end, three residual blocks, the 1x1 fusion conv, the
    attention pooling and the output MLP) holds ``args.num_experts`` copies.
    A softmax gate scores each sample, and the sample is processed only by its
    highest-scoring expert (hard top-1 routing) in both training and eval.

    NOTE(review): only the arg-max index of each gate is used. The gate
    probabilities never reach the output, so the gate layers receive no
    gradient and keep their initial weights. Confirm this is intended; a
    Switch-Transformer-style router would scale the expert output by the
    selected probability.
    """

    def __init__(self, block, model_scale, context, summed, C=1024, n_cls=2, **kwargs):
        """
        :param block: residual block class (e.g. ``Bottle2neck``).
        :param model_scale: ``scale`` passed to every residual block.
        :param context: if True, attention also sees the utterance-level mean
            and std of the frame features.
        :param summed: must be True; forward only implements the summed variant.
        :param C: channel width of the residual blocks.
        :param n_cls: number of output logits of each output expert.
        :param kwargs: requires ``nOut``, ``encoder_type`` ("ECA" or "ASP"),
            ``log_sinc``, ``norm_sinc``, ``out_bn`` and ``sinc_stride``;
            optional ``num_samples`` (input waveform length, default 96000).
        :raises ValueError: for an unknown ``encoder_type``.
        """
        super().__init__()

        nOut = kwargs["nOut"]
        sinc_stride = kwargs["sinc_stride"]
        num_samples = kwargs.get("num_samples", 96000)
        num_experts = args.num_experts

        self.context = context
        self.encoder_type = kwargs["encoder_type"]
        self.log_sinc = kwargs["log_sinc"]
        self.norm_sinc = kwargs["norm_sinc"]
        self.out_bn = kwargs["out_bn"]
        self.summed = summed

        # The gates and LayerNorms act on the time axis, so their sizes are the
        # sequence lengths after each stage. This assumes the sinc Encoder uses
        # no padding (asteroid's default). With 96000 samples and sinc_stride=10
        # this gives 9575 -> 1915 -> 638, the values the model was built with.
        sinc_kernel = 251
        len_sinc = (num_samples - sinc_kernel) // sinc_stride + 1
        len1 = len_sinc // 5  # after layer1's MaxPool1d(5)
        len2 = len1 // 3  # after layer2's MaxPool1d(3) (and self.mp3 on x1)
        emb_dim = 1536
        pooled_dim = 2 * emb_dim  # attentive mean and std concatenated

        self.preprocess = nn.Sequential(
            PreEmphasis(), nn.InstanceNorm1d(1, eps=1e-4, affine=True)
        )

        self.gate_conv1 = nn.Sequential(nn.Linear(num_samples, num_experts), nn.Softmax(dim=-1))
        self.conv1 = nn.ModuleList([
            Encoder(ParamSincFB(C // 4, sinc_kernel, stride=sinc_stride))
            for _ in range(num_experts)
        ])
        self.relu = nn.ReLU()

        self.gate1 = nn.Sequential(nn.Linear(len_sinc, num_experts), nn.Softmax(dim=-1))
        self.layer1 = nn.ModuleList([
            block(C // 4, C, kernel_size=3, dilation=2, scale=model_scale, pool=5, layernorm_dim=len_sinc)
            for _ in range(num_experts)
        ])
        self.gate2 = nn.Sequential(nn.Linear(len1, num_experts), nn.Softmax(dim=-1))
        self.layer2 = nn.ModuleList([
            block(C, C, kernel_size=3, dilation=3, scale=model_scale, pool=3, layernorm_dim=len1)
            for _ in range(num_experts)
        ])
        self.gate3 = nn.Sequential(nn.Linear(len2, num_experts), nn.Softmax(dim=-1))
        self.layer3 = nn.ModuleList([
            block(C, C, kernel_size=3, dilation=4, scale=model_scale, layernorm_dim=len2)
            for _ in range(num_experts)
        ])

        self.gate4 = nn.Sequential(nn.Linear(len2, num_experts), nn.Softmax(dim=-1))
        self.layer4 = nn.ModuleList([nn.Conv1d(3 * C, emb_dim, kernel_size=1) for _ in range(num_experts)])

        attn_input = emb_dim * 3 if self.context else emb_dim
        print("self.encoder_type", self.encoder_type)
        if self.encoder_type == "ECA":
            attn_output = emb_dim
        elif self.encoder_type == "ASP":
            attn_output = 1
        else:
            raise ValueError("Undefined encoder")

        self.gate_attention = nn.Sequential(nn.Linear(len2, num_experts), nn.Softmax(dim=-1))
        self.attention = nn.ModuleList([nn.Sequential(
            nn.Conv1d(attn_input, 128, kernel_size=1),
            nn.ReLU(),
            nn.LayerNorm(normalized_shape=len2),
            nn.Conv1d(128, attn_output, kernel_size=1),
            nn.Softmax(dim=2),
        ) for _ in range(num_experts)])

        self.bn5 = nn.LayerNorm(normalized_shape=pooled_dim)

        self.gate_output = nn.Sequential(
            nn.Linear(pooled_dim, num_experts),
            nn.Softmax(dim=1)
        )

        self.mlp_experts = nn.ModuleList([nn.Sequential(
            nn.Linear(pooled_dim, pooled_dim * 2),
            nn.ReLU(),
            nn.Linear(pooled_dim * 2, pooled_dim),
            nn.ReLU(),
            nn.LayerNorm(normalized_shape=pooled_dim),
            nn.Linear(pooled_dim, nOut),
            nn.LayerNorm(normalized_shape=nOut),  # was hard-coded 256
            nn.ReLU(),
            nn.Linear(nOut, n_cls),  # was hard-coded 2
        ) for _ in range(num_experts)])

        self.mp3 = nn.MaxPool1d(3)

        # Not used by forward (the output experts produce the logits);
        # presumably kept so existing checkpoints still load.
        self.cls = nn.Linear(nOut, n_cls)

    @staticmethod
    def _gate_scores(gate, feats):
        """Return (batch, num_experts) routing scores: the gate's output averaged over dim 1."""
        return torch.mean(gate(feats), dim=1)

    @staticmethod
    def _top1_dispatch(scores, experts, inputs):
        """Run each sample through its highest-scoring expert, batched per expert.

        :param scores: (batch, num_experts) routing scores.
        :param experts: ``nn.ModuleList`` indexed like the columns of ``scores``.
        :param inputs: expert input; row ``b`` goes to ``experts[argmax(scores[b])]``.
        :return: expert outputs with rows in the original batch order.
        """
        _, idx = torch.max(scores, dim=1)
        positions, outputs = [], []
        for e, expert in enumerate(experts):
            sel = torch.nonzero(idx == e, as_tuple=True)[0]
            if sel.numel() == 0:
                continue
            positions.append(sel)
            outputs.append(expert(inputs[sel]))
        # The rows are grouped by expert; the inverse permutation restores batch order.
        order = torch.cat(positions)
        return torch.cat(outputs)[torch.argsort(order)]

    def forward(self, x, eval=False):
        """
        :param x: input mini-batch (bs, samp)
        :param eval: unused; kept for call compatibility. Routing is hard
            top-1 in both training and evaluation.
        :return: (bs, n_cls) logits from the selected output experts.
        :raises ValueError: if the model was built with ``summed`` falsy.
        """
        if not self.summed:
            raise ValueError("RawNet3_detect_sparse_experts only supports summed=True")

        with torch.cuda.amp.autocast(enabled=False):
            x = self.preprocess(x)
            x = self._top1_dispatch(self._gate_scores(self.gate_conv1, x), self.conv1, x)
            x = torch.abs(x)

            if self.log_sinc:
                x = torch.log(x + 1e-6)
            if self.norm_sinc == "mean":
                x = x - torch.mean(x, dim=-1, keepdim=True)
            elif self.norm_sinc == "mean_std":
                m = torch.mean(x, dim=-1, keepdim=True)
                s = torch.std(x, dim=-1, keepdim=True).clamp(min=0.001)
                x = (x - m) / s

        x1 = self._top1_dispatch(self._gate_scores(self.gate1, x), self.layer1, x)
        x2 = self._top1_dispatch(self._gate_scores(self.gate2, x1), self.layer2, x1)
        # layer3 is routed on x2 alone but consumes the summed skip connection.
        x3 = self._top1_dispatch(self._gate_scores(self.gate3, x2), self.layer3, self.mp3(x1) + x2)

        fused = torch.cat((self.mp3(x1), x2, x3), dim=1)
        x = self._top1_dispatch(self._gate_scores(self.gate4, fused), self.layer4, fused)
        x = self.relu(x)

        if self.context:
            t = x.size(-1)
            global_x = torch.cat(
                (
                    x,
                    torch.mean(x, dim=2, keepdim=True).repeat(1, 1, t),
                    torch.sqrt(
                        torch.var(x, dim=2, keepdim=True).clamp(min=1e-4, max=1e4)
                    ).repeat(1, 1, t),
                ),
                dim=1,
            )
        else:
            global_x = x

        w = self._top1_dispatch(
            self._gate_scores(self.gate_attention, global_x), self.attention, global_x
        )

        # Attentive statistics pooling over time.
        mu = torch.sum(x * w, dim=2)
        sg = torch.sqrt(
            (torch.sum((x ** 2) * w, dim=2) - mu ** 2).clamp(min=1e-4, max=1e4)
        )

        x = self.bn5(torch.cat((mu, sg), 1))

        # gate_output already yields (batch, num_experts); no averaging needed.
        return self._top1_dispatch(self.gate_output(x), self.mlp_experts, x)


if __name__ == "__main__":
    # Smoke test: build the model, report its size, and run one random batch.
    # Note: norm_sinc=True matches neither "mean" nor "mean_std", so forward
    # applies no sinc normalization.
    model = RawNet3_detect_sparse_experts(
        encoder_type='ECA', nOut=256, sinc_stride=10, log_sinc=True, norm_sinc=True,
        out_bn=True,
        block=Bottle2neck, model_scale=8, context=True, summed=True,
        n_cls=2,
    )
    n_params = sum(p.numel() for p in model.parameters())
    print(f"number of parameters: {n_params:,}")
    # 600 * 160 = 96000 samples, the waveform length the gates are sized for.
    a = torch.rand(8, 600 * 160)
    print(a)
    out = model(a, eval=False)
    print(out)
