import torch
import torch.nn as nn

from .base import BaseDivergence


class MaxAffineNet(nn.Module):
    """Pointwise maximum over ``K`` scalar-valued subnetworks.

    Each subnetwork maps an input row to a single scalar; the network output
    is the largest of the ``K`` scalars together with the index of the
    subnetwork that produced it. With ``linear_subnet=True`` every subnetwork
    is a single affine map; otherwise each one is a one-hidden-layer ReLU MLP.

    Inputs are expected to be 2-D, ``(N, input_dim)``: subnetwork outputs are
    concatenated along dim 1, which yields an ``(N, K)`` score matrix.

    Args:
        input_dim: Number of features in each input row.
        K: Number of subnetworks; must be at least 1.
        subnet_width: Hidden width of each subnetwork (unused when
            ``linear_subnet`` is true).
        linear_subnet: Use a single ``nn.Linear`` per subnetwork instead of
            the ReLU MLP.

    Raises:
        ValueError: If ``K`` is smaller than 1.
    """

    def __init__(self, input_dim, K=10, subnet_width=100, linear_subnet=False):
        super().__init__()
        # An empty ModuleList would only fail later, inside torch.cat, with an
        # error that does not mention K; reject it where the cause is obvious.
        if K < 1:
            raise ValueError(f"K must be at least 1, got {K}")
        self.input_dim = input_dim
        self.K = K
        self.subnet_width = subnet_width
        self.linear_subnet = linear_subnet
        self.subnets = nn.ModuleList([self._init_subnet() for _ in range(K)])

    def _init_subnet(self):
        """Build one subnetwork mapping ``input_dim`` features to a scalar."""
        if self.linear_subnet:
            return nn.Linear(self.input_dim, 1)
        return nn.Sequential(
            nn.Linear(self.input_dim, self.subnet_width),
            nn.ReLU(),
            nn.Linear(self.subnet_width, 1),
        )

    def _evaluate_subnets(self, x):
        """Return every subnetwork's output on ``x``, concatenated on dim 1."""
        return torch.cat([fn(x) for fn in self.subnets], dim=1)

    def forward(self, x):
        """Evaluate the max over subnetworks.

        Args:
            x: Input tensor of shape ``(N, input_dim)``.

        Returns:
            Tuple ``(ind_max, out_max)``. Note the order: the index of the
            maximising subnetwork comes first, the maximum value second.
            Both keep dim 1, i.e. have shape ``(N, 1)``.
        """
        out_max, ind_max = torch.max(self._evaluate_subnets(x), dim=1, keepdim=True)
        return ind_max, out_max

    def gather_by_idx(self, x, q_star):
        """Pick, for each row of ``x``, the output of one chosen subnetwork.

        Args:
            x: Input tensor of shape ``(N, input_dim)``.
            q_star: Integer index tensor passed to ``torch.gather`` along
                dim 1; entry ``[i, j]`` selects the subnetwork whose output
                on row ``i`` is returned at ``[i, j]``.

        Returns:
            Tensor with the same shape as ``q_star``.
        """
        return torch.gather(self._evaluate_subnets(x), 1, q_star)

    def gather_mat(self, x, q_star):
        """Evaluate all rows of ``x`` on every subnetwork listed in ``q_star``.

        Args:
            x: Input tensor of shape ``(N, input_dim)``.
            q_star: Integer index tensor; it is flattened to ``M`` subnetwork
                indices.

        Returns:
            Tensor of shape ``(N, M)`` whose column ``j`` holds the output of
            subnetwork ``q_star.flatten()[j]`` on every row of ``x``.
        """
        return self._evaluate_subnets(x)[:, q_star.flatten()]


class MaxAffineDivergence(BaseDivergence):
    """Divergence built from a max-affine potential ``phi``.

    ``phi`` must be callable and return ``(argmax_index, max_value)``, and
    must provide ``gather_mat`` and ``gather_by_idx`` (as ``MaxAffineNet``
    does). Both methods below evaluate ``phi(x)`` minus the value on ``x`` of
    the component that is maximal at ``y``.

    NOTE(review): ``self.phi`` and ``self.batch_compute_mat`` are not defined
    in this class; presumably ``BaseDivergence`` provides them - confirm.
    """

    def __init__(self, phi, **kwargs):
        # Nothing extra to set up; all arguments are forwarded to the base class.
        super().__init__(phi, **kwargs)

    def compute_mat(self, x, y):
        """Return the divergence between every row of ``x`` and every row of ``y``.

        The max value on ``x`` is broadcast against ``phi.gather_mat``'s
        result; with a ``MaxAffineNet`` potential that gives one row per
        sample of ``x`` and one column per sample of ``y``.
        """
        _, phi_at_x = self.phi(x)
        active_at_y, _ = self.phi(y)
        selected_on_x = self.phi.gather_mat(x, active_at_y)
        difference = phi_at_x - selected_on_x
        return difference

    def pairwise_distance(self, x, y):
        """Return the flattened divergence between matching rows of ``x`` and ``y``.

        NOTE(review): assumes ``x`` and ``y`` hold the same number of rows,
        since row ``i`` of ``x`` is paired with row ``i`` of ``y`` - confirm
        against callers.
        """
        _, phi_at_x = self.phi(x)
        active_at_y, _ = self.phi(y)
        selected_on_x = self.phi.gather_by_idx(x, active_at_y)
        difference = phi_at_x - selected_on_x
        return difference.flatten()

    def compute_full_mat(self, x, y):
        """Delegate the full matrix computation to ``batch_compute_mat``."""
        full_matrix = self.batch_compute_mat(x, y)
        return full_matrix
