from abc import ABCMeta, abstractmethod

import pickle
from dataclasses import dataclass
from typing import Optional, Callable

import torch
import torch.nn
import torch.optim
from torch.cuda.amp import autocast, GradScaler

from XXX.uib.modules.encoder_decoder import EncoderDecoder, CategoricalEncoderDecoder
from XXX.progress_bar import with_progress_bar

from XXX.uib.losses.cross_entropies import (
    CrossEntropies,
    deterministic_continuous_encoder_cross_entropies,
    deterministic_categorical_encoder_loss,
    deterministic_categorical_encoder_cross_entropies,
    stochastic_continuous_encoder_cross_entropies,
)
from experiments.models import stochastic_model


from experiments.models.stochastic_model import StochasticModel


@dataclass
class Output:
    """Result of a single `predict` / `fit` call on one of the dynamics classes."""

    # Prediction tensor; the stochastic dynamics in this module average it over dim 1 first.
    prediction: torch.Tensor
    # Cross-entropy terms computed by the `XXX.uib.losses.cross_entropies` helpers.
    cross_entropies: CrossEntropies

    # Encoding / captured latent, when the dynamics provides one.
    latent: Optional[torch.Tensor] = None
    # Value of the dynamics' `loss_function` on the batch (detached when returned from `fit`).
    loss: Optional[torch.Tensor] = None
    # Separate decoder loss; only set by dynamics that use one (e.g. `TwoLossesDynamics`).
    decoder_loss: Optional[torch.Tensor] = None


class LatentExtractor:
    """Capture (and optionally modify) the output of a named sub-module via a forward hook.

    On every forward pass the hooked module's output is optionally transformed by
    ``noise_injector``. The result is stored in ``current_latent``; inside a stochastic
    run it is first passed through ``stochastic_model.unflatten_tensor``. With
    ``stop_gradients`` the module output is replaced by a detached copy that requires
    grad again, so gradients coming from downstream layers stop at the hooked module.
    ``current_latent`` itself is not detached.
    """

    model: torch.nn.Module
    # Name as reported by `model.named_modules()`; "" is the root module itself.
    latent_layer_name: str
    stop_gradients: bool
    hook: torch.utils.hooks.RemovableHandle
    # Optional callable applied to the hooked module's output.
    noise_injector: Optional[Callable]

    # Latest latent recorded by `hook_callback`.
    current_latent: torch.Tensor

    def __init__(self, model, *, layer_name=None, stop_gradients=None, noise_injector: Optional[Callable] = None):
        """Install a forward hook on ``model``'s sub-module named ``layer_name``.

        Raises:
            ValueError: if no sub-module with that name exists.
        """
        self.noise_injector = noise_injector
        self.model = model
        self.latent_layer_name = layer_name or ""
        self.stop_gradients = bool(stop_gradients)

        named_modules = dict(self.model.named_modules())
        try:
            module = named_modules[self.latent_layer_name]
        except KeyError:
            # List only the names: the module reprs can be enormous.
            raise ValueError(
                f"Could not find {self.latent_layer_name!r} in {list(named_modules)}!"
            ) from None
        self.hook = module.register_forward_hook(self.hook_callback)

    def remove(self):
        """Uninstall the forward hook."""
        self.hook.remove()

    def hook_callback(self, module, input: tuple, output: torch.Tensor):
        """Forward hook: record the latent and optionally replace the module's output.

        Returns the replacement output, or ``None`` to leave ``output`` unchanged
        (PyTorch forward-hook convention).
        """
        new_output = self.noise_injector(output) if self.noise_injector else output

        if stochastic_model.in_stochastic_run():
            # Presumably restores the separate sample dimension of a flattened batch — TODO confirm.
            latent = stochastic_model.unflatten_tensor(new_output, stochastic_model.get_current_num_samples())
        else:
            latent = new_output
        self.current_latent = latent

        if self.stop_gradients and new_output.requires_grad:
            # A fresh leaf tensor: downstream layers still get gradients, but they do not
            # flow back into the hooked module through this output.
            return new_output.detach().requires_grad_(True)
        if new_output is not output:
            return new_output
        return None


class DynamicsInterface(metaclass=ABCMeta):
    """Abstract per-batch training / evaluation strategy for a model.

    Subclasses define how a batch is evaluated (`predict`), how a single
    optimization step is taken (`fit`), and what to prepare before a
    validation pass (`pre_validation`).
    """

    model: torch.nn.Module
    optimizer: torch.optim.Optimizer
    loss_function: torch.nn.functional.cross_entropy

    def __init__(self, model, optimizer, loss_function):
        self.model, self.optimizer, self.loss_function = model, optimizer, loss_function

    def clone(self):
        """Return an independent deep copy made via a pickle round trip."""
        serialized = pickle.dumps(self)
        return pickle.loads(serialized)

    @abstractmethod
    def predict(self, images, labels):
        """Evaluate one batch."""

    @abstractmethod
    def fit(self, images, labels):
        """Run one training step on a batch."""

    @abstractmethod
    def pre_validation(self, train_loader):
        """Prepare for validation, possibly using the training data."""


class Dynamics(DynamicsInterface):
    """Plain supervised dynamics: the model maps inputs directly to predictions."""

    def predict(self, images, labels) -> Output:
        """Evaluate one batch without tracking gradients."""
        with torch.no_grad():
            model_output = self.model(images)
            batch_loss = self.loss_function(model_output, labels)
            entropies = deterministic_continuous_encoder_cross_entropies(model_output, labels)

        return Output(model_output, entropies, loss=batch_loss)

    def fit(self, images, labels) -> Output:
        """Take one optimizer step on the batch and return detached results."""
        self.optimizer.zero_grad()

        model_output = self.model(images)
        batch_loss = self.loss_function(model_output, labels)
        batch_loss.backward()

        with torch.no_grad():
            entropies = deterministic_continuous_encoder_cross_entropies(model_output, labels)

        self.optimizer.step()

        return Output(model_output.detach(), entropies, loss=batch_loss.detach())

    def pre_validation(self, train_loader):
        """Nothing to prepare before validation."""


class StochasticContinuousDynamics(DynamicsInterface):
    """
    Works with either `stochastic_continuous_encoder_decoder_cross_entropy` or
    `stochastic_continuous_encoder_prediction_cross_entropy` as loss.

    The model is wrapped with `stochastic_model.as_stochastic_model` unless it
    already is a `StochasticModel`. Returned predictions are averaged over dim 1.
    """

    latent_extractor: Optional[LatentExtractor]

    def __init__(self, model, optimizer, loss_function, *, latent_extractor: Optional[LatentExtractor] = None):
        # Do not wrap an already-stochastic model a second time
        # (same guard as in `StochasticContinuousFullLossDynamics`).
        if isinstance(model, StochasticModel):
            wrapped_model = model
        else:
            wrapped_model = stochastic_model.as_stochastic_model(model)

        super().__init__(wrapped_model, optimizer, loss_function)

        self.latent_extractor = latent_extractor

    def _get_current_latent(self):
        """Return the detached latent captured by the extractor, or None without one."""
        latent = None
        if self.latent_extractor:
            latent = self.latent_extractor.current_latent.detach()
        return latent

    def predict(self, images, labels) -> Output:
        """Evaluate one batch without gradients."""
        with torch.no_grad():
            prediction = self.model(images)
            loss = self.loss_function(prediction, labels)

            cross_entropies = stochastic_continuous_encoder_cross_entropies(prediction, labels)

        return Output(prediction.mean(dim=1), cross_entropies, loss=loss, latent=self._get_current_latent())

    def fit(self, images, labels) -> Output:
        """Take one optimizer step on the batch and return detached results."""
        self.optimizer.zero_grad()
        prediction = self.model(images)

        loss = self.loss_function(prediction, labels)
        loss.backward()

        self.optimizer.step()

        with torch.no_grad():
            cross_entropies = stochastic_continuous_encoder_cross_entropies(prediction, labels)

        return Output(
            prediction.detach().mean(dim=1), cross_entropies, loss=loss.detach(), latent=self._get_current_latent()
        )

    def pre_validation(self, train_loader):
        pass


def full_stochastic_loss(latent_x_k_z, prediction_x_k_y, labels_x):
    """Placeholder signature for a loss over (latent, prediction, labels).

    Referenced in this module as the `loss_function` annotation of
    `StochasticContinuousFullLossDynamics`. Calling it always raises.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()


class StochasticContinuousFullLossDynamics(DynamicsInterface):
    """Stochastic dynamics whose loss receives the captured latent as well as the prediction.

    `fit` runs the forward pass and loss under `autocast` and steps through a
    `GradScaler`; `predict` runs without autocast.
    """

    latent_extractor: LatentExtractor
    loss_function: full_stochastic_loss

    def __init__(self, model, optimizer, loss_function, *, latent_extractor: LatentExtractor):
        if isinstance(model, StochasticModel):
            stochastic = model
        else:
            stochastic = stochastic_model.as_stochastic_model(model)

        super().__init__(stochastic, optimizer, loss_function)

        self.latent_extractor = latent_extractor
        self.scaler = GradScaler()

    def _get_current_latent(self):
        """Return the latent captured by the extractor (not detached), or None without one."""
        return self.latent_extractor.current_latent if self.latent_extractor else None

    def predict(self, images, labels) -> Output:
        """Evaluate one batch without gradients and in full precision."""
        # NOTE: autocast is intentionally not used at test time. Its benefit is marginal
        # there (unlike during training), and it produced +inf losses in epoch 0.
        with torch.no_grad():
            prediction = self.model(images)
            current = self._get_current_latent()
            batch_loss = self.loss_function(current, prediction, labels)
            entropies = stochastic_continuous_encoder_cross_entropies(prediction, labels)

        return Output(prediction.mean(dim=1), entropies, loss=batch_loss, latent=current)

    def fit(self, images, labels) -> Output:
        """Take one mixed-precision optimizer step and return detached results."""
        self.optimizer.zero_grad()

        with autocast():
            prediction = self.model(images)
            current = self._get_current_latent()
            batch_loss = self.loss_function(current, prediction, labels)

        scaler = self.scaler
        scaler.scale(batch_loss).backward()
        scaler.step(self.optimizer)
        scaler.update()

        with torch.no_grad():
            entropies = stochastic_continuous_encoder_cross_entropies(prediction, labels)

        return Output(
            prediction.detach().mean(dim=1),
            entropies,
            loss=batch_loss.detach(),
            latent=current.detach(),
        )

    def pre_validation(self, train_loader):
        """Nothing to prepare before validation."""


class EncodingLossDynamics(DynamicsInterface):
    """Dynamics that work together with an EncoderDecoder class.

    Here, we train through a loss on the encoding only: `fit` back-propagates
    `loss_function(encoding, labels)`, and the decoder output is computed without
    gradients purely for reporting. Before validation, the decoder is reset and
    refit on the training data via `model.fit_decoder`.
    """

    model: EncoderDecoder

    def predict(self, images, labels) -> Output:
        """Evaluate one batch without gradients."""
        with torch.no_grad():
            encoding = self.model.encoder(images)
            prediction = self.model.decoder(encoding)
            loss = self.loss_function(prediction, labels)

            cross_entropies = deterministic_continuous_encoder_cross_entropies(prediction, labels)

        return Output(prediction, cross_entropies, latent=encoding, loss=loss)

    def fit(self, images, labels) -> Output:
        """Take one optimizer step on the encoding loss and return detached results."""
        self.optimizer.zero_grad()

        encoding = self.model.encoder(images)
        loss = self.loss_function(encoding, labels)
        loss.backward()

        # The decoder output is only reported; it is not part of the loss graph.
        with torch.no_grad():
            prediction = self.model.decoder(encoding)
            cross_entropies = deterministic_continuous_encoder_cross_entropies(prediction, labels)

        self.optimizer.step()

        return Output(prediction.detach(), cross_entropies, latent=encoding.detach(), loss=loss.detach())

    def pre_validation(self, train_loader):
        """Reset the decoder and refit it with one pass over ``train_loader``."""
        self.model.decoder.reset()

        with torch.no_grad():
            for images, labels in with_progress_bar(train_loader):
                self.model.fit_decoder(images, labels)


class DecodingLossDynamics(DynamicsInterface):
    """End-to-end dynamics: the loss is `loss_function(model.decode(encoding), labels)`."""

    model: EncoderDecoder

    def predict(self, images, labels) -> Output:
        """Evaluate one batch without gradients."""
        with torch.no_grad():
            latent = self.model.encoder(images)
            decoded = self.model.decode(latent)
            batch_loss = self.loss_function(decoded, labels)
            entropies = deterministic_continuous_encoder_cross_entropies(decoded, labels)

        return Output(decoded, entropies, latent=latent, loss=batch_loss)

    def fit(self, images, labels) -> Output:
        """Back-propagate the decoding loss through encoder and decoder, then step."""
        self.optimizer.zero_grad()

        latent = self.model.encoder(images)
        decoded = self.model.decode(latent)
        batch_loss = self.loss_function(decoded, labels)
        batch_loss.backward()

        self.optimizer.step()

        with torch.no_grad():
            entropies = deterministic_continuous_encoder_cross_entropies(decoded, labels)

        return Output(decoded.detach(), entropies, latent=latent.detach(), loss=batch_loss.detach())

    def pre_validation(self, train_loader):
        """Nothing to prepare before validation."""


class FullLossDynamics(DynamicsInterface):
    """End-to-end dynamics whose loss sees the encoding, the decoded prediction and the labels."""

    model: EncoderDecoder

    def predict(self, images, labels) -> Output:
        """Evaluate one batch without gradients."""
        with torch.no_grad():
            latent = self.model.encoder(images)
            decoded = self.model.decode(latent)
            batch_loss = self.loss_function(latent, decoded, labels)
            entropies = deterministic_continuous_encoder_cross_entropies(decoded, labels)

        return Output(decoded, entropies, latent=latent, loss=batch_loss)

    def fit(self, images, labels) -> Output:
        """Back-propagate the full loss, take one optimizer step, return detached results."""
        self.optimizer.zero_grad()

        latent = self.model.encoder(images)
        decoded = self.model.decode(latent)
        batch_loss = self.loss_function(latent, decoded, labels)
        batch_loss.backward()

        self.optimizer.step()

        with torch.no_grad():
            entropies = deterministic_continuous_encoder_cross_entropies(decoded, labels)

        return Output(decoded.detach(), entropies, latent=latent.detach(), loss=batch_loss.detach())

    def pre_validation(self, train_loader):
        """Nothing to prepare before validation."""


class TwoLossesDynamics(DynamicsInterface):
    """Train encoder and decoder with separate losses.

    `loss_function(encoding, labels)` drives the encoder. `decoder_loss(prediction, labels)`
    is computed on a detached encoding, so it only reaches the decoder. Both gradients
    are accumulated before a single optimizer step.
    """

    model: EncoderDecoder
    decoder_loss: torch.nn.functional.cross_entropy

    def __init__(self, model, optimizer, *, loss_function, decoder_loss):
        super().__init__(model, optimizer, loss_function)
        self.decoder_loss = decoder_loss

    def predict(self, images, labels) -> Output:
        """Evaluate both losses on one batch without gradients."""
        with torch.no_grad():
            latent = self.model.encoder(images)
            decoded = self.model.decode(latent)

            encoder_term = self.loss_function(latent, labels)
            decoder_term = self.decoder_loss(decoded, labels)

            entropies = deterministic_continuous_encoder_cross_entropies(decoded, labels)

        return Output(decoded, entropies, latent=latent, loss=encoder_term, decoder_loss=decoder_term)

    def fit(self, images, labels) -> Output:
        """Accumulate encoder and decoder gradients, then take one optimizer step."""
        self.optimizer.zero_grad()

        latent = self.model.encoder(images)
        encoder_term = self.loss_function(latent, labels)
        encoder_term.backward()

        # Cut the graph so the decoder loss cannot reach the encoder.
        decoded = self.model.decode(latent.detach())
        decoder_term = self.decoder_loss(decoded, labels)
        decoder_term.backward()

        self.optimizer.step()

        with torch.no_grad():
            entropies = deterministic_continuous_encoder_cross_entropies(decoded, labels)

        return Output(
            decoded.detach(),
            entropies,
            latent=latent.detach(),
            loss=encoder_term.detach(),
            decoder_loss=decoder_term.detach(),
        )

    def pre_validation(self, train_loader):
        """Nothing to prepare before validation."""


class CategoricalEncoderFullLossDynamics(DynamicsInterface):
    """
    For use with `deterministic_categorical_encoder_prediction_cross_entropy` or
    `deterministic_categorical_encoder_decoder_cross_entropy` (or similar).

    `pre_validation` DOES NOT RETRAIN THE DECODER; instead `fit` updates it on
    every batch through `model.decoder.fit` before computing the loss.
    """

    model: CategoricalEncoderDecoder
    loss_function: deterministic_categorical_encoder_loss

    def predict(self, inputs_x_, labels_x) -> Output:
        """Evaluate one batch without gradients."""
        with torch.no_grad():
            encoded = self.model.encoder(inputs_x_)
            decoder_probs = self.model.get_decoder_prob_z_y(encoded)
            batch_loss = self.loss_function(encoded, decoder_probs, labels_x)
            entropies = deterministic_categorical_encoder_cross_entropies(encoded, decoder_probs, labels_x)
            combined = self.model.combine(encoded, decoder_probs)

        return Output(combined.float(), entropies, latent=encoded, loss=batch_loss)

    def fit(self, inputs_x_, labels_x) -> Output:
        """Refit the decoder on the batch, then take one optimizer step on the full loss."""
        self.optimizer.zero_grad()

        encoded = self.model.encoder(inputs_x_)

        # Update the decoder through its own `fit` method on this batch's latents
        # (presumably a non-gradient update — confirm against the decoder class).
        self.model.decoder.fit(encoded, labels_x)

        decoder_probs = self.model.get_decoder_prob_z_y(encoded)
        batch_loss = self.loss_function(encoded, decoder_probs, labels_x)
        batch_loss.backward()

        self.optimizer.step()

        with torch.no_grad():
            entropies = deterministic_categorical_encoder_cross_entropies(encoded, decoder_probs, labels_x)
            combined = self.model.combine(encoded, decoder_probs)

        return Output(combined, entropies, latent=encoded.detach(), loss=batch_loss.detach())

    def pre_validation(self, train_loader):
        """Nothing to prepare before validation."""


class CategoricalEncoderEncodingLossDynamics(DynamicsInterface):
    """Dynamics that work together with a `CategoricalEncoderDecoder` class.

    `fit` back-propagates `loss_function(latent, labels)`. It then updates the
    decoder via `model.decoder.fit` on the same batch's latents. `pre_validation`
    resets the decoder and refits it on the training data.
    """

    model: CategoricalEncoderDecoder

    def predict(self, images, labels) -> Output:
        """Evaluate one batch without gradients."""
        with torch.no_grad():
            encoded = self.model.encoder(images)
            batch_loss = self.loss_function(encoded, labels)
            decoder_probs = self.model.get_decoder_prob_z_y(encoded)
            entropies = deterministic_categorical_encoder_cross_entropies(encoded, decoder_probs, labels)
            combined = self.model.combine(encoded, decoder_probs)

        return Output(combined, entropies, latent=encoded, loss=batch_loss)

    def fit(self, images, labels) -> Output:
        """Step on the encoding loss, then refit the decoder on this batch."""
        self.optimizer.zero_grad()

        encoded = self.model.encoder(images)
        batch_loss = self.loss_function(encoded, labels)
        batch_loss.backward()
        self.optimizer.step()

        # The decoder is updated through its own `fit` method, after the optimizer step.
        self.model.decoder.fit(encoded, labels)

        with torch.no_grad():
            decoder_probs = self.model.get_decoder_prob_z_y(encoded)
            combined = self.model.combine(encoded, decoder_probs)
            entropies = deterministic_categorical_encoder_cross_entropies(encoded, decoder_probs, labels)

        return Output(combined.detach(), entropies, latent=encoded.detach(), loss=batch_loss.detach())

    def pre_validation(self, train_loader):
        """Reset the decoder and refit it with one pass over ``train_loader``."""
        self.model.decoder.reset()

        with torch.no_grad():
            for batch_images, batch_labels in with_progress_bar(train_loader):
                self.model.fit_decoder(batch_images, batch_labels)
