from __future__ import annotations

import torch
import torch.optim as optim
import tqdm
import wandb

from .critic import Critic
from .encoder import Encoder
from .generator import Generator
from .critic_image import Critic_image
from .snapshot import SnapshotSaver

class FeatureInversionPipeline:
    """Feature inversion pipeline.

    Repeatedly generates images, encodes them, scores the encoded features
    against the target features with the critic, and updates the generator with
    the optimizer. The per-step results of the most recent run are stored in
    ``self.history``.

    Parameters
    ----------
    generator : Generator
        Generator module that generates images.
    encoder : Encoder
        Encoder module that extracts layer-wise features from images.
    critic : Critic
        Critic module that computes the loss between generated features and target
        features.
    optimizer : optim.Optimizer
        Optimizer module that optimizes the generator.
    num_iterations : int
        Number of optimization iterations.
    critic_image : Critic_image, optional
        Critic_image module that computes the loss on the image space, by default
        None. Its output is added to the feature loss.
    scheduler : optim.lr_scheduler.LRScheduler, optional
        Scheduler module that schedules the learning rate of the optimizer, by default
        None.
    log_interval : int, optional
        Interval of logging to stdout, by default -1. If not positive, logging is
        disabled.
    with_wandb : bool, optional
        Whether to use wandb, by default False.
    snapshot_saver : SnapshotSaver, optional
        SnapshotSaver module that saves snapshots, by default None.
    eval_metric : callable, optional
        Deprecated single evaluation metric for features, by default None. It is
        prepended to ``eval_metrics``.
    eval_metrics : callable | list[callable], optional
        Evaluation metrics for features, by default None. Each metric is called
        with ``(generated_features, target_features)`` and must return a dict.
    pixel_eval_metrics : callable | list[callable], optional
        Evaluation metrics for pixels, by default None. Each metric is called with
        the generated images and must return a dict.
    stop_criteria : callable, optional
        Stop criteria, by default None.
        Receives the current metrics and returns True if the optimization should stop.
    wandb_log_interval : int, optional
        Interval of wandb logging, by default 1. If not positive, wandb logging is
        disabled.
    record_grad_norm : bool, optional
        Whether to store ``generator.grad_norm()`` in the results of every step,
        by default False.
    gradient_clipping : bool, optional
        Whether to clip the gradient norm of the generator parameters to 1.0
        before each optimizer step, by default False.
    eval_interval : int, optional
        Interval of evaluation, by default 1. If not positive, periodic evaluation
        is disabled. The last iteration is always evaluated.
    """

    def __init__(
        self,
        generator: Generator,
        encoder: Encoder,
        critic: Critic,
        optimizer: optim.Optimizer,
        num_iterations: int,
        critic_image: Critic_image | None = None,
        scheduler: optim.lr_scheduler.LRScheduler | None = None,
        log_interval: int = -1,
        with_wandb: bool = False,
        snapshot_saver: SnapshotSaver | None = None,
        eval_metric: callable | None = None,
        eval_metrics: callable | list[callable] | None = None,
        pixel_eval_metrics: callable | list[callable] | None = None,
        stop_criteria: callable | None = None,
        wandb_log_interval: int = 1,
        record_grad_norm: bool = False,
        gradient_clipping: bool = False,
        eval_interval: int = 1,
    ) -> None:
        super().__init__()
        self.generator = generator
        self.encoder = encoder
        self.critic = critic
        self.optimizer = optimizer
        self.critic_image = critic_image
        self.scheduler = scheduler
        self.num_iterations = num_iterations
        self.log_interval = log_interval
        self.with_wandb = with_wandb
        self.snapshot_saver = snapshot_saver
        self.stop_criteria = stop_criteria
        self.wandb_log_interval = wandb_log_interval
        self.record_grad_norm = record_grad_norm
        self.gradient_clipping = gradient_clipping
        self.eval_interval = eval_interval

        eval_metrics = self._as_metric_list(eval_metrics)
        if eval_metric is not None:
            # `eval_metric` is kept only for backward compatibility.
            print('Warning: use eval_metrics instead of eval_metric')
            eval_metrics = [eval_metric] + (eval_metrics or [])
        self.eval_metrics = eval_metrics
        self.pixel_eval_metrics = self._as_metric_list(pixel_eval_metrics)

        if self.with_wandb:
            self.critic.enable_wandb()
            if self.critic_image is not None:
                self.critic_image.enable_wandb()

    @staticmethod
    def _as_metric_list(
        metrics: callable | list[callable] | None,
    ) -> list[callable] | None:
        """Normalize a metric argument to a list of callables (or None)."""
        if metrics is None:
            return None
        if isinstance(metrics, list):
            return metrics
        if isinstance(metrics, tuple):
            # A tuple of metrics is a sequence, not a single callable.
            return list(metrics)
        return [metrics]

    def _should_evaluate(self, step: int) -> bool:
        """Return True if the evaluation metrics should be computed at ``step``."""
        # Always evaluate at the last step so the final results are complete.
        if step == self.num_iterations - 1:
            return True
        # Guard against eval_interval == 0, which would raise ZeroDivisionError.
        return self.eval_interval > 0 and step % self.eval_interval == 0

    def __call__(
            self,
            target_features: dict[str, torch.Tensor],
            gram_matrices: dict[str, torch.Tensor] | None = None,
            wandb_names: list[str] | None = None
            ) -> torch.Tensor:
        """Run the optimization loop and return the generated images.

        Parameters
        ----------
        target_features : dict[str, torch.Tensor]
            Target features indexed by the layer names.
        gram_matrices : dict[str, torch.Tensor] | None
            Currently unused by this method, by default None.
        wandb_names : list[str] | None
            Names of the stimuli for wandb logging, by default None.

        Returns
        -------
        torch.Tensor
            Generated images (detached), produced by the generator after the
            last optimizer step.
        """
        history = []  # list[dict[str, list[float]]]
        # The context manager closes the progress bar even on early stop/error.
        with tqdm.tqdm(range(self.num_iterations), dynamic_ncols=True) as pbar:
            for step in pbar:
                self.optimizer.zero_grad()
                generated_images = self.generator()
                # Copy taken before the optimizer step, used for pixel evaluation.
                gen_images = generated_images.clone().detach()
                generated_features = self.encoder(generated_images)
                loss = self.critic(generated_features, target_features)
                if self.critic_image is not None:
                    # Out-of-place add: do not mutate the tensor the critic returned.
                    loss = loss + self.critic_image(generated_images)
                loss.sum().backward()

                # Record the gradient norm (before any clipping is applied).
                if self.record_grad_norm:
                    grad_norm = self.generator.grad_norm()

                # Gradient clipping.
                # This applies the gradient clipping to all samples at once.
                if self.gradient_clipping:
                    torch.nn.utils.clip_grad_norm_(
                        self.generator.parameters(), max_norm=1.0
                    )

                self.optimizer.step()
                if self.scheduler is not None:
                    self.scheduler.step()

                # Evaluation of the features and images.
                # metric_name -> list[float], element i corresponds to the i-th stimulus
                results = {'loss': loss.tolist()}
                if self.record_grad_norm:
                    results['grad_norm'] = grad_norm
                if self._should_evaluate(step):
                    with torch.no_grad():
                        results.update(
                            self.evaluate_features(generated_features, target_features)
                        )
                        results.update(self.evaluate_image(gen_images))

                # Store the results.
                history.append(results)

                if (
                    self.with_wandb
                    and self.wandb_log_interval > 0
                    and step % self.wandb_log_interval == 0
                ):
                    self.log_wandb(step, results, wandb_names)

                if self.log_interval > 0 and step % self.log_interval == 0:
                    print(f"Step [{step+1}/{self.num_iterations}]: loss={loss.mean().item():.4f}")

                if self.snapshot_saver is not None:
                    self.snapshot_saver(step, generated_images.detach(), results)

                if self.stop_criteria is not None and self.stop_criteria(results):
                    print('Stopping criteria met')
                    break

        self.history = history
        # No graph is needed for the returned images.
        with torch.no_grad():
            return self.generator().detach()

    def log_wandb(self, step: int, results: dict, wandb_names: list[str] | None = None) -> None:
        """
        Report evaluation metrics to wandb.

        Args:
            step (int): Step passed to ``wandb.log``.
            results (dict): Evaluation results. metric_name -> list[float]
            wandb_names (list[str] | None): Names of the stimuli for wandb logging,
                by default None. If None, ``stimulus_<i>`` is used, with one name
                per element of ``results['loss']``.
        """
        if wandb_names is None:
            wandb_names = [f"stimulus_{i}" for i in range(len(results['loss']))]

        # Regroup the results by stimulus: name -> metric -> value
        log = {
            name: {metric_name: values[i] for metric_name, values in results.items()}
            for i, name in enumerate(wandb_names)
        }
        wandb.log(log, step=step)

    def evaluate_features(self, generated_features, target_features):
        """Apply every feature metric and merge the returned dicts.

        Returns an empty dict when no feature metric is configured. If two
        metrics return the same key, the later one wins.
        """
        eval_results = {}  # metric_name -> list[float]
        if self.eval_metrics is None:
            return eval_results
        for metric in self.eval_metrics:
            eval_results.update(metric(generated_features, target_features))
        return eval_results

    def evaluate_image(self, generated_images):
        """Apply every pixel metric and merge the returned dicts.

        Returns an empty dict when no pixel metric is configured. If two
        metrics return the same key, the later one wins.
        """
        eval_results = {}  # metric_name -> list[float]
        if self.pixel_eval_metrics is None:
            return eval_results
        for metric in self.pixel_eval_metrics:
            eval_results.update(metric(generated_images))
        return eval_results

    def reset_states(self, gen_seeds=None) -> None:
        """Reset the state of the pipeline.

        Args:
            gen_seeds (list[int] | None, optional): Random seeds for the generator.
                Forwarded to ``generator.reset_states``.

        Notes
        -----
        This method is needed to reset the state of the optimizer and the generator
        when the pipeline is used for multiple stimuli. Otherwise, the initial
        generated image for the second stimulus is the final generated image for the
        first stimulus. Another implementation idea is to put the optimizer and the
        generator in the __call__ method, instead of the __init__ method.

        The optimizer is re-created from its class and ``optimizer.defaults`` only,
        so options set per parameter group are not carried over.

        NOTE(review): ``self.scheduler`` is not touched here, so it is not
        re-bound to the newly created optimizer. Confirm whether callers
        re-create the scheduler after calling this method.
        """

        self.generator.reset_states(gen_seeds)
        self.optimizer = type(self.optimizer)(
            self.generator.parameters(), **self.optimizer.defaults
        )
        # Some snapshot savers have states.
        if self.snapshot_saver is not None:
            self.snapshot_saver.reset_states()
