from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, List, Tuple, Optional, Union
import torch
from torch import nn, Tensor

from inference_models import InferenceModelBase, ST, InferenceResult
from experiments.leaves.fit_dirichlet_leaves import fit_dirichlet
from torch.distributions import Categorical, Dirichlet

EPS = 1e-8

@dataclass
class TrainResult:
    """Losses and beliefs reported by one ``ANeSIBase.train_all`` step."""

    # Loss used to update the perception network.
    percept_loss: Tensor
    # Accumulated loss of the inference model (off-policy and/or on-policy part).
    q_loss: Tensor
    # Entropy of the predicted beliefs; declared optional.
    entropy: Optional[Tensor]
    # The three belief tensors concatenated along dim 1.
    P: Tensor

class ANeSIBase(ABC, Generic[ST], nn.Module):

    def __init__(self,
                 q: InferenceModelBase[ST],
                 perception: nn.Module,
                 amount_samples: int,
                 belief_size: List[int],
                 dirichlet_init: float = 0.01,
                 dirichlet_iters: int = 50,
                 dirichlet_lr: float = 1.0,
                 dirichlet_L2: float = 0.0,
                 fixed_alpha: Optional[Union[float, Tensor]] = None,
                 K_beliefs: int = 100,
                 predict_only: bool = False,
                 P_source: str = "prior",
                 q_lr=1e-3,
                 q_loss="bce",
                 policy: str = "off",
                 perception_lr=1e-3,
                 perception_loss='log-q',
                 percept_loss_pref=1.0,
                 entropy_weight: Optional[float] = None,
                 **kwargs
                 ):
        """
        Store the configuration and create the optimizers and the Dirichlet prior parameters.

        :param q: The inference model
        :param perception: The perception network. Should accept samples from data. If None, no perception
            optimizer is created
        :param amount_samples: The amount of samples to draw to train the inference model
        :param belief_size: Sizes of the belief vectors; entries 0, 1 and 2 size the three Dirichlet parameters
        :param dirichlet_init: The initial concentration of the Dirichlet distribution. The parameters are
            initialised to x + log(1 - exp(-x)), the inverse softplus of this value (presumably fit_dirichlet
            maps them back through a softplus - TODO confirm)
        :param dirichlet_iters: Forwarded to fit_dirichlet by train_all
        :param dirichlet_lr: The learning rate of the Adam optimizer over the Dirichlet parameters
        :param dirichlet_L2: Forwarded to fit_dirichlet by train_all
        :param fixed_alpha: If given, use these concentrations instead of fitted ones. A number is used for
            every entry of all three priors; otherwise entries 0, 1 and 2 of the value are wrapped as the
            three parameters. In this case neither `alpha` nor `dirichlet_optimizer` is created
        :param K_beliefs: The amount of beliefs to keep to fit the Dirichlet
        :param predict_only: If True, assume only the prediction model is used, not the explain model
        :param P_source: If 'prior', use the prior to sample P. If 'percept', use perceptions of P. If 'both',
            use both. 'percept' is not recommended and is unlikely to converge
        :param q_lr: The learning rate of the inference model
        :param q_loss: The loss function of the inference model. Only 'bce' is currently accepted
        :param policy: The policy to use. Either 'off', 'on' or 'both'. 'on' and 'both' require
            predict_only=False
        :param perception_lr: The learning rate of the perception network
        :param perception_loss: The perception loss. One of 'sampled', 'log-q', 'both' or 'none'. 'sampled' and
            'both' require predict_only=False
        :param percept_loss_pref: Stored on the instance. Documented intent: when using perception_loss='both',
            prefer log-q if > 1.0, otherwise sampled (NOTE(review): not read by any code visible in this
            class - confirm)
        :param entropy_weight: If set and non-zero, train_all subtracts entropy_weight times the belief entropy
            from the perception loss
        :param kwargs: Accepted and ignored
        """
        super().__init__()
        assert perception_loss in ['sampled', 'log-q', 'both', 'none'], \
            f"unknown perception_loss: {perception_loss!r}"
        assert q_loss in ['bce'], f"unsupported q_loss: {q_loss!r}"  # mse
        assert policy in ['both', 'off', 'on'], f"unknown policy: {policy!r}"
        assert P_source in ['prior', 'percept', 'both'], f"unknown P_source: {P_source!r}"
        # prediction only option is only supported for off-policy learning of nrm
        assert not (predict_only and policy in ['on', 'both']), \
            "predict_only=True requires policy='off'"
        # Sampled loss requires explain model
        assert not (predict_only and perception_loss in ['sampled', 'both']), \
            "predict_only=True cannot be combined with a sampled perception loss"

        self.predict_only = predict_only
        self.P_source = P_source
        self.q = q
        self.perception = perception
        self.amount_samples = amount_samples
        self.belief_size = belief_size
        self.K_beliefs = K_beliefs
        self.dirichlet_iters = dirichlet_iters
        self.dirichlet_lr = dirichlet_lr
        self.dirichlet_L2 = dirichlet_L2
        self.q_loss = q_loss
        self.policy = policy
        self.perception_loss = perception_loss
        self.percept_loss_pref = percept_loss_pref
        self.entropy_weight = entropy_weight

        # We're training these two models separately, so let's also use two different optimizers.
        #  This ensures we won't accidentally update the wrong model.
        self.q_optimizer = torch.optim.Adam(self.q.parameters(), lr=q_lr)
        if self.perception is not None:
            self.perception_optimizer = torch.optim.Adam(self.perception.parameters(), lr=perception_lr)

        if fixed_alpha is not None:
            if isinstance(fixed_alpha, (int, float)):
                # Cast to float so that an integer concentration (e.g. 1) yields a floating point parameter;
                #  integer tensors cannot require gradients.
                fill = float(fixed_alpha)
                values = [torch.full((belief_size[i],), fill) for i in range(3)]
            else:
                values = [fixed_alpha[i] for i in range(3)]
            self.fixed_alpha = nn.ParameterList([nn.Parameter(value) for value in values])
        else:
            self.fixed_alpha = None
            # Inverse softplus of the requested initial concentration
            _x = torch.tensor(dirichlet_init)
            t_initial_concentration = _x + torch.log(-torch.expm1(-_x))
            self.alpha = nn.ParameterList([
                nn.Parameter(torch.full((belief_size[i],), t_initial_concentration))
                for i in range(3)
            ])
            self.dirichlet_optimizer = torch.optim.Adam(self.alpha, lr=dirichlet_lr)

        # Rolling buffers of recent perception outputs; filled by train_all
        self.beliefs1 = None
        self.beliefs2 = None
        self.beliefs3 = None

    def joint_matching_loss(self, P1: torch.Tensor, P2, P3, q: List[torch.Tensor], y: torch.Tensor, w: torch.Tensor,
                            predict_only: bool=False, assume_w_correct: bool=False, punish_incorrect_weight: Optional[float] = 1000.):
        """
        Loss that matches the inference model's joint probability to the probability of the world under P.

        :param P1: Probabilities of the first variable, used as `Categorical(probs=P1)`
        :param P2: Probabilities of the second variable
        :param P3: Probabilities of the third variable
        :param q: The probabilities produced by the inference model, one tensor per step
        :param y: The output belonging to w
        :param w: The sampled world; columns 0, 1 and 2 index into P1, P2 and P3. A trailing dimension of
            size 1 is squeezed away first
        :param predict_only: If True, return the negative mean log-probability of q and ignore P and w
        :param assume_w_correct: If False, worlds whose symbolic_function output differs from y are penalised
        :param punish_incorrect_weight: The amount subtracted from log p for such worlds. None disables the
            penalty
        :return: The scalar loss
        """
        # TODO: This condition looks hacky
        if all(r.shape[-1] == q[0].shape[-1] for r in q):
            # All steps share their last dimension: stack them and sum over everything but the batch
            reduce_dims = tuple(range(1, q[0].dim() + 1))
            log_q = (torch.stack(q, -1) + EPS).log().sum(reduce_dims)
        else:
            log_q = 0.
            for prob in q:
                step_log_prob = (prob + EPS).log()
                if prob.dim() == 2:
                    step_log_prob = step_log_prob.sum(-1)
                log_q = log_q + step_log_prob

        if predict_only:
            # KL div loss. Increases probability of observed ys
            return -log_q.mean()

        # Joint matching loss
        # (batch,)
        if w.shape[-1] == 1:
            w = w.squeeze(-1)
        log_p = (Categorical(probs=P1).log_prob(w[:, 0])
                 + Categorical(probs=P2).log_prob(w[:, 1])
                 + Categorical(probs=P3).log_prob(w[:, 2]))
        if not assume_w_correct and punish_incorrect_weight is not None:
            # A world is consistent only if every output along the last dimension matches y
            consistent = (self.symbolic_function(w) == y).all(-1).float()
            log_p = log_p + (1 - consistent) * -punish_incorrect_weight

        if self.q_loss == 'mse':
            return (log_q - log_p).pow(2).mean()
        if self.q_loss == 'bce':
            return nn.BCELoss()(log_q.exp(), log_p.exp())
        raise NotImplementedError()

    def off_policy_loss(self, p_P1: Dirichlet, p_P2: Dirichlet, p_P3: Dirichlet, batch_P1: torch.Tensor, batch_P2, batch_P3) -> torch.Tensor:
        """
        Algorithm 2
        Samples beliefs and worlds, computes their output and returns the joint matching loss of the
        inference model on them.
        For now assumes all w_i are the same size

        With P_source 'prior' or 'both', batch size times amount_samples beliefs are drawn from each prior;
        'both' additionally appends the given batch beliefs. With 'percept' only the batch beliefs are used.
        """
        priors = (p_P1, p_P2, p_P3)
        observed = (batch_P1, batch_P2, batch_P3)
        use_prior = self.P_source != 'percept'

        if not use_prior:
            beliefs = list(observed)
        else:
            n_draws = batch_P1.shape[0] * self.amount_samples
            beliefs = [prior.sample((n_draws,)) for prior in priors]
            if self.P_source == 'both':
                beliefs = [torch.cat([drawn, seen], 0) for drawn, seen in zip(beliefs, observed)]
        P1, P2, P3 = beliefs

        # One categorical per variable; build all of them before drawing any world
        world_dists = [Categorical(probs=P) for P in beliefs]
        # (batch, |W|)
        w = torch.stack([dist.sample() for dist in world_dists], dim=1)

        # (batch,)
        y = self.symbolic_function(w)

        state = self.initial_state(P1, P2, P3, y, w, generate_w=not self.predict_only)

        if use_prior:
            amt_samples = w.shape[0]
        else:
            amt_samples = 1
        result = self.q.forward(state, amt_samples=amt_samples)
        return self.joint_matching_loss(P1, P2, P3, result.forward_probabilities, y, w,
                                        predict_only=self.predict_only, assume_w_correct=True)

    def log_q_loss(self, P1: torch.Tensor, P2, P3, y: torch.Tensor):
        """
        Perception loss: the negative mean log-probability that the inference model assigns to the label y
        given the beliefs. No world is generated for this loss.
        """
        state = self.initial_state(P1, P2, P3, y, generate_w=False)
        step_probs = self.q.forward(state).forward_probabilities
        # Sum the log-probabilities of the individual steps, then average
        log_q_y = torch.log(torch.stack(step_probs, -1) + EPS).sum(-1)
        return -log_q_y.mean()

    def train_all(self, x: torch.Tensor, y: torch.Tensor) -> TrainResult:
        """
        Algorithm 3
        Runs one training step on the batch (x, y):
        - Computes the beliefs P1, P2, P3 with the perception network
        - Adds them to the rolling belief buffers and obtains the Dirichlet priors (fitted, or fixed)
        - Unless the policy is 'on': calls algorithm 2 (off_policy_loss) and steps the inference model
        - Unless the policy is 'off': computes the on-policy loss and steps the inference model
        - Computes the perception loss (sampled and/or log-q), optionally minus the weighted entropy,
          and steps the perception network

        :param x: The input batch, passed to the perception network
        :param y: The labels of the batch; cast with .long() before being used in the losses
        :return: The losses, the per-sample entropy of the beliefs and the concatenated beliefs
        """
        use_off_policy_loss = self.policy != 'on'
        use_on_policy_loss = self.policy != 'off'
        use_sampled_loss = self.perception_loss in ['sampled', 'both']
        use_log_q_loss = self.perception_loss in ['log-q', 'both']

        self.q_optimizer.zero_grad()

        q_loss = 0.
        P1, P2, P3 = self.perception(x)
        if use_off_policy_loss or self.perception_loss == 'both':
            # Store detached beliefs: the buffers are only used as data for the Dirichlet fit. Keeping the
            #  attached tensors would keep the autograd graph of every past batch reachable.
            fresh = (P1.detach(), P2.detach(), P3.detach())
            if self.beliefs1 is None:
                merged = fresh
            else:
                stored = (self.beliefs1, self.beliefs2, self.beliefs3)
                merged = tuple(torch.cat((old, new), dim=0) for old, new in zip(stored, fresh))
            # Keep only the K_beliefs most recent rows
            self.beliefs1, self.beliefs2, self.beliefs3 = (b[-self.K_beliefs:] for b in merged)

            if self.fixed_alpha is not None:
                # With fixed concentrations there is no `alpha` attribute; use them directly
                p_P1 = Dirichlet(self.fixed_alpha[0])
                p_P2 = Dirichlet(self.fixed_alpha[1])
                p_P3 = Dirichlet(self.fixed_alpha[2])
            else:
                p_P1, p_P2, p_P3 = fit_dirichlet([self.beliefs1, self.beliefs2, self.beliefs3], self.alpha,
                                                 self.dirichlet_optimizer, self.dirichlet_iters, self.dirichlet_L2)

            if use_off_policy_loss:
                _q_loss = self.off_policy_loss(p_P1, p_P2, p_P3, P1.detach(), P2.detach(), P3.detach())
                _q_loss.backward()
                q_loss += _q_loss
                self.q_optimizer.step()
                self.q_optimizer.zero_grad()

        loss_sampled = 0.
        loss_log_q = 0.

        self.perception_optimizer.zero_grad()
        if use_sampled_loss or use_on_policy_loss:
            _loss_sampled, _q_loss = self.sampled_loss(P1, P2, P3, y.long(), use_sampled_loss, use_on_policy_loss)
            if use_sampled_loss:
                loss_sampled = _loss_sampled
            if use_on_policy_loss:
                _q_loss.backward()
                q_loss += _q_loss
                self.q_optimizer.step()
        if use_log_q_loss:
            loss_log_q = self.log_q_loss(P1, P2, P3, y.long())

        loss_percept = torch.tensor(0., device=x.device)
        entropy = torch.tensor(0., device=x.device)
        if self.perception_loss != 'none':
            # Losses that are switched off are left at 0.
            loss_percept = loss_sampled + loss_log_q

            # The entropy of the categorical distribution P, averaged over the three variables
            entropy = -sum(torch.sum(P * torch.log(P + EPS), dim=-1) for P in (P1, P2, P3)) / 3

            objective = loss_percept
            if self.entropy_weight:
                # This **maximizes** the entropy. Use a negative entropy weight to minimize.
                #  The entropy has one entry per sample, so reduce it to a scalar before backward().
                objective = objective - self.entropy_weight * entropy.mean()
            # sampled_loss returns the plain float 0. when no sample was successful; there is nothing to
            #  backpropagate in that case.
            if isinstance(objective, torch.Tensor):
                objective.backward()
            if not isinstance(loss_percept, torch.Tensor):
                loss_percept = torch.tensor(loss_percept, device=x.device)
        self.perception_optimizer.step()

        return TrainResult(q_loss=q_loss, percept_loss=loss_percept, entropy=entropy, P=torch.cat((P1, P2, P3), dim=1))

    def test(self, x: torch.Tensor, y: torch.Tensor, true_w: Optional[List[torch.Tensor]] = None
             ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Algorithm 1
        Evaluate on a batch: beam search with the inference model, and the argmax of the beliefs pushed
        through the symbolic function ("prior" prediction).
        :param x: The data to perform beam search on
        :param y: The true label y to predict and compare to
        :param true_w: The true w (explanation to get to y). Used to evaluate if explanations are correct
        :return: (beam search accuracy, prior accuracy on y, explanation accuracy, prior accuracy on w).
            The last two are zero tensors when true_w is not given
        """
        beliefs = self.perception(x)
        P1, P2, P3 = beliefs
        state = self.initial_state(P1, P2, P3, generate_w=not self.predict_only)
        result: InferenceResult[ST] = self.q.beam(state, beam_size=self.amount_samples * x.shape[0])
        successes = self.success(result.final_state.y, y, beam=True).float()

        # Most likely value of each variable according to the perception network alone
        prior_predictions = torch.stack([torch.argmax(P, dim=-1) for P in (P1, P2, P3)], dim=1)
        successes_prior = (y == self.symbolic_function(prior_predictions)).float().mean()

        if true_w is None:
            return torch.mean(successes), successes_prior, torch.tensor(0), torch.tensor(0)

        explain_acc = torch.tensor(0., device=successes.device)
        if not self.predict_only:
            for i, w_true in enumerate(true_w):
                # Get beam search prediction of w, compare to ground truth w
                explain_acc += (result.final_state.w[i][:, 0] == w_true).float().mean()
            explain_acc /= len(true_w)

        prior_acc = (prior_predictions == torch.stack(true_w, 1)).float().mean()

        return torch.mean(successes), successes_prior, explain_acc, prior_acc

    @abstractmethod
    def initial_state(self, P: torch.Tensor, y: Optional[torch.Tensor] = None, w: Optional[torch.Tensor] = None,
                      generate_w=True) -> ST:
        """
        Build the state from which the inference model starts. A world w may only be given together with y.

        NOTE(review): every call in this class passes three belief tensors (P1, P2, P3) before y, which
        this signature does not accept - concrete subclasses presumably override it with
        (P1, P2, P3, y, w, generate_w). Confirm and align this declaration.
        """
        assert y is not None or w is None

    @abstractmethod
    def symbolic_function(self, w: torch.Tensor) -> torch.Tensor:
        """
        Compute the output that belongs to the world w. To be implemented by subclasses.

        Within this class the result is used as the label y of sampled worlds and is compared to y with ==.
        """
        ...

    @abstractmethod
    def preprocess_y(self, y: torch.Tensor) -> List[torch.Tensor]:
        """
        Turn the label tensor y into a list of tensors. To be implemented by subclasses.

        Not called by any code visible in this class.
        """
        ...

    @abstractmethod
    def success(self, prediction: List[torch.Tensor], y: torch.Tensor, beam: bool = False) -> torch.Tensor:
        """
        Returns the _probability_ of success of the prediction with respect to y. To be implemented by
        subclasses. Should probably return the most likely result and compare this instead.

        test() calls this with the y of the final beam search state and beam=True, and averages the result.
        # TODO: Use a beam search here somehow to parse y
        """
        ...

    def sampled_loss(self, P1: torch.Tensor, P2, P3, y: torch.Tensor, compute_perception_loss: bool, compute_q_loss: bool) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Sample one world per example from the inference model conditioned on y, and derive the sampled
        perception loss and/or the on-policy loss of the inference model from it.

        The beliefs are detached before the state is built.
        NOTE(review): whether the perception loss still carries gradient to the perception network then
        depends on initial_state / log_p_world, which are not visible here - confirm.

        :param compute_perception_loss: If True, compute the perception loss from the log-probability of the
            sampled worlds
        :param compute_q_loss: If True, compute the loss of the inference model
        :return: (perception loss, inference model loss). A loss that was not requested is the float 0.;
            the perception loss is also the float 0. when no sampled world reproduces y
        """
        state = self.initial_state(P1.detach(), P2.detach(), P3.detach(), y, generate_w=True)
        result = self.q.forward(state, amt_samples=1)
        final = result.final_state

        percept_loss = 0.
        q_loss = 0.

        # Take sum of log probs over all dimensions
        log_p_w = final.log_p_world()

        if compute_perception_loss:
            prediction = self.symbolic_function(torch.stack(final.w, -1))
            # TODO: Not sure about the squeeze here
            # TODO: Also not sure about the min here...
            successes = torch.min(prediction.squeeze() == y, dim=-1)[0].float()
            guaranteed_success = self.q.prune and self.q.no_actions_behaviour == 'raise'
            if guaranteed_success:
                # If we prune, we know we are successful by definition
                assert successes.all()
                percept_loss = -log_p_w.mean()
            elif successes.any():
                # Only worlds that reproduce y contribute
                percept_loss = -(log_p_w * successes).mean()

        if compute_q_loss:
            step_probs = result.forward_probabilities
            n_y = len(final.y)
            # The first n_y steps predict y, the remaining steps predict w given y
            log_q_y = torch.log(torch.stack(step_probs[:n_y], 1) + EPS).sum(-1).unsqueeze(-1)
            log_q_w_y = torch.log(torch.stack(step_probs[n_y:], -1) + EPS).sum(-1)
            log_q = log_q_y + log_q_w_y
            # The target is a constant for the inference model
            target = log_p_w.detach()
            if self.q_loss == 'mse':
                q_loss = (log_q - target).pow(2).mean()
            elif self.q_loss == 'bce':
                q_loss = nn.BCELoss()(log_q.exp(), target.exp())
            else:
                raise NotImplementedError()

        return percept_loss, q_loss
