#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import warnings

import torch
import torch.nn as nn

import numpy as np
import matplotlib.pyplot as plt

import os

# Convolutional autoencoder (ConvAE) for MNIST-like (28x28) input
class ConvAutoencoder(nn.Module):
    """Convolutional autoencoder with a fully connected bottleneck.

    The encoder applies two stride-2 convolutions, the result is flattened
    and projected to ``encoding_dim`` features, and the decoder mirrors this
    with a linear layer followed by two stride-2 transposed convolutions and
    a final sigmoid. The linear layers are sized for a ``32 x 7 x 7`` feature
    map, which is what the encoder produces for a 28x28 spatial input.
    """

    def __init__(self, input_channels, encoding_dim):
        super().__init__()

        self.input_channels = input_channels
        self.encoding_dim = encoding_dim

        # Flattened size of the encoder output (32 channels, 7x7 spatial).
        flat_features = 32 * 7 * 7

        # Submodules are created in encoder -> bottleneck -> decoder order.
        self.encoder = self._create_encoder()
        self.bottleneck = nn.Linear(flat_features, encoding_dim)
        self.decoder_input = nn.Linear(encoding_dim, flat_features)
        self.decoder = self._create_decoder()

    def _create_encoder(self):
        """Build the downsampling stack: C x 28 x 28 -> 16 x 14 x 14 -> 32 x 7 x 7."""
        layers = [
            nn.Conv2d(self.input_channels, 16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*layers)

    def _create_decoder(self):
        """Build the upsampling stack: 32 x 7 x 7 -> 16 x 14 x 14 -> C x 28 x 28."""
        layers = [
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(16, self.input_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid(),  # squashes the reconstruction into (0, 1)
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x`` to the bottleneck and return its reconstruction."""
        features = self.encoder(x)
        flattened = features.view(features.size(0), -1)
        code = self.bottleneck(flattened)
        expanded = self.decoder_input(code)
        # Restore the (channels, height, width) layout the decoder expects.
        feature_map = expanded.view(expanded.size(0), 32, 7, 7)
        return self.decoder(feature_map)

class ConvAE_anomaly_detector:
    """Train a ``ConvAutoencoder`` and score samples by reconstruction error.

    Parameters
    ----------
    input_channels : number of channels of the input images.
    encoding_dim : size of the autoencoder bottleneck.
    learning_rate : step size of the SGD optimizer used in ``fit``.
    epochs : number of passes over the training loader per ``fit`` call.
    save_epochs : epoch indices (0-based) after which the weights are saved;
        ``None`` means no scheduled saves.
    ID : name of the sub-directory of ``saved_models`` used for checkpoints.
        When ``None`` checkpoints are written directly into ``saved_models``.
    checkpoint_loss : key of the additional loader whose loss is monitored;
        the weights are saved every time that loss reaches a new minimum.
    """

    def __init__(self, input_channels, encoding_dim, learning_rate=0.01, epochs=50, save_epochs=None, ID=None, checkpoint_loss=None):
        self.input_channels = input_channels
        self.encoding_dim = encoding_dim
        self.learning_rate = learning_rate
        self.epochs = epochs
        self.save_epochs = save_epochs if save_epochs is not None else []
        self.ID = ID
        self.checkpoint_loss = checkpoint_loss

        # Per-epoch training losses; accumulates across repeated fit() calls.
        self.losses = []
        self.autoencoder = ConvAutoencoder(input_channels=self.input_channels, encoding_dim=self.encoding_dim)
        self._is_fitted = False

        self.minimum_checkpoint_loss_epoch = None
        self.additional_losses = None

    def fit(self, X, additional_loaders=None):
        """Train the autoencoder on the batches yielded by ``X``.

        Parameters
        ----------
        X : iterable of batches (e.g. a ``DataLoader``); the first element of
            every batch is used as both input and reconstruction target.
        additional_loaders : optional dict mapping a name to a loader of the
            same form. After every epoch the loss on each loader is recorded
            in ``self.additional_losses[name]`` without updating the weights.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``X`` yields no samples.
        """
        # Treat "no additional loaders" as an empty mapping so the default
        # call signature works instead of failing on None.
        if additional_loaders is None:
            additional_loaders = {}
        additional_losses = {name: [] for name in additional_loaders}

        if self.checkpoint_loss is not None and self.checkpoint_loss not in additional_loaders:
            warnings.warn(
                f"checkpoint_loss={self.checkpoint_loss!r} is not a key of "
                "additional_loaders; no best-loss checkpoints will be saved."
            )

        # Summed squared error; divided by the sample count per epoch below.
        loss_function = nn.MSELoss(reduction="sum")
        optimizer = torch.optim.SGD(self.autoencoder.parameters(), lr=self.learning_rate)

        for epoch in range(self.epochs):
            self.losses.append(self._train_one_epoch(X, loss_function, optimizer))

            save_this_epoch = False
            for name, loader in additional_loaders.items():
                current = self._evaluate(loader, loss_function)
                history = additional_losses[name]
                history.append(current)

                # A new (or tied) minimum of the monitored loss triggers a save.
                if name == self.checkpoint_loss and current == min(history):
                    self.minimum_checkpoint_loss_epoch = epoch
                    save_this_epoch = True

            if save_this_epoch or epoch in self.save_epochs:
                self._save_checkpoint(epoch)

        self.additional_losses = additional_losses
        self._is_fitted = True

        return self

    def _train_one_epoch(self, dataloader, loss_function, optimizer):
        """Run one optimisation pass and return the summed loss per sample."""
        total_loss = 0.0
        n_samples = 0
        for batch in dataloader:
            inputs = batch[0]
            output = self.autoencoder(inputs)
            loss = loss_function(output, inputs)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            # Count samples directly instead of inferring them from
            # len(loader) and batch_size, which breaks for drop_last or
            # batch samplers.
            n_samples += inputs.shape[0]

        if n_samples == 0:
            raise ValueError("Training dataloader yielded no samples.")
        return total_loss / n_samples

    def _evaluate(self, loader, loss_function):
        """Return the summed loss per sample on ``loader`` (NaN if it is empty)."""
        total_loss = 0.0
        n_samples = 0
        with torch.no_grad():
            for batch in loader:
                inputs = batch[0]
                total_loss += loss_function(self.autoencoder(inputs), inputs).item()
                n_samples += inputs.shape[0]
        return total_loss / n_samples if n_samples else float("nan")

    def _save_checkpoint(self, epoch):
        """Save the current weights as ``saved_models[/ID]/epoch_<epoch>.pth``."""
        if self.ID is None:
            model_dir = "saved_models"
        else:
            model_dir = os.path.join("saved_models", str(self.ID))
        os.makedirs(model_dir, exist_ok=True)
        self.save_model(os.path.join(model_dir, f"epoch_{epoch}.pth"))

    def plot_loss(self, plot_additionals=None):
        """Plot the recorded training loss on the current matplotlib axes.

        Parameters
        ----------
        plot_additionals : optional iterable of keys of
            ``self.additional_losses`` to draw alongside the training curve.
        """
        # Use the recorded history length rather than self.epochs so the plot
        # still works after several fit() calls.
        n_epochs = len(self.losses)
        train_epochs = range(1, n_epochs + 1)

        if plot_additionals is not None:
            plt.plot(train_epochs, self.losses, label="Train")
            for additional_dataset_name in plot_additionals:
                series = self.additional_losses[additional_dataset_name]
                plt.plot(range(1, len(series) + 1), series, label=additional_dataset_name)
            plt.legend()
        else:
            plt.plot(train_epochs, self.losses)
        plt.xlabel('Epoch')
        plt.ylabel('Loss (MSE)')
        plt.xlim(1, n_epochs + 1)
        plt.tight_layout()

    def _reconstruction_errors(self, inputs):
        """Return the mean squared reconstruction error of each sample as a tensor."""
        with torch.no_grad():
            reconstructed = self.autoencoder(inputs)
            return torch.mean((inputs - reconstructed) ** 2, dim=[1, 2, 3])

    def predict(self, X):
        """Return one reconstruction error per sample as a NumPy array.

        Parameters
        ----------
        X : either a 4-D array-like / tensor of samples, or a ``DataLoader``
            whose batches carry the samples as their first element. For a
            loader the errors follow the order in which it yields samples, so
            a shuffling loader gives a shuffled result.
        """
        if isinstance(X, torch.utils.data.DataLoader):
            per_batch = [self._reconstruction_errors(batch[0]) for batch in X]
            if not per_batch:
                return np.empty((0,), dtype=np.float32)
            return torch.cat(per_batch).numpy()

        # as_tensor avoids the copy (and warning) torch.tensor produces when
        # X is already a tensor.
        X_tensor = torch.as_tensor(X, dtype=torch.float32)
        return self._reconstruction_errors(X_tensor).numpy()

    def fit_predict(self, X):
        """Fit on the loader ``X`` and return the reconstruction errors of its samples."""
        self.fit(X)
        return self.predict(X)

    def save_model(self, filename):
        """Write the autoencoder's ``state_dict`` to ``filename``."""
        torch.save(self.autoencoder.state_dict(), filename)
        print(f"Model saved to {filename}")

    def load_model(self, filename):
        """Load weights from ``filename``; print a message if the file is missing."""
        if os.path.exists(filename):
            self.autoencoder.load_state_dict(torch.load(filename))
            self._is_fitted = True  # only (partially) fitted models are ever saved
            print(f"Model loaded from {filename}")
        else:
            print(f"Model file {filename} not found.")

    def __sklearn_is_fitted__(self):
        """Return True once the model has been fitted or loaded."""
        return hasattr(self, "_is_fitted") and self._is_fitted


class LatentSpacePredictor(nn.Module):
    """Score latent-space points by how well their decoded image is reconstructed.

    Each latent vector is decoded by the wrapped detector's autoencoder into
    an artificial sample; that sample is then passed through the full
    autoencoder and the summed squared difference between the two is the
    score.

    Parameters
    ----------
    detector : a fitted detector exposing an ``autoencoder`` attribute with
        ``decoder_input`` and ``decoder`` submodules.

    Raises
    ------
    ValueError
        If ``detector`` is not marked as fitted.
    """

    def __init__(self, detector):
        super().__init__()

        if not getattr(detector, "_is_fitted", False):
            raise ValueError("Autoencoder used for initialization is not fitted.")

        self.detector = detector
        self._is_fitted = True

    # Only exists for compatibility with sklearn
    def fit(self, X=None, y=None):
        """Do nothing and return ``self``; the wrapped detector is already fitted."""
        return self

    def predict(self, X):
        """Return the summed squared reconstruction error for each latent vector.

        Parameters
        ----------
        X : array-like or tensor with one latent vector per row.

        Returns
        -------
        numpy.ndarray of dtype float64 with one entry per row of ``X``.
        """
        # as_tensor avoids the copy (and warning) torch.tensor produces when
        # X is already a tensor.
        X_tensor = torch.as_tensor(X, dtype=torch.float32)
        autoencoder = self.detector.autoencoder
        with torch.no_grad():
            artificial_samples = autoencoder.decoder_input(X_tensor)
            artificial_samples = artificial_samples.view(artificial_samples.size(0), 32, 7, 7)
            artificial_samples = autoencoder.decoder(artificial_samples)
            reconstructed_samples = autoencoder(artificial_samples)

            # Per-sample sum of squared errors, computed for the whole batch
            # at once instead of looping over samples.
            squared_error = (reconstructed_samples - artificial_samples) ** 2
            loss = squared_error.flatten(start_dim=1).sum(dim=1)
        return loss.numpy().astype(np.float64)

    def __sklearn_is_fitted__(self):
        """
        Check fitted status and return a Boolean value.
        """
        return hasattr(self, "_is_fitted") and self._is_fitted
