import io
import numpy as np
from mpi4py import MPI
from typing import Any
from .algorithm import *
from torch.optim import *
from logging import Logger

class SchedulerFedAT:
    """Server-side scheduler that groups clients into tiers and updates per tier.

    Clients first report a finish time through `speed_record`. Once every
    client has reported, they are partitioned into tiers of similar finish
    time (`_create_tier`). Afterwards each arriving local model is buffered
    for its client's tier, and the tier is aggregated as soon as all of its
    clients have arrived (`_update`).

    Communication conventions used by this class:
      * client ``i`` is addressed as MPI rank ``i + 1``;
      * the info tuple is sent with tag ``i + 1``;
      * model payloads are sent/received with tag ``i + 1 + comm_size``.
    """

    def __init__(
        self,
        comm: "MPI.Comm",
        server: Any,
        local_steps: int,
        num_clients: int,
        num_global_epochs: int,
        lr: float,
        speed_ratio: float,
        logger: Logger,
    ):
        """Store the configuration and initialise empty bookkeeping tables.

        Args:
            comm: MPI communicator used to talk to the clients.
            server: Server object; this class calls its `buffer`,
                `buffer_update`, `setup_tiers`, `buffer_tier` and
                `tier_update` methods and reads its `model` attribute.
            local_steps: Value forwarded to clients with every model.
            num_clients: Number of clients expected to report a speed.
            num_global_epochs: Upper bound compared against the number of
                processed client arrivals before models are sent back.
            lr: Learning rate forwarded to clients with every model.
            speed_ratio: Maximum ratio between a client's finish time and the
                smallest finish time of its tier.
            logger: Logger used for progress messages.
        """
        self.iter = 0
        self.comm = comm
        self.server = server
        self.logger = logger
        self.num_clients = num_clients
        self.num_global_epochs = num_global_epochs
        self.local_steps = local_steps
        self.lr = lr
        self.speed_ratio = speed_ratio
        self.comm_size = comm.Get_size()
        # Defined up front so it can be read before the first client arrives.
        self.validation_flag = False
        self.client_speed_info = {}   # client index -> reported finish time
        self.tiering_info = {}        # tier index -> list of client indices
        self.clients_info = {}        # client index -> tier index
        self.arriving_clients = {}    # tier index -> clients arrived this round

    def _create_tier(self):
        """Partition clients into tiers based on their recorded finish times.

        Clients are visited in ascending order of finish time. A client joins
        the current tier while its finish time divided by the tier's smallest
        finish time does not exceed `speed_ratio`; otherwise it opens a new
        tier. Tier 0 therefore holds the clients with the smallest times.
        """
        # Rebuild from scratch (in place) so a repeated call cannot leave
        # stale tiers from a previous partition behind.
        self.tiering_info.clear()
        self.clients_info.clear()
        self.arriving_clients.clear()

        tier = -1
        tier_min_time = None
        ordered = sorted(self.client_speed_info.items(), key=lambda item: item[1])
        for client, finish_time in ordered:
            if tier_min_time is None:
                fits_current_tier = False
            elif tier_min_time == 0:
                # The ratio is undefined for a zero baseline; only other
                # zero-time clients are considered equally fast.
                fits_current_tier = finish_time == 0
            else:
                fits_current_tier = finish_time / tier_min_time <= self.speed_ratio

            if not fits_current_tier:
                tier += 1
                tier_min_time = finish_time
                self.tiering_info[tier] = []
                self.arriving_clients[tier] = []
            self.tiering_info[tier].append(client)
            self.clients_info[client] = tier

        self.logger.info(f"Tiering information: {self.tiering_info}")
        self.logger.info(f"Client info: {self.clients_info}")

    def speed_record(self, local_model_size: int, client_idx: int, finish_time: float):
        """Record a client's finish time and buffer its local model.

        When the number of recorded clients equals `num_clients`, the tiers
        are created, the server performs its buffered update and tier setup,
        and the global model is sent to clients ``0 .. num_clients - 1``.

        Args:
            local_model_size: Size in bytes of the incoming serialized model.
            client_idx: Index of the reporting client.
            finish_time: Time reported by the client; smaller values end up
                in lower-numbered tiers.
        """
        self.client_speed_info[client_idx] = finish_time
        local_model = self._recv_model(local_model_size, client_idx)
        self.server.buffer(local_model)
        if len(self.client_speed_info) != self.num_clients:
            return

        self._create_tier()
        self.server.buffer_update()

        num_tiers = len(self.tiering_info)
        tier_client_counts = [len(self.tiering_info[t]) for t in range(num_tiers)]
        self.server.setup_tiers(num_tiers, tier_client_counts)
        self._broadcast_model(range(self.num_clients))

    def local_update(self, local_model_size: int, client_idx: int):
        """Schedule update when receive information from one client."""
        local_model = self._recv_model(local_model_size, client_idx)
        self._update(local_model, client_idx)

    def _update(self, local_model: dict, client_idx: int):
        """Buffer one local model and aggregate its tier once it is complete.

        `validation_flag` is set to True only when this arrival completed a
        tier and triggered `server.tier_update`; otherwise it is False.
        """
        self.logger.info(f"Client {client_idx} arrives")
        # Note: `iter` counts client arrivals, not tier updates.
        self.iter += 1
        self.validation_flag = False
        # Update the global model
        self.server.model.to("cpu")
        client_tier = self.clients_info[client_idx]
        arrived = self.arriving_clients[client_tier]
        arrived.append(client_idx)
        self.server.buffer_tier(local_model, client_tier)
        if len(arrived) != len(self.tiering_info[client_tier]):
            return

        self.logger.info(f"Update tier {client_tier}")
        self.server.tier_update(client_tier)
        if self.iter < self.num_global_epochs:
            self._broadcast_model(self.tiering_info[client_tier])
        # Start a fresh round for this tier.
        self.arriving_clients[client_tier] = []
        self.validation_flag = True

    def _recv_model(self, local_model_size: int, client_idx: int):
        """Receive `local_model_size` bytes from a client and deserialize them.

        Returns:
            The object produced by `torch.load` on the received bytes.
        """
        # Imported here so the method does not depend on `torch` leaking in
        # through one of the module's star imports.
        import torch

        local_model_bytes = np.empty(local_model_size, dtype=np.byte)
        self.comm.Recv(local_model_bytes, source=client_idx+1, tag=client_idx+1+self.comm_size)
        local_model_buffer = io.BytesIO(local_model_bytes.tobytes())
        # NOTE(security): torch.load is pickle-based; the payload comes from a
        # remote client, so this is only safe when all clients are trusted.
        return torch.load(local_model_buffer)

    def _serialize_global_model(self) -> bytes:
        """Return the server model's state dict serialized with `torch.save`."""
        import torch

        global_model_buffer = io.BytesIO()
        torch.save(self.server.model.state_dict(), global_model_buffer)
        return global_model_buffer.getvalue()

    def _send_serialized(self, client_idx: int, global_model_bytes: bytes):
        """Send the info tuple followed by an already-serialized model."""
        # Send (buffer size, finish flag, local steps, lr) - INFO - to the client in a blocking way
        self.comm.send((len(global_model_bytes), False, self.local_steps, self.lr), dest=client_idx+1, tag=client_idx+1)
        # Send the buffered model - MODEL - to the client in a blocking way
        self.comm.Send(np.frombuffer(global_model_bytes, dtype=np.byte), dest=client_idx+1, tag=client_idx+1+self.comm_size)

    def _broadcast_model(self, client_indices):
        """Send the current global model to several clients.

        The model is serialized once and the same bytes are reused for every
        recipient, since the model does not change between the sends.
        """
        global_model_bytes = self._serialize_global_model()
        for client_idx in client_indices:
            self._send_serialized(client_idx, global_model_bytes)

    def _send_model(self, client_idx: int):
        """Serialize the current global model and send it to one client."""
        self._send_serialized(client_idx, self._serialize_global_model())

