import random
import numpy as np
from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm import tqdm
from scipy.stats import ks_2samp, cramervonmises_2samp

class BCvMAlgorithm:
    '''
    Leave-one-out Cramer-von-Mises-style test statistic.

    For agent i and each feature map l, a random point of Y_{-i} (the pooled
    submissions of every other agent) is mapped to T. The rest of Y_{-i} is
    mapped to Z_il. The statistic accumulates
    coeff_l * (g_l(submission_i, T) - ECDF_{Z_il}(T))**2 over all feature maps.
    '''

    def __init__(self, feature_maps, feature_coeffs, g_functions):
        '''
        feature_maps: list of functions that map data points to a feature space
        feature_coeffs: list of coefficients for each feature map
        g_functions: list of g_functions for each feature map; each is called as
            g(submission_i, T) and its result is compared against the empirical CDF

        Raises ValueError if the three lists do not have the same length.
        '''
        self.feature_maps = feature_maps
        self.feature_coeffs = feature_coeffs
        self.g_functions = g_functions

        if len(feature_maps) != len(feature_coeffs) or len(feature_maps) != len(g_functions):
            raise ValueError('feature_maps, feature_coeffs, and g_functions must have the same length.')

    @staticmethod
    def _pool_others(submissions, i):
        '''Return Y_{-i}: all submissions except the i-th, concatenated into one array.'''
        # Filtering by index (rather than `submissions[:i] + submissions[i+1:]`)
        # works for lists, tuples and 2-D ndarrays alike; with ndarrays `+`
        # would add elementwise instead of joining.
        others = [submission for j, submission in enumerate(submissions) if j != i]
        if not others:
            # Single agent: np.concatenate([]) would raise, so use an empty pool.
            return np.array([])
        return np.concatenate(others)

    def generate_statistics(self, submissions):
        '''
        submissions: sequence of submissions, one per agent; each submission is a
            list (or 1-D array) of data points

        Returns a list with one statistic per agent, in submission order. Agents
        whose pool Y_{-i} has at most one point (including the single-agent case)
        get 0, because no ECDF can be formed after removing T.

        Uses the module-level `random` generator (one randrange per feature map
        per agent), so results are reproducible under random.seed.
        '''
        test_statistics = []
        for i, own_submission in enumerate(submissions):
            Y_mi = self._pool_others(submissions, i)  # Y_{-i}

            if len(Y_mi) <= 1:
                test_statistics.append(0)
                continue

            D_i = 0
            for feature_map, coeff, g_function in zip(self.feature_maps, self.feature_coeffs, self.g_functions):
                # Choose a random point from Y_mi as T and exclude it from Z_il.
                rand_index = random.randrange(len(Y_mi))

                Z_il = [feature_map(y) for ind, y in enumerate(Y_mi) if ind != rand_index]
                T = feature_map(Y_mi[rand_index])

                # Empirical CDF of the remaining mapped points, evaluated at T.
                T_order = sum(1 for z in Z_il if z <= T)
                ecdf_at_T = T_order / len(Z_il)

                g_il_at_T = g_function(own_submission, T)

                D_i += coeff * (g_il_at_T - ecdf_at_T) ** 2

            test_statistics.append(D_i)
        return test_statistics


class KSAlg:
    def generate_statistics(self, submissions):
        '''
        Score every agent with the two-sample Kolmogorov-Smirnov statistic
        between the pooled submissions of all other agents and its own.

        submissions: list of lists, where each inner list is a submission (list of data points) from an agent
        returns: list holding one KS statistic per agent, in submission order
        '''
        def ks_against_rest(idx):
            # Y_{-idx}: every other agent's data joined into a single array.
            rest = np.concatenate(submissions[:idx] + submissions[idx + 1:])
            return ks_2samp(rest, submissions[idx]).statistic

        return [ks_against_rest(idx) for idx in range(len(submissions))]

class CVMAlg:
    def generate_statistics(self, submissions):
        '''
        Score every agent with the two-sample Cramer-von Mises statistic
        between the pooled submissions of all other agents and its own.

        submissions: list of lists, where each inner list is a submission (list of data points) from an agent
        returns: list holding one statistic per agent, in submission order
        '''
        statistics = []
        for idx, own in enumerate(submissions):
            # Y_{-idx}: every other agent's data joined into a single array.
            rest = np.concatenate(submissions[:idx] + submissions[idx + 1:])
            result = cramervonmises_2samp(rest, own)
            statistics.append(result.statistic)
        return statistics

class DiffMeans:
    def generate_statistics(self, submissions):
        '''
        Score every agent by the absolute difference between the mean of its own
        submission and the mean of all other agents' pooled submissions.

        submissions: list of lists, where each inner list is a submission (list of data points) from an agent
        returns: list holding one absolute mean gap per agent, in submission order
        '''
        def mean_gap(idx):
            # Y_{-idx}: every other agent's data joined into a single array.
            rest = np.concatenate(submissions[:idx] + submissions[idx + 1:])
            own_mean = np.mean(submissions[idx])
            rest_mean = np.mean(rest)
            return np.abs(own_mean - rest_mean)

        return [mean_gap(idx) for idx in range(len(submissions))]


class Agent:
    '''An agent that holds a private dataset and reports it through its submission function.'''

    def __init__(self, submission_func, data_generator):
        # Applied to the private data when the agent reports.
        self.submission_func = submission_func
        # Source of the private data; must provide generate_data(size).
        self.data_generator = data_generator
        # Private dataset; remains None until generate_dataset is called.
        self.data = None

    def generate_dataset(self, size):
        '''Request a dataset of the given size from the data generator and keep it as private data.'''
        fresh_data = self.data_generator.generate_data(size)
        self.data = fresh_data

    def report_dataset(self):
        '''
        Return submission_func applied to the private data.

        Raises RuntimeError if generate_dataset has not been called yet.
        '''
        if self.data is not None:
            return self.submission_func(self.data)
        raise RuntimeError("Data not generated. Call generate_dataset before submitting data.")


class DataSharingExperiment:
    # np.random.seed only accepts seeds in [0, 2**32 - 1]. Per-run seeds are
    # reduced modulo this value so base_seed + run_idx can never fall outside it.
    _SEED_MODULUS = 2**32

    def __init__(self, algorithm, submission_functions, data_generator):
        '''
        algorithm: any object exposing generate_statistics(submissions) that returns one
            statistic per agent (e.g. BCvMAlgorithm, KSAlg, CVMAlg, DiffMeans)
        submission_functions: list of functions that each agent will use to submit their data
        data_generator: an instance of DataGenerator that will be used to generate data for each agent;
            it must provide generate_data(size) and resample_from_prior()
        '''
        self.algorithm = algorithm
        self.data_generator = data_generator
        # All agents share the same data_generator, so one resample_from_prior()
        # call per run affects every agent's draws.
        self.agents = [
            Agent(submission_func=submission_func, data_generator=data_generator)
            for submission_func in submission_functions
        ]

    def _single_run(self, run_idx, num_data_per_agent, base_seed):
        """
        Execute one independent run. This is executed in a separate worker process.

        Returns the per-agent statistics produced by self.algorithm for this run.
        """
        # 1) Reseed Python and NumPy (SciPy.stats uses NumPy RNG under the hood).
        #    The modulo keeps the seed inside np.random.seed's accepted range;
        #    for any base_seed + run_idx already in range it is a no-op.
        seed = (base_seed + run_idx) % self._SEED_MODULUS
        random.seed(seed)
        np.random.seed(seed)

        # 2) Resample prior, generate data & submissions
        self.data_generator.resample_from_prior()
        submissions = []
        for agent in self.agents:
            agent.generate_dataset(num_data_per_agent)
            submissions.append(agent.report_dataset())

        # 3) Compute stats and return
        return self.algorithm.generate_statistics(submissions=submissions)

    def run_experiment(self, runs, num_data_per_agent, seed=None, n_jobs=8):
        """
        runs: number of independent experiments
        num_data_per_agent: samples per agent each run
        seed: int or None.  If int, the per-run seeds are fully determined by it, so results
            are reproducible provided all randomness goes through `random` / `np.random`.
        n_jobs: number of worker processes to spin up.

        Returns (mean, variance): two arrays of length len(self.agents) holding the per-agent
        mean and sample variance (ddof=1) of the statistics across runs. With runs == 1 the
        variance is NaN.

        Each run is dispatched to a ProcessPoolExecutor, so `self` (algorithm, agents,
        submission functions, data generator) must be picklable. State changed inside a
        worker is not propagated back to this object.
        """
        # Choose a reproducible base seed if none given
        if seed is None:
            seed = random.randrange(self._SEED_MODULUS)
            print(f"[run_experiment] No seed specified using base_seed={seed}")

        base_seed = seed

        # Prepare to collect results
        test_stats_over_runs = np.zeros((runs, len(self.agents)))

        # Launch parallel runs
        with ProcessPoolExecutor(max_workers=n_jobs) as exe:
            # schedule all runs, remembering which run index each future belongs to
            futures = {
                exe.submit(self._single_run, i, num_data_per_agent, base_seed): i
                for i in range(runs)
            }
            # Futures finish out of order; store each by its run index.
            for fut in tqdm(as_completed(futures), total=runs):
                i = futures[fut]
                test_stats_over_runs[i, :] = fut.result()

        # average across runs
        average_performances = test_stats_over_runs.mean(axis=0)
        variance_performances = test_stats_over_runs.var(axis=0, ddof=1)

        return average_performances, variance_performances