import torch
import os
import numpy as np
import torch.nn.functional as F
import torch.nn as nn
import random
from collections import deque
from utils.model_utils import get_log_path
from net.models import build_model
from torch.utils.data import DataLoader
from server.collaboration import *


class Server:
    """Coordinator for ensemble distillation across parties and reward allocation.

    Visible responsibilities:
      1. query each registered party and build an ensemble prediction per
         distillation sample (``collab``), recording per-party weights with a
         value estimator;
      2. split the collected (data, ensemble prediction) pairs among parties
         according to their Shapley values (``reward``);
      3. evaluate the summed-logit ensemble on a test loader
         (``evaluate_ensemble``).
    """

    def __init__(self, args, distill_data):
        """Set up the server.

        Args:
            args: configuration namespace. ``num_users`` and ``cuda_num`` are
                read; ``dataset`` is read if present (used by
                ``save_model`` / ``load_model``).
            distill_data: dataset iterated one sample at a time by ``collab``.
        """
        # Seeds Python's global RNG, so this affects every user of `random`.
        random.seed(42)
        self.num_users = args.num_users
        self.parties = []
        self.device = torch.device(args.cuda_num)
        # Previously never initialised, so save_model/load_model always failed
        # unless a caller assigned these attributes externally.
        self.dataset = getattr(args, "dataset", None)
        self.student_model = None
        self.distill_data = distill_data
        self.distill_loader = DataLoader(distill_data, batch_size=1, shuffle=True)
        # (input, ensemble_prediction) pairs produced by collab().
        self.rewards = []
        self.valuator = ModelValueEstimator(args.num_users)

    def add_party(self, party):
        """Register a party; its position in ``self.parties`` is its index."""
        self.parties.append(party)

    def query(self, data):
        """Collect every party's prediction for ``data``.

        Args:
            data: input batch tensor. It is moved to ``self.device`` once,
                then passed to each party's ``get_pred``.

        Returns:
            np.ndarray stacking the parties' ``get_pred`` outputs along a new
            leading axis, in ``self.parties`` order. Each output is presumably
            a (B, K) array, giving (N, B, K); confirm against ``get_pred``.
        """
        # Move the batch once instead of once per party.
        batch = data.to(self.device)
        return np.array([party.get_pred(batch) for party in self.parties])

    def collab(self, ensemble_method='avg', num_classes=10):
        """Build ensemble predictions for the distillation set and record weights.

        For each sample of ``self.distill_loader``, an ensemble prediction and
        a per-party weight vector summing to 1 are computed. The weights go to
        ``self.valuator.add_record`` and the ``(data, prediction)`` pair is
        appended to ``self.rewards``.

        Args:
            ensemble_method: one of
                ``'avg'``: uniform mean of party outputs, with equal weights;
                ``'opt'``: weights from ``get_optimal_weights``; the prediction
                    is the one-hot ground-truth label;
                ``'kv'``: ``knowledge_vote``; the returned parties share the
                    weight equally, and the method falls back to ``'avg'`` when
                    no prediction is returned.
            num_classes: number of classes used for one-hot encoding in
                ``'opt'`` (previously hard-coded to 10).

        Raises:
            ValueError: if ``ensemble_method`` is not one of the above.
                Previously this surfaced as an ``UnboundLocalError``.
        """
        if ensemble_method not in ('avg', 'opt', 'kv'):
            raise ValueError(f"Unknown ensemble_method: {ensemble_method!r}")
        print('[Ensemble Distillation]')
        for data, y in self.distill_loader:
            ens_logits = self.query(data)  # one entry per party
            if ensemble_method == 'avg':
                ens_pred = torch.from_numpy(ens_logits.mean(0))
                weights = np.ones(self.num_users) / self.num_users
            elif ensemble_method == 'opt':
                one_hot_y = one_hot_encode(y, num_classes)
                models = [party.model for party in self.parties]
                weights = get_optimal_weights(
                    models, data.to(self.device), one_hot_y, num_classes
                )
                ens_pred = torch.from_numpy(one_hot_y).unsqueeze(0)
            else:  # 'kv'
                weights = np.zeros(self.num_users)
                indices, ens_pred, _ = knowledge_vote(
                    torch.from_numpy(ens_logits).squeeze(1)
                )
                if ens_pred is not None:
                    weights[indices] += 1 / len(indices)
                else:
                    # No vote result: fall back to the uniform average.
                    ens_pred = torch.from_numpy(ens_logits.mean(0))
                    weights = np.ones(self.num_users) / self.num_users

            # Internal invariant: the weights must be a distribution over parties.
            assert abs(weights.sum() - 1) < 1e-3
            self.valuator.add_record(weights)
            self.rewards.append((data, ens_pred))

    def reward(self):
        """Distribute the collected ensemble predictions to parties by Shapley value.

        Shapley values are normalised by their maximum. The top party is
        therefore entitled to all ``len(self.rewards)`` pairs, and each other
        party to a floored proportional share (clipped at 0).
        ``sample_data_points`` selects the indices. Each chosen pair's input
        and prediction (first batch element) are appended to the party's
        ``reward_data`` and ``reward_label`` lists.

        Raises:
            ValueError: if the largest Shapley value is not positive. The
                normalisation is undefined then; it previously produced NaN
                and a garbage integer count.
        """
        print('[Allocate Rewards]')
        num_rewards = len(self.rewards)
        print('Total Rewards: ', num_rewards)
        shapleys = self.valuator.get_shapley()
        print('Shapleys: ', shapleys)

        shap_values = np.asarray(shapleys, dtype=float)
        max_shap = np.max(shap_values)
        if not max_shap > 0:  # also catches NaN
            raise ValueError(
                f"Cannot normalise Shapley values with non-positive maximum: {shapleys}"
            )

        dict_rewards = {}
        for i in range(len(self.parties)):
            # A negative Shapley value earns nothing rather than a negative count.
            share = max(shap_values[i] / max_shap, 0.0)
            dict_rewards[i] = int(np.floor(share * num_rewards))

        # Parties in ascending order of reward count. The i-th index list
        # returned by sample_data_points is assumed to belong to the i-th party
        # in this order (matches the original pairing).
        sorted_keys = sorted(dict_rewards, key=dict_rewards.get)
        Ts = [dict_rewards[key] for key in sorted_keys]
        all_indexs = sample_data_points(Ts, list(range(num_rewards)))
        for key, indexs in zip(sorted_keys, all_indexs):
            party = self.parties[key]
            for idx in indexs:
                X_idx, y_idx = self.rewards[idx]
                party.reward_data.append(X_idx[0].tolist())
                party.reward_label.append(y_idx[0].tolist())
            print('Reward Party ', key, ' ', len(indexs), ' Ensemble Predictions | Shapley: ', shapleys[key])

    def save_model(self):
        """Save ``self.student_model`` with ``torch.save`` to models/<dataset>/server.pt.

        Raises:
            AttributeError: if ``self.dataset`` or ``self.student_model`` is unset.
        """
        dataset = getattr(self, "dataset", None)
        model = getattr(self, "student_model", None)
        if dataset is None or model is None:
            raise AttributeError(
                "save_model requires both self.dataset and self.student_model to be set"
            )
        model_path = os.path.join("models", dataset)
        os.makedirs(model_path, exist_ok=True)
        torch.save(model, os.path.join(model_path, "server" + ".pt"))

    def load_model(self):
        """Load ``self.student_model`` from models/<dataset>/server.pt.

        Raises:
            AttributeError: if ``self.dataset`` is unset.
            FileNotFoundError: if the checkpoint file does not exist.
        """
        dataset = getattr(self, "dataset", None)
        if dataset is None:
            raise AttributeError("load_model requires self.dataset to be set")
        model_path = os.path.join("models", dataset, "server" + ".pt")
        if not os.path.exists(model_path):
            raise FileNotFoundError(model_path)
        # NOTE(review): torch.load unpickles arbitrary objects, so only load
        # trusted files. Since torch 2.6, weights_only defaults to True, which
        # rejects a pickled nn.Module; pass weights_only=False there if needed.
        self.student_model = torch.load(model_path)

    def evaluate_ensemble(self):
        """Evaluate the summed-logit ensemble of all parties on party 0's test loader.

        Returns:
            float: accuracy over all evaluated samples, or 0.0 if the loader
            yields nothing.
        """
        correct = 0
        total = 0
        with torch.no_grad():
            print('---Evaluate Ensemble Model on test dataset---')
            for x, y in self.parties[0].testloaderfull:
                x = x.to(self.device)
                y = y.to(self.device)
                ens_logits = self.query(x)
                output = F.softmax(torch.from_numpy(ens_logits.sum(0)), dim=1).to(self.device)
                correct += (torch.argmax(output, dim=1) == y).sum().item()
                # Count the samples actually in this batch. The old code added
                # parties[0].test_samples per batch, which over-counts whenever
                # the loader yields more than one batch.
                total += y.shape[0]
        accuracy = correct / total if total else 0.0
        print(f"Ensemble Test Accuracy: {accuracy:.4}, Total Tested Samples: {total}")
        return accuracy