from omegaconf import DictConfig
import numpy as np
from random import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Any
from einops import rearrange, reduce, repeat
from typing import List

from lightning.pytorch.utilities.types import STEP_OUTPUT

from algorithms.common.base_pytorch_algo import BasePytorchAlgo
from utils.logging_utils import get_validation_metrics_for_states, log_timeseries_plots
from algorithms.common.metrics import crps_quantile_sum
from .models.diffusion_transition import DiffusionTransitionModel


class DiffusionForcingPrediction(BasePytorchAlgo):
    def __init__(self, cfg: DictConfig):
        """
        Cache hyper-parameters from ``cfg`` on the instance, then call the base constructor.

        Every attribute is assigned *before* ``super().__init__(cfg)`` runs.
        NOTE(review): presumably the base constructor triggers ``_build_model``, which
        reads these attributes -- confirm against ``BasePytorchAlgo`` before reordering.

        Parameters
        ----------
        cfg
            Algorithm configuration. The fields read here are ``metadata``, ``x_shape``,
            ``z_shape``, ``frame_stack``, ``gt_cond_prob``, ``gt_first_frame``,
            ``context_frames``, ``chunk_size``, ``calc_crps_sum``, ``external_cond_dim``,
            ``uncertainty_scale``, ``use_covariates``, ``learnable_init_z`` and
            ``diffusion.{cum_snr_decay, sampling_timesteps}``.
        """
        self.cfg = cfg
        # ``register_metadata`` is not defined in this file; it may set attributes such as
        # ``target_dimension``, ``context_length`` and ``lags_seq``, which are probed below.
        self.register_metadata(cfg.metadata)  # sets certain attributes based on dataset metadata, if any
        # Dataset metadata, when present, takes precedence over the configured observation shape.
        self.x_shape = [self.target_dimension] if hasattr(self, 'target_dimension') else cfg.x_shape
        self.z_shape = cfg.z_shape
        self.frame_stack = cfg.frame_stack
        # NOTE(review): this writes back into the config object that was passed in. Building a
        # second instance from the same ``cfg`` raises the decay to the power ``frame_stack``
        # again -- confirm callers always pass a fresh config.
        self.cfg.diffusion.cum_snr_decay = self.cfg.diffusion.cum_snr_decay**self.frame_stack
        # Shape of one model step: ``frame_stack`` frames folded into the leading axis of x_shape.
        self.x_stacked_shape = list(self.x_shape)
        self.x_stacked_shape[0] *= cfg.frame_stack
        self.is_spatial = len(self.x_shape) == 3  # pixel
        self.gt_cond_prob = cfg.gt_cond_prob  # probability to condition one-step diffusion o_t+1 on ground truth o_t
        # Probability checked only for the first step (t == 0) in ``training_step``.
        self.gt_first_frame = cfg.gt_first_frame
        self.context_frames = self.context_length if hasattr(self, 'context_length') else cfg.context_frames  # number of context frames at validation time
        # Maximum number of stacked frames sampled jointly at validation; <= 0 means "all remaining".
        self.chunk_size = cfg.chunk_size
        # Used both as a flag and as the number of repeated trajectories per validation sample.
        self.calc_crps_sum = cfg.calc_crps_sum
        self.external_cond_dim = cfg.external_cond_dim
        # Per-frame offset (in sampling steps) applied to later frames of a chunk at validation.
        self.uncertainty_scale = cfg.uncertainty_scale
        self.sampling_timesteps = cfg.diffusion.sampling_timesteps
        # self.predict_residual = cfg.predict_residual  # fixme
        self.use_covariates = cfg.use_covariates
        # Filled by ``validation_step`` and consumed/cleared in ``on_validation_epoch_end``.
        self.validation_step_outputs = []
        self.min_crps_sum = float("inf")
        self.learnable_init_z = cfg.learnable_init_z
        if hasattr(self, 'lags_seq'):
            # Dataset is GluonTS dataset
            # Shift every lag down by one; ``get_lagged_subsequences`` treats lag 0 as the
            # observation window itself.
            self.shifted_lags_seq = [x - 1 for x in self.lags_seq]
            if self.use_covariates:
                # Overrides the configured value. Covariates are built in
                # ``get_observations_from_gluonts_dataset`` from lagged targets, one index
                # embedding per target dimension, and time features.
                # NOTE(review): the ``time_features_dim * 2`` term presumably matches the width
                # of the time-feature tensor -- confirm against the dataset.
                self.external_cond_dim = len(self.lags_seq) * self.x_shape[0] + self.time_features_dim * 2
            else:
                self.external_cond_dim = 0

        super().__init__(cfg)

    def _build_model(self):
        """Instantiate the transition model, the data statistics and the optional learnable parts."""
        diffusion_cfg = self.cfg.diffusion
        self.transition_model = DiffusionTransitionModel(
            self.x_stacked_shape,
            self.z_shape,
            self.external_cond_dim,
            diffusion_cfg,
        )
        self.register_data_mean_std(self.cfg.data_mean, self.cfg.data_std)

        # Optional learnable initial latent, drawn from a standard normal.
        if self.learnable_init_z:
            initial_latent = torch.randn(list(self.z_shape))
            self.init_z = nn.Parameter(initial_latent, requires_grad=True)

        # One scalar embedding per target dimension; only needed when covariates are used.
        if self.use_covariates:
            self.embed = nn.Embedding(num_embeddings=self.x_shape[0], embedding_dim=1)
        else:
            self.embed = None

    def configure_optimizers(self):
        """
        Build the single AdamW optimizer used for training.

        It optimizes the transition model's parameters plus, when enabled, the learnable
        initial latent ``init_z``.

        NOTE(review): the parameters of ``self.embed`` (created when ``use_covariates`` is
        set) are not passed to the optimizer here, so that embedding keeps its initial
        weights -- confirm this is intended.
        """
        trainable = [*self.transition_model.parameters()]
        if self.learnable_init_z:
            trainable += [self.init_z]

        return torch.optim.AdamW(
            trainable,
            lr=self.cfg.lr,
            weight_decay=self.cfg.weight_decay,
            betas=self.cfg.optimizer_beta,
        )

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
        """
        Run one optimizer step, then apply a linear learning-rate warm-up by hand.

        The learning rate is rewritten *after* the step, so the value set here is the one
        used by the following step.
        """
        optimizer.step(closure=optimizer_closure)

        # No scheduler is involved: once the warm-up window is over, leave the lr untouched.
        current_step = self.trainer.global_step
        warmup_steps = self.cfg.warmup_steps
        if current_step >= warmup_steps:
            return

        lr_scale = min(1.0, float(current_step + 1) / warmup_steps)
        for group in optimizer.param_groups:
            group["lr"] = lr_scale * self.cfg.lr

    def _preprocess_batch(self, batch, validation=False):
        """
        Turn a raw batch into time-major, frame-stacked model inputs.

        Parameters
        ----------
        batch
            Either a dict (handled as a GluonTS batch via
            ``get_observations_from_gluonts_dataset``) or a sequence whose first entry is the
            observation tensor and whose second entry, if used, is the conditioning tensor.
        validation
            When true and ``calc_crps_sum`` is set, every sample is repeated
            ``calc_crps_sum`` times so several trajectories are drawn per sample.

        Returns
        -------
        tuple
            ``(xs, conditions, masks, init_z)``: normalized observations rearranged to
            ``(t, b, fs * c, ...)``; per-step conditions (a tensor, or a list of ``None``
            when no conditioning is used); an all-ones mask of shape ``(frames, batch)``;
            and the initial latent expanded to the batch size.

        Raises
        ------
        ValueError
            If the frame count or ``context_frames`` is not divisible by ``frame_stack``,
            or if conditioning is enabled but the batch carries no conditioning tensor.
        """
        if isinstance(batch, dict):
            # Dataset is GluonTS dataset
            observations, normalized_covariates = self.get_observations_from_gluonts_dataset(batch)
            batch = [observations, normalized_covariates]

        # repeat batch for crps sum for time series prediction
        if validation and self.calc_crps_sum:
            n_samples = self.calc_crps_sum
            # Keep ``None`` placeholders in place so positional indices (batch[1], ...) stay valid.
            batch = [
                None if d is None else d[None].expand(n_samples, *([-1] * d.ndim)).flatten(0, 1)
                for d in batch
            ]

        xs = batch[0]
        batch_size, n_frames = xs.shape[:2]

        if n_frames % self.frame_stack != 0:
            raise ValueError("Number of frames must be divisible by frame stack size")
        if self.context_frames % self.frame_stack != 0:
            raise ValueError("Number of context frames must be divisible by frame stack size")

        # The mask is indexed per raw frame, so build it before folding frames into stacks.
        masks = torch.ones(n_frames, batch_size, device=xs.device)
        n_frames = n_frames // self.frame_stack

        if self.external_cond_dim:
            if len(batch) < 2 or batch[1] is None:
                raise ValueError("external_cond_dim is set but the batch has no conditioning tensor")
            conditions = batch[1]
            # The condition of the very first frame is replaced by zeros.
            conditions = torch.cat([torch.zeros_like(conditions[:, :1]), conditions[:, 1:]], 1)
            conditions = rearrange(conditions, "b (t fs) d -> t b (fs d)", fs=self.frame_stack).contiguous()
        else:
            conditions = [None] * n_frames

        xs = self._normalize_x(xs)
        xs = rearrange(xs, "b (t fs) c ... -> t b (fs c) ...", fs=self.frame_stack).contiguous()

        if self.learnable_init_z:
            init_z = self.init_z[None].expand(batch_size, *self.z_shape)
        else:
            # Allocate directly on the target device instead of creating on CPU and copying.
            init_z = torch.zeros(batch_size, *self.z_shape, device=xs.device)

        return xs, conditions, masks, init_z
     
    @staticmethod
    def get_lagged_subsequences(
        sequence: torch.Tensor,
        sequence_length: int,
        indices: List[int],
        subsequences_length: int = 1,
    ) -> tuple:
        """
        Adapted from source:
        https://github.com/zalandoresearch/pytorch-ts/blob/master/pts/model/time_grad/time_grad_network.py

        Returns lagged subsequences of a given sequence.
        Parameters
        ----------
        sequence
            the sequence from which lagged subsequences should be extracted.
            Shape: (N, T, C).
        sequence_length
            length of sequence in the T (time) dimension (axis = 1).
        indices
            list of lag indices to be used.
        subsequences_length
            length of the subsequences to be extracted.
        Returns
        --------
        lagged : Tensor
            a tensor of shape (N, S, C, I),
            where S = subsequences_length and I = len(indices),
            containing lagged subsequences.
            Specifically, lagged[i, :, j, k] = sequence[i, -indices[k]-S+j, :].
        """
        # we must have: history_length + begin_index >= 0
        # that is: history_length - lag_index - sequence_length >= 0
        # hence the following assert
        assert max(indices) + subsequences_length <= sequence_length, (
            f"lags cannot go further than history length, found lag "
            f"{max(indices)} while history length is only {sequence_length}"
        )
        assert all(lag_index >= 0 for lag_index in indices)

        lagged_values = []
        observations = None
        for lag_index in indices:
            begin_index = -lag_index - subsequences_length
            end_index = -lag_index if lag_index > 0 else None
            if lag_index == 0:
                observations = sequence[:, begin_index:end_index, ...]
            else:
                lagged_values.append(sequence[:, begin_index:end_index, ...].unsqueeze(1))
        return observations, torch.cat(lagged_values, dim=1).permute(0, 2, 3, 1)

    def get_observations_from_gluonts_dataset(self, batch):
        """
        Process batch of GLuonTS time series dataset to get inputs for the model. This includes
        standardizing the data and adding covariates if applicable. Covariates adapted from:
        https://github.com/zalandoresearch/pytorch-ts/blob/master/pts/model/time_grad/time_grad_network.py
        """

        past_time_feat = batch['past_time_feat']
        future_time_feat = batch['future_time_feat']
        past_target_cdf = batch['past_target_cdf']
        future_target_cdf = batch['future_target_cdf']
        target_dimension_indicator = batch['target_dimension_indicator']

        target_dim = past_target_cdf.shape[-1]
        assert target_dim == self.x_shape[0]

        if not self.use_covariates:
            if 0 in future_target_cdf.shape:
                # validation
                sequence = past_target_cdf
            else:
                # training
                sequence = torch.cat((past_target_cdf, future_target_cdf), dim=1)
            observations = sequence[:, -(self.context_length + self.prediction_length):, ...]
            return observations, None

        if 0 in future_target_cdf.shape:
            # validation and test
            time_feat = past_time_feat[:, -(self.context_length + self.prediction_length):, ...]
            sequence = past_target_cdf
        else:
            # training
            time_feat = torch.cat(
                (past_time_feat[:, -self.context_length:, ...], future_time_feat),
                dim=1,
            )
            sequence = torch.cat((past_target_cdf, future_target_cdf), dim=1)
        sequence_length = self.history_length + self.prediction_length
        subsequences_length = self.context_length + self.prediction_length

        # (batch_size, sub_seq_len, target_dim, num_lags)
        observations, lags = self.get_lagged_subsequences(
            sequence=sequence,
            sequence_length=sequence_length,
            indices=self.shifted_lags_seq,
            subsequences_length=subsequences_length,
        )

        num_input_lags = len(self.shifted_lags_seq) - 1
        input_lags = lags.reshape(
            (-1, subsequences_length, num_input_lags * target_dim)
        )

        # (batch_size, target_dim, embed_dim=1)
        index_embeddings = self.embed(target_dimension_indicator)

        # (batch_size, seq_len, target_dim)
        repeated_index_embeddings = (
            index_embeddings.unsqueeze(1)
            .expand(-1, subsequences_length, -1, -1)
            .reshape((-1, subsequences_length, target_dim))
        )

        # (batch_size, subsequences_length, input_dim)
        normalized_input_lags = self._normalize_x(input_lags, num_repeats=num_input_lags)
        normalized_covariates = torch.cat((normalized_input_lags, repeated_index_embeddings, time_feat), dim=-1)
        return observations, normalized_covariates

    def reweigh_loss(self, loss, weight=None):
        """
        Reduce a per-element loss to a scalar mean, optionally weighting each frame.

        ``loss`` is laid out as ``(t, b, fs * c, ...)`` and is unfolded to
        ``(t, b, fs, c, ...)``. ``weight``, when given, is laid out per raw frame as
        ``((t fs), b, ...)`` and is padded with trailing singleton axes so it broadcasts
        against the unfolded loss.
        """
        unfolded = rearrange(loss, "t b (fs c) ... -> t b fs c ...", fs=self.frame_stack)
        if weight is None:
            return unfolded.mean()

        # Number of trailing singleton axes needed for the weight to match the loss rank.
        n_pad = unfolded.ndim - weight.ndim - 1
        pattern = "(t fs) b ... -> t b fs ..." + " 1" * n_pad
        per_frame = rearrange(weight, pattern, fs=self.frame_stack)
        return (unfolded * per_frame).mean()

    def training_step(self, batch, batch_idx):
        """
        Roll the transition model over every stacked frame and return the training loss.

        At each step the model is, with probability ``gt_cond_prob`` (or ``gt_first_frame``
        as a second chance on the first step), called with ``deterministic_t=0``; otherwise
        ``deterministic_t`` is ``None``.

        Returns
        -------
        dict
            ``loss`` (scalar), plus un-normalized ``xs_pred`` and ``xs`` laid out per raw
            frame as ``((t fs), b, c, ...)``.
        """
        # training step for dynamics
        xs, conditions, masks, *_, init_z = self._preprocess_batch(batch)
        n_frames = xs.shape[0]

        xs_pred = []
        step_losses = []
        z = init_z
        cum_snr = None
        for t in range(n_frames):
            # random() lies in [0, 1), so the strict "<" makes probability 0 mean "never"
            # and probability 1 mean "always".
            use_ground_truth = random() < self.gt_cond_prob or (t == 0 and random() < self.gt_first_frame)
            deterministic_t = 0 if use_ground_truth else None

            z, x_next_pred, step_loss, cum_snr = self.transition_model(
                z, xs[t], conditions[t], deterministic_t=deterministic_t, cum_snr=cum_snr
            )
            xs_pred.append(x_next_pred)
            step_losses.append(step_loss)

        xs_pred = torch.stack(xs_pred)
        x_loss = self.reweigh_loss(torch.stack(step_losses), masks)
        loss = x_loss

        # Log only every 20th batch.
        if batch_idx % 20 == 0:
            self.log_dict(
                {
                    "training/loss": loss,
                    "training/x_loss": x_loss,
                }
            )
            # self.log_gradient_stats()  # fixme

        # Unfold the frame stacks so outputs are indexed by raw frame.
        xs = rearrange(xs, "t b (fs c) ... -> (t fs) b c ...", fs=self.frame_stack)
        xs_pred = rearrange(xs_pred, "t b (fs c) ... -> (t fs) b c ...", fs=self.frame_stack)

        return {
            "loss": loss,
            "xs_pred": self._unnormalize_x(xs_pred),
            "xs": self._unnormalize_x(xs),
        }

    @torch.no_grad()
    def validation_step(self, batch, batch_idx, namespace="validation") -> STEP_OUTPUT:
        """
        Encode the context frames, sample the remaining frames, and log validation metrics.

        The first ``context_frames`` raw frames are fed through the transition model with
        ``deterministic_t=0``. The rest are generated chunk by chunk from Gaussian noise with
        ``ddim_sample_step``; within a chunk, frame ``t`` lags ``t * uncertainty_scale``
        sampling steps behind the first frame.

        The un-normalized ``(prediction, ground truth)`` pair is appended to
        ``validation_step_outputs`` for ``on_validation_epoch_end``.

        Returns
        -------
        Tensor
            Mean squared error between prediction and ground truth in normalized space.
        """
        xs, conditions, masks, *_, init_z = self._preprocess_batch(batch, validation=True)

        n_frames, batch_size, *_ = xs.shape
        record_all = self.transition_model.return_all_timesteps
        xs_pred = []
        xs_pred_all = []
        z = init_z

        # context
        for t in range(self.context_frames // self.frame_stack):
            z, x_next_pred, _, _ = self.transition_model(z, xs[t], conditions[t], deterministic_t=0)
            xs_pred.append(x_next_pred)

        # prediction
        while len(xs_pred) < n_frames:
            remaining = n_frames - len(xs_pred)
            horizon = min(remaining, self.chunk_size) if self.chunk_size > 0 else remaining

            chunk = [
                torch.randn((batch_size,) + tuple(self.x_stacked_shape), device=self.device) for _ in range(horizon)
            ]

            # Defined up front so ``z`` stays valid even if the schedule has zero iterations.
            z_chunk = z
            for i in range(self.sampling_timesteps + horizon * self.uncertainty_scale):
                if record_all:
                    # Store a copy: ``chunk`` is updated in place below, so storing the list
                    # itself would make every snapshot identical to the final state.
                    xs_pred_all.append([frame.clone() for frame in chunk])

                # Each sweep restarts from the latent at the start of the chunk.
                z_chunk = z
                for t in range(horizon):
                    new_i = min(i - t * self.uncertainty_scale, self.sampling_timesteps - 1)
                    if new_i < 0:
                        # Later frames have not started denoising yet.
                        break

                    chunk[t], z_chunk = self.transition_model.ddim_sample_step(
                        chunk[t], z_chunk, conditions[len(xs_pred) + t], new_i
                    )
            z = z_chunk
            xs_pred += chunk

        xs_pred = torch.stack(xs_pred)
        loss = F.mse_loss(xs_pred, xs, reduction="none")
        loss = self.reweigh_loss(loss, masks)

        xs = rearrange(xs, "t b (fs c) ... -> (t fs) b c ...", fs=self.frame_stack)
        xs_pred = rearrange(xs_pred, "t b (fs c) ... -> (t fs) b c ...", fs=self.frame_stack)

        xs = self._unnormalize_x(xs)
        xs_pred = self._unnormalize_x(xs_pred)

        if not self.is_spatial:
            if record_all:
                limit = self.transition_model.sampling_timesteps
                # sampling_timesteps = [2 ** i for i in range(int(np.log2(limit - 1)) + 1)] + [limit]
                # NOTE(review): a snapshot holds only the frames of one chunk, still in stacked
                # layout, while ``xs`` covers every raw frame; index ``limit`` also requires
                # more than ``limit`` recorded snapshots. Confirm the intended comparison.
                for i in np.linspace(1, limit, 5, dtype=int):
                    # Use a separate name so the final prediction stored below is not overwritten.
                    snapshot = self._unnormalize_x(torch.stack(xs_pred_all[i]))
                    metric_dict = get_validation_metrics_for_states(snapshot, xs)
                    self.log_dict(
                        {f"{namespace}/{i}_sampling_steps_{k}": v for k, v in metric_dict.items()},
                        on_step=False,
                        on_epoch=True,
                        prog_bar=True,
                    )
            else:
                metric_dict = get_validation_metrics_for_states(xs_pred, xs)
                self.log_dict(
                    {f"{namespace}/{k}": v for k, v in metric_dict.items()},
                    on_step=False,
                    on_epoch=True,
                    prog_bar=True,
                )

        self.validation_step_outputs.append((xs_pred.detach().cpu(), xs.detach().cpu()))

        return loss

    def on_validation_epoch_end(self, namespace="validation", log_visualizations=False) -> None:
        if not self.validation_step_outputs:
            return

        # multiple trajectories sampled, compute CRPS_sum and visualize trajectories
        if self.calc_crps_sum:
            all_preds = []
            all_gt = []
            for pred, gt in self.validation_step_outputs:
                all_preds.append(pred.view(pred.shape[0], self.calc_crps_sum, -1, *pred.shape[2:]))
                all_gt.append(gt.view(gt.shape[0], self.calc_crps_sum, -1, *gt.shape[2:]))
            all_preds = torch.cat(all_preds, 2).float().permute(1, 0, 2, 3)
            gt = torch.cat(all_gt, 2).float()[:, 0]
            crps_sum_val = crps_quantile_sum(all_preds[:, self.context_frames:], gt[self.context_frames:])
            self.min_crps_sum = min(self.min_crps_sum, crps_sum_val)
            self.log_dict(
                {f"{namespace}/crps_sum": crps_sum_val, f"{namespace}/min_crps_sum": self.min_crps_sum},
                on_step=False,
                on_epoch=True,
                prog_bar=True,
            )
            if log_visualizations:
                log_timeseries_plots(all_preds, gt, self.context_frames, namespace, self.trainer.global_step, 
                                     self.frequency)
        self.validation_step_outputs.clear()

    def test_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        """Run the validation logic on a test batch, logging under the ``test`` namespace."""
        step_output = self.validation_step(*args, **kwargs, namespace="test")
        return step_output
    
    def on_test_epoch_end(self) -> None:
        self.min_crps_sum = float("inf")  # reset min CRPS_sum
        self.on_validation_epoch_end(namespace="test", log_visualizations=True)

    def _normalize_x(self, xs, num_repeats=1):
        # num_repeats is useful to normalize covariates, where lag sequences are stacked
        mean = torch.repeat_interleave(self.data_mean, num_repeats)
        std = torch.repeat_interleave(self.data_std, num_repeats)
        shape = [1] * (xs.ndim - mean.ndim) + list(mean.shape)
        mean = mean.reshape(shape)
        std = std.reshape(shape)
        return (xs - mean) / std

    def _unnormalize_x(self, xs, num_repeats=1):
        # num_repeats is useful to normalize covariates, where lag sequences are stacked
        mean = torch.repeat_interleave(self.data_mean, num_repeats)
        std = torch.repeat_interleave(self.data_std, num_repeats)
        shape = [1] * (xs.ndim - mean.ndim) + list(mean.shape)
        mean = mean.reshape(shape)
        std = std.reshape(shape)
        return xs * std + mean
