from collections.abc import Iterable
from pathlib import Path
from typing import Tuple

import numpy as np
import plotly.graph_objects as go
import torch

from src.config import Config


class OneDFixedInferencer:
    """Plot model predictions against targets for a fixed batch of 1-D samples.

    Up to four dataset samples are picked once (seeded, so the choice is
    reproducible), moved to ``config.training.device`` and cached; every call to
    :meth:`visualize` runs the model on that same batch, so plots produced at
    different points in time are directly comparable.
    """

    # Upper bound on the number of cached samples, and the seed used to pick them.
    _MAX_SAMPLES = 4
    _SEED = 42

    def __init__(self, model: torch.nn.Module, dl: torch.utils.data.Subset, config: "Config"):
        """Store the collaborators and read the optional plotting flag.

        Parameters
        ----------
        model : torch.nn.Module
            Called as ``model(batch)`` where ``batch`` is :attr:`visualization_data`.
            It is put in eval mode here.
        dl : torch.utils.data.Subset
            Only ``dl.dataset`` is used: it must support ``len()`` and integer
            indexing, returning dicts of tensors that include a ``'y'`` key.
            NOTE(review): for an actual ``Subset``, ``.dataset`` is the full
            underlying dataset rather than the subset; the name ``dl`` suggests
            callers pass a DataLoader — confirm.
        config : Config
            Must provide ``config.training.device`` (read when the cached batch is
            built) and may provide
            ``config.training.inference_params['show_absolute_error']``.
        """
        self.model = model
        self.dl = dl
        self.config = config
        self.model.eval()

        params = getattr(self.config.training, 'inference_params', None)
        if isinstance(params, Iterable) and 'show_absolute_error' in params:
            self.show_abs_label = params['show_absolute_error']
        else:
            self.show_abs_label = False

        self._visualization_data = None

    @property
    def visualization_data(self) -> dict:
        """Fixed batch used for plotting, built lazily on first access and cached.

        Returns
        -------
        dict
            Maps every key of the dataset items to a tensor of the selected
            samples stacked along a new leading axis, on ``config.training.device``.

        Raises
        ------
        ValueError
            If the dataset is empty.
        """
        if self._visualization_data is None:
            dataset = self.dl.dataset
            n_total = len(dataset)
            if n_total == 0:
                raise ValueError("Cannot build visualization data from an empty dataset.")

            # A private RandomState draws exactly the same indices as seeding the
            # legacy global RNG would, without resetting the process-wide NumPy
            # RNG state that other code may depend on.
            rng = np.random.RandomState(self._SEED)
            idxs = rng.choice(n_total, size=min(n_total, self._MAX_SAMPLES), replace=False)

            # Fetch each sample exactly once: __getitem__ may be expensive or
            # stochastic, so re-reading it per key would be slow and could mix
            # fields coming from different draws.
            samples = [dataset[int(i)] for i in idxs]
            device = self.config.training.device
            self._visualization_data = {
                key: torch.stack([sample[key] for sample in samples]).to(device)
                for key in samples[0].keys()
            }
        return self._visualization_data

    @staticmethod
    def _squeeze_keep_batch(arr: np.ndarray) -> np.ndarray:
        """Drop singleton axes of ``arr`` except the leading (batch) axis.

        Unlike a bare ``squeeze()``, a batch holding a single sample keeps its
        leading axis, so ``arr[i]`` always selects a sample.
        """
        singleton_axes = tuple(ax for ax in range(1, arr.ndim) if arr.shape[ax] == 1)
        if not singleton_axes:
            return arr
        return np.squeeze(arr, axis=singleton_axes)

    def visualize(self, path: Path, n_samples: int = 4) -> Tuple[list, list]:
        """Write one HTML line plot (prediction vs. target) per cached sample.

        Parameters
        ----------
        path : Path
            Directory the ``sample_{i}.html`` files are written into; it is not
            created here.
        n_samples : int
            Maximum number of samples to plot. Also capped by the size of the
            cached batch (at most four samples, see :attr:`visualization_data`).

        Returns
        -------
        Tuple[list, list]
            The written HTML file paths, and a list of the same length whose
            entries are all ``"plotly_line_plot"``.
        """
        batch = self.visualization_data

        # Force eval mode for this forward pass even if a training loop switched
        # the model back to train mode after construction; restore the caller's
        # mode afterwards so this call has no lasting side effect on the model.
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                y_pred = self.model(batch)
        finally:
            self.model.train(was_training)

        y_pred = self._squeeze_keep_batch(y_pred.cpu().numpy())
        y_true = self._squeeze_keep_batch(batch['y'].cpu().numpy())
        if self.show_abs_label:
            y_true = np.abs(y_true)

        output_data = []
        for i in range(min(n_samples, len(y_true))):
            # atleast_1d keeps plotting working when each sample is a scalar.
            target = np.atleast_1d(y_true[i])
            x = np.arange(target.shape[-1])

            fig = go.Figure()
            # Predicted mean
            fig.add_trace(go.Scatter(
                x=x, y=np.atleast_1d(y_pred[i]), mode='lines', name='Prediction',
                line=dict(color='blue'),
            ))
            # Target
            fig.add_trace(go.Scatter(
                x=x, y=target, mode='lines', name='Target',
                line=dict(color='green'),
            ))

            html_path = path / f"sample_{i}.html"
            fig.write_html(html_path)
            output_data.append(html_path)

        # One label per written file, so both lists always have the same length.
        return output_data, ["plotly_line_plot"] * len(output_data)
