#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import warnings
import itertools

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

import numpy as np
# from sklearn.base import BaseEstimator
import matplotlib.pyplot as plt
import seaborn as sns

from collections import OrderedDict



# Define the autoencoder model
class Autoencoder(nn.Module):
    """
    Fully connected autoencoder with a mirrored encoder/decoder.

    The encoder maps ``input_dim -> hidden_layer_dims... -> encoding_dim`` and
    the decoder uses the same layer sizes in reverse order. Every linear layer
    is followed by an instance of ``activation_function``, except the final
    decoder layer, whose output is returned without an activation.

    Parameters:
    input_dim : int
        Number of features of the input (and of the reconstruction).
    encoding_dim : int
        Number of units in the bottleneck layer.
    hidden_layer_dims : sequence of int, default=None
        Sizes of the hidden layers between input and bottleneck. ``None``
        (or an empty sequence) means no hidden layers.
    activation_function : type, default=nn.ReLU
        Activation *class* that is instantiated after each linear layer.
        Passing ``nn.Linear`` disables the activation layers entirely.
    linear_layer_params : dict, default=None
        Extra keyword arguments for every ``nn.Linear``. ``None`` means
        ``{"bias": True}``.
    activation_params : dict, default=None
        Keyword arguments for every activation instance. ``None`` means
        ``{"inplace": True}``; supply a different dict when the chosen
        activation class does not accept ``inplace``.
    """

    def __init__(self, input_dim, encoding_dim, hidden_layer_dims=None, activation_function=nn.ReLU, linear_layer_params=None, activation_params=None):
        super().__init__()

        self.input_dim = input_dim
        # Copy the containers so neither a caller's object nor a default is
        # shared between instances; list() also accepts tuples and the like.
        self.hidden_layer_dims = [] if hidden_layer_dims is None else list(hidden_layer_dims)
        self.encoding_dim = encoding_dim
        self.activation_function = activation_function
        self.linear_layer_params = {"bias": True} if linear_layer_params is None else dict(linear_layer_params)
        self.activation_params = {"inplace": True} if activation_params is None else dict(activation_params)

        # The encoder must be built first: the decoder derives its layer
        # sizes from ``self.encoding_layer_sizes``.
        self.encoder = self._create_encoder()
        self.decoder = self._create_decoder()

    def forward(self, x):
        """Encode ``x`` and return its reconstruction."""
        return self.decoder(self.encoder(x))

    def _create_encoder(self):
        """Build the encoder; every linear layer gets an activation."""
        self.encoding_layer_sizes = [self.input_dim] + self.hidden_layer_dims + [self.encoding_dim]
        return self._build_stack("Encoding", self.encoding_layer_sizes, activate_last=True)

    def _create_decoder(self):
        """Build the decoder as the mirror image of the encoder, leaving the output layer unactivated."""
        self.decoding_layer_sizes = list(reversed(self.encoding_layer_sizes))
        return self._build_stack("Decoding", self.decoding_layer_sizes, activate_last=False)

    def _build_stack(self, prefix, layer_sizes, activate_last):
        """Return an ``nn.Sequential`` of named linear (+ activation) layers for consecutive ``layer_sizes``."""
        # ``nn.Linear`` as activation class is the marker for "no activation".
        use_activation = self.activation_function is not nn.Linear
        last_index = len(layer_sizes) - 1

        layers = []
        for i, (n_in, n_out) in enumerate(zip(layer_sizes, layer_sizes[1:]), start=1):
            layers.append((f"{prefix}Linear{i}", nn.Linear(n_in, n_out, **self.linear_layer_params)))
            # Compare by value (``!=``), not identity: ``is not`` on ints only
            # works by accident for small numbers.
            if use_activation and (activate_last or i != last_index):
                layers.append((f"{prefix}{self.activation_function.__name__}{i}", self.activation_function(**self.activation_params)))

        return nn.Sequential(OrderedDict(layers))

class AE_anomaly_detector:
    """
    Anomaly detector that scores samples by autoencoder reconstruction error.

    ``fit`` trains an ``Autoencoder`` on the data with SGD and an MSE loss;
    ``predict`` returns the per-sample mean squared reconstruction error.
    """

    def __init__(
        self,
        encoding_dim,
        hidden_layer_dims=None,
        activation_function=nn.ReLU,
        linear_layer_params=None,
        activation_params=None,
        learning_rate=0.01,
        epochs=50,
        batch_size=None,
        monitor_loss_functions=None,
    ):
        """
        Initialize the autoencoder parameters.

        Parameters:
        encoding_dim : int
            Dimensionality of the encoded representation.
        hidden_layer_dims : sequence of int, default=None
            Sizes of the hidden layers between input and encoding.
            ``None`` means no hidden layers.
        activation_function : type, default=nn.ReLU
            Activation class handed to the ``Autoencoder``.
        linear_layer_params : dict, default=None
            Keyword arguments for the linear layers. ``None`` means
            ``{"bias": True}``.
        activation_params : dict, default=None
            Keyword arguments for the activation layers. ``None`` means
            ``{"inplace": True}``.
        learning_rate : float, default=0.01
            Learning rate for the SGD optimizer.
        epochs : int, default=50
            Number of epochs for training.
        batch_size : int, default=None
            Mini-batch size. ``None`` trains on the full data set in every
            epoch.
        monitor_loss_functions : dict, default=None
            Maps a name to a callable ``f(X_tensor, y, output)`` that is
            evaluated once per epoch when ``y`` is passed to ``fit``. Only
            supported when ``batch_size`` is ``None``.
        """
        self.encoding_dim = encoding_dim
        # Copy the containers so instances never share (or alias) them.
        self.hidden_layer_dims = [] if hidden_layer_dims is None else list(hidden_layer_dims)
        self.learning_rate = learning_rate
        self.activation_function = activation_function
        self.linear_layer_params = {"bias": True} if linear_layer_params is None else dict(linear_layer_params)
        self.activation_params = {"inplace": True} if activation_params is None else dict(activation_params)
        self.batch_size = batch_size
        self.epochs = epochs
        self.monitor_loss_functions = monitor_loss_functions

        self.losses = []
        self.autoencoder = None
        self._monitor_losses_calculated = False
        self._is_fitted = False

    @staticmethod
    def _as_float_tensor(X):
        """Return a float32 tensor copy of array-like or tensor input ``X``."""
        if isinstance(X, torch.Tensor):
            # torch.tensor(tensor) works but emits a UserWarning; this is the
            # copy construct that warning recommends.
            return X.detach().clone().to(dtype=torch.float32)
        return torch.tensor(np.asarray(X), dtype=torch.float32)

    def fit(self, X, y=None):
        """
        Fit the autoencoder to the input data.

        Calling ``fit`` again trains a fresh model and replaces the recorded
        loss history instead of appending to it.

        Parameters:
        X : array-like or torch.Tensor, shape (n_samples, n_features)
            Training data.
        y : object, default=None
            Not used for training. It is only forwarded to the callables in
            ``monitor_loss_functions``.

        Returns:
        self : object
            Returns self.

        Raises:
        NotImplementedError
            If ``y`` and ``monitor_loss_functions`` are both given while a
            manual ``batch_size`` is set.
        """
        X_tensor = self._as_float_tensor(X)
        self.input_dim = X_tensor.shape[1]

        if y is None and self.monitor_loss_functions is not None:
            warnings.warn("Monitor losses are specified, but no 'y' was supplied, so they will not be calculated.")

        monitor_losses = None
        if y is not None:
            if self.monitor_loss_functions is not None:
                if self.batch_size is not None:
                    raise NotImplementedError("Monitoring additional loss functions is not implemented when a manual batch size is set.")
                monitor_losses = {name: np.zeros((self.epochs,)) for name in self.monitor_loss_functions}
            else:
                warnings.warn("Argument 'y' is not used during (unsupervised) training. It is only used to monitor external losses by setting the 'monitor_loss_functions' argument during initialization.")

        model = Autoencoder(input_dim=self.input_dim, encoding_dim=self.encoding_dim, hidden_layer_dims=self.hidden_layer_dims, activation_function=self.activation_function, linear_layer_params=self.linear_layer_params, activation_params=self.activation_params)

        # Define loss function and optimizer
        criterion = nn.MSELoss()
        optimizer = torch.optim.SGD(model.parameters(), lr=self.learning_rate)

        if self.batch_size is not None:
            losses = self._train_minibatch(model, criterion, optimizer, X_tensor)
        else:
            losses = self._train_full_batch(model, criterion, optimizer, X_tensor, y, monitor_losses)

        # Publish the results only after training succeeded, replacing
        # whatever an earlier fit left behind.
        self.losses = losses
        if monitor_losses is not None:
            self.monitor_losses = monitor_losses
        elif hasattr(self, "monitor_losses"):
            # plot_loss detects monitoring via hasattr, so drop stale curves.
            del self.monitor_losses
        self._monitor_losses_calculated = monitor_losses is not None

        self.autoencoder = model
        self._is_fitted = True

        return self

    def _train_minibatch(self, model, criterion, optimizer, X_tensor):
        """Train on shuffled mini-batches and return the mean per-sample loss of each epoch."""
        loader = DataLoader(TensorDataset(X_tensor), batch_size=self.batch_size, shuffle=True)
        n_samples = X_tensor.shape[0]

        losses = []
        for _ in range(self.epochs):
            running_loss = 0.0
            for (X_batch,) in loader:
                loss = criterion(model(X_batch), X_batch)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Weight by batch size: the last batch may be smaller.
                running_loss += loss.item() * X_batch.shape[0]

            losses.append(running_loss / n_samples)
        return losses

    def _train_full_batch(self, model, criterion, optimizer, X_tensor, y, monitor_losses):
        """Train on the whole data set per epoch, filling ``monitor_losses`` in place when it is given."""
        losses = []
        for epoch in range(self.epochs):
            output = model(X_tensor)
            loss = criterion(output, X_tensor)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            losses.append(loss.item())

            if monitor_losses is not None:
                # ``output`` is the reconstruction computed before this
                # epoch's parameter update.
                with torch.no_grad():
                    for name, loss_function in self.monitor_loss_functions.items():
                        monitor_losses[name][epoch] = loss_function(X_tensor, y, output)
        return losses

    def plot_loss(self, plot_additional_losses=False):
        """
        Plot the loss curve.

        Parameters:
        plot_additional_losses : bool, default=False
            If True and monitor losses were recorded during ``fit``, plot
            each of them on its own y-axis next to the training loss.

        Returns:
        None
        """
        has_monitor_losses = hasattr(self, "monitor_losses")
        if plot_additional_losses and not has_monitor_losses:
            warnings.warn("No additional losses were monitored, defaulting to normal plotting of loss.")

        # Use the recorded history length rather than ``self.epochs`` so the
        # plot stays valid if ``epochs`` is changed after fitting.
        n_epochs = len(self.losses)
        plot_x_range = range(1, n_epochs + 1)

        if plot_additional_losses and has_monitor_losses:
            palette = itertools.cycle(sns.color_palette())

            fig, ax1 = plt.subplots()

            # Training loss on the primary axis
            p1, = ax1.plot(plot_x_range, self.losses, label="Overall MSE", color=next(palette))
            ax1.set_xlabel("Epoch")
            ax1.set_ylabel("overall MSE", color=p1.get_color())
            ax1.tick_params(axis='y', labelcolor=p1.get_color())

            axes = [ax1]

            # One additional y-axis per monitored loss
            for i, monitor_loss_name in enumerate(self.monitor_losses):
                ax = ax1.twinx()

                # Shift each further axis outward so the scales do not overlap
                ax.spines['right'].set_position(('outward', 60 * i))
                p_i, = ax.plot(plot_x_range, self.monitor_losses[monitor_loss_name], label=monitor_loss_name, color=next(palette))
                ax.set_ylabel(monitor_loss_name, color=p_i.get_color())
                ax.tick_params(axis='y', labelcolor=p_i.get_color())

                axes.append(ax)

            # Collect the handles of all axes into one shared legend
            lines, labels = [], []
            for ax in axes:
                ax_lines, ax_labels = ax.get_legend_handles_labels()
                lines.extend(ax_lines)
                labels.extend(ax_labels)
            ax1.legend(lines, labels)

        else:
            plt.plot(plot_x_range, self.losses)
            plt.xlabel('Epoch')
            plt.ylabel('Loss (MSE)')
            plt.title('Loss Curve')
            plt.xlim(1, n_epochs + 1)
            plt.tight_layout()

    def predict(self, X):
        """
        Predict reconstruction loss for the input data.

        Parameters:
        X : array-like or torch.Tensor, shape (n_samples, n_features)
            Samples.

        Returns:
        anomaly_scores : array, shape (n_samples,)
            Reconstruction loss (mean squared error over the features) for
            each sample.

        Raises:
        RuntimeError
            If called before a model is available (``fit`` was not called).
        """
        if self.autoencoder is None:
            raise RuntimeError("This AE_anomaly_detector instance is not fitted yet. Call 'fit' before using 'predict'.")

        X_tensor = self._as_float_tensor(X)

        # Compute reconstruction errors
        with torch.no_grad():
            reconstructed = self.autoencoder(X_tensor)
            reconstruction_errors = torch.mean((X_tensor - reconstructed) ** 2, dim=1)

        return reconstruction_errors.numpy()

    def fit_predict(self, X, y=None):
        """
        Fit the model to the training data and then predict on the same data.

        Parameters:
        X : array-like or torch.Tensor, shape (n_samples, n_features)
            Training data.
        y : object, default=None
            Forwarded to ``fit``, where it is only used for monitor losses.

        Returns:
        y_pred : array, shape (n_samples,)
            Reconstruction loss predicted on the same data.
        """
        self.fit(X, y)
        return self.predict(X)

    def __sklearn_is_fitted__(self):
        """
        Check fitted status and return a Boolean value.
        """
        return getattr(self, "_is_fitted", False)
