import copy
from typing import Dict, Tuple

import torch
import torch.nn as nn
from pydantic import BaseModel

from utils.data import GaussianDiffusion
from utils.models import init_model


def init_predictor(config: Dict) -> nn.Module:
    """
    Build the predictor selected by ``config['predictor_name']``.

    Args:
        config (Dict): Predictor configuration. It must contain the key
            ``'predictor_name'``; the whole dict is forwarded to the
            predictor's constructor.

    Returns:
        A newly constructed predictor module.

    Raises:
        KeyError: If ``'predictor_name'`` is missing from ``config``.
        NotImplementedError: If the requested predictor name is not supported.
    """
    requested = config['predictor_name']

    # Only the default predictor exists; reject anything else up front.
    if requested != 'default':
        raise NotImplementedError(f"predictor_name ({requested}) not implemented.")

    return Predictor(config)


class Predictor(nn.Module):
    """
    Denoising network paired with a ``GaussianDiffusion`` helper (``self.g_diff``).

    ``train_step`` / ``val_step`` compute an MSE loss between the network output
    and a target. ``guided_sample``, ``compute_combined_error`` and ``get_eps``
    take tensors in data space and normalise them with ``data_means`` /
    ``data_stds``. They flatten each sample before handing it to the network or
    to ``self.g_diff``.

    ``data_means`` / ``data_stds`` are only set from the config when present.
    Otherwise they must be assigned on the instance before those three methods
    are called.
    """

    class Config(BaseModel):
        # Known configuration keys and their defaults.

        # predictor
        predictor_name: str = 'default'

        # model
        model_name: str = 'mlp'
        input_size: int = 61
        output_size: int = 60
        hidden_sizes: tuple = (512, 512, 512, 512, 512)

        # g_diff
        n_steps: int = 1000
        beta_first: float = 1e-4
        beta_last: float = 2e-2

    # Plain-dict snapshot of the defaults declared above.
    default_config = Config().dict()

    def __init__(self, config):
        """
        Args:
            config (Dict): Predictor configuration. It is validated against
                ``Config``; the original dict is still what gets passed to
                ``GaussianDiffusion`` and ``init_model``. ``data_means`` /
                ``data_stds`` are picked up from it when present (e.g. when
                loading a trained model).
        """
        super().__init__()
        # Validated view of the known keys, with defaults filled in for any
        # that are missing.
        self.config = self.Config(**config).dict()

        # attributes -- read from the validated config so they always agree with
        # ``self.config`` and fall back to the declared defaults.
        self.n_steps = self.config['n_steps']
        self.beta_first = self.config['beta_first']
        self.beta_last = self.config['beta_last']

        self.g_diff = GaussianDiffusion(config)

        # when loading a trained model these will be in the config
        if 'data_means' in config:
            self.data_means = config['data_means']
        if 'data_stds' in config:
            self.data_stds = config['data_stds']

        # init model
        self.criterion = nn.MSELoss()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = init_model(config).to(self.device)

    def forward(self, x):
        """Run the wrapped network on ``x``."""
        return self.model(x)

    def train_step(self, x, y):
        """Return ``{'loss': MSE(self(x), y)}`` for one training batch."""
        return {'loss': self.criterion(self(x), y)}

    def val_step(self, x, y):
        """Return ``{'loss': MSE(self(x), y)}`` for one validation batch."""
        return {'loss': self.criterion(self(x), y)}

    @staticmethod
    def _match_float_dtype(out, ref):
        """Cast ``out`` to ``ref``'s dtype when ``ref`` is a floating-point tensor."""
        # The previous in-place implementation kept the input's floating dtype
        # (e.g. float32 stayed float32 even with float64 statistics). Preserve
        # that so the network keeps receiving the dtype it expects.
        if torch.is_tensor(ref) and ref.is_floating_point():
            return out.to(ref.dtype)
        return out

    def _normalise(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` from data space to normalised space; ``x`` itself is not modified."""
        # Out-of-place arithmetic replaces deepcopy + in-place ops. The caller's
        # tensor is never mutated. Tensors that require grad (or are non-leaf)
        # are accepted. Integer inputs are promoted instead of raising.
        out = (x - self.data_means) / self.data_stds
        return self._match_float_dtype(out, x)

    def _denormalise(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` from normalised space back to data space; ``x`` itself is not modified."""
        out = x * self.data_stds + self.data_means
        return self._match_float_dtype(out, x)

    def guided_sample(self, x_prompt: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        """
        Samples the diffusion model using RePaint. The method includes steps for data normalization and denormalization.

        Args:
            x_prompt (torch.Tensor): The input prompt tensor of shape (batch_size, input_length, n_sequences).
                It is not modified.

        Returns:
            A tuple containing:
                - The sampled response (torch.Tensor) of shape (batch_size, input_length, n_sequences),
                  in data space.
                - A dictionary containing information for other unused metrics.
        """
        batch_size, input_length, n_sequences = x_prompt.shape

        # normalise, then flatten each sample into one row for the diffusion helper
        x_flat = self._normalise(x_prompt).reshape(batch_size, -1).float()

        # sample
        x, info = self.g_diff.guided_sample(self.model, x_flat)
        x = x.reshape(batch_size, input_length, n_sequences)

        # un-normalise
        return self._denormalise(x), info

    def compute_combined_error(self, x_sampled: torch.Tensor) -> torch.Tensor:
        """
        Computes the combined error for the given prompt-response pairs.

        Args:
            x_sampled (torch.Tensor): The sampled prompt-response pairs (batch_size, input_length, n_sequences),
                in data space. It is not modified.

        Returns:
            Whatever ``self.g_diff.compute_combined_error`` returns for the
            normalised, flattened samples. Presumably one error per sample,
            shape (batch_size,) -- TODO confirm against GaussianDiffusion.
        """
        bs, _, _ = x_sampled.shape
        x_flat = self._normalise(x_sampled).reshape(bs, -1)
        return self.g_diff.compute_combined_error(self.model, x_flat)

    def get_eps(self, x_sampled: torch.Tensor) -> torch.Tensor:
        """
        Computes the denoising epsilon for the provided sample on the final time-step.

        The samples are normalised and flattened. The time feature
        ``self.g_diff.t2float(0)`` is appended to every row, and the result is
        fed to the network.

        Args:
            x_sampled (torch.Tensor): The sampled prompt-response pairs (batch_size, input_length, n_sequences),
                in data space. It is not modified.

        Returns:
            The network output for that input, detached and moved to the CPU.
            Presumably of shape (batch_size, output_size) -- TODO confirm.
        """
        bs, _, _ = x_sampled.shape
        x_flat = self._normalise(x_sampled).reshape(bs, -1)

        # one copy of the t=0 time feature per sample
        t_feature = self.g_diff.t2float(0).unsqueeze(0).repeat(bs, 1)
        model_input = torch.cat(
            [x_flat.to(self.g_diff.device), t_feature], dim=1
        ).to(self.g_diff.device)

        # The output is detached anyway, so skip building an autograd graph.
        with torch.no_grad():
            eps = self.model(model_input)
        return eps.detach().cpu()
