from typing import Any, Dict

import numpy as np
import torch

from .base import Algorithm

# Not used by TrainPredictor. -100 matches the default ``ignore_index`` of
# PyTorch's cross-entropy losses; presumably kept for masking targets
# elsewhere — TODO confirm it is still needed.
IGNORE_INDEX = -100


class TrainPredictor(Algorithm):
    """
    Train a segment-scoring predictor from pairwise preference feedback.

    Each feedback pair (``obs_1``/``action_1`` vs. ``obs_2``/``action_2``) is scored by
    ``self.network.predictor``; the output is averaged over its last dimension and
    ``pred_2 - pred_1`` is used as the logit of a binary cross-entropy loss against
    ``label`` (so larger labels favour the second segment).

    When a replay batch is supplied and ``smoothing_coeff > 0``, a smoothing penalty is
    added: the mean squared difference between the predictor's scores on two windows
    of the same replay sequences, offset by a random number of steps along dimension 1.

    Args:
        smoothing_coeff: weight of the smoothing loss in the total training loss;
            ``0`` disables the smoothing term entirely.
    """

    def __init__(self, *args, smoothing_coeff: float = 1.0, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Everything below (optimizer + losses) operates on the ``predictor`` container.
        assert "predictor" in self.network.CONTAINERS
        self.smoothing_coeff = smoothing_coeff

    def setup_optimizers(self) -> None:
        """
        Create the predictor optimizer with separate weight-decay parameter groups.

        Trainable parameters with two or more dimensions (e.g. matmul weights and
        embeddings) use ``self.optim_kwargs`` as-is; the remaining trainable parameters
        (e.g. biases and norm scales) use the same kwargs with ``weight_decay`` forced
        to ``0.0``. Parameters with ``requires_grad=False`` are excluded.
        """
        decay_params, no_decay_params = [], []
        for param in self.network.predictor.parameters():
            if not param.requires_grad:
                continue
            (decay_params if param.dim() >= 2 else no_decay_params).append(param)

        # optim_kwargs is applied after "params", matching the previous dict.update order.
        decay_group = {"params": decay_params, **self.optim_kwargs}
        no_decay_group = {"params": no_decay_params, **self.optim_kwargs, "weight_decay": 0.0}
        self.optim["predictor"] = self.optim_class((decay_group, no_decay_group))

    def _get_predictor_losses(self, feedback_batch, replay_batch):
        """
        Compute the preference loss, the optional smoothing loss and the accuracy.

        Args:
            feedback_batch: mapping with ``obs_1``, ``action_1``, ``obs_2``, ``action_2``
                and ``label`` entries. Dimension 1 of the actions is used as the
                segment length.
            replay_batch: optional mapping with ``obs`` and ``action`` entries. Only used
                when ``self.smoothing_coeff > 0``; its dimension 1 must then be longer
                than the feedback segment length.

        Returns:
            ``(predictor_loss, smoothing_loss, accuracy)``. ``smoothing_loss`` is the
            Python float ``0.0`` when the smoothing term is skipped.

        Raises:
            ValueError: if the replay sequences are not longer than the feedback
                segments, so no shifted window can be taken.
        """
        # Score both segments of every pair with a single forward pass.
        obs = torch.cat((feedback_batch["obs_1"], feedback_batch["obs_2"]), dim=0)
        action = torch.cat((feedback_batch["action_1"], feedback_batch["action_2"]), dim=0)

        pred_1, pred_2 = self.network.predictor(obs, action).mean(dim=-1).chunk(2, dim=0)
        # sigmoid(pred_2 - pred_1) is fit to the label.
        logits = pred_2 - pred_1
        predictor_loss = torch.nn.functional.binary_cross_entropy_with_logits(
            logits, feedback_batch["label"].float(), reduction="mean"
        )

        with torch.no_grad():
            # Fraction of pairs where the sign of the logit agrees with the rounded label.
            accuracy = ((logits > 0) == torch.round(feedback_batch["label"])).float().mean()

        if self.smoothing_coeff > 0.0 and replay_batch is not None:
            seg_len = action.shape[1]
            smooth_obs, smooth_action = replay_batch["obs"], replay_batch["action"]
            replay_len = smooth_action.shape[1]
            if replay_len <= seg_len:
                raise ValueError(
                    "Smoothing requires replay sequences longer than the feedback segments "
                    f"(replay length {replay_len}, segment length {seg_len})."
                )

            # np.random.randint excludes its upper bound; the +1 lets the largest valid
            # shift (replay_len - seg_len) be sampled, so the window can reach the end.
            delta = np.random.randint(1, replay_len - seg_len + 1)
            smooth_obs = torch.cat((smooth_obs[:, :seg_len], smooth_obs[:, delta : delta + seg_len]), dim=0)
            smooth_action = torch.cat((smooth_action[:, :seg_len], smooth_action[:, delta : delta + seg_len]), dim=0)

            smooth_pred_1, smooth_pred_2 = (
                self.network.predictor(smooth_obs, smooth_action).mean(dim=-1).chunk(2, dim=0)
            )
            # Penalize score changes between the two windows. A more principled variant
            # would be BCE-with-logits of (smooth_pred_2 - smooth_pred_1) against a
            # constant 0.5 target; the squared difference is used instead.
            smoothing_loss = torch.square(smooth_pred_2 - smooth_pred_1).mean()
        else:
            smoothing_loss = 0.0

        return predictor_loss, smoothing_loss, accuracy

    @staticmethod
    def _unpack_predictor_batch(batch):
        """Split ``batch`` into ``(replay_batch, feedback_batch)``.

        A 2-element tuple/list is ``(replay, feedback)``; anything else is treated as a
        feedback-only batch with no replay data.
        """
        if isinstance(batch, (tuple, list)):
            replay_batch, feedback_batch = batch
            return replay_batch, feedback_batch
        return None, batch

    @staticmethod
    def _predictor_metrics(predictor_loss, smoothing_loss, accuracy) -> Dict:
        """Convert the loss tensors to Python numbers for logging."""
        return dict(
            predictor_loss=predictor_loss.item(),
            # smoothing_loss is a plain float when the smoothing term was skipped.
            smoothing_loss=0 if isinstance(smoothing_loss, float) else smoothing_loss.item(),
            accuracy=accuracy.item(),
        )

    def train_step(self, batch: Dict, step: int, total_steps: int) -> Dict:
        """Run one optimizer step on ``predictor_loss + smoothing_coeff * smoothing_loss``."""
        replay_batch, feedback_batch = self._unpack_predictor_batch(batch)

        predictor_loss, smoothing_loss, accuracy = self._get_predictor_losses(feedback_batch, replay_batch)

        self.optim["predictor"].zero_grad(set_to_none=True)
        (predictor_loss + self.smoothing_coeff * smoothing_loss).backward()
        self.optim["predictor"].step()

        return self._predictor_metrics(predictor_loss, smoothing_loss, accuracy)

    def validation_step(self, batch: Any) -> Dict:
        """Compute the same metrics as ``train_step`` without gradients or updates."""
        replay_batch, feedback_batch = self._unpack_predictor_batch(batch)

        with torch.no_grad():
            predictor_loss, smoothing_loss, accuracy = self._get_predictor_losses(feedback_batch, replay_batch)

        return self._predictor_metrics(predictor_loss, smoothing_loss, accuracy)
