import torch
import numpy as np
import torch.nn.functional as F
from torch import nn
from itertools import chain, combinations
import math
from scipy.optimize import minimize
from numpy import linalg as LA
import random
import bisect


class ModelValueEstimator:
    """Running mean of per-party weight vectors.

    Each ``add_record`` call accumulates one weight vector, and
    ``get_shapley`` returns the element-wise average of everything recorded
    so far. The method name indicates this average is used as the per-party
    (Shapley) value estimate.

    Attributes:
        weights_sum: element-wise sum of all recorded weight vectors
            (float64 array of length ``num_party``).
        T: number of vectors recorded so far.
    """

    def __init__(self, num_party):
        self.T = 0
        self.weights_sum = np.zeros(num_party)

    def add_record(self, weights):
        """Accumulate one weight vector in place into ``weights_sum``."""
        np.add(self.weights_sum, weights, out=self.weights_sum)
        self.T = self.T + 1

    def get_shapley(self):
        """Return the mean of all recorded weight vectors.

        With no records (``T == 0``) this is a NumPy 0/0 division. It yields
        NaNs (with a RuntimeWarning) rather than raising.
        """
        return np.divide(self.weights_sum, self.T)

def get_cos(a, b):
    """Return the cosine similarity between vectors ``a`` and ``b``.

    There is no zero-norm guard. If either vector is all zeros, the result
    is a NumPy division by zero (NaN/inf with a RuntimeWarning) rather than
    an exception.
    """
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    return np.dot(a, b) / denom


def knowledge_vote(predictions, num_class=10, confidence_threshold=0.95):
    """Select rows of ``predictions`` that confidently agree on one class.

    Rows whose maximum score is at least ``confidence_threshold`` are kept
    (the "confident" rows). The consensus class is the argmax of the
    column-wise sum over those rows. The supporting rows are the confident
    rows whose own argmax equals that class.

    Args:
        predictions: 2-D tensor of per-row class scores; presumably softmax
            probabilities of shape (N, num_class). TODO confirm with callers.
        num_class: Unused by this function; kept for interface compatibility.
        confidence_threshold: Minimum top score for a row to join the vote.

    Returns:
        A 4-tuple ``(support_idx, mean_pred, consensus_class, n_confident)``:
        - ``support_idx`` indexes the rows of ``predictions`` that back the
          consensus.
        - ``mean_pred`` is the mean of *all* confident rows (not only the
          supporting ones), with a leading dimension of 1.
        - ``consensus_class`` is a Python int.
        - ``n_confident`` is the number of confident rows.
        Returns ``(None, None, None, None)`` when no row is confident or no
        confident row backs the consensus class.
    """
    nothing = (None, None, None, None)

    # Confidence gate: keep only rows whose top score clears the threshold.
    top_scores = predictions.max(dim=1).values
    kept = torch.where(top_scores >= confidence_threshold)[0]
    kept_preds = predictions[kept]
    if kept_preds.shape[0] == 0:
        return nothing

    # Consensus vote over the summed scores of the confident rows.
    winner = kept_preds.sum(dim=0).argmax().item()
    agreeing = torch.where(kept_preds.argmax(dim=1) == winner)[0]
    voters = kept[agreeing]
    if voters.shape[0] == 0:
        return nothing

    soft_label = kept_preds.mean(dim=0).unsqueeze(0)
    return voters, soft_label, winner, len(kept)


def one_hot_encode(label, num_classes):
    """Return a float64 vector of length ``num_classes`` with a 1 at ``label``.

    ``label`` is used directly as a NumPy index. Negative values therefore
    count from the end, and out-of-range values raise ``IndexError``.
    """
    encoded = np.zeros(num_classes, dtype=np.float64)
    encoded[label] = 1.0
    return encoded

def get_optimal_weights(models, X, y, num_class):
    """Find convex ensemble weights whose mixed softmax output best matches ``y``.

    Each model is run once on ``X``, and the softmax of row 0 of its output
    is cached. SLSQP then searches for weights that lie in [0, 1], sum to 1,
    and minimise the mean absolute difference between the weighted sum of
    those probability vectors and ``y``.

    Args:
        models: Sized sequence of callables mapping ``X`` to a tensor that is
            softmaxed over dim 1; presumably logits of shape (batch, classes).
            Only row 0 of each output is used, so ``X`` presumably holds a
            single sample. TODO confirm with callers.
        X: Model input, passed unchanged to every model.
        y: Target vector, presumably a one-hot label, broadcast-compatible
            with a single probability row. Integer-typed arrays are accepted.
        num_class: Unused; kept for interface compatibility.

    Returns:
        numpy.ndarray: ``res.x`` from ``scipy.optimize.minimize``, one weight
        per model.

    Raises:
        ValueError: If ``models`` is empty.
    """
    num_models = len(models)
    if num_models == 0:
        raise ValueError("get_optimal_weights needs at least one model")

    # Run every model once up front. The optimiser evaluates the loss many
    # times and only needs these cached probability rows.
    with torch.no_grad():
        preds = np.stack([
            F.softmax(model(X), dim=1).detach().cpu().numpy()[0]
            for model in models
        ])  # shape: (num_models, classes)

    # Work in float64 so integer-typed targets (e.g. an int one-hot array)
    # never force the float mixture into an integer buffer.
    target = np.asarray(y, dtype=float)

    def loss(weights):
        # Weighted mix of the cached probability rows vs. the target (mean L1).
        prediction = weights @ preds
        return np.abs(prediction - target).mean()

    # Start from the uniform mix.
    x0 = np.full(num_models, 1.0 / num_models)
    # Equality constraint: the weights must sum to 1.
    cons = [{'type': 'eq', 'fun': lambda w: w.sum() - 1}]
    # Each weight is bounded to [0, 1].
    bounds = [(0, 1)] * num_models

    res = minimize(loss, x0, bounds=bounds, method='SLSQP',
                   options={'disp': False, 'maxiter': 5000}, constraints=cons)
    return res.x

def sample_data_points(Ts, all_idx):
    """Draw a nested sequence of random index subsets of increasing size.

    For each target size ``Ti`` in ``Ts``, the previous subset is extended
    with ``Ti - previous_Ti`` indices. These are drawn without replacement
    (via ``np.random.choice``) from the indices not yet selected. Every
    returned set therefore contains the one before it and has exactly ``Ti``
    elements.

    Args:
        Ts: Target subset sizes. They must be non-decreasing and must not
            exceed the number of distinct values in ``all_idx``; otherwise
            ``np.random.choice`` raises ``ValueError``.
        all_idx: Pool of candidate indices. Any iterable is accepted,
            including a one-shot iterator; it is materialised once.

    Returns:
        list[set]: One set per entry of ``Ts``. The elements are the values
        drawn by ``np.random.choice`` (NumPy scalars).
    """
    # Build the pool once, outside the loop. Rebuilding it per iteration is
    # wasted work, and for a one-shot iterator it would leave the pool empty
    # after the first pass.
    pool = set(all_idx)
    data_idxs = []
    prev_data = set()
    pre_Ti = 0
    for Ti in Ts:
        # prev_data is always a subset of pool, so this symmetric difference
        # is exactly the not-yet-chosen indices. The operand order keeps the
        # candidate ordering, and hence seeded draws, stable.
        remaining = list(prev_data ^ pool)
        Di = set(np.random.choice(remaining, Ti - pre_Ti, replace=False))
        Di.update(prev_data)
        data_idxs.append(Di)
        prev_data = Di
        pre_Ti = Ti
    return data_idxs