"""
Sampler and voter classes.
"""

import numpy as np
import mallows_kendall as mk
from comparison_helpers import voter_unsatisfied, convert_candidates_to_ranks, convert_ranks_to_candidates
import choix


class Voter:
    """Generic voter class that can make pairwise comparisons.

    Also keeps track of total pairwise comparisons made for the voter.

    Args:
      full_rank: sequence of candidate integers (numpy array or list), where
        the position of a candidate integer indicates its rank (index 0 is the
        most preferred candidate).
    """

    def __init__(self, full_rank):
        # Number of pairwise_prefer queries this voter has answered so far.
        self.total_pairwise_comparisons = 0
        self.full_rank = full_rank

    def pairwise_prefer(self, cand1, cand2):
        """Returns True iff voter strictly prefers cand1 over cand2 according to full_rank.

        Every call increments ``total_pairwise_comparisons`` by one, even if it
        raises.

        Args:
          cand1: integer candidate.
          cand2: integer candidate.

        Returns: bool, True iff cand1 appears strictly before cand2 in
          ``self.full_rank`` (so comparing a candidate with itself is False).

        Raises:
          IndexError: if either candidate does not appear in ``self.full_rank``.
        """
        self.total_pairwise_comparisons += 1
        # Compare against an array view so plain Python lists work too:
        # ``list == int`` is a scalar False and would never match.
        ranking = np.asarray(self.full_rank)
        pos1 = np.flatnonzero(ranking == cand1)[0]
        pos2 = np.flatnonzero(ranking == cand2)[0]
        return bool(pos1 < pos2)

class Sampler:
    """Generic sampler class that can also sample from unsatisfied voters.

    Subclasses implement ``sample_ranks``. This base class wraps the sampled
    rankings in ``Voter`` objects, counts how many voters were drawn and
    accepted, and (optionally) keeps every voter so pairwise-comparison
    statistics can be aggregated afterwards.

    Args:
      num_candidates: number of candidates.
      beta: voter satisfaction parameter, forwarded to ``voter_unsatisfied``.
      store_voters: if True, keep every sampled voter in ``all_voters`` and
        every accepted (unsatisfied) voter in ``unrejected_voters``.
      rejection_max_iters: multiplier bounding how many batches ``sample``
        draws before giving up.
    """

    def __init__(self, num_candidates, beta=0.9, store_voters=True, rejection_max_iters=10):
        # Voters drawn from the original distribution / voters accepted by
        # rejection sampling in ``sample``.
        self.total_samples = 0
        self.total_samples_unrejected = 0
        self.num_candidates = num_candidates

        # Parameters for sampling from unsatisfied voters. The two lists are
        # passed to ``voter_unsatisfied``; presumably callers append to them
        # between rounds -- not done anywhere in this class.
        self.beta = beta
        self.all_prev_Ss = []
        self.all_prev_lotteries = []

        # Keep track of all voters (only when requested, to save memory).
        self.store_voters = store_voters
        if store_voters:
            self.all_voters = []
            self.unrejected_voters = []
        self.rejection_max_iters = rejection_max_iters

    def max_pairwise_comparisons(self):
        """Computes maximum number of pairwise comparisons for a single voter.

        Considers every sampled voter; returns 0 if voters are not stored or
        none have been sampled.
        """
        if not self.store_voters:
            return 0
        return max((v.total_pairwise_comparisons for v in self.all_voters), default=0)

    def all_pairwise_comparisons(self):
        """Returns the list of pairwise-comparison counts, one per sampled voter.

        Returns an empty list if voters are not stored.
        """
        if not self.store_voters:
            return []
        return [voter.total_pairwise_comparisons for voter in self.all_voters]

    def max_pairwise_comparisons_unrejected(self):
        """Computes maximum number of pairwise comparisons for a single accepted voter.

        Only voters accepted by ``sample`` are considered; returns 0 if voters
        are not stored or none have been accepted.
        """
        if not self.store_voters:
            return 0
        return max((v.total_pairwise_comparisons for v in self.unrejected_voters), default=0)

    def total_pairwise_comparisons(self):
        """Computes the total number of pairwise comparisons over all sampled voters.

        Returns 0 if voters are not stored.
        """
        if not self.store_voters:
            return 0
        return sum(voter.total_pairwise_comparisons for voter in self.all_voters)

    def total_pairwise_comparisons_unrejected(self):
        """Computes the total number of pairwise comparisons over accepted voters.

        Returns 0 if voters are not stored.
        """
        if not self.store_voters:
            return 0
        return sum(voter.total_pairwise_comparisons for voter in self.unrejected_voters)

    def sample_ranks(self, num_voters):
        """Samples full rankings from an original distribution.

        Must be implemented by subclasses.
        """
        raise NotImplementedError

    def sample_original(self, num_voters):
        """Samples voters from an original distribution.

        Args:
          num_voters: number of voters to sample.

        Returns: list of Voter objects.
        """
        samples = self.sample_ranks(num_voters)

        self.total_samples += num_voters
        voters = [Voter(sample) for sample in samples]
        if self.store_voters:
            self.all_voters.extend(voters)
        return voters

    def sample(self, num_voters, strict=False, batch_size=10, verbose=False):
        """Samples unsatisfied voters using rejection sampling.

        Voters are drawn in batches of ``batch_size`` via ``sample_original``;
        each is kept only if ``voter_unsatisfied`` reports it as unsatisfied
        given ``all_prev_Ss`` / ``all_prev_lotteries``. Sampling stops once at
        least ``num_voters`` voters are kept, or after
        ``(num_voters / batch_size) * rejection_max_iters`` batches.

        Note: a whole batch is screened at once, so the result may hold more
        than ``num_voters`` voters; it holds fewer if the iteration cap is hit.

        Args:
          num_voters: target number of unsatisfied voters.
          strict: forwarded to ``voter_unsatisfied``.
          batch_size: voters drawn per iteration.
          verbose: print per-voter diagnostics.

        Returns: numpy array of the accepted Voter objects.
        """
        unsatisfied_voters = []
        max_iters = (num_voters / batch_size) * self.rejection_max_iters
        iters = 0
        while len(unsatisfied_voters) < num_voters:
            iters += 1
            for voter in self.sample_original(batch_size):
                if verbose:
                    print('voter rank', voter.full_rank)
                if voter_unsatisfied(
                    self.beta,
                    voter,
                    self.all_prev_Ss,
                    self.all_prev_lotteries,
                    strict=strict,
                    verbose=verbose):
                    if verbose:
                        print('unsatisfied')
                    unsatisfied_voters.append(voter)
                    # Count accepted voters regardless of store_voters so the
                    # counter is meaningful in both modes (like total_samples).
                    self.total_samples_unrejected += 1
                    if self.store_voters:
                        self.unrejected_voters.append(voter)
            if iters >= max_iters:
                if verbose:
                    print("Max iters of %d reached, %d unsatisfied voters found" % (max_iters, len(unsatisfied_voters)))
                break

        return np.array(unsatisfied_voters)
    
class MallowsSampler(Sampler):
    """Class to sample from a single Mallows distribution.

    Args:
      num_candidates: number of candidates
      store_voters: if True, keep references to every sampled voter.
      center: numpy array of length num_candidates. Each value represents a candidate index, and the position represents the rank.
        Ex: [3,1,2] means candidate 3 is ranked first, candidate 1 is ranked second, and candidate 2 is ranked third.
        Defaults to the identity ranking [0, 1, ..., num_candidates - 1].
      theta: dispersion parameter.
      beta: voter satisfaction parameter.
      rejection_max_iters: multiplier bounding rejection-sampling batches in ``sample``.

    """

    def __init__(self, num_candidates, store_voters=True, center=None, theta=1.5, beta=0.9, rejection_max_iters=10):
        super().__init__(num_candidates, beta=beta, store_voters=store_voters, rejection_max_iters=rejection_max_iters)
        if center is None:
            center = np.arange(num_candidates)
        self.center = np.array(center)
        # The center is handed to mk.sample as s0 in rank form (values as
        # ranks), so convert it once here rather than on every draw.
        self.center_with_values_as_ranks = convert_candidates_to_ranks(center, num_candidates)
        self.theta = theta

    def sample_ranks(self, num_voters):
        """Samples full rankings from the Mallows distribution.

        Returns: list of length num_voters. Each entry is the full ranking of a
          single sampled voter, converted back with convert_ranks_to_candidates
          so that values are candidates and positions are ranks.
        """
        samples = mk.sample(m=num_voters, n=self.num_candidates, theta=self.theta, s0=self.center_with_values_as_ranks)
        return [convert_ranks_to_candidates(sample, self.num_candidates) for sample in samples]
    
class MallowsMixtureSampler(Sampler):
    """Class to sample from a mixture of Mallows.

    Example: mixture_sampler = MallowsMixtureSampler(num_candidates=5, centers=[np.array([0,1,2,3,4]), np.array([4,3,2,1,0])], center_probs = [0.5, 0.5], thetas = [1.5, 1.5], beta=0.9)

    Args:
      num_candidates: number of candidates
      centers: list of numpy arrays of length num_candidates. Each entry defines one center.
      center_probs: probabilities of sampling from each center.
      thetas: dispersion parameter corresponding to each center (or None).
      phis: alternative per-center dispersion parameter forwarded to mk.sample as ``phi`` (or None).
      beta: voter satisfaction parameter.
      store_voters: if True, keep references to every sampled voter.
      rejection_max_iters: multiplier bounding rejection-sampling batches in ``sample``.

    """

    def __init__(self, num_candidates, centers=None, center_probs=None, thetas=None, phis=None, beta=0.9, store_voters=True, rejection_max_iters=10):
        super().__init__(num_candidates, beta=beta, store_voters=store_voters, rejection_max_iters=rejection_max_iters)

        # None defaults avoid one mutable list being shared by every instance
        # constructed without these arguments.
        if centers is None:
            centers = []
        if center_probs is None:
            center_probs = []

        self.centers = centers
        # Centers are handed to mk.sample as s0 in rank form (values as ranks).
        self.centers_with_values_as_ranks = np.array(
            [convert_candidates_to_ranks(center, num_candidates) for center in centers])

        self.center_probs = center_probs
        self.thetas = thetas
        self.phis = phis

    def sample_ranks(self, num_voters):
        """Samples full rankings from the mixture distribution.

        The number of voters per center is drawn from a multinomial with
        ``center_probs``; samples are returned grouped by center, in center
        order (they are not shuffled).

        Returns: list of length num_voters. Each entry is the full ranking of a
          single sampled voter, converted back with convert_ranks_to_candidates
          so that values are candidates and positions are ranks.
        """
        # Decide number of voters to draw from each center.
        num_voters_per_center = np.random.multinomial(num_voters, self.center_probs, size=1).flatten()

        per_center_samples = []
        for i, cur_center in enumerate(self.centers_with_values_as_ranks):
            cur_num_voters = num_voters_per_center[i]
            if cur_num_voters == 0:
                continue
            cur_theta = None if self.thetas is None else self.thetas[i]
            cur_phi = None if self.phis is None else self.phis[i]
            per_center_samples.append(
                mk.sample(m=cur_num_voters, n=self.num_candidates, theta=cur_theta, phi=cur_phi, s0=cur_center))

        if not per_center_samples:
            return []
        # Stack once instead of re-copying the growing array on every center.
        all_samples = np.vstack(per_center_samples)
        return [convert_ranks_to_candidates(sample, self.num_candidates) for sample in all_samples]

