import math

import torch
import torch.nn.functional as F
from torch import nn


class DartsGenotype(nn.Module):
    """Score candidate feature maps of a DARTS-like cell with a key/query dot product.

    For block ``block_id`` the ``P`` candidates are split into
    ``N = block_id + 2`` groups, candidate ``p`` belonging to group ``p % N``.
    Each candidate gets a key (pooled per spatial patch, then an MLP from
    ``self.K``). The block's attention states are pooled and concatenated into a
    query (an MLP from ``self.Q``). Coefficients come from a softmax over
    *pairs* of candidates taken from different groups.

    Args:
        filters: channel count of every state and width of the key/query MLPs.
        PS: spatial patch size used for pooling. ``None`` uses the full width
            of ``k_states`` as the patch size (one patch if maps are square).
    """

    def __init__(self, filters, PS=8):
        super().__init__()

        self.filters = filters
        self.PS = PS

        # One key MLP per block (4 blocks): pooled C-dim candidate -> `filters` dims.
        self.K = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(filters, self.filters, bias=False),
                    nn.LeakyReLU(),
                    nn.Linear(self.filters, self.filters, bias=False),
                )
                for _ in range(4)
            ]
        )
        # One query MLP per block; block `block_id` consumes block_id + 2
        # concatenated pooled attention states.
        self.Q = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear((block_id + 2) * filters, self.filters, bias=False),
                    nn.LeakyReLU(),
                    nn.Linear(self.filters, self.filters, bias=False),
                )
                for block_id in range(4)
            ]
        )

    def compute_coeffs(self, k_states, block_id, att_states):
        """Return per-candidate mixing coefficients for block ``block_id``.

        Args:
            k_states: tensor of shape ``(B, P, C, H, W)`` holding the ``P``
                candidate maps. ``H`` and ``W`` must be divisible by the patch
                size, otherwise the reshape fails.
            block_id: selects ``self.K[block_id]`` / ``self.Q[block_id]`` and
                the group count ``N = block_id + 2``.
            att_states: ``block_id + 2`` tensors of shape ``(B, C, H, W)``.
                Any state whose width differs from ``W`` is pooled with a
                doubled window, i.e. it is assumed to be ``(B, C, 2H, 2W)``.

        Returns:
            Tensor of shape ``(B, P)`` with the dtype/device of the scores.
            Within each spatial patch, every cross-group pair's probability
            is added to both of its members, so a patch's coefficients sum to
            2 (when at least one cross-group pair exists). Patches are then
            averaged.
        """
        PS = k_states.shape[-1] if self.PS is None else self.PS

        B, P, C, H, W = k_states.shape
        HP, WP = H // PS, W // PS
        NEW_B = B * HP * WP  # one row per (batch item, patch row, patch col)
        device = k_states.device

        # Keys: average-pool each candidate inside every PS x PS patch.
        k = k_states.reshape(B, P, C, HP, PS, WP, PS)
        k = k.permute(0, 3, 5, 1, 2, 4, 6)  # Shape: (B, HP, WP, P, C, PS, PS)
        k = k.reshape(NEW_B * P, C, PS, PS)
        # flatten(1) instead of squeeze(): squeeze() would also drop a
        # size-1 leading or channel dim and break the following layers.
        k = F.avg_pool2d(k, (PS, PS)).flatten(1)
        k = self.K[block_id](k)
        k = k.reshape(NEW_B, P, self.filters)

        def ps_to_q(s):
            # States at a different width are treated as double resolution,
            # so they are pooled over a 2*PS window to cover the same patch.
            QPS = PS
            if s.shape[-1] != W:
                QPS *= 2
            q = s.reshape(B, C, HP, QPS, WP, QPS)
            q = q.permute(0, 2, 4, 1, 3, 5)
            q = q.reshape(NEW_B, C, QPS, QPS)
            return F.avg_pool2d(q, (QPS, QPS)).flatten(1)

        # Query: concatenated pooled attention states -> MLP.
        q = torch.cat([ps_to_q(p_state) for p_state in att_states], dim=1)
        q = self.Q[block_id](q)
        q = q.reshape(NEW_B, self.filters, 1)

        # Scaled dot-product score per candidate; squeeze only the trailing
        # singleton so a single-row batch keeps its 2-D shape.
        qk = torch.bmm(k, q).squeeze(-1) / math.sqrt(k.shape[-1])
        qk = torch.clamp(qk, min=-20, max=20)

        # Generate all pairs of candidates. With N = block_id + 2 groups of
        # M candidates each, pairs inside the same group (same input) are
        # dropped so that at most one op per input is picked.
        N = block_id + 2
        candidates = torch.arange(P, device=device)
        pairs = torch.combinations(candidates, r=2)
        idx1 = pairs[:, 0]
        idx2 = pairs[:, 1]

        # Build the group ids on the same device as the indices; a CPU tensor
        # indexed by CUDA indices raises at runtime.
        group = candidates % N
        mask = group[idx1] != group[idx2]
        idx1 = idx1[mask]
        idx2 = idx2[mask]

        # [NEW_B, P*(P-1)/2 - N*M*(M-1)/2] when P = N*M
        pair_logits = qk[:, idx1] + qk[:, idx2]
        pair_probs = F.softmax(pair_logits, dim=-1)

        # Match qk's dtype/device so scatter_add_ works outside float32.
        coeffs = qk.new_zeros(NEW_B, P)
        coeffs.scatter_add_(1, idx1.unsqueeze(0).expand(NEW_B, -1), pair_probs)
        coeffs.scatter_add_(1, idx2.unsqueeze(0).expand(NEW_B, -1), pair_probs)

        # Average the per-patch coefficients over the spatial grid.
        coeffs = coeffs.reshape(B, HP, WP, P)
        return coeffs.mean(dim=[1, 2])


class NasBenchGenotype(nn.Module):
    """Score candidate feature maps of a NAS-Bench-like cell with a key/query dot product.

    For block ``block_id`` the ``P`` candidates are laid out as
    ``num_ops`` ops applied to each of ``N = block_id + 1`` inputs. Candidate
    ``j * N + i`` is op ``j`` on input ``i``. Keys are pooled per spatial patch
    and passed through ``self.K``. The query is the concatenation of the pooled
    attention states passed through ``self.Q``. Scores are softmax-normalized
    over the ops of each input separately.

    Args:
        filters: channel count of every state and width of the key/query MLPs.
        PS: spatial patch size used for pooling. ``None`` uses the full width
            of ``k_states`` as the patch size (one patch if maps are square).
    """

    def __init__(self, filters, PS=4):
        super().__init__()

        self.filters = filters
        self.PS = PS

        # One key MLP per block (4 blocks): pooled C-dim candidate -> `filters` dims.
        self.K = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(filters, filters, bias=False),
                    nn.LeakyReLU(),
                    nn.Linear(filters, filters, bias=False),
                )
                for _ in range(4)
            ]
        )

        # One query MLP per block; block `block_id` consumes block_id + 1
        # concatenated pooled attention states.
        self.Q = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear((block_id + 1) * filters, filters, bias=False),
                    nn.LeakyReLU(),
                    nn.Linear(filters, filters, bias=False),
                )
                for block_id in range(4)
            ]
        )

    def compute_coeffs(self, k_states, block_id, att_states):
        """Return per-candidate mixing coefficients for block ``block_id``.

        Args:
            k_states: tensor of shape ``(B, P, C, H, W)``. ``P`` must be a
                multiple of ``N = block_id + 1``. ``H`` and ``W`` must be
                divisible by the patch size, otherwise the reshape fails.
            block_id: selects ``self.K[block_id]`` / ``self.Q[block_id]`` and
                the input count ``N = block_id + 1``.
            att_states: ``block_id + 1`` tensors of shape ``(B, C, H, W)``.
                Any state whose width differs from ``W`` is pooled with a
                doubled window, i.e. it is assumed to be ``(B, C, 2H, 2W)``.

        Returns:
            Tensor of shape ``(B, P)``. For each input ``i``, the coefficients
            of candidates ``i, i + N, i + 2N, ...`` sum to 1 (softmax per patch,
            then averaged over patches).

        Raises:
            ValueError: if ``P`` is not a multiple of ``block_id + 1``.
        """
        PS = k_states.shape[-1] if self.PS is None else self.PS

        B, P, C, H, W = k_states.shape
        HP, WP = H // PS, W // PS
        NEW_B = B * HP * WP  # one row per (batch item, patch row, patch col)

        N = block_id + 1
        if P % N:
            raise ValueError(
                f"number of candidates P={P} is not a multiple of the "
                f"number of inputs N={N} for block_id={block_id}"
            )
        num_ops = P // N

        # Keys: average-pool each candidate inside every PS x PS patch.
        k = k_states.reshape(B, P, C, HP, PS, WP, PS)
        k = k.permute(0, 3, 5, 1, 2, 4, 6)  # Shape: (B, HP, WP, P, C, PS, PS)
        k = k.reshape(NEW_B * P, C, PS, PS)
        # flatten(1) instead of squeeze(): squeeze() would also drop a
        # size-1 leading or channel dim and break the following layers.
        k = F.avg_pool2d(k, (PS, PS)).flatten(1)
        k = self.K[block_id](k)
        k = k.reshape(NEW_B, P, self.filters)

        def ps_to_q(s):
            # States at a different width are treated as double resolution,
            # so they are pooled over a 2*PS window to cover the same patch.
            QPS = PS
            if s.shape[-1] != W:
                QPS *= 2
            q = s.reshape(B, C, HP, QPS, WP, QPS)
            q = q.permute(0, 2, 4, 1, 3, 5)
            q = q.reshape(NEW_B, C, QPS, QPS)
            return F.avg_pool2d(q, (QPS, QPS)).flatten(1)

        # Query: concatenated pooled attention states -> MLP.
        q = torch.cat([ps_to_q(p_state) for p_state in att_states], dim=1)
        q = self.Q[block_id](q)
        q = q.reshape(NEW_B, self.filters, 1)

        # Scaled dot-product score per candidate; squeeze only the trailing
        # singleton so a single-row batch keeps its 2-D shape.
        qk = torch.bmm(k, q).squeeze(-1) / math.sqrt(k.shape[-1])
        qk = torch.clamp(qk, min=-20, max=20)

        # Candidate j * N + i is op j on input i: view as (NEW_B, num_ops, N)
        # and softmax over the op axis, i.e. one soft op choice per input.
        coeffs = F.softmax(qk.reshape(NEW_B, num_ops, N), dim=1).reshape(NEW_B, P)

        # Average the per-patch coefficients over the spatial grid.
        coeffs = coeffs.reshape(B, HP, WP, P)
        return coeffs.mean(dim=[1, 2])
