import pandas as pd
import random
from typing import List, Tuple
from itertools import permutations
from scipy.stats import kendalltau
import numpy as np

class Vote:
    """One voter's answer to one question.

    Attributes:
        question_number: Question index (simulate_voting numbers questions from 1).
        options: The alternatives offered in the question.
        ranking: The ranking the voter reported (their sampled signal).
        predicted_probs: Despite the name, Voter.vote stores the single predicted
            ranking here (the top key of its prediction dict), not the dict itself.
        is_expert: Whether the voter was simulated as an expert.
    """

    def __init__(self, question_number: int, options: List[int], ranking: List[int], predicted_probs: dict, is_expert: bool):
        # Plain attribute bag; assigned in the same order as the parameters.
        (self.question_number, self.options, self.ranking,
         self.predicted_probs, self.is_expert) = (
            question_number, options, ranking, predicted_probs, is_expert)

class Voter:
    """A simulated voter who samples a noisy ranking and predicts others' rankings.

    Experts sample their signal around the ground truth; non-experts sample it
    around a uniformly random candidate ranking.
    """

    def __init__(self, is_expert: bool):
        self.is_expert = is_expert

    def vote(self, question_number: int, options: List[int], ground_truth: List[int], all_worlds: List[Tuple[int]]):
        """Cast this voter's vote on one question.

        Steps:
          1. Choose the centre ranking: ``ground_truth`` for experts, a uniformly
             random element of ``all_worlds`` otherwise.
          2. Score every ranking in ``all_worlds`` with ``signal_probability``
             against that centre, normalise, and sample one as the signal.
          3. For every ranking s_j, estimate ``compute_conditional_prob(s_j, signal)``.
          4. Renormalise those over rankings other than the signal (``predict``)
             and keep the most probable one.

        Returns:
            A Vote whose ``ranking`` is the sampled signal and whose
            ``predicted_probs`` field holds the single predicted ranking.
        """
        centroid = ground_truth if self.is_expert else random.choice(all_worlds)

        # Mallows-style weight of each candidate signal around the centre.
        raw_weights = {
            candidate: signal_probability(candidate, centroid, self.is_expert)
            for candidate in all_worlds
        }
        weight_sum = sum(raw_weights.values())
        signal_dist = {candidate: weight / weight_sum for candidate, weight in raw_weights.items()}

        signal = random.choices(list(signal_dist), weights=list(signal_dist.values()), k=1)[0]

        # Conditional probability of each ranking given the signal just drawn.
        conditional_probs = {
            other: compute_conditional_prob(other, signal, all_worlds, ground_truth, self.is_expert, prediction=True)
            for other in all_worlds
        }

        predicted_probs = self.predict(signal, conditional_probs, all_worlds, ground_truth)
        best_guess = max(predicted_probs, key=predicted_probs.get)

        return Vote(question_number, options, signal, best_guess, self.is_expert)

    def predict(self, signal, conditional_probs, all_worlds, ground_truth):
        """Drop the voter's own signal and renormalise the remaining probabilities.

        ``all_worlds`` and ``ground_truth`` are accepted but not used.

        Returns:
            Dict mapping each ranking other than ``signal`` to its share of the
            remaining probability mass (empty if nothing remains).
        """
        others = {}
        for world, prob in conditional_probs.items():
            if world != signal:
                others[world] = prob

        mass = sum(others.values())
        return {world: prob / mass for world, prob in others.items()}


def generate_subsets(n, k, step):
    """Build overlapping, evenly spaced subsets of the alternatives 1..n.

    Subset number ``first`` (0-based) takes ``k`` alternatives starting at
    position ``first`` and ``step`` positions apart. Start positions advance by
    one, and subsets are produced only while the last chosen position is
    still inside 1..n.

    Example: n=36, k=5, step=6 gives [1, 7, 13, 19, 25], [2, 8, ...], ...,
    [12, 18, 24, 30, 36].
    """
    pool = list(range(1, n + 1))
    span = (k - 1) * step  # offset from a subset's first position to its last

    result = []
    for first in range(n - span):
        candidate = [pool[first + j * step] for j in range(k)]
        if candidate in result:  # keep only the first occurrence of any subset
            continue
        result.append(candidate)
    return result


def mallows_distance(ranking1, ranking2):
    """Return the normalised Kendall tau distance between two rankings.

    Both arguments are orderings of the same alternatives. The result is
    ``1 - tau = 2 * D / (m * (m - 1) / 2)``, where ``D`` is the number of
    alternative pairs that the two rankings order differently and ``m`` is
    the number of alternatives. It ranges from 0 (identical) to 2 (exact
    reverse). Rankings with fewer than two alternatives have distance 0.0.

    Raises:
        ValueError: If the rankings have different lengths.
        KeyError: If ranking1 contains an alternative missing from ranking2.
    """
    if len(ranking1) != len(ranking2):
        raise ValueError("rankings must have the same length")

    # Compare by alternative, not by label. Feeding the raw item sequences to
    # kendalltau correlates the labels position by position. That only matches
    # the ranking distance when one argument is in ascending label order. For
    # example, (1, 3, 2) vs (2, 3, 1) are exact reverses but would score 2/3.
    position_in_2 = {alt: pos for pos, alt in enumerate(ranking2)}
    ranks = [position_in_2[alt] for alt in ranking1]

    m = len(ranks)
    num_pairs = m * (m - 1) // 2
    if num_pairs == 0:
        return 0.0  # nothing can be discordant (kendalltau would give NaN)

    discordant = sum(ranks[i] > ranks[j] for i in range(m) for j in range(i + 1, m))
    return 2 * discordant / num_pairs

def normalization_constant(phi, m):
    """Return the Mallows normalising constant Z(phi, m).

    Z(phi, m) = prod_{i=1}^{m-1} (1 + phi + phi**2 + ... + phi**i), which equals
    the sum of phi ** D over all m! rankings when D counts discordant pairs.
    For m <= 1 the product is empty and 1 is returned.
    """
    total = 1
    partial = 1  # running geometric sum 1 + phi + ... + phi**i
    for i in range(1, m):
        partial += phi ** i
        total *= partial
    return total

def signal_probability(signal, world, is_expert, dispersion=None):
    """Return the Mallows-model weight of observing ``signal`` when the truth is ``world``.

    Computed as ``phi ** mallows_distance(signal, world) / normalization_constant(phi, m)``
    with ``m = len(signal)``.

    Args:
        signal: The observed ranking.
        world: The reference ranking (same alternatives as ``signal``).
        is_expert: Selects which dispersion distribution to sample from when
            ``dispersion`` is not given.
        dispersion: Optional fixed Mallows dispersion ``phi``. If None (the
            default), ``phi`` is drawn anew on every call: |N(0.15, 0.075)| for
            experts, |N(0.7, 0.3)| for non-experts. Calls that draw their own
            ``phi`` are therefore not mutually consistent. Pass a fixed value to
            evaluate one coherent distribution.

    Returns:
        The (float) weight for this signal/world pair.
    """
    if dispersion is None:
        # Both are drawn (as before) so the RNG stream is unchanged for callers
        # that rely on the default.
        dispersion_expert_vote = abs(np.random.normal(0.15, 0.075))
        dispersion_nonexpert_vote = abs(np.random.normal(0.7, 0.3))
        dispersion = dispersion_expert_vote if is_expert else dispersion_nonexpert_vote

    distance = mallows_distance(signal, world)

    # Number of alternatives comes from the ranking itself (was hard-coded to 5),
    # so subsets of any size get the matching normalising constant.
    m = len(signal)

    phi = dispersion
    # NOTE(review): mallows_distance is on the 1 - tau scale ([0, 2]), while
    # normalization_constant sums phi**D over discordant-pair counts D, so these
    # weights do not sum to 1 over all rankings (Voter.vote and compute_posterior
    # renormalise them). Also, a non-expert phi above 1 makes rankings farther
    # from `world` more likely. Confirm both are intended.
    return phi ** distance / normalization_constant(phi, m)

# Module-level memo for compute_posterior:
# (signal, world, ground_truth, is_expert) -> posterior probability.
# Nothing visible in this file clears it, so it persists across questions.
computed_posteriors = dict()

def compute_posterior(signal, world, all_worlds, ground_truth, is_expert):
    """Return P(world | signal) under a uniform prior over ``all_worlds``.

    Uses Bayes' rule with likelihoods from ``signal_probability``:
        P(w | s) = P(s | w) * prior / sum_v P(s | v) * prior

    On a cache miss, the posterior for every world in ``all_worlds`` is computed
    in one pass from a single set of likelihoods and stored in
    ``computed_posteriors``. Because ``signal_probability`` draws a fresh
    dispersion on each call, recomputing the normaliser separately per world
    yields posteriors that do not sum to 1. Computing the whole row at once
    guarantees they do. It also needs |all_worlds| likelihood evaluations
    instead of roughly |all_worlds|**2.

    Args:
        signal: The observed ranking.
        world: The ranking whose posterior is requested.
        all_worlds: Candidate rankings (support of the uniform prior).
        ground_truth: Used only as part of the cache key.
        is_expert: Forwarded to ``signal_probability``; also part of the cache key.

    Returns:
        The posterior probability (float).
    """
    # Convert sequences to tuples so they can be used in a dict key.
    signal_key = tuple(signal)
    truth_key = tuple(ground_truth)
    key = (signal_key, tuple(world), truth_key, is_expert)
    if key in computed_posteriors:
        return computed_posteriors[key]

    prior = 1 / len(all_worlds)  # uniform prior
    likelihoods = [(tuple(w), signal_probability(signal, w, is_expert)) for w in all_worlds]

    # Total probability of the signal; shared by every world in this row.
    total_signal_prob = sum(lik * prior for _, lik in likelihoods)

    for w_key, lik in likelihoods:
        computed_posteriors[(signal_key, w_key, truth_key, is_expert)] = lik * prior / total_signal_prob

    if key not in computed_posteriors:
        # `world` is not one of all_worlds: score it against the same normaliser.
        computed_posteriors[key] = signal_probability(signal, world, is_expert) * prior / total_signal_prob

    return computed_posteriors[key]



def compute_conditional_prob(s_j, s_k, all_worlds, ground_truth, is_expert, num_samples=1000, prediction=False):
    """Monte Carlo estimate of sum_w P(w | s_k) * P(s_j | w).

    Draws ``num_samples`` worlds with replacement, weighted by
    ``compute_posterior(s_k, w, ...)``. Returns the mean of
    ``signal_probability(s_j, w, is_expert)`` over the draws.

    Args:
        s_j: The ranking whose conditional probability is estimated.
        s_k: The observed ranking that conditions the posterior over worlds.
        all_worlds: Candidate rankings to sample from.
        ground_truth: Forwarded to ``compute_posterior`` (cache key only).
        is_expert: Forwarded to ``compute_posterior`` and ``signal_probability``.
        num_samples: Number of posterior draws; 0 divides by zero.
        prediction: Not used by this function.

    Returns:
        The averaged probability (float).
    """
    posterior_weights = [
        compute_posterior(s_k, world, all_worlds, ground_truth, is_expert)
        for world in all_worlds
    ]
    draws = random.choices(all_worlds, weights=posterior_weights, k=num_samples)

    accumulated = sum((signal_probability(s_j, drawn, is_expert) for drawn in draws), 0.0)
    return accumulated / num_samples



def simulate_voting(num_voters: int, subsets: List[List[int]], ground_truths: List[List[int]]) -> List[Vote]:
    """Have every simulated voter vote on every question.

    The expert share is drawn once per call from Beta(1, 2.5). The number of
    experts is then drawn from Binomial(num_voters, share), and the first
    ``num_experts`` voters are experts.

    Args:
        num_voters: Number of voters to simulate.
        subsets: The alternatives of each question, in question order.
        ground_truths: ``ground_truths[i]`` is the true ranking for
            ``subsets[i]``; must have at least ``len(subsets)`` entries.

    Returns:
        A flat list of Vote objects, grouped by question (numbered from 1),
        with voters in the same order within each question.
    """
    prob_expert = np.random.beta(1, 2.5)

    # Sample the number of experts based on the probability
    num_experts = np.random.binomial(num_voters, prob_expert)
    voters = [Voter(is_expert=(i < num_experts)) for i in range(num_voters)]

    votes = []
    for question_number, subset in enumerate(subsets, start=1):
        # Provide the ground truth to the voters when they are casting their votes.
        ground_truth = ground_truths[question_number - 1]

        # Every ordering of the subset is a candidate world.
        all_worlds = list(permutations(subset))

        for voter in voters:
            votes.append(voter.vote(question_number, subset, ground_truth, all_worlds))

    return votes


def write_to_csv(votes, path='simulated_data.csv', domain=1, treatment=6):
    """Write votes to a CSV file, one row per vote, without the DataFrame index.

    Columns: question, options, votes (the voter's ranking), predictions (the
    Vote's ``predicted_probs`` field), is_expert, domain, treatment.

    Args:
        votes: Iterable of Vote objects (anything with the same attributes works).
        path: Output file; defaults to 'simulated_data.csv' in the working directory.
        domain: Constant written to every row's 'domain' column (previously hard-coded 1).
        treatment: Constant written to every row's 'treatment' column (previously hard-coded 6).
    """
    rows = [
        [vote.question_number, vote.options, vote.ranking, vote.predicted_probs, vote.is_expert, domain, treatment]
        for vote in votes
    ]

    df = pd.DataFrame(rows, columns=['question', 'options', 'votes', 'predictions', 'is_expert', 'domain', 'treatment'])
    df.to_csv(path, index=False)


# Example end-to-end run. This is module-level code, so it executes on import
# as well as when the file is run directly.
# NOTE(review): consider an `if __name__ == "__main__":` guard.
# n: total number of alternatives; k: size of each subset;
# step: spacing between consecutive elements within a subset.
num_voters, n, k, step = 10, 36, 5, 6

subsets = generate_subsets(n, k, step)

# Ground truth per subset: its options in ascending order. This is a
# placeholder; substitute real ground truths when available.
ground_truths = list(map(sorted, subsets))

votes = simulate_voting(num_voters, subsets, ground_truths)
write_to_csv(votes)