import os, argparse, json, time
import copy
from copy import deepcopy

from tqdm import tqdm
from functools import partial
import torch, torchvision
import numpy as np

import data , models 
from data import my_subset
from data import unbalanced_dataset
import experiment_manager as xpm
from client import Client
from utils import *
from server import Server
import resource
# Compact numpy printing for logged arrays: 4 decimals, no scientific notation.
np.set_printoptions(suppress=True, precision=4)
# Current (soft, hard) open-file-descriptor limits. NOTE(review): this value is not
# referenced elsewhere in the visible code — confirm whether a setrlimit was intended.
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
def reduce_average(target, sources):
  """Overwrite every tensor in ``target`` with the element-wise mean over ``sources``.

  Args:
    target: mapping from name to tensor; each entry is updated in place through
      ``.data`` (so parameters keep their identity and autograd registration).
    sources: sequence of mappings that contain every key of ``target``, each holding
      a tensor of the same shape as the corresponding target entry.

  ``torch.mean`` only accepts floating-point/complex inputs, so integer tensors
  (e.g. counters) are averaged in double precision and cast back to the target's
  dtype, which truncates toward zero.
  """
  for name in target:
    # detach(): the average must not be tracked by autograd.
    stacked = torch.stack([source[name].detach() for source in sources])
    if stacked.is_floating_point() or stacked.is_complex():
      # torch.mean already returns a freshly allocated tensor; no clone needed.
      target[name].data = torch.mean(stacked, dim=0)
    else:
      target[name].data = torch.mean(stacked.double(), dim=0).to(target[name].dtype)


import os

def _str2bool(value):
  """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

  Needed because ``type=bool`` is an argparse pitfall: ``bool("False")`` is ``True``,
  so every non-empty value would switch the option on.

  Raises:
    argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
  """
  lowered = value.strip().lower()
  if lowered in ("1", "true", "t", "yes", "y"):
    return True
  if lowered in ("0", "false", "f", "no", "n"):
    return False
  raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser()
parser.add_argument("--schedule", default="main", type=str)
parser.add_argument("--start", default=0, type=int)
parser.add_argument("--end", default=None, type=int)
# Accepts "--reverse_order", "--reverse_order True" and "--reverse_order false".
parser.add_argument("--reverse_order", default=False, type=_str2bool, nargs="?", const=True)
parser.add_argument("--hp", default=None, type=str)
parser.add_argument("--project", default=None, type=str)
parser.add_argument('--debugging', action="store_true")
parser.add_argument("--DATA_PATH", default=None, type=str)
parser.add_argument("--runs_name", default=None, type=str)
parser.add_argument("--RESULTS_PATH", default=None, type=str)
parser.add_argument("--ACC_PATH", default=None, type=str)
parser.add_argument("--CHECKPOINT_PATH", default=None, type=str)


args = parser.parse_args()


def run_experiment(xp, xp_count, n_experiments):
  """Run one federated-learning experiment described by ``xp.hyperparameters``.

  Builds a train/validation split of the training data, creates one ``Client`` per
  model instance listed in ``hp["models"]`` (name -> count) plus a ``Server``, and then
  runs ``hp["communication_rounds"]`` rounds of client selection, local training and
  server-side aggregation. Metrics go to both ``xp`` and Weights & Biases. Results are
  saved under ``args.RESULTS_PATH`` on every logging round, and the server model is
  saved under ``args.CHECKPOINT_PATH`` at the end.

  Args:
    xp: experiment object exposing ``hyperparameters``, ``log``, ``is_log_round`` and
      ``save_to_disc``.
    xp_count: zero-based index of this experiment (used only for progress output).
    n_experiments: number of scheduled experiments (used only for progress output).

  Raises:
    ValueError: if ``hp["aggregation_mode"]`` is neither ``"FedAVG"``, ``"optimize"``
      nor a mode containing ``"PROX"``.
  """
  t0 = time.time()
  print(xp)
  hp = xp.hyperparameters

  # Validate up front so an unsupported mode fails before any data loading or
  # training happens (instead of stopping in a debugger mid-run).
  aggregation_mode = hp["aggregation_mode"]
  if aggregation_mode not in ("FedAVG", "optimize") and "PROX" not in aggregation_mode:
    raise ValueError(f"Unknown aggregation_mode: {aggregation_mode!r}")

  wandb_run = wandb.init(project=args.project, config=hp, reinit=True, name=args.runs_name)
  print(wandb.config)

  num_classes = {"mnist": 10, "cifar10": 10, "cifar100": 100, "nlp": 4, 'news20': 20}[hp["dataset"]]

  distill_iter = hp.get("distill_iter", None)
  distill_lr = hp.get("distill_lr", None)

  # One entry per client, e.g. {"A": 2, "B": 1} -> ["A", "A", "B"].
  model_names = [model_name for model_name, k in hp["models"].items() for _ in range(k)]

  # Values in hp["local_optimizer"][1] are defaults; a top-level hp entry with the
  # same name (e.g. hp["lr"]) overrides them.
  optimizer, optimizer_hp = getattr(torch.optim, hp["local_optimizer"][0]), hp["local_optimizer"][1]
  optimizer_fn = lambda x: optimizer(x, **{k: hp[k] if k in hp else v for k, v in optimizer_hp.items()})

  print(f"dataset : {hp['dataset']}")
  train_data_all, test_data = data.get_data(hp["dataset"], args.DATA_PATH)
  split = hp["val_size"]
  np.random.seed(hp["random_seed"])
  torch.manual_seed(hp["random_seed"])
  # The first `val_size` shuffled indices become the validation set, the rest training.
  indices = np.random.permutation(len(train_data_all))
  train_indices, val_indices = indices[split:], indices[:split]
  train_data = my_subset(train_data_all, train_indices, np.array(train_data_all.targets)[train_indices])
  val_data = deepcopy(train_data_all)
  val_data.data = val_data.data[val_indices]
  val_data.targets = np.array(val_data.targets)[val_indices]
  val_data = unbalanced_dataset(val_data, hp.get("val_imbalance_factor", -1))
  client_loaders, test_loader = data.get_loaders(train_data, test_data, n_clients=len(model_names),
        alpha=hp["alpha"], batch_size=hp["batch_size"], n_data=None, num_workers=0, seed=hp["random_seed"])
  val_loader = torch.utils.data.DataLoader(val_data, batch_size=hp["val_batch_size"], shuffle=False)
  server = Server(np.unique(model_names), test_loader, val_loader,
                  n_participants=hp["participation_rate"] * len(client_loaders),
                  num_classes=num_classes, start_optimize=hp.get("start_optimize", 0))
  clients = [Client(model_name, optimizer_fn, loader, idnum=i, num_classes=num_classes)
             for i, (loader, model_name) in enumerate(zip(client_loaders, model_names))]

  server.number_client_all = len(client_loaders)

  models.print_model(clients[0].model)
  # Start Distributed Training Process
  print("Start Distributed Training..\n")
  t1 = time.time()
  xp.log({"prep_time": t1 - t0})
  maximum_acc_test, maximum_acc_val = 0, 0
  xp.log({"server_val_{}".format(key): value for key, value in server.evaluate_ensemble().items()})
  test_accs, val_accs = [], []
  # The proximal term is only used by FedProx-style ("PROX") modes.
  lambda_fedprox = hp["lambda_fedprox"] if "PROX" in aggregation_mode else 0.0
  for c_round in range(1, hp["communication_rounds"] + 1):

    participating_clients = server.select_clients(clients, hp["participation_rate"], hp.get('unbalance_rate', 1), hp.get('sample_mode', "uniform"))
    xp.log({"participating_clients": np.array([c.id for c in participating_clients])})
    for client in participating_clients:
      client.synchronize_with_server(server)
      client.compute_weight_update(hp["local_epochs"], lambda_fedprox=lambda_fedprox)

    if aggregation_mode == "optimize":
      server.optimize(participating_clients, distill_iter, distill_lr, c_round)
    else:  # "FedAVG" or a "PROX" variant (validated above): plain averaging.
      server.fedavg(participating_clients)

    if xp.is_log_round(c_round):
      print("Experiment: {} ({}/{})".format(args.schedule, xp_count + 1, n_experiments))

      xp.log({'communication_round': c_round, 'epochs': c_round * hp['local_epochs']})
      xp.log({key: clients[0].optimizer.param_groups[0][key] for key in optimizer_hp})
      if server.weights is not None:
        xp.log({"weights": np.array(server.weights.cpu())})
      if "confi" in aggregation_mode:
        xp.log({"label_acc": server.label_acc})
        wandb.log({"label_acc": server.label_acc})

      # Evaluate once per logging round and reuse the result for every log below.
      stats = server.evaluate_ensemble()
      for key, value in stats.items():
        if key == "test_accuracy" and value > maximum_acc_test:
          maximum_acc_test = value
          wandb.log({"maximum_acc_{}_a_{}_test".format("accuracy", hp["alpha"]): maximum_acc_test})
        elif key == "val_accuracy" and value > maximum_acc_val:
          maximum_acc_val = value
          wandb.log({"maximum_acc_{}_a_{}_val".format("accuracy", hp["alpha"]): maximum_acc_val})
      xp.log({"server_val_{}".format(key): value for key, value in stats.items()})
      wandb.log({"server_{}_a_{}".format(key, hp["alpha"]): value for key, value in stats.items()})
      xp.log({"epoch_time": (time.time() - t1) / c_round})
      test_accs.append(stats['test_accuracy'])
      val_accs.append(stats['val_accuracy'])
      # Save results to disk (best effort: a failed save must not abort training).
      try:
        xp.save_to_disc(path=args.RESULTS_PATH, name=hp['log_path'])
      except Exception as err:
        print("Saving results Failed!", repr(err))

      # Timing: extrapolate the mean time per round over the remaining rounds.
      eta_seconds = int((time.time() - t1) / c_round * (hp['communication_rounds'] - c_round))
      print("Remaining Time (approx.):",
            '{:02d}:{:02d}:{:02d}'.format(eta_seconds // 3600, (eta_seconds % 3600 // 60), eta_seconds % 60),
            "[{:.2f}%]\n".format(c_round / hp['communication_rounds'] * 100))

  # Save model to disk
  server.save_model(path=args.CHECKPOINT_PATH, name=hp["save_model"])
  # Delete objects to free up GPU memory
  del server
  clients.clear()
  torch.cuda.empty_cache()
  wandb_run.finish()

def run():
  """Expand the hyperparameter grid and run every resulting experiment in order.

  The grid comes from ``--hp``, a JSON list of dicts whose values are lists of
  candidate settings, or from a small built-in config when ``--debugging`` is set.
  Each dict is expanded with ``xpm.get_all_hp_combinations``. The expanded list is
  sliced with ``--start``/``--end`` and then, if ``--reverse_order`` is set, reversed.
  """
  if args.debugging:
    hpstr = '[{ "random_seed" : [4], "dataset" : ["news20"], "models" : [{"LogisticRegression" : 80}], "attack_rate" :  [0], "attack_method": ["-"], "participation_rate" : [0.4], "alpha" : [0.01], "communication_rounds" : [200], "local_epochs" : [1], "batch_size" : [32],  "val_size" : [256], "val_batch_size": [64], "local_optimizer" : [ ["Adam", {"lr": 0.001}]],"start_optimize":[0], "aggregation_mode" : ["optimize"], "distill_iter": [10], "distill_lr": [1e-2],  "sample_size": [0], "save_scores" : [false], "pretrained" : [null], "save_model" : [null],  "log_frequency" : [1], "log_path" : ["new_noniid/"]}]'
    experiments_raw = json.loads(hpstr)
  elif args.hp is None:
    # Without this, json.loads(None) fails with an opaque TypeError.
    parser.error("--hp is required unless --debugging is set")
  else:
    experiments_raw = json.loads(args.hp)
  hp_dicts = [hp for x in experiments_raw for hp in xpm.get_all_hp_combinations(x)][args.start:args.end]
  if args.reverse_order:
    hp_dicts = hp_dicts[::-1]
  experiments = [xpm.Experiment(hyperparameters=hp) for hp in hp_dicts]

  n_experiments = len(experiments)
  print("Running {} Experiments..\n".format(n_experiments))
  for xp_count, experiment in enumerate(experiments):
    run_experiment(experiment, xp_count, n_experiments)
 
  
if __name__ == "__main__":
  # Imported here instead of with the other imports. Because this statement runs at
  # module scope, `wandb` becomes a module global, which run_experiment() uses.
  import wandb
  run()
   