# -*- encoding: utf-8 -*-

import torch
import torch.nn as nn
from asteroid_filterbanks import Encoder, ParamSincFB

from RawNetBasicBlock import Bottle2neck, PreEmphasis
from hyperparameters import args


class RawNet3_detect_sparse_experts(nn.Module):
    """RawNet3-style classifier whose stages are top-1 (sparse) mixtures of experts.

    Every stage (sinc front-end, three backbone blocks, the 1x1 fusion
    convolution, the attention pooling and the final MLP head) owns
    ``args.num_experts`` parallel copies plus a small softmax gate.  For each
    sample the gate picks the single highest-scoring expert and only that
    expert is evaluated.

    :param block: constructor of the backbone block, called as
        ``block(in_ch, out_ch, kernel_size=..., dilation=..., scale=..., pool=...)``.
    :param model_scale: forwarded to ``block`` as ``scale``.
    :param context: if true, the attention input is the feature map
        concatenated with its per-channel mean and standard deviation.
    :param summed: must be truthy; ``forward`` only implements the variant in
        which ``layer3`` receives ``mp3(x1) + x2``.
    :param C: backbone channel width (the front-end produces ``C // 4`` channels).
    :param n_cls: number of output classes of every expert head.
    :param kwargs: must contain ``nOut``, ``encoder_type`` (``"ECA"`` or
        ``"ASP"``), ``log_sinc``, ``norm_sinc``, ``out_bn`` and ``sinc_stride``.
    :raises ValueError: if ``encoder_type`` is neither ``"ECA"`` nor ``"ASP"``.

    NOTE: the gate input widths below (96000, 9575, 1915, 638) are hard-coded.
    96000 is the number of input samples (e.g. ``600 * 160``); the others are
    presumably the frame counts after the front-end with ``sinc_stride=10``
    and after the ``pool=5`` / ``pool=3`` blocks -- TODO confirm against
    ``Bottle2neck``.  Other input lengths or strides will not match the gates.
    """

    def __init__(self, block, model_scale, context, summed, C=1024, n_cls=2, **kwargs):
        super().__init__()

        nOut = kwargs["nOut"]
        num_experts = args.num_experts

        self.context = context
        self.encoder_type = kwargs["encoder_type"]
        self.log_sinc = kwargs["log_sinc"]
        self.norm_sinc = kwargs["norm_sinc"]
        # Stored for compatibility; forward() does not read it.
        self.out_bn = kwargs["out_bn"]
        self.summed = summed

        self.preprocess = nn.Sequential(
            PreEmphasis(), nn.InstanceNorm1d(1, eps=1e-4, affine=True)
        )

        # Front-end: one parametric sinc filterbank per expert.
        self.gate_conv1 = nn.Sequential(nn.Linear(96000, num_experts), nn.Softmax(dim=-1))
        self.conv1 = nn.ModuleList([Encoder(
            ParamSincFB(
                C // 4,
                251,
                stride=kwargs["sinc_stride"],
            )
        ) for _ in range(num_experts)])
        self.relu = nn.ReLU()

        # Backbone blocks, each replicated per expert.
        self.gate1 = nn.Sequential(nn.Linear(9575, num_experts), nn.Softmax(dim=-1))
        self.layer1 = nn.ModuleList([block(
            C // 4, C, kernel_size=3, dilation=2, scale=model_scale, pool=5
        ) for _ in range(num_experts)])
        self.gate2 = nn.Sequential(nn.Linear(1915, num_experts), nn.Softmax(dim=-1))
        self.layer2 = nn.ModuleList([block(
            C, C, kernel_size=3, dilation=3, scale=model_scale, pool=3
        ) for _ in range(num_experts)])
        self.gate3 = nn.Sequential(nn.Linear(638, num_experts), nn.Softmax(dim=-1))
        self.layer3 = nn.ModuleList([block(
            C, C, kernel_size=3, dilation=4, scale=model_scale) for _ in range(num_experts)])

        # 1x1 convolution fusing the three backbone outputs.
        self.gate4 = nn.Sequential(nn.Linear(638, num_experts), nn.Softmax(dim=-1))
        self.layer4 = nn.ModuleList([nn.Conv1d(3 * C, 1536, kernel_size=1) for _ in range(num_experts)])

        if self.context:
            attn_input = 1536 * 3
        else:
            attn_input = 1536
        print("self.encoder_type", self.encoder_type)
        if self.encoder_type == "ECA":
            attn_output = 1536
        elif self.encoder_type == "ASP":
            attn_output = 1
        else:
            raise ValueError(f"Undefined encoder: {self.encoder_type!r}")

        self.gate_attention = nn.Sequential(nn.Linear(638, num_experts), nn.Softmax(dim=-1))
        self.attention = nn.ModuleList([nn.Sequential(
            nn.Conv1d(attn_input, 128, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Conv1d(128, attn_output, kernel_size=1),
            nn.Softmax(dim=2),
        ) for _ in range(num_experts)])

        self.bn5 = nn.BatchNorm1d(3072)

        self.gate_output = nn.Sequential(
            nn.Linear(3072, num_experts),
            nn.Softmax(dim=1)
        )

        # Expert heads produce the class logits directly.
        self.mlp_experts = nn.ModuleList([nn.Sequential(
            nn.Linear(3072, 3072 * 4),
            nn.ReLU(),
            nn.Linear(3072 * 4, 3072),
            nn.ReLU(),
            nn.BatchNorm1d(3072),
            nn.Linear(3072, nOut),
            nn.BatchNorm1d(nOut),
            nn.ReLU(),
            nn.Linear(nOut, n_cls),
        ) for _ in range(num_experts)])

        self.mp3 = nn.MaxPool1d(3)

        # Not used by forward(); kept so that existing checkpoints still load.
        self.cls = nn.Linear(nOut, n_cls)

    def _route(self, experts, gate, gate_input, expert_input=None):
        """Run every sample through the expert its gate scores highest.

        :param experts: ``nn.ModuleList`` of interchangeable experts.
        :param gate: module mapping ``gate_input`` to per-expert scores on the
            last dimension.
        :param gate_input: tensor the routing decision is based on.
        :param expert_input: tensor fed to the chosen expert; defaults to
            ``gate_input``.
        :return: expert outputs, in the same sample order as the input batch.
        :raises ValueError: if the batch is empty.

        Samples choosing the same expert are processed together in one call.
        BatchNorm layers inside an expert therefore still need at least two
        samples per selected expert when the model is in training mode.
        """
        if expert_input is None:
            expert_input = gate_input
        n_samples = expert_input.size(0)
        if n_samples == 0:
            raise ValueError("Cannot route an empty batch")

        scores = gate(gate_input)
        if scores.dim() == 3:
            # (bs, channels, experts): average the per-channel votes.
            scores = torch.mean(scores, dim=1)
        _, choice = torch.max(scores, dim=1)

        routed = None
        for expert_id in torch.unique(choice).tolist():
            rows = torch.nonzero(choice == expert_id, as_tuple=True)[0]
            part = experts[expert_id](expert_input[rows])
            if routed is None:
                # Allocate lazily: the output shape is only known after the
                # first expert has run.
                routed = part.new_empty((n_samples,) + tuple(part.shape[1:]))
            routed[rows] = part
        return routed

    def forward(self, x, eval=False):
        """
        Compute class logits with top-1 expert routing at every stage.

        :param x: input mini-batch (bs, samp)
        :param eval: currently unused; hard (top-1) routing is always applied.
            Kept so existing callers keep working.
        :return: tensor of shape (bs, n_cls)
        :raises ValueError: if the model was built with a falsy ``summed``.
        """
        if not self.summed:
            raise ValueError("forward() only supports summed=True")

        # The front-end always runs in full precision.
        with torch.cuda.amp.autocast(enabled=False):
            x = self.preprocess(x)
            x = torch.abs(self._route(self.conv1, self.gate_conv1, x))

            if self.log_sinc:
                x = torch.log(x + 1e-6)
            if self.norm_sinc == "mean":
                x = x - torch.mean(x, dim=-1, keepdim=True)
            elif self.norm_sinc == "mean_std":
                m = torch.mean(x, dim=-1, keepdim=True)
                # clamp (out of place) instead of masked assignment so the
                # result of torch.std is not modified in place under autograd.
                s = torch.std(x, dim=-1, keepdim=True).clamp(min=0.001)
                x = (x - m) / s

        x1 = self._route(self.layer1, self.gate1, x)
        x2 = self._route(self.layer2, self.gate2, x1)
        x1_pooled = self.mp3(x1)
        # The gate of layer3 looks at x2 only; the expert sees the summed map.
        x3 = self._route(self.layer3, self.gate3, x2, x1_pooled + x2)

        fused = torch.cat((x1_pooled, x2, x3), dim=1)
        x = self.relu(self._route(self.layer4, self.gate4, fused))

        t = x.size()[-1]

        if self.context:
            # Append per-channel mean and std, broadcast over time.
            global_x = torch.cat(
                (
                    x,
                    torch.mean(x, dim=2, keepdim=True).repeat(1, 1, t),
                    torch.sqrt(
                        torch.var(x, dim=2, keepdim=True).clamp(
                            min=1e-4, max=1e4
                        )
                    ).repeat(1, 1, t),
                ),
                dim=1,
            )
        else:
            global_x = x

        w = self._route(self.attention, self.gate_attention, global_x)

        # Attention-weighted mean and standard deviation over time.
        mu = torch.sum(x * w, dim=2)
        sg = torch.sqrt(
            (torch.sum((x ** 2) * w, dim=2) - mu ** 2).clamp(min=1e-4, max=1e4)
        )

        x = torch.cat((mu, sg), 1)
        x = self.bn5(x)

        return self._route(self.mlp_experts, self.gate_output, x)


if __name__ == "__main__":
    # Smoke test: build the model, report its size and run one forward pass.
    model = RawNet3_detect_sparse_experts(
        encoder_type='ECA', nOut=256, sinc_stride=10, log_sinc=True,
        # forward() compares norm_sinc against the strings "mean" and
        # "mean_std"; passing a bool would silently skip normalisation.
        norm_sinc="mean",
        out_bn=True,
        block=Bottle2neck, model_scale=8, context=True, summed=True,
        n_cls=2)

    # calculate the number of parameters
    n_params = sum(p.numel() for p in model.parameters())
    print("number of parameters:", n_params)

    # Hard routing can hand an expert a single sample, which BatchNorm1d
    # rejects in training mode, so run the check in inference mode.
    model.eval()

    a = torch.rand(8, 600 * 160)
    print(a)
    with torch.no_grad():
        out = model(a, eval=False)
    print(out)
