from typing import Tuple, Dict, List, Optional, Union

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from tslearn.metrics import SoftDTWLossPyTorch

from data.preprocessing import Pipeline
from metrics.cross_correlation import get_cross_correlation
from model.diffusion import Diffusion
from model.pl_modules.logger import Logger


class DDM(Logger):

    def __init__(
        self, path_storage: str, save_plots_in_file: bool, file_extension: str, save_synthetic_data: bool,
        n_samples_evaluation: int, save_plots_in_wandb: bool,
        batch_size: int, lr: float, seq_len: int, max_lag: int,
        use_loss_weights: bool, loss_fn_sparsification: Optional[Union[nn.L1Loss, nn.MSELoss]],
        loss_fn_fourier: Optional[nn.MSELoss], loss_fn_conv: Optional[nn.MSELoss],
        loss_fn_dtw: Optional[SoftDTWLossPyTorch],
        diffusion: Diffusion, denoising_network: nn.Module,
        feature_names: List[str], pipeline: Optional[Pipeline],
    ) -> None:
        """Set up the diffusion model, its denoising network and all loss terms.

        Args:
            path_storage, save_plots_in_file, file_extension, save_synthetic_data,
            save_plots_in_wandb: Forwarded unchanged to ``Logger.__init__``.
            n_samples_evaluation: Forwarded to ``Logger.__init__`` and also stored
                on this instance; caps how many generated samples are evaluated.
            batch_size: Stored as ``self.batch_size``; used as the number of
                samples generated in ``on_validation_epoch_end``.
            lr: Learning rate handed to the Adam optimizer.
            seq_len: Sequence length of the generated samples; also passed as
                ``padding`` to the cross-correlation loss.
            max_lag: Forwarded to ``Logger.__init__``.
            use_loss_weights: When true, the reconstruction loss is computed
                element-wise and scaled by a per-timestep weight.
            loss_fn_sparsification: Optional loss pulling the network's
                coefficients towards zero.
            loss_fn_fourier: Optional loss applied to the real and imaginary
                parts of the FFT of prediction and target.
            loss_fn_conv: Optional loss on cross-correlations.
            loss_fn_dtw: Optional soft-DTW loss.
            diffusion: Diffusion process. Must expose ``diffusion_timesteps`` and
                ``target`` and, when ``use_loss_weights`` is true, ``betas``,
                ``posterior_variance``, ``alphas`` and ``alphas_cumprod``.
            denoising_network: Network called as ``network(x_t, t.float())``.
            feature_names: One name per feature; its length defines
                ``self.n_features``.
            pipeline: Optional preprocessing pipeline used to inverse-transform
                generated samples.
        """
        super().__init__(
            path_storage, save_plots_in_file, file_extension, save_synthetic_data,
            n_samples_evaluation, save_plots_in_wandb, feature_names, max_lag
        )
        # Modules and loss objects are excluded from the saved hyperparameters.
        self.save_hyperparameters(
            ignore=[
                'denoising_network', 'diffusion',
                'loss_fn_sparsification', 'loss_fn_fourier', 'loss_fn_conv', 'loss_fn_dtw'
            ]
        )

        self.n_samples_evaluation = n_samples_evaluation

        self.batch_size = batch_size
        self.lr = lr
        self.seq_len = seq_len

        self.use_loss_weights = use_loss_weights
        self.loss_fn_sparsification = loss_fn_sparsification
        self.loss_fn_fourier = loss_fn_fourier
        self.loss_fn_conv = loss_fn_conv
        self.loss_fn_dtw = loss_fn_dtw

        self.diffusion = diffusion
        self.diffusion_timesteps = diffusion.diffusion_timesteps

        self.denoising_network = denoising_network

        self.n_features: int = len(feature_names)
        self.pipeline = pipeline

        loss_weights: Optional[Tensor] = None
        if use_loss_weights:
            # Element-wise loss so that each sample can be scaled by the weight of its timestep.
            reduction = 'none'
            if diffusion.target == 'noise':
                # NOTE(review): if `posterior_variance` contains a zero (presumably at t=0 for
                # DDPM-style schedules - confirm), the corresponding weight is inf.
                loss_weights = (diffusion.betas ** 2 / (2. * diffusion.posterior_variance * diffusion.alphas *
                                                        (1. - diffusion.alphas_cumprod)))
            else:
                # NOTE(review): formula kept as-is; confirm the `2. - alphas_cumprod` denominator is intended.
                loss_weights = .5 * torch.sqrt(diffusion.alphas_cumprod) / (2. - diffusion.alphas_cumprod)
            loss_weights = torch.as_tensor(loss_weights)
        else:
            reduction = 'mean'
        # Registered as a buffer so the weights follow the module across devices (`.to()`, `.cuda()`,
        # Lightning device placement). Non-persistent, so checkpoints / state_dict keys are unchanged.
        self.register_buffer('loss_weights', loss_weights, persistent=False)
        self.loss_fn_reconstruction = nn.MSELoss(reduction=reduction)

    def get_loss_reconstruction(self, x_0_pred: Tensor, x_0: Tensor, t: Tensor) -> Tensor:
        loss = self.loss_fn_reconstruction(x_0_pred, x_0)
        if self.use_loss_weights:
            loss = torch.mean(
                torch.einsum(
                    'abc,a->abc', loss, self.loss_weights[t]
                )
            )
        return loss

    def get_loss_fourier(self, pred: Tensor, real: Tensor) -> Tensor:
        fft_real = torch.fft.fft(real, dim=1, norm='forward')
        fft_pred = torch.fft.fft(pred, dim=1, norm='forward')
        loss = (self.loss_fn_fourier(torch.real(fft_pred), torch.real(fft_real)) +
                self.loss_fn_fourier(torch.imag(fft_pred), torch.imag(fft_real)))
        return loss

    def get_loss_conv(self, x_0_reconstructed: Tensor, x_0: Tensor) -> Tensor:
        """Return ``self.loss_fn_conv`` between the cross-correlations of both inputs.

        Cross-correlations are computed by ``get_cross_correlation`` with
        ``padding=self.seq_len``, first for the real data and then for the
        reconstruction.
        """
        target_correlations = get_cross_correlation(x_0, padding=self.seq_len)
        reconstructed_correlations = get_cross_correlation(x_0_reconstructed, padding=self.seq_len)
        return self.loss_fn_conv(reconstructed_correlations, target_correlations)

    def get_loss_dtw(self, x_0_reconstructed: Tensor, x_0: Tensor) -> Tensor:
        loss = torch.mean(self.loss_fn_dtw(x_0_reconstructed, x_0))
        return loss

    def get_losses(
        self, x_0: Tensor, x_t: Tensor, real: Tensor, pred: Tensor, t: Tensor, coefficients: Tensor
    ) -> Dict[str, Tensor]:
        losses = dict()

        loss_reconstruction = self.get_loss_reconstruction(pred, real, t)
        losses['loss_reconstruction'] = 10. * loss_reconstruction

        if self.loss_fn_sparsification:
            loss_sparsification = self.loss_fn_sparsification(coefficients, torch.zeros_like(coefficients))
            losses['loss_sparsification'] = loss_sparsification

        if self.loss_fn_fourier:
            loss_fourier = self.get_loss_fourier(pred, real)
            losses['loss_fourier'] = 100. * loss_fourier

        if self.loss_fn_conv or self.loss_fn_dtw:
            x_0_reconstructed = self.diffusion.predict_start_from_noise(x_t, pred, t) \
                if self.diffusion.target == 'noise' else pred

            if self.loss_fn_conv:
                loss_conv = self.get_loss_conv(x_0_reconstructed, x_0)
                losses['loss_conv'] = loss_conv

            if self.loss_fn_dtw and (x_0.shape[-1] != 36 or self.current_epoch > 0):
                loss_dtw = self.get_loss_dtw(x_0_reconstructed, x_0)
                losses['loss_dtw'] = .01 * loss_dtw

        return losses

    def forward_denoising_network(self, x_t: Tensor, t: Tensor) -> Tensor:
        # x_t.shape = [batch_size, seq_len, n_features]
        # t.shape = [batch_size]
        return self.denoising_network(x_t, t.float())

    def forward(self, batch: dict[str, Tensor]) -> Tuple[Dict[str, Tensor], Tensor, Tensor, Tensor]:
        x_0 = batch["x_0"]
        # x_0.shape = [batch_size, seq_len, n_features]

        t = torch.randint(0, self.diffusion_timesteps, size=(self.batch_size,), device=self.device, dtype=torch.long)
        # t.shape = [batch_size]

        x_t, noise_real = self.diffusion.forward_diffusion(x_0, t)
        # x_t.shape = noise_real.shape = [batch_size, seq_len, n_features]

        pred, coefficients = self.forward_denoising_network(x_t, t)
        # pred.shape = [batch_size, seq_len, n_features]

        real = noise_real if self.diffusion.target == 'noise' else x_0
        losses = self.get_losses(x_0, x_t, real, pred, t, coefficients)

        return losses, t, real, pred

    def training_step(self, batch: dict[str, Tensor]) -> Tensor:
        """Log each training loss term and return their sum for optimisation."""
        losses, *_ = self(batch)

        for name, value in losses.items():
            self.log(f'train_{name}', value, prog_bar=True)

        total_loss = sum(losses.values())
        self.log('train_total_loss', total_loss)
        return total_loss

    def validation_step(self, batch: dict[str, Tensor], batch_idx: int) -> None:
        """Log each validation loss term and their sum; nothing is returned.

        NOTE: logging of the current learning rate and a per-batch call to
        ``evaluate_prediction`` on the first batch used to live here and are
        currently disabled.
        """
        losses, *_ = self(batch)

        for name, value in losses.items():
            self.log(f'val_{name}', value, prog_bar=True)

        self.log('val_total_loss', sum(losses.values()))

    def on_validation_epoch_end(self) -> None:
        """From epoch 40 onwards, generate samples from noise and evaluate them.

        A batch of ``self.batch_size`` Gaussian-noise sequences is denoised via
        ``predict_step``; the first ``self.n_samples_evaluation`` results are
        passed to ``evaluate_synthetic``.
        """
        # Sampling is skipped during the first 40 epochs.
        if self.current_epoch < 40:
            return

        noise_shape = (self.batch_size, self.seq_len, self.n_features)
        x_t = torch.randn(noise_shape, device=self.device)

        synthetic, coefficients = self.predict_step(x_t)
        self.evaluate_synthetic(
            synthetic[:self.n_samples_evaluation],
            coefficients[:self.n_samples_evaluation],
        )

    def predict_step(self, x_t: Tensor) -> (np.ndarray, np.ndarray):
        coefficients = None
        for i in reversed(range(self.diffusion_timesteps)):
            t = torch.full((self.batch_size,), i, device=self.device, dtype=torch.long)  # [B]
            pred, coefficients = self.forward_denoising_network(x_t, t)  # [B, seq_len, n_features]
            x_t = self.diffusion.backward_diffusion(x_t, pred, t, i)  # [B, seq_len, n_features]

        synthetic = x_t.detach().cpu().numpy()[:, self.max_lag:]  # [B, seq_len-max_lag, n_features]
        coefficients = coefficients.detach().cpu().numpy()

        # [n_samples_evaluation, n_features, n_features*max_lag, seq_len-max_lag]
        synthetic = self.pipeline.batched_inverse_transform(synthetic) if self.pipeline else synthetic
        # [B, seq_len-max_lag, n_features]

        return synthetic, coefficients

    def configure_optimizers(self) -> Dict:
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1., end_factor=.001, total_iters=20)
        return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler}

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}\n"
            f"\tdiffusion={self.diffusion}\n"
            f"\tloss_noise={self.loss_fn_reconstruction}\n"
            f"\tdenoising_network={self.denoising_network}\n"
        )
