from dataclasses import dataclass
from typing import Dict, Any, Tuple
import wandb
import torch
import torch.nn.functional as F
import numpy as np

from ddlm.modeling.diffusion import DiffusionOutput


class Strategy:
    """Shape of the strategy interface used in this module.

    A strategy exposes a state factory (``get_initial_state``) and a
    per-step update (``new_step``). Both methods here are placeholders that
    do nothing and return ``None``.
    """

    def get_initial_state(self, n_x) -> Dict[str, Any]:
        """Placeholder state factory; ignores ``n_x`` and returns ``None``."""
        return None

    def new_step(
        self, outputs: DiffusionOutput, state: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Placeholder update; ignores both arguments and returns ``None``."""
        return None


class NoStrategy:
    """Strategy that never flags a sample for exit."""

    @staticmethod
    def get_initial_state(n_x: torch.Tensor):
        """Return a state whose ``exit_mask`` is all ``False``.

        The mask holds one boolean per row of ``n_x`` (``n_x.size(0)``) and
        is allocated on the same device as ``n_x``.
        """
        batch_size = n_x.size(0)
        never_exit = torch.zeros(batch_size, dtype=torch.bool, device=n_x.device)
        return {"exit_mask": never_exit}

    @staticmethod
    def new_step(
        outputs: DiffusionOutput, state: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Hand ``state`` back untouched; ``outputs`` is not inspected."""
        return state

@dataclass
class PatienceStrategy:
    """Exit a sample once its arg-max prediction has stopped changing.

    A per-sample counter is incremented every step on which the arg-max of
    the new logits equals the arg-max of the previous logits, and reset to
    zero otherwise. A sample exits when its counter exceeds ``patience``.
    """

    # Number of consecutive unchanged steps tolerated before a sample exits
    # (the exit fires when the counter becomes strictly greater than this).
    patience: int = 2

    @staticmethod
    def get_initial_state(n_x: torch.Tensor) -> Dict[str, Any]:
        """Create the initial state for a batch with ``n_x.size(0)`` samples.

        Returns a dict with a zeroed ``patience`` counter (long), an
        all-``False`` ``exit_mask`` (bool) and ``previous_logits`` set to
        ``None``; the tensors live on the device of ``n_x``.
        """
        batch_size = n_x.size(0)
        return {
            "patience": torch.zeros(batch_size, device=n_x.device, dtype=torch.long),
            "exit_mask": torch.zeros(batch_size, device=n_x.device, dtype=torch.bool),
            "previous_logits": None,
        }

    def new_step(
        self, outputs: DiffusionOutput, state: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Update counters, stored logits and ``exit_mask`` for one step.

        ``outputs.logits`` is paired row by row with the samples that have
        not exited yet, so it is assumed to contain exactly those samples in
        batch order. ``state`` is modified in place and returned.
        """
        # Snapshot the active rows once: the mask is rewritten at the end.
        active = ~state["exit_mask"]
        logits = outputs.logits.detach()
        counters = state["patience"][active]
        previous = state["previous_logits"]

        if previous is None:
            # First step: nothing to compare against, just remember the logits.
            state["previous_logits"] = logits.clone()
            newly_exited = torch.zeros_like(counters, dtype=torch.bool)
        else:
            # A sample is "unchanged" only if every arg-max position agrees.
            unchanged = previous[active].argmax(-1) == logits.argmax(-1)
            if unchanged.dim() > 1:
                unchanged = unchanged.flatten(1).all(dim=1)
            unchanged = unchanged.to(counters.device)
            counters = torch.where(
                unchanged, counters + 1, torch.zeros_like(counters)
            )
            newly_exited = counters > self.patience
            previous[active] = logits

        state["patience"][active] = counters
        state["exit_mask"][active] = newly_exited
        return state

@dataclass
class EntropyStrategy:
    """Exit a sample once the mean entropy of its predictions is small.

    A sample exits when its entropy is at most ``threshold / 64`` and more
    than ``min_step`` calls to ``new_step`` have already happened.
    """

    # Entropy bound; it is divided by a hard-coded 64 before the comparison.
    threshold: float = 0.1
    # No sample may exit while the step counter is <= this value.
    min_step: int = 0

    @staticmethod
    def get_initial_state(n_x: torch.Tensor) -> Dict[str, Any]:
        """Create the initial state for a batch with ``n_x.size(0)`` samples.

        Returns an all-``False`` ``exit_mask`` on the device of ``n_x`` and a
        ``step`` counter starting at 0.
        """
        return {
            "exit_mask": torch.zeros(n_x.size(0), device=n_x.device, dtype=torch.bool),
            "step": 0,
        }

    @staticmethod
    def _count_entropy(logits_batch):
        """Return one entropy value per sample.

        The softmax is taken over the last dimension, the entropy is summed
        over that dimension and then averaged over dimension 1.
        ``torch.special.entr`` defines ``-p * log(p)`` as 0 for ``p == 0``,
        so logits of ``-inf`` contribute zero instead of NaN.
        """
        probabilities = torch.nn.functional.softmax(logits_batch, dim=-1)
        positional_entropy = torch.special.entr(probabilities).sum(dim=-1)
        return positional_entropy.mean(dim=1)

    def new_step(
        self, outputs: DiffusionOutput, state: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Flag low-entropy samples in ``exit_mask`` and advance ``step``.

        ``outputs.logits`` is assumed to contain one row per sample that has
        not exited yet, in batch order. ``state`` is modified in place and
        returned.
        """
        active = ~state["exit_mask"]
        entropy = self._count_entropy(outputs.logits)

        # Before ``min_step`` is passed nothing may exit, so the mask (which
        # is already False for every active sample) is left untouched.
        if state["step"] > self.min_step:
            # NOTE(review): the 64 is a magic constant — presumably a
            # sequence-length normalisation; confirm before changing it.
            confident = entropy <= (self.threshold / 64.0)
            state["exit_mask"][active] = confident.to(state["exit_mask"].device)

        state["step"] += 1
        return state
    

@dataclass
class FixedStrategy:
    """Exit every sample after a fixed number of ``new_step`` calls."""

    # Step count at which the whole batch exits. The comparison is ``>=`` so
    # that a non-integer value still triggers on the next whole step.
    # NOTE(review): the 0.1 default looks copied from the threshold-based
    # strategies; callers are presumably expected to pass a positive step
    # count — confirm.
    threshold: float = 0.1

    @staticmethod
    def get_initial_state(n_x) -> Dict[str, Any]:
        """Create the initial state for a batch with ``n_x.size(0)`` samples.

        Returns a ``step`` counter starting at 0 and an all-``False``
        ``exit_mask`` on the device of ``n_x``.
        """
        return {
            "step": 0,
            "exit_mask": torch.zeros(n_x.size(0), device=n_x.device, dtype=torch.bool),
        }

    def new_step(
        self, outputs: DiffusionOutput, state: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Advance ``step`` and mark all samples as exited at the threshold.

        ``outputs`` is not inspected. ``state`` is updated and returned; when
        the threshold is reached ``exit_mask`` is replaced by an all-``True``
        tensor of the same shape.
        """
        state["step"] += 1

        # ``>=`` rather than ``==``: an integer counter can never equal a
        # fractional threshold, which would silently disable the exit.
        if state["step"] >= self.threshold:
            state["exit_mask"] = torch.ones_like(state["exit_mask"])

        return state

@dataclass
class KLDivStrategy:
    """Exit a sample once its prediction barely moves between two steps.

    The movement is measured with a symmetric KL divergence between the
    current and the previous step's distributions; a sample exits when that
    value drops below ``threshold`` after more than ``min_step`` steps.
    """

    # Upper bound (exclusive) on the symmetric KL divergence for exiting.
    threshold: float = 0.1
    # No sample may exit while the step counter is <= this value.
    min_step: int = 20

    @staticmethod
    def get_initial_state(n_x) -> Dict[str, Any]:
        """Create the initial state for a batch with ``n_x.size(0)`` samples.

        Returns a ``step`` counter starting at 0, an all-``False``
        ``exit_mask`` on the device of ``n_x`` and ``previous_logits`` set to
        ``None``.
        """
        return {
            "step": 0,
            "exit_mask": torch.zeros(n_x.size(0), device=n_x.device, dtype=torch.bool),
            "previous_logits": None,
        }

    @staticmethod
    def _count_kl_divergence(current, previous):
        """Return KL(current‖previous) + KL(previous‖current) as a float.

        Both arguments are the logits of a single sample; they are
        normalised with a log-softmax over dimension 1 and each direction is
        reduced with ``batchmean`` (sum divided by the size of dimension 0).
        """
        current = F.log_softmax(current, dim=1)
        previous = F.log_softmax(previous, dim=1)
        forward = F.kl_div(current, previous, reduction="batchmean", log_target=True)
        backward = F.kl_div(previous, current, reduction="batchmean", log_target=True)
        # One host transfer for the sum instead of one per direction.
        return (forward + backward).item()

    def new_step(
        self, outputs: DiffusionOutput, state: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Update stored logits and ``exit_mask`` and advance ``step``.

        ``outputs.logits`` is paired row by row with the samples that have
        not exited yet, so it is assumed to contain exactly those samples in
        batch order. ``state`` is modified in place and returned.
        """
        # Snapshot the active rows once: the mask may be rewritten below.
        active = ~state["exit_mask"]
        logits = outputs.logits.detach()
        previous = state["previous_logits"]

        if previous is None:
            # First step: nothing to compare against, just remember the logits.
            state["previous_logits"] = logits.clone()
        else:
            if state["step"] > self.min_step:
                converged = [
                    self._count_kl_divergence(new_logits, old_logits) < self.threshold
                    for new_logits, old_logits in zip(logits, previous[active])
                ]
                state["exit_mask"][active] = torch.tensor(
                    converged, dtype=torch.bool, device=state["exit_mask"].device
                )
            # ``active`` still describes the rows that produced ``logits``.
            previous[active] = logits

        state["step"] += 1
        return state

class LogStrategy:
    """Record patience, entropy and KL statistics without ever exiting.

    ``new_step`` refreshes the ``patience``, ``kl`` and ``entropy`` entries
    of the state on every call and never modifies ``exit_mask``.
    """

    @staticmethod
    def get_initial_state(n_x: torch.Tensor) -> Dict[str, Any]:
        """Create the initial state for a batch with ``n_x.size(0)`` samples.

        ``patience`` is a long counter, ``exit_mask`` is all ``False``, and
        ``entropy`` / ``kl`` are floating-point zero placeholders that
        ``new_step`` overwrites; all tensors live on the device of ``n_x``.
        """
        batch_size = n_x.size(0)
        device = n_x.device
        return {
            "patience": torch.zeros(batch_size, device=device, dtype=torch.long),
            "exit_mask": torch.zeros(batch_size, device=device, dtype=torch.bool),
            "previous_logits": None,
            # Float placeholders: the values stored later are real-valued.
            "entropy": torch.zeros(batch_size, device=device, dtype=torch.float),
            "kl": torch.zeros(batch_size, device=device, dtype=torch.float),
        }

    @staticmethod
    def _count_entropy(logits_batch):
        """Return one entropy value per sample.

        The softmax is taken over the last dimension, the entropy is summed
        over that dimension and then averaged over dimension 1.
        ``torch.special.entr`` defines ``-p * log(p)`` as 0 for ``p == 0``,
        so logits of ``-inf`` contribute zero instead of NaN.
        """
        probabilities = torch.nn.functional.softmax(logits_batch, dim=-1)
        positional_entropy = torch.special.entr(probabilities).sum(dim=-1)
        return positional_entropy.mean(dim=1)

    @staticmethod
    def _count_patience(outputs, state):
        """Update ``state["patience"]`` in place and return it.

        A sample's counter grows by one when the arg-max of its new logits
        equals the arg-max of its previous logits at every position, and is
        reset to zero otherwise.
        """
        counters = state["patience"]
        unchanged = state["previous_logits"].argmax(-1) == outputs.logits.argmax(-1)
        if unchanged.dim() > 1:
            unchanged = unchanged.flatten(1).all(dim=1)
        unchanged = unchanged.to(counters.device)

        counters[unchanged] += 1
        counters[~unchanged] = 0
        return counters

    @staticmethod
    def _count_kl_divergence(outputs, state):
        """Return per-sample KL divergences between current and previous logits.

        Each sample's logits are normalised with a log-softmax over
        dimension 1 and each direction is reduced with ``batchmean``. The
        result maps ``"current_previous"`` and ``"previous_current"`` to the
        two directed values and ``"double"`` to their sum, each as a list of
        floats with one entry per sample.
        """
        current_previous = []
        previous_current = []

        for current, previous in zip(outputs.logits, state["previous_logits"]):
            current = F.log_softmax(current, dim=1)
            previous = F.log_softmax(previous, dim=1)
            current_previous.append(
                F.kl_div(
                    current, previous, reduction="batchmean", log_target=True
                ).tolist()
            )
            previous_current.append(
                F.kl_div(
                    previous, current, reduction="batchmean", log_target=True
                ).tolist()
            )

        return {
            "current_previous": current_previous,
            "previous_current": previous_current,
            "double": [i + j for i, j in zip(current_previous, previous_current)],
        }

    def new_step(
        self, outputs: DiffusionOutput, state: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Refresh the logged statistics from ``outputs`` and return ``state``.

        On the first call there are no previous logits, so the patience
        counters stay untouched and the KL values are computed against the
        current logits themselves. ``state["kl"]`` is stored as a list of
        floats, ``state["entropy"]`` as a tensor.
        """
        logits = outputs.logits.detach()

        if state["previous_logits"] is None:
            # Temporary baseline for the KL helper; the persistent copy is
            # taken once at the end of this call.
            state["previous_logits"] = logits
        else:
            state["patience"] = self._count_patience(outputs, state).clone()

        kl_div = self._count_kl_divergence(outputs, state)
        state["kl"] = kl_div["double"]
        state["entropy"] = self._count_entropy(outputs.logits)
        state["previous_logits"] = logits.clone()

        return state
    