import torch.nn as nn
from torchtyping import TensorType
from typing import Union
import wandb

from ..configs import (
    NesimConfig,
    NeighbourhoodCosineSimilarity,
    LaplacianPyramid,
    Ring,
    Ring1D,
)
from ..utils.getting_modules import get_module_by_name
from .laplacian_pyramid.loss import LaplacianPyramidLoss
from .neighbourhood_cossim.loss import NeighbourhoodCosineSimilarityLoss
from .ring.loss import RingLoss, RingLoss1D
from .scale_scheduler import ExponentialDecayScale, ExponentialDecayScaleGenerator


## unified API for a single layer loss
class SingleLayerLossHandler:
    """Unified API around the topographic loss of a single layer.

    Builds the concrete loss object (Laplacian pyramid, neighbourhood cosine
    similarity, ring or 1D ring) that matches ``single_layer_config`` and
    applies the configured ``scale`` when ``compute`` is called.

    When ``config.scale`` is ``None`` the layer is only monitored: ``compute``
    returns ``None`` and ``get_log_data`` evaluates the loss on demand.
    Otherwise ``compute`` returns the scaled loss and caches the unscaled
    value in ``latest_loss`` for logging.
    """

    def __init__(
        self,
        single_layer_config: Union[
            NeighbourhoodCosineSimilarity, LaplacianPyramid, Ring, Ring1D
        ],
        model: nn.Module,
        device: str,
    ):
        """
        Args:
            single_layer_config: per-layer config; provides ``layer_name``,
                ``scale`` and the loss-specific hyperparameters.
            model: model containing the layer; the layer is looked up by
                ``single_layer_config.layer_name``.
            device: device string forwarded to the loss constructors that
                accept it (the 1D ring loss is built without it).

        Raises:
            TypeError: if ``single_layer_config`` is not a supported config type.
        """
        layer = get_module_by_name(module=model, name=single_layer_config.layer_name)

        if isinstance(single_layer_config, LaplacianPyramid):
            self.layer_loss = LaplacianPyramidLoss(
                layer=layer,
                device=device,
                factor_h=single_layer_config.shrink_factor,
                factor_w=single_layer_config.shrink_factor,
            )
        elif isinstance(single_layer_config, NeighbourhoodCosineSimilarity):
            self.layer_loss = NeighbourhoodCosineSimilarityLoss(
                layer=layer,
                device=device,
            )
        # Ring1D is checked before Ring so that it keeps its own branch even
        # in case it is declared as a subclass of Ring.
        elif isinstance(single_layer_config, Ring1D):
            self.layer_loss = RingLoss1D(
                layer=layer,
                freq_inner=single_layer_config.freq_inner,
                freq_outer=single_layer_config.freq_outer,
            )
        elif isinstance(single_layer_config, Ring):
            self.layer_loss = RingLoss(
                layer=layer,
                radius_inner=single_layer_config.radius_inner,
                radius_outer=single_layer_config.radius_outer,
                device=device,
            )
        else:
            raise TypeError(f"Invalid single_layer_config: {single_layer_config}")

        self.config = single_layer_config

        # Unscaled loss value from the most recent ``compute`` call. It is only
        # filled in when a scale is configured; with ``scale=None`` the loss
        # is instead computed on demand by ``get_log_data``.
        self.latest_loss = None

        # An exponential-decay config becomes a stateful generator queried in
        # ``compute``; any other scale value is stored unchanged.
        if isinstance(self.config.scale, ExponentialDecayScale):
            self.scale = ExponentialDecayScaleGenerator(config=self.config.scale)
        else:
            self.scale = self.config.scale

    def compute(self):
        """Return the scaled loss for this layer, or ``None`` if no scale is set.

        Side effects: stores ``loss.item()`` in ``self.latest_loss``. For a
        non-numeric (scheduled) scale, calls ``get_value(step=True)`` on it,
        which presumably advances its schedule.
        """
        if self.config.scale is None:
            return None

        loss = self.layer_loss.get_loss()
        self.latest_loss = loss.item()

        # Any plain number (int or float) multiplies directly; a float-only
        # check would send an integer scale such as ``1`` to ``get_value``.
        if isinstance(self.scale, (int, float)):
            return self.scale * loss
        return self.scale.get_value(step=True) * loss

    def __repr__(self):
        return f"SingleLayerLossHandler with loss: {self.layer_loss} on layer: {self.config.layer_name}"

    def get_log_data(self):
        """Return ``{layer_name: loss_value}`` for logging.

        With a configured scale, the cached ``latest_loss`` is reported (so
        ``compute`` must have run at least once). Without a scale, the loss
        is evaluated now via ``get_loss().item()``.
        """
        data = {}
        if self.config.scale is not None:
            assert (
                self.latest_loss is not None
            ), "Cannot wandb log the loss before its computed at least once. Run self.compute()"
            data[self.config.layer_name] = self.latest_loss
        else:
            data[self.config.layer_name] = self.layer_loss.get_loss().item()

        return data


class NesimLoss:
    """Aggregate topographic loss over every layer listed in a NesimConfig.

    One ``SingleLayerLossHandler`` is built per entry in
    ``config.layer_wise_configs``. ``compute`` combines their scaled losses;
    ``get_log_data``, ``wandb_log`` and ``get_all_grid_states`` report
    per-layer values under ``<loss-type prefix> + layer_name`` keys.
    """

    # (loss class, log-key prefix) pairs, checked in order. RingLoss1D is
    # listed before RingLoss so it keeps its own prefix even in case it is a
    # subclass of RingLoss.
    _LOSS_PREFIXES = (
        (LaplacianPyramidLoss, "laplacian_loss_"),
        (NeighbourhoodCosineSimilarityLoss, "neighbourhood_cossim_loss_"),
        (RingLoss1D, "ring_loss_1d_"),
        (RingLoss, "ring_loss_"),
    )

    def __init__(self, model: nn.Module, config: NesimConfig, device: str):
        self.model = model
        self.config = config
        self.device = device

        self.layer_handlers = [
            SingleLayerLossHandler(
                single_layer_config=single_layer_config,
                device=self.device,
                model=self.model,
            )
            for single_layer_config in self.config.layer_wise_configs
        ]

    def __repr__(self):
        return f"NesimLoss: {self.layer_handlers}"

    @classmethod
    def _get_prefix(cls, layer_loss) -> str:
        """Return the log-key prefix for the type of ``layer_loss``.

        Raises:
            TypeError: for an unsupported loss type, rather than reusing a
                prefix left over from another handler.
        """
        for loss_cls, prefix in cls._LOSS_PREFIXES:
            if isinstance(layer_loss, loss_cls):
                return prefix
        raise TypeError(f"No log prefix registered for loss: {layer_loss}")

    def compute(self, reduce_mean=True) -> Union[TensorType, dict, None]:
        """Compute the scaled loss of every handler that has a scale.

        Args:
            reduce_mean: if ``True``, return the mean of all scaled layer
                losses, or ``None`` when no handler produced one. Otherwise
                return a ``{layer_name: scaled_loss}`` dict.
        """
        losses_from_each_layer = {}
        all_losses = []

        for loss_handler in self.layer_handlers:
            loss = loss_handler.compute()
            if loss is None:
                # Unscaled handlers are monitor-only and do not contribute.
                continue
            losses_from_each_layer[loss_handler.config.layer_name] = loss
            all_losses.append(loss)

        if reduce_mean is True:
            # Average the list, not the dict values: if two configs share a
            # layer_name, the dict keeps only the last loss for that name.
            if all_losses:
                return sum(all_losses) / len(all_losses)
            return None
        return losses_from_each_layer

    def get_log_data(self) -> dict:
        """Return ``{prefix + layer_name: loss_value}`` for every handler."""
        all_data = {}
        for loss_handler in self.layer_handlers:
            layer_name = loss_handler.config.layer_name
            key = self._get_prefix(loss_handler.layer_loss) + layer_name
            all_data[key] = loss_handler.get_log_data()[layer_name]
        return all_data

    def wandb_log(self):
        """Log per-layer losses and aggregate statistics to wandb.

        Besides one entry per handler, this logs the following when they apply:
        the mean neighbourhood cosine similarity loss, the mean 1D ring loss,
        and the mean current value of the dynamic (non-numeric) loss scales
        of Laplacian pyramid layers.
        """
        all_data = {}

        all_neighbourhood_cossim_losses = []
        all_ring_1d_losses = []
        all_dynamic_loss_scales = []

        for loss_handler in self.layer_handlers:
            layer_loss = loss_handler.layer_loss
            prefix = self._get_prefix(layer_loss)

            if isinstance(layer_loss, LaplacianPyramidLoss):
                scale = loss_handler.scale
                # Only scheduled scales are averaged. Plain numbers and missing
                # scales have no ``get_value`` to query.
                if scale is not None and not isinstance(scale, (int, float)):
                    all_dynamic_loss_scales.append(scale.get_value(step=False))
            elif isinstance(layer_loss, NeighbourhoodCosineSimilarityLoss):
                all_neighbourhood_cossim_losses.append(layer_loss.get_loss().item())
            elif isinstance(layer_loss, RingLoss1D):
                all_ring_1d_losses.append(layer_loss.get_loss().item())

            # NOTE(review): this stores the handler's whole ``{layer_name: value}``
            # dict (a nested dict), unlike ``get_log_data`` which unwraps the
            # scalar. It is kept as-is to avoid renaming existing metrics;
            # confirm intent.
            all_data[prefix + loss_handler.config.layer_name] = (
                loss_handler.get_log_data()
            )

        if all_neighbourhood_cossim_losses:
            wandb.log(
                {
                    "mean_neighbourhood_cosine_similarity": sum(
                        all_neighbourhood_cossim_losses
                    )
                    / len(all_neighbourhood_cossim_losses)
                }
            )

        if all_ring_1d_losses:
            wandb.log(
                {"mean_ring_1d_loss": sum(all_ring_1d_losses) / len(all_ring_1d_losses)}
            )

        if all_dynamic_loss_scales:
            wandb.log(
                {
                    "mean_loss_scale": sum(all_dynamic_loss_scales)
                    / len(all_dynamic_loss_scales)
                }
            )

        wandb.log(all_data)

    def get_all_grid_states(self):
        """Return ``{prefix + layer_name: layer_loss.grid_container.grid}``.

        NOTE(review): this assumes every handler's loss object exposes
        ``grid_container``. The ring losses are not shown here, so confirm
        that they do.
        """
        all_data = {}
        for loss_handler in self.layer_handlers:
            key = (
                self._get_prefix(loss_handler.layer_loss)
                + loss_handler.config.layer_name
            )
            all_data[key] = loss_handler.layer_loss.grid_container.grid
        return all_data
