import numpy as np
import pandas as pd
import random as rd
from collections import defaultdict
from scipy.special import binom
from typing import Tuple
from utils import Console

# Map each user-facing law name (the ``name`` argument of ``Sampler``) to the
# one-letter code that ``Sampler`` branches on internally.
laws = dict(Ball="B", Distance="D", Uniform="U")


class Sampler:
    """Draw perturbed neighbours of a query instance over grouped binary features.

    Features sharing the prefix before their first ``_`` (e.g. one-hot columns
    ``color_red``, ``color_blue``) form one group; a feature without ``_`` is a
    group on its own. Each group is encoded as a single integer (its 0/1 values
    read as a base-2 number). A sample is obtained by changing the value of a
    random number of groups; that number is drawn from the law selected by
    ``name`` (a key of the module-level ``laws`` mapping).
    """

    # Constructor
    def __init__(
        self, features: pd.Index, name: str, radius=0, spread=0.0, verbose=True
    ):
        """Store the sampler configuration.

        Args:
            features: column names of the data set, in column order.
            name: law name looked up in ``laws``. An unknown name leaves
                ``self.law`` as ``None``, which ``setDistribution`` then handles
                exactly like the "Uniform" law.
            radius: maximal number of changed groups; only kept for the "Ball"
                law (other laws overwrite it in ``setDistribution``).
            spread: exponential penalty rate per changed group; only kept for
                the "Distance" law (other laws reset it to 0).
            verbose: forwarded to ``Console``.
        """
        self.console = Console(verbose=verbose)
        self.features = features
        self.nbFeatures = len(features)
        self.lawName = name
        self.law = laws.get(name)
        self.radius = radius
        self.spread = spread

    # Print sampler attributes
    def __str__(self):
        """Return a summary of the law name, radius and spread."""
        text = self.console.string("Sampler Information", endl=True)
        text += self.console.string("Law", self.lawName, endl=True)
        text += self.console.string("Radius", self.radius, endl=True)
        text += self.console.string("Spread", self.spread, endl=True)
        return text

    # Setup sampler by constructing groups and distribution
    def setup(self) -> None:
        """Build the feature groups, then the distance distribution (order matters)."""
        self.console.log("Setup sampler")
        self.setGroups()
        self.setDistribution()

    # Get instance to explain, in a deterministic way (from id) or in a random way
    def query(self, instances: pd.DataFrame, id=-1) -> np.ndarray:
        """Select the instance to explain and store it as the sampling center.

        Args:
            instances: data set to pick the instance from.
            id: positional row index; any negative value picks a row uniformly
                at random.

        Returns:
            The selected row as a NumPy array (also stored in ``self.center``).
        """
        self.console.log("Get instance to explain")
        if id > -1:
            self.centerID = id
        else:
            self.centerID = np.random.randint(0, instances.shape[0])
        self.center = instances.iloc[self.centerID].to_numpy()
        return self.center

    # Get nb_samples from distance-based distribution
    def sample(self, *, nb_samples: int) -> np.ndarray:
        """Draw ``nb_samples`` neighbours of ``self.center`` (set by ``query``).

        Returns:
            An ``(nb_samples, nbFeatures)`` int64 array, one decoded sample per row.
        """
        self.console.log("Get samples", nb_samples)

        samples = np.zeros((nb_samples, self.nbFeatures), dtype="int64")
        # Encode the query once; every sample is derived from this encoding.
        encoded_center = self.encode(self.center)
        for m in range(nb_samples):
            encoded_sample = self.getEncodedSample(center=encoded_center)
            samples[m] = self.decode(encoded_sample)
        return samples

    # Split features into groups to enforce domain constraints
    def setGroups(self) -> None:
        """Partition the feature positions into contiguous groups.

        Sets ``self.nbGroups`` and ``self.groups``, an int64 array of group start
        positions followed by ``nbFeatures``, so group ``i`` spans
        ``groups[i]:groups[i + 1]``.

        Raises:
            ValueError: if the features of one category are not adjacent, which
                would make the start/end representation above wrong.
        """
        self.console.log("Set feature groups")

        # Category of a feature = prefix before its first "_" (whole name if none).
        categories = defaultdict(list)
        for i, c in enumerate(self.features):
            categories[str(c).split("_", 1)[0]].append(i)

        # Positions are collected in increasing order, so a category is
        # contiguous iff its span length equals its number of members.
        for k, positions in categories.items():
            if positions[-1] - positions[0] + 1 != len(positions):
                raise ValueError(
                    f"Features of category '{k}' are not contiguous: {positions}"
                )

        self.nbGroups = len(categories)
        # Start position of each group (first-appearance order), then the data dimension.
        starts = [positions[0] for positions in categories.values()]
        self.groups = np.array(starts + [self.nbFeatures], dtype="int64")

    # Set the distribution of the penalized binomial random variable
    def setDistribution(self) -> None:
        """Compute ``self.probabilities``, the law of the number of changed groups.

        The weight of distance ``k`` (for ``k`` in ``0..radius``) is
        ``binom(nbGroups, k)`` times ``exp(-spread * k)`` for the "Distance"
        law, and times 1 otherwise. The weights are then normalized to sum to 1.
        For every law except "Ball", ``radius`` is first set to ``nbGroups``
        (all distances allowed). For every law except "Distance", ``spread`` is
        reset to 0.
        """
        self.console.log("Set distribution for law", self.lawName)
        if self.law != "B":
            self.radius = self.nbGroups
        distances = np.arange(self.radius + 1)
        # Number of ways to choose which k groups to change.
        weights = binom(self.nbGroups, distances).astype("float64")
        if self.law == "D":
            weights = weights * np.exp(-self.spread * distances)
        else:
            self.spread = 0.0
        self.probabilities = weights / weights.sum()

    # Evaluate the Hamming distance between two arrays
    def getHammingDistance(self, x: np.ndarray, y: np.ndarray) -> int:
        """Return the number of positions where ``x`` and ``y`` differ."""
        return int(np.count_nonzero(x != y))

    # Encode boolean instance into integer vector using group positions
    def encode(self, instance: np.ndarray) -> np.ndarray:
        """Encode a 0/1 instance into one integer per group (bits read in base 2).

        Values are coerced with ``int`` first, so float (``1.0``) or boolean
        (``True``) columns encode the same as integer ones.
        """
        encoded = np.zeros(self.nbGroups, dtype="int64")
        for i in range(self.nbGroups):
            bits = instance[self.groups[i] : self.groups[i + 1]]
            encoded[i] = int("".join(str(int(b)) for b in bits), 2)
        return encoded

    # Decode integer instance into boolean vector using group positions
    def decode(self, instance: np.ndarray) -> np.ndarray:
        """Inverse of ``encode``: expand each group integer back into its bits."""
        decoded = np.zeros(self.nbFeatures, dtype="int64")
        for i in range(self.nbGroups):
            width = self.groups[i + 1] - self.groups[i]
            bits = np.array(
                list(np.binary_repr(int(instance[i]), width=width)), dtype="int64"
            )
            decoded[self.groups[i] : self.groups[i + 1]] = bits
        return decoded

    # Get encoded sample from distance law using encoded center
    def getEncodedSample(self, center: np.ndarray) -> np.ndarray:
        """Return a copy of the encoded ``center`` with a random set of groups changed.

        The number of changed groups is drawn from ``self.probabilities``. The
        groups themselves are drawn uniformly without replacement, and each
        changed group receives a value uniformly chosen among its admissible
        values other than the center's.
        """
        distance = np.random.choice(self.radius + 1, p=self.probabilities)
        flips = np.random.choice(self.nbGroups, distance, replace=False)
        encoded_instance = center.copy()
        for i in flips:
            width = self.groups[i + 1] - self.groups[i]
            if width == 1:
                # A lone binary feature takes the values 0 or 1. Restricting to
                # powers of two would leave no alternative when the center is 1.
                candidates = np.array([0, 1])
            else:
                # One-hot group: each admissible value has a single bit set,
                # i.e. is a power of two.
                candidates = np.power(2, np.arange(width))
            # Exclude the center's own value so the group really changes.
            candidates = candidates[candidates != center[i]]
            encoded_instance[i] = np.random.choice(candidates)
        return encoded_instance
