import random
from tqdm import tqdm
from functools import partial
from collections import OrderedDict
import torch
import torch.optim as optim
#from torchcontrib.optim import SWA
import torch.nn as nn
import numpy as np 
from utils import *
import models as model_utils
from sklearn.linear_model import LogisticRegression

import os

# from gmm_torch.gmm import GaussianMixture
from math import sqrt

# Device string used when placing models: CUDA when torch reports it available, CPU otherwise.
device = 'cpu' if not torch.cuda.is_available() else 'cuda'
    
class Device(object):
  """Base class for an object that owns a data loader and a model.

  This class never assigns ``self.model`` itself; ``evaluate``,
  ``save_model`` and ``load_model`` all read it, so a subclass (or the
  caller) must set it before using them.
  """

  def __init__(self, loader):
    """Store the default data loader.

    Args:
      loader: Loader used by ``evaluate`` when none is passed explicitly.
    """
    self.loader = loader

  def evaluate(self, loader=None):
    """Run ``eval_op`` on ``self.model``.

    Args:
      loader: Loader to evaluate on. ``self.loader`` is used when ``None``.

    Returns:
      Whatever ``eval_op`` returns.
    """
    # Compare against None explicitly. Truth-testing an object that defines
    # __len__ calls it, so an empty loader would silently be replaced by
    # self.loader, and a loader whose length is undefined would raise.
    if loader is None:
      loader = self.loader
    return eval_op(self.model, loader)

  def save_model(self, path=None, name=None, verbose=True):
    """Save the model's state dict to ``path + name``.

    Nothing is written when ``name`` is ``None`` or empty.

    Args:
      path: Prefix prepended verbatim to ``name`` (no separator is inserted).
        ``None`` is treated as an empty prefix.
      name: File name appended to ``path``.
      verbose: When true, print the destination after saving.
    """
    if not name:
      return
    # path defaults to None; concatenating None with a str would raise.
    target = (path or "") + name
    torch.save(self.model.state_dict(), target)
    if verbose:
      print("Saved model to", target)

  def load_model(self, path=None, name=None, verbose=True):
    """Load a state dict from ``path + name`` into ``self.model``.

    Nothing is loaded when ``name`` is ``None`` or empty.

    Args:
      path: Prefix prepended verbatim to ``name`` (no separator is inserted).
        ``None`` is treated as an empty prefix.
      name: File name appended to ``path``.
      verbose: When true, print the source after loading.
    """
    if not name:
      return
    source = (path or "") + name
    # Map storages to CPU so a checkpoint written on a GPU can be read on a
    # machine without one; load_state_dict then copies the values into the
    # model's existing parameters.
    # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
    state = torch.load(source, map_location="cpu")
    self.model.load_state_dict(state)
    if verbose:
      print("Loaded model from", source)
  
class Client(Device):
  """Device that owns a local model, built by name, and an optimizer for it."""

  def __init__(self, model_name, optimizer_fn, loader, idnum=0, num_classes=10, images_train=None, labels_train=None, eta=0.5):
    """Build the model and optimizer for this client.

    Args:
      model_name: Key passed to ``model_utils.get_model``; the first element
        of its result is used as the model constructor.
      optimizer_fn: Callable that takes the model parameters and returns an
        optimizer.
      loader: Default data loader, stored by the base class.
      idnum: Identifier stored as ``self.id``.
      num_classes: Forwarded to the model constructor.
      images_train: Accepted but not used by this constructor.
      labels_train: Accepted but not used by this constructor.
      eta: Accepted but not used by this constructor.
    """
    super().__init__(loader)
    self.id = idnum

    self.model_name = model_name
    self.model_fn = partial(model_utils.get_model(self.model_name)[0], num_classes=num_classes)
    self.model = self.model_fn().to(device)

    # Name -> parameter mapping; the values are the model's own parameter
    # objects, not copies.
    self.W = dict(self.model.named_parameters())

    self.optimizer_fn = optimizer_fn
    self.optimizer = self.optimizer_fn(self.model.parameters())

  def synchronize_with_server(self, server):
    """Copy the server's weights for this client's architecture into the local model.

    Args:
      server: Object exposing ``model_dict``, indexed by model name.
    """
    server_state = server.model_dict[self.model_name].state_dict()
    # strict=False: missing or unexpected keys are tolerated instead of raising.
    self.model.load_state_dict(server_state, strict=False)

  def compute_weight_update(self, epochs=1, loader=None, lambda_fedprox=0.0):
    """Train the local model with ``train_op``.

    Args:
      epochs: Forwarded to ``train_op``.
      loader: Loader to train on. ``self.loader`` is used when ``None``.
      lambda_fedprox: Forwarded to ``train_op``.

    Returns:
      Whatever ``train_op`` returns.
    """
    # Compare against None explicitly. Truth-testing an object that defines
    # __len__ calls it, so an empty loader would silently be replaced by
    # self.loader, and a loader whose length is undefined would raise.
    if loader is None:
      loader = self.loader
    return train_op(self.model, loader, self.optimizer, epochs, lambda_fedprox=lambda_fedprox)

  def predict_logit(self, x):
    """Return the model's raw output on ``x`` with the model in train mode.

    No softmax is applied here and gradients are not tracked.
    """
    # NOTE(review): train mode is presumably intentional, since
    # predict_logit_eval is the eval-mode counterpart -- confirm.
    self.model.train()
    with torch.no_grad():
      y_ = self.model(x)
    return y_

  def predict_logit_eval(self, x):
    """Return the model's raw output on ``x`` with the model in eval mode.

    No softmax is applied here and gradients are not tracked.
    """
    self.model.eval()
    with torch.no_grad():
      y_ = self.model(x)
    return y_

