import random
from tqdm import tqdm
from functools import partial
from collections import OrderedDict
import torch
import torch.optim as optim
#from torchcontrib.optim import SWA
import torch.nn as nn
import numpy as np 

import models as model_utils
from sklearn.linear_model import LogisticRegression

import os

from math import sqrt
from utils import *
from client import Device
from higher.patch import buffer_sync, make_functional
# Torch device string used throughout this module: CUDA when available, CPU otherwise.
device = 'cuda' if torch.cuda.is_available() else 'cpu'



class Server(Device):
  """Server-side device that aggregates client weights into its own models.

  One model is built per entry of ``model_names``. Client weights (``client.W``)
  are merged into those models either by plain averaging (``fedavg``) or with
  per-client mixing weights learned on ``val_loader`` (``optimize``).
  """

  def __init__(self, model_names, loader, val_loader, n_participants, num_classes=10, start_optimize=0):
    """Build the server models and initialise bookkeeping state.

    Args:
      model_names: iterable of names understood by ``model_utils.get_model``.
      loader: data loader forwarded to ``Device.__init__``.
      val_loader: loader used by ``sync_bn``, ``optimize`` and ``evaluate_ensemble``.
      n_participants: accepted for interface compatibility; not used in this class.
      num_classes: passed as ``num_classes`` to each model constructor.
      start_optimize: first communication round (compared against ``c_round``)
        in which ``optimize`` actually trains the mixing weights.
    """
    super().__init__(loader)
    self.val_loader = val_loader

    # The first element returned by get_model is called as the model constructor.
    self.model_dict = {
      model_name: model_utils.get_model(model_name)[0](num_classes=num_classes).to(device)
      for model_name in model_names
    }

    # name -> {parameter name -> Parameter}. Both dicts are built the same way and
    # therefore reference the very same Parameter objects of the models above.
    self.parameter_dict = {
      model_name: dict(model.named_parameters())
      for model_name, model in self.model_dict.items()
    }
    self.avg_parameter_dict = {
      model_name: dict(model.named_parameters())
      for model_name, model in self.model_dict.items()
    }

    self.number_client_all = 0

    self.find_acc = 0
    self.weights = None  # per-client aggregation weights, set by fedavg()/optimize()

    self.mu = None
    self.pi = None
    self.var = None
    self.label_acc = None

    self.start_optimize = start_optimize

    self.models = list(self.model_dict.values())

  def evaluate_ensemble(self, loader=None):
    """Return ``eval_op_ensemble`` evaluated on the server models.

    NOTE(review): ``loader`` is accepted but ignored; ``self.loader`` and
    ``self.val_loader`` are always used. Confirm whether callers expect the
    argument to override ``self.loader``.
    """
    return eval_op_ensemble(self.models, self.loader, self.val_loader)

  def select_clients(self, clients, frac=1.0, unbalance_rate=1, sample_mode="uniform"):
    """Sample ``int(len(clients) * frac)`` clients without replacement.

    Args:
      clients: sequence of clients to sample from.
      frac: fraction of clients to select.
      unbalance_rate: currently unused.
      sample_mode: only ``"uniform"`` is supported.

    Raises:
      ValueError: if ``sample_mode`` is not ``"uniform"`` (previously the method
        silently returned ``None`` in that case).
    """
    if sample_mode != "uniform":
      raise ValueError(f"unsupported sample_mode: {sample_mode!r}")
    return random.sample(clients, int(len(clients) * frac))

  def fedavg(self, clients):
    """Aggregate client weights with equal weighting.

    Stores a uniform weight vector in ``self.weights`` and, for every model name
    present among ``clients``, calls ``reduce_average`` with that model's server
    parameters as target and the matching clients' ``W`` as sources.
    """
    unique_client_model_names = np.unique([client.model_name for client in clients])
    self.weights = torch.Tensor([1. / len(clients)] * len(clients))
    for model_name in unique_client_model_names:
      sources = [client.W for client in clients if client.model_name == model_name]
      reduce_average(target=self.parameter_dict[model_name], sources=sources)

  def sync_bn(self):
    """Run every server model in train mode over ``val_loader``.

    Only the side effects of the forward passes matter here (presumably
    refreshing BatchNorm running statistics, given the method name), so the
    passes run under ``torch.no_grad()`` to avoid building autograd graphs.
    The models are left in train mode.
    """
    for model in self.models:
      model.train()
      with torch.no_grad():
        for x, _ in self.val_loader:
          model(x.to(device))

  @staticmethod
  def _normalize_weights(raw_weights):
    """Return ReLU(raw_weights) divided by its sum."""
    clipped = torch.relu(raw_weights)
    return clipped / clipped.sum()

  def optimize(self, clients, distill_iter, distill_lr, c_round):
    """Learn per-client mixing weights on ``val_loader`` and aggregate with them.

    The server models are first set to the plain client average and passed
    through ``sync_bn``. If ``c_round >= self.start_optimize``, a weight vector
    (one entry per client, initialised uniformly) is trained with Adam so that
    the weighted combination of client parameters minimises cross-entropy on
    ``val_loader``. The resulting weights are stored in ``self.weights`` and
    handed to ``reduce_weighted``.

    Args:
      clients: participating clients; each exposes ``model_name`` and ``W``.
      distill_iter: number of passes over ``val_loader``.
      distill_lr: Adam learning rate for the mixing weights.
      c_round: current communication round.
    """
    unique_client_model_names = np.unique([client.model_name for client in clients])

    # Created on the module-level `device` (was hard-coded to "cuda", which
    # fails on CPU-only hosts even though the rest of the file falls back to CPU).
    weights_ori = torch.full([len(clients)], 1. / len(clients), dtype=torch.float, requires_grad=True, device=device)
    weight_optimizer = torch.optim.Adam([weights_ori], lr=distill_lr)

    for model_name in unique_client_model_names:
      sources = [client.W for client in clients if client.model_name == model_name]
      reduce_average(target=self.parameter_dict[model_name], sources=sources)
    self.sync_bn()

    if c_round >= self.start_optimize:
      criterion = torch.nn.CrossEntropyLoss(reduction="mean")
      loss = None  # keeps the progress print valid when no batch/model was processed
      for it in range(distill_iter):
        for x, true_y in self.val_loader:
          x = x.to(device)
          true_y = true_y.to(device)
          for model in self.models:
            weights = self._normalize_weights(weights_ori)
            # NOTE(review): every client's W is indexed with this model's
            # parameter names; this looks like it assumes all clients share
            # the model's architecture - confirm for heterogeneous setups.
            weighted_sum = [
              torch.sum(weights * torch.stack([client.W[name].detach() for client in clients], dim=-1), dim=-1)
              for name, _ in model.named_parameters()
            ]
            model_patched = make_functional(model)
            buffer_sync(model, model_patched)
            y_p = model_patched(x, params=weighted_sum)
            loss = criterion(y_p, true_y)
            weight_optimizer.zero_grad()
            loss.backward()
            weight_optimizer.step()
        print (f"it {it} loss : {loss}")

    # Re-derive the normalised weights from the raw tensor so that the final
    # optimiser step is reflected (the in-loop value predates that step).
    self.weights = self._normalize_weights(weights_ori).detach()
    for model_name in unique_client_model_names:
      sources = [client.W for client in clients if client.model_name == model_name]
      reduce_weighted(target=self.parameter_dict[model_name], sources=sources, weights=self.weights)

