from typing import Tuple
from expground.algorithms import misc
import torch
import numpy as np

from collections import namedtuple
from torch.distributions import Categorical
from torch import optim
from torch.functional import tensordot

from expground.types import PolicyID, Dict, List, Tuple
from expground.logger import Log
from expground.utils import data
from expground.algorithms.base_policy import Policy

# One recorded transition: the selection vector over supports (a one-hot row
# when produced by LearnableDistribution.add_transition) and its reward.
Selection_and_Reward = namedtuple("SelectionReward", ["selection", "reward"])


def _pg_loss(
    logits: torch.Tensor,
    selection: torch.Tensor,
    reward: torch.Tensor,
    training_kwargs=None,
):
    """Compute the loss used to update the meta-strategy logits.

    The per-support average reward is ``sum(selection * reward) / max(count, 1)``
    taken over dim 0 (the history axis). The loss then depends on
    ``training_kwargs["policy_mode"]``:

    * ``"softmax"`` (default): cross-entropy between the detached softmax of
      the average rewards (target) and the softmax of ``logits``.
    * any other value: ``-logits[i] * r_i`` where ``i`` is the support with the
      highest average reward ``r_i``. When ``r_i`` is exactly zero it is nudged
      to 1e-4 so the gradient does not vanish.

    Args:
        logits (torch.Tensor): Trainable logits, one entry per support.
        selection (torch.Tensor): Selection weights; in
            ``LearnableDistribution.optimize`` this is a stack of one-hot rows
            of shape (history, n_supports).
        reward (torch.Tensor): Rewards broadcast against ``selection``; in
            ``LearnableDistribution.optimize`` this has shape (history, 1).
        training_kwargs (dict, optional): Only ``"policy_mode"`` is read.
            ``None`` is treated as an empty dict.

    Returns:
        torch.Tensor: A scalar loss.
    """
    if training_kwargs is None:
        training_kwargs = {}
    prob_mode = training_kwargs.get("policy_mode", "softmax")

    # Clipping the counts to >= 1 avoids division by zero for supports that
    # were never selected (their reward sum is 0 anyway).
    counts = torch.clip(selection.sum(dim=0), 1.0)
    weighted_reward = (selection * reward).sum(dim=0) / counts

    if prob_mode == "softmax":
        target = torch.softmax(weighted_reward, dim=-1).detach()
        # log_softmax is the numerically stable form of log(softmax(logits)).
        return -torch.sum(target * torch.log_softmax(logits, dim=-1))

    idx = weighted_reward.argmax()
    x = weighted_reward[idx]
    if x == 0.0:
        x = x + 1e-4
    return -logits[idx] * x


class LearnableDistribution:
    """Learnable categorical distribution over a set of policy supports.

    Every support (a policy id) owns a one-element float32 logit tensor in
    ``self._logits``. ``self._tensor`` is their concatenation in
    ``self._supports`` order and is the parameter trained by :meth:`optimize`.
    ``self._target_tensor`` is a copy that :meth:`update_target` moves towards
    ``self._tensor``. It acts like a policy: :meth:`sample` returns a support key.

    ``training_kwargs`` read by this class:
        sample_mode: "random" (default), "soft" or "explore" (see :meth:`sample`).
        policy_mode: "softmax" (default), "regret", or another value that
            ``_pg_loss`` treats as its argmax objective.
        lr: Adam learning rate (required by :meth:`optimize` unless
            ``policy_mode == "regret"``).
        train_every: number of transitions between optimisation steps
            (required by :meth:`add_transition`).
    """

    def __init__(self, policy_mapping: Dict[str, "Policy"], **training_kwargs):
        # NOTE(review): ``policy_mapping`` is accepted but not used by this class.
        self._supports = []
        self._logits: Dict[str, torch.Tensor] = dict()
        self._tensor = None
        self._target_tensor = None
        self._training_kwargs = training_kwargs
        self._history: List[Selection_and_Reward] = []
        # Not assigned elsewhere in this class. When set, it overrides the
        # learnable logits in ``tensor``, ``sample`` and ``probs``.
        self._fixed_dist = None

    @property
    def target_tensor(self) -> torch.Tensor:
        """The slowly-updated target copy of the logits."""
        return self._target_tensor

    @property
    def tensor(self) -> torch.Tensor:
        """Return the concatenated logits (or the fixed distribution's logits).

        Returns:
            torch.Tensor: A tensor instance.
        """
        if self._fixed_dist is not None:
            return self._fixed_dist.logits
        return self._tensor

    @property
    def history(self) -> List[Selection_and_Reward]:
        """Transitions recorded since the last optimisation or reset."""
        return self._history

    def _rebuild_tensor(self):
        """Re-concatenate per-support logits into the trainable tensor.

        Also resets the target tensor to a copy of the new tensor. With no
        supports left both become ``None``, because ``torch.cat`` rejects an
        empty list.
        """
        if not self._supports:
            self._tensor = None
            self._target_tensor = None
            return
        self._tensor = torch.cat([self._logits[k] for k in self._supports], dim=-1)
        self._tensor.requires_grad = True
        self._target_tensor = torch.clone(self._tensor)

    def pop(self, key):
        """Remove ``key`` from the supports and rebuild the trainable tensor.

        The history is cleared because its one-hot selections index the old
        support list.

        Raises:
            ValueError: if ``key`` is not a support.
        """
        self._supports.remove(key)
        self._logits.pop(key)
        self._history = []
        self._rebuild_tensor()

    def keys(self) -> Tuple:
        """Return a tuple of supports.

        Returns:
            Tuple: A tuple of supports
        """
        return tuple(self._supports)

    def set_states(
        self, logits: Dict[str, float] = None, probs: Dict[str, float] = None
    ):
        """Overwrite the support logits, then rebuild the trainable tensor.

        Args:
            logits: Mapping support -> new logit. Every key must already be a
                support.
            probs: Mapping support -> probability, required for every
                support. Converted to logits via ``Categorical(probs=...).logits``.
        """
        if logits is not None:
            for k, logit in logits.items():
                assert k in self._logits, "%s" % list(self._logits.keys())
                self._logits[k].data[0] = logit
            if logits:
                # Rebuild once after all writes instead of once per key.
                self._rebuild_tensor()
        if probs is not None:
            probs_tensor = torch.tensor(
                [probs[k] for k in self._supports], dtype=torch.float32
            )
            converted = Categorical(probs=probs_tensor).logits
            for i, k in enumerate(self._supports):
                self._logits[k].data[0] = converted[i]
            self._rebuild_tensor()

    def sample(self):
        """Sample a support key.

        Uses the fixed distribution when one is set. Otherwise the
        ``sample_mode`` training kwarg selects the strategy:

        * "random" (default): uniform choice.
        * "soft": a draw from ``Categorical(logits=self.tensor)``.
        * "explore": argmax of ``misc.gumbel_softmax(..., explore=True)``.

        :return: the sampled PolicyID
        :raises ValueError: for "distill" or any unrecognised ``sample_mode``.
        """
        if self._fixed_dist is not None:
            idx = self._fixed_dist.sample()
        else:
            sample_mode = self._training_kwargs.get("sample_mode", "random")
            if sample_mode == "soft":
                idx = Categorical(logits=self.tensor).sample()
            elif sample_mode == "random":
                idx = np.random.choice(len(self._supports))
            elif sample_mode == "explore":
                soft = misc.gumbel_softmax(logits=self.tensor, hard=False, explore=True)
                idx = soft.argmax(-1).item()
            else:
                # "distill" and unknown modes are rejected explicitly rather
                # than leaving ``idx`` unbound.
                raise ValueError("Unexpectd sample_mode: {}".format(sample_mode))
        return self._supports[int(idx)]

    def __getitem__(self, item: PolicyID) -> float:
        """Get the raw logit of the selected policy id."""
        return self._logits[item].item()

    def __setitem__(self, key: PolicyID, value: float):
        """Register ``key`` as a support (if new) and re-randomise all logits.

        ``value`` is written to the support's logit. Every logit, including
        this one, is then redrawn uniformly from [-1, 1), and the history is
        cleared.

        Args:
            key (PolicyID): Support key.
            value (float): Weight or logit.
        """
        self._history = []
        if key not in self._supports:
            self._supports.append(key)
            self._logits[key] = torch.tensor([value], dtype=torch.float32)
        self._logits[key].data[0] = value
        assert self._logits[key].data[0] == value, self._logits[key].data

        # Reset all logits to uniform random values in [-1, 1).
        for v in self._logits.values():
            v.data[0] = 2.0 * np.random.random() - 1.0
        self._rebuild_tensor()

    def onehot_from_logits(self) -> torch.Tensor:
        """Return a onehot from trainable logits.

        Returns:
            torch.Tensor: A onehot tensor.
        """
        return misc.onehot_from_logits(self.tensor)

    def dict_values(self) -> Dict[PolicyID, float]:
        """Return a dict mapping each support to its probability."""
        p = self.probs().detach().numpy().tolist()
        return dict(zip(self._supports, p))

    def probs(self) -> torch.Tensor:
        """Softmax of the logits, or the fixed distribution's probs if set."""
        if self._fixed_dist is not None:
            return self._fixed_dist.probs
        return torch.softmax(self.tensor, dim=-1)

    def sync_logits(self):
        """Copy the current trainable tensor back into per-support logits."""
        for i, support in enumerate(self._supports):
            self._logits[support].data.copy_(self._tensor[i].data)

    def target_probs(self) -> torch.Tensor:
        """Softmax of the target tensor."""
        return torch.softmax(self._target_tensor, dim=-1)

    def update_target(self, tau=0.01):
        """In place: target <- (1 - tau) * target + tau * tensor."""
        self._target_tensor.data.copy_(
            self._target_tensor.data * (1.0 - tau) + self._tensor.data * tau
        )

    def optimize(self, **kwargs):
        """Take one update step on the logits from the recorded history.

        Keyword Args:
            gamma: Optional factor applied to column ``observed_index`` of the
                selection matrix.
            observed_index: Column scaled by ``gamma``. Required when ``gamma``
                is given.

        Does nothing when the history is empty.
        """
        if not self._history:
            return

        selection, reward = map(np.array, zip(*self._history))
        try:
            selection = torch.tensor(selection, dtype=torch.float32)
            if kwargs.get("gamma", None) is not None:
                selection[:, kwargs["observed_index"]] *= kwargs["gamma"]
            reward = torch.tensor(reward, dtype=torch.float32).view((-1, 1))
        except Exception:
            print(
                "----------- sections:",
                selection,
                type(selection),
                selection.dtype,
                reward,
            )
            raise

        old_logits = self.tensor
        # Same default as ``_pg_loss`` so both agree when the kwarg is absent.
        policy_mode = self._training_kwargs.get("policy_mode", "softmax")
        if policy_mode == "regret":
            with torch.no_grad():
                probs = torch.softmax(old_logits, dim=-1)
                idx = selection[-1].argmax()
                old_logits = old_logits + 1.0
                old_logits[idx] = old_logits[idx] - (1.0 - reward[-1, 0]) / probs[idx]
        else:
            loss = _pg_loss(old_logits, selection, reward, self._training_kwargs)
            # lr used in experiments: 0.01 (seed0), 0.001 (seed1), 0.005 (seed2),
            # 0.01 (seed3, discount, newest)
            optimizer = optim.Adam([old_logits], lr=self._training_kwargs["lr"])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        logits = dict(zip(self._supports, old_logits.detach().numpy().tolist()))
        self.set_states(logits=logits)

    def add_transition(self, policy_id: PolicyID, acc_reward: float):
        """Record a transition and train once ``train_every`` transitions exist.

        Transitions are ignored while there are fewer than two supports.

        Args:
            policy_id (PolicyID): Policy id.
            acc_reward (float): Accumulated reward collected by the chosen policy.
        """
        if len(self._supports) < 2:
            return
        index = self._supports.index(policy_id)
        one_hot = np.zeros(len(self._supports), dtype=np.float64)
        one_hot[index] = 1.0
        self._history.append(Selection_and_Reward(one_hot, acc_reward))
        # ``>=`` so training still triggers if ``train_every`` was lowered.
        if len(self._history) >= self._training_kwargs["train_every"]:
            self.optimize()
            self._history = []
