import os
import time
import numpy as np
from collections import defaultdict
from misc.utils import *

class ParentProcess:
    """Coordinate training rounds between one server and a pool of workers.

    The parent owns a plain dict ``sd`` that is handed to the server and to
    every worker, builds one ``WorkerProcess`` per client, and then runs the
    round loop in :meth:`start`.
    """

    def __init__(self, args, Server, Client):
        """Build the shared dict, the workers and the server.

        Args:
            args: Configuration object. This class reads ``gpu`` (a
                comma-separated string of integer ids), ``n_clients``,
                ``frac``, ``n_rnds``, ``seed``, ``checkpt_path``,
                ``log_path`` and, optionally, ``max_iterations``.
            Server: Callable invoked as ``Server(args, sd, gpu_server)``.
            Client: Callable forwarded to each ``WorkerProcess``.
        """
        self.args = args
        self.gpus = [int(g) for g in args.gpu.split(',')]
        # The first listed GPU is handed to the server and to every worker.
        self.gpu_server = self.gpus[0]
        self.proc_id = os.getpid()
        print(f'main process id: {self.proc_id}')

        # Dict passed by reference to the server and all workers.
        self.sd = {'is_done': False, 'quantization_info': {}}

        # Number of "iterations" (groups of ``n_rnds`` rounds); defaults to 5
        # when ``args`` does not define ``max_iterations``.
        self.max_iterations = getattr(self.args, 'max_iterations', 5)

        # Workers are created before the server, as in the original ordering.
        self.create_workers(Client)
        self.server = Server(args, self.sd, self.gpu_server)

    def create_workers(self, Client):
        """Create one ``WorkerProcess`` per client id in ``range(n_clients)``."""
        self.clients = [
            WorkerProcess(self.args, self.gpu_server, self.sd, Client, worker_id=i)
            for i in range(self.args.n_clients)
        ]

    def start(self):
        """Run ``n_rnds * max_iterations`` rounds and return the server's ``test_acc``.

        Each round reseeds NumPy's global RNG with ``seed + round``, samples
        ``round(n_clients * frac)`` distinct client ids, notifies the server,
        runs the selected workers in ``self.clients`` order, and then reports
        the selected ids back to the server.

        ``sd['is_done']`` is set to ``True`` when the loop ends, including
        when a round raises, so that holders of ``sd`` never observe a run
        that failed but still looks in progress.

        Returns:
            ``self.server.test_acc`` as it stands after the last round.

        Raises:
            Whatever the directory creation, the server hooks or the client
            rounds raise; the exception propagates after ``is_done`` is set.
        """
        self.sd['is_done'] = False
        try:
            os.makedirs(self.args.checkpt_path, exist_ok=True)
            os.makedirs(self.args.log_path, exist_ok=True)
            self.n_connected = round(self.args.n_clients * self.args.frac)

            total_rounds = self.args.n_rnds * self.max_iterations
            iteration_count = 0

            for curr_rnd in range(total_rounds):
                self.curr_rnd = curr_rnd
                # Reseeding per round makes the selection a pure function of
                # (seed, round), independent of RNG use inside the clients.
                np.random.seed(self.args.seed + curr_rnd)
                self.selected = np.random.choice(
                    self.args.n_clients, self.n_connected, replace=False
                ).tolist()
                self.updated = set(self.selected)

                self.server.on_round_begin(curr_rnd)

                for client in self.clients:
                    # Set lookup instead of scanning the ``selected`` list.
                    if client.worker_id in self.updated:
                        client.client_round([client.worker_id], curr_rnd)

                self.server.on_round_complete(self.updated)

                if (curr_rnd + 1) % self.args.n_rnds == 0:
                    iteration_count += 1
                    print(f"Completed iteration {iteration_count}/{self.max_iterations}")

                    # total_rounds is n_rnds * max_iterations, so this can
                    # only fire on the final round; kept for its log line.
                    if iteration_count >= self.max_iterations:
                        print("Max iterations reached, stopping training.")
                        break
        finally:
            # Always publish completion, even if a round failed.
            self.sd['is_done'] = True
        return self.server.test_acc

class WorkerProcess:
    """Pair a single ``Client`` instance with a per-client state store.

    The wrapper keeps the constructor arguments as attributes, instantiates
    ``Client(args, worker_id, gpu_id, sd)`` once, and holds a
    ``defaultdict(dict)`` that is passed to the client on every round.
    """

    def __init__(self, args, gpu_id, sd, Client, worker_id=0, q=None):
        """Store the configuration and build the wrapped client.

        Args:
            args: Configuration object forwarded to ``Client``.
            gpu_id: Device id forwarded to ``Client``.
            sd: Dict shared with the parent; forwarded to ``Client``.
            Client: Callable invoked as ``Client(args, worker_id, gpu_id, sd)``.
            worker_id: Identifier of this worker (default ``0``).
            q: Stored on the instance only; not otherwise used in this class.
        """
        self.args, self.gpu_id, self.sd = args, gpu_id, sd
        self.worker_id, self.q = worker_id, q
        # Missing client ids map to a fresh empty dict on first access.
        self.client_state = defaultdict(dict)
        self.client = Client(args, worker_id, gpu_id, sd)

    def client_round(self, id_list, curr_rnd):
        """Run one round on the wrapped client for each id in ``id_list``.

        For every id, in order, the client's ``chance_state``,
        ``on_receive_message``, ``on_round_begin`` and ``update_state`` hooks
        are called. NOTE(review): presumably ``chance_state`` loads and
        ``update_state`` saves the per-id state — confirm against ``Client``.

        Args:
            id_list: Iterable of client ids to process.
            curr_rnd: Current round index, passed to ``on_receive_message``.
        """
        worker_client, state_store = self.client, self.client_state
        for cid in id_list:
            worker_client.chance_state(state_store, cid)
            worker_client.on_receive_message(curr_rnd)
            worker_client.on_round_begin(cid)
            worker_client.update_state(state_store, cid)

            