import torch.nn as nn
from utils import model_loader
from models.autoencoder import Autoencoder
from torch.utils import data
import torch


class AutoEncoderSOM(nn.Module):
    """Autoencoder whose latent codes are fed to a SOM module.

    The encoder/decoder pair comes from ``Autoencoder`` and the SOM is built
    by ``model_loader.choose_som``; both are project helpers, so their
    internals are not described here.
    """

    def __init__(self, d_in=1, hw_in=28, som_input=2, n_max=20, at=0.985, lr0=0.1, lr=0.1, lr_push=1.0,
                 ds_beta=0.5, eps_ds=1., ld=0.05, gamma=3.0, device='cpu', semi=False):
        """Build the SOM and the encoder/decoder pair.

        Args:
            d_in: Multiplier of the flattened input size
                (presumably the number of input channels -- TODO confirm).
            hw_in: Side length of the input; the autoencoder is built for
                ``hw_in * hw_in * d_in`` flattened features.
            som_input: Size of the latent code, which is also the SOM input
                size.
            n_max: Forwarded to ``model_loader.choose_som`` as ``n_max``.
            at, lr0, lr, lr_push, ds_beta, eps_ds, ld, gamma: SOM
                hyper-parameters, passed through
                ``model_loader.load_som_params`` and then to the SOM as its
                ``param_set``.
            device: Device the SOM is moved to and to which ``cluster``
                moves its input batches.
            semi: Forwarded to ``model_loader.choose_som``; it also selects
                which return signature ``cluster`` expects from the SOM.
        """
        super().__init__()

        self.som_input_size = som_input
        self.d_in = d_in
        self.hw_in = hw_in
        self.device = device

        self.n_max = n_max

        params = {'at': at, 'ds_beta': ds_beta, 'lr0': lr0, 'lr': lr, 'lr_push': lr_push,
                  'eps_ds': eps_ds, 'ld': ld, 'gamma': gamma}
        params = model_loader.load_som_params(params)

        self.semi = semi
        self.som = model_loader.choose_som(input_size=self.som_input_size,
                                           n_max=self.n_max,
                                           param_set=params,
                                           device=device,
                                           semi=self.semi)

        self.som = self.som.to(self.device)

        # NOTE(review): only the SOM is moved to ``device`` here; the encoder
        # and decoder follow the module only if the caller invokes ``.to()``.
        autoencoder = Autoencoder(input_flatten=(self.hw_in * self.hw_in * self.d_in),
                                  latent_space_input=self.som_input_size)
        self.encoder = autoencoder.encoder
        self.decoder = autoencoder.decoder

    def autoencoder_extract_features(self, x):
        """Encode ``x`` and reshape the result to ``(-1, som_input_size)``."""
        encoded_features = self.encoder(x)
        return encoded_features.view(-1, self.som_input_size)

    def forward(self, x, y=None):
        """Run the encoder, the decoder and the SOM on ``x``.

        Args:
            x: Input batch accepted by the encoder.
            y: Optional second argument for the SOM; it is passed to the SOM
                only when it is not ``None``.

        Returns:
            Tuple ``(encoded_features, decoded_features, som_output)``.
        """
        encoded_features = self.encoder(x)
        decoded_features = self.decoder(encoded_features)

        som_features = encoded_features.view(-1, self.som_input_size)
        if y is None:
            som_output = self.som(som_features)
        else:
            som_output = self.som(som_features, y)

        return encoded_features, decoded_features, som_output

    def cluster(self, dataloader):
        """Encode every sample of ``dataloader`` and cluster the codes with the SOM.

        Every sample of every batch is used, so the result does not depend on
        the batch size of ``dataloader``.

        Args:
            dataloader: Iterable yielding ``(samples, targets)`` batches.

        Returns:
            Tuple ``(predicted_clusters, predicted_labels, true_labels,
            cluster_result)`` as produced by ``self.som.cluster``.
            ``predicted_labels`` is ``None`` when ``self.semi`` is false.
        """
        feature_batches = []
        label_batches = []
        # Pure inference: no autograd graph is needed for feature extraction.
        with torch.no_grad():
            for samples, targets in dataloader:
                samples, targets = samples.to(self.device), targets.to(self.device)
                outputs = self.autoencoder_extract_features(samples)

                # Keep the whole batch, not just its first element.
                feature_batches.append(outputs.detach().cpu())
                label_batches.append(targets.detach().reshape(-1).cpu().long())

        if feature_batches:
            extracted_features = torch.cat(feature_batches, dim=0)
            labels = torch.cat(label_batches, dim=0)
        else:
            # Empty input: build well-shaped empty tensors instead of failing in torch.cat.
            extracted_features = torch.empty((0, self.som_input_size))
            labels = torch.empty((0,), dtype=torch.long)

        extracted_dataset = data.TensorDataset(extracted_features, labels)
        extracted_dataloader = data.DataLoader(extracted_dataset)

        predicted_labels = None
        if self.semi:
            predicted_clusters, predicted_labels, true_labels, cluster_result = self.som.cluster(extracted_dataloader)
        else:
            predicted_clusters, true_labels, cluster_result = self.som.cluster(extracted_dataloader)

        return predicted_clusters, predicted_labels, true_labels, cluster_result

    def write_output(self, output_path, cluster_result):
        """Delegate writing of ``cluster_result`` to the SOM's ``write_output``."""
        self.som.write_output(output_path, cluster_result)
