import torch.nn as nn
from torch.nn.functional import normalize
import torch


class Encoder(nn.Module):
    """Per-view MLP encoder: input_dim -> 500 -> 500 -> 2000 -> feature_dim.

    Every hidden layer is followed by a ReLU; the final projection to
    ``feature_dim`` has no activation.

    Args:
        input_dim: Width of the input features.
        feature_dim: Width of the produced latent features.
    """

    def __init__(self, input_dim, feature_dim):
        super().__init__()
        widths = [input_dim, 500, 500, 2000]
        layers = []
        # Hidden blocks are (Linear, ReLU) pairs; indices inside the
        # Sequential stay 0..6 so state_dict keys match existing checkpoints.
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        layers.append(nn.Linear(widths[-1], feature_dim))
        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch of inputs to latent features."""
        latent = self.encoder(x)
        return latent


class Decoder(nn.Module):
    """Per-view MLP decoder: feature_dim -> 2000 -> 500 -> 500 -> input_dim.

    Mirrors ``Encoder``: ReLU after every hidden layer, no activation on the
    final reconstruction layer.

    Args:
        input_dim: Width of the reconstructed output (the view's input width).
        feature_dim: Width of the latent features being decoded.
    """

    def __init__(self, input_dim, feature_dim):
        super().__init__()
        widths = [feature_dim, 2000, 500, 500]
        layers = []
        # Keep the (Linear, ReLU) layout so Sequential indices 0..6 and
        # therefore state_dict keys are unchanged.
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        layers.append(nn.Linear(widths[-1], input_dim))
        self.decoder = nn.Sequential(*layers)

    def forward(self, x):
        """Reconstruct a batch of inputs from latent features."""
        reconstruction = self.decoder(x)
        return reconstruction


class Network(nn.Module):
    """Multi-view autoencoder built around a shared learnable anchor matrix.

    Each view ``v`` has its own ``Encoder``/``Decoder`` pair. The latent ``z``
    of every view goes through a shared ``Linear + BatchNorm1d`` head that
    outputs ``anchor_num`` (== ``class_num``) scores. Those scores are turned
    into non-negative, row-normalised assignments ``h`` and multiplied by the
    anchors ``C`` to get ``z_pre``. The view's decoder reconstructs its
    input from ``z_pre``.

    Args:
        view: Number of views.
        input_size: Per-view input widths, indexed by view.
        feature_dim: Width of the latent feature space (and of each anchor).
        class_num: Number of classes; also used as the number of anchors.
        device: Device the per-view encoders/decoders and ``avg`` are put on.
        batch_size: Side length of the uniform averaging matrix ``avg``.
    """

    def __init__(self, view, input_size, feature_dim,
                 class_num, device, batch_size):
        super().__init__()
        self.encoders = []
        self.decoders = []
        # Uniform (batch_size x batch_size) matrix with every entry
        # 1 / batch_size. It is a plain attribute, not a registered buffer,
        # so it is not part of state_dict and does not follow ``.to()``.
        self.avg = torch.full((batch_size, batch_size), 1.0 / batch_size,
                              device=device)

        self.batch_size = batch_size
        self.device = device
        self.class_num = class_num
        self.anchor_num = class_num

        # Encoders and decoders are created interleaved per view; this order
        # fixes parameter initialisation under a given RNG seed.
        for v in range(view):
            self.encoders.append(
                Encoder(input_size[v], feature_dim).to(device))
            self.decoders.append(
                Decoder(input_size[v], feature_dim).to(device))

        self.encoders = nn.ModuleList(self.encoders)
        self.decoders = nn.ModuleList(self.decoders)

        self.feature_contrastive_module = nn.Sequential(
            nn.Linear(feature_dim, self.anchor_num),
            nn.BatchNorm1d(self.anchor_num),
        )
        self.view = view
        # Learnable anchors, one row of width feature_dim per anchor.
        self.C = nn.Parameter(torch.randn(self.anchor_num, feature_dim))

    def forward(self, xs, max_view=-1):
        """Encode, assign to anchors, and reconstruct every view.

        Args:
            xs: Per-view input batches; ``xs[v]`` is fed to
                ``self.encoders[v]``. Alignment assumes every view has the
                batch size ``len(xs[0])`` -- TODO confirm with callers.
            max_view: Index of the reference view. When ``>= 0``, every
                other view's aligned outputs are its re-normalised ``z_pre``
                and its ``h``, each replaced row-wise by its batch mean. The
                reference view's own outputs are passed through unchanged.
                When negative (default), no alignment is computed and the
                aligned entries are the integer placeholder ``0``.

        Returns:
            Tuple ``(hs, xrs, zs, zs_pre, zs_pre_align, hs_align)`` of
            per-view lists.
        """
        hs = []
        zs_pre_align = []
        hs_align = []
        xrs = []
        zs = []
        zs_pre = []

        p = None
        if max_view >= 0:
            size_x = len(xs[0])
            # ``p @ t`` replaces every row of ``t`` by the column-wise batch
            # mean. The matrix depends only on the batch size, so build it
            # once instead of once per view. ``torch.mm(h, self.C)`` below
            # already requires ``h`` to share C's device and dtype, so
            # matching C keeps ``p`` compatible after ``.to()``/``.double()``.
            p = torch.full((size_x, size_x), 1.0 / size_x,
                           device=self.C.device, dtype=self.C.dtype)

        for v in range(self.view):
            z = self.encoders[v](xs[v])
            h = normalize(self.feature_contrastive_module(z), dim=1)
            h = self.activate_and_normalize(h)
            z_pre = torch.mm(h, self.C)

            z_pre_align = 0
            h_align = 0
            if p is not None:
                z_pre_t = normalize(z_pre, dim=1)
                z_pre_t = self.activate_and_normalize(z_pre_t)
                if v != max_view:
                    z_pre_align = torch.mm(p, z_pre_t)
                    h_align = torch.mm(p, h)
                else:
                    z_pre_align = z_pre_t
                    h_align = h

            xr = self.decoders[v](z_pre)
            hs.append(h)
            zs_pre_align.append(z_pre_align)
            hs_align.append(h_align)
            zs.append(z)
            xrs.append(xr)
            zs_pre.append(z_pre)
        return hs, xrs, zs, zs_pre, zs_pre_align, hs_align

    def activate_and_normalize(self, tensor):
        """Clamp negatives to zero, then divide each row (dim 1) by its sum.

        A row whose clamped sum is exactly zero is divided by 1 instead, so
        it stays all zeros rather than becoming NaN.
        """
        tensor = torch.clamp(tensor, min=0)
        row_sums = tensor.sum(dim=1, keepdim=True)
        # Add 1 only where the sum is zero, to avoid 0 / 0.
        row_sums = row_sums + (row_sums == 0).float()
        tensor = tensor / row_sums
        return tensor

    def get_C(self):
        """Return a copy of the anchor matrix ``C``.

        ``clone()`` is differentiable, so the copy is not detached from the
        autograd graph.
        """
        return self.C.clone()

    def forward_plot(self, xs):
        """Return per-view latents and raw head outputs.

        Unlike ``forward``, the head output is returned as-is: no
        ``normalize``/``activate_and_normalize`` and no anchor mixing.

        Returns:
            Tuple ``(zs, hs)`` of per-view lists.
        """
        zs = []
        hs = []
        for v in range(self.view):
            z = self.encoders[v](xs[v])
            zs.append(z)
            hs.append(self.feature_contrastive_module(z))
        return zs, hs
