import ray
import torch
from typing import Optional
from src.log import Checkpoint
from src.apps import App

from flwr.server.server import Server
from flwr.server.client_manager import ClientManager
from flwr.simulation.ray_transport.ray_client_proxy import RayClientProxy

import logging
logger = logging.getLogger(__name__)

def parse_ray_resources(cpus: int, vram: int) -> dict:
    """Translate a per-client VRAM budget into Ray task resources.

    Given the amount of VRAM (in MB) requested for each client, work out the
    fraction of the first GPU that this corresponds to. The CUDA runtime
    allocates ~1GB upon initialization, so that amount is subtracted from the
    detected VRAM before computing the ratio. CPU resources are returned
    without modification.

    Args:
        cpus: Number of CPUs to reserve per client.
        vram: VRAM (in MB) each client is expected to need.

    Returns:
        A dict with the keys Ray expects for a task's resources:
        ``num_cpus`` (``cpus`` unchanged) and ``num_gpus`` (fraction of a GPU,
        ``0.0`` when no CUDA device is available).

    Raises:
        ValueError: If the first GPU has no VRAM left once the CUDA runtime
            overhead is discounted, so no meaningful ratio can be computed.
    """
    # VRAM (in MB) assumed to be taken by the CUDA runtime once initialized.
    # You can verify this yourself by just running:
    # `t = torch.randn(10).cuda()` (will use ~1GB VRAM)
    cuda_overhead_mb = 1024

    gpu_ratio = 0.0
    if torch.cuda.is_available():
        # VRAM of the first GPU, converted to MB (since the user inputs VRAM in MB)
        total_vram = float(torch.cuda.get_device_properties(0).total_memory) / (1024**2)

        # discard the VRAM taken by the CUDA runtime to get a good estimate of
        # what is really available to the clients
        total_vram -= cuda_overhead_mb

        if total_vram <= 0:
            # Without this guard we would divide by zero or hand Ray a
            # negative GPU fraction.
            raise ValueError(
                "The first GPU has no usable VRAM left after discounting "
                f"{cuda_overhead_mb} MB for the CUDA runtime "
                f"({total_vram} MB remaining)."
            )

        gpu_ratio = float(vram) / total_vram
        logger.info(f"GPU percentage per client: {100*gpu_ratio:.2f} % ({vram}/{total_vram})")

        # ! Limitation: this won't work well if multiple GPUs with different VRAMs are detected by Ray.
        # The code above therefore assumes all GPUs have the same amount of VRAM. The same `gpu_ratio`
        # will be used in GPUs #1, #2, etc (even though there won't be 1GB taken by CUDA runtime)
        # TODO: probably we can do something smarter: run a single training batch and monitor the
        # real memory usage. This would remove the need for the user to specify VRAM (which often
        # takes a few rounds of trial and error)
    else:
        # `Logger.warn` is a deprecated alias; `warning` is the supported spelling.
        logger.warning("No CUDA device found. Disabling GPU usage for Flower clients.")

    # these keys are the ones expected by ray to specify CPU and GPU resources for each
    # Ray Task, representing a client workload.
    return {'num_cpus': cpus, 'num_gpus': gpu_ratio}

def parse_ray_resources_2(cpus: int, gpus: float):
    """Bundle per-client CPU and GPU amounts under the keys Ray expects.

    Both values are passed through untouched: ``cpus`` is stored under
    ``num_cpus`` and ``gpus`` under ``num_gpus``.
    """
    resources = dict(num_cpus=cpus, num_gpus=gpus)
    return resources

def start_simulation(
    ckp: Checkpoint,
    server: Server,
    app: App,
) -> None:
    """Run a Ray-backed simulation: register the clients, run the app, disconnect.

    Ray is initialised from ``ckp.config.simulation.ray_init_args``, one
    ``RayClientProxy`` per client is registered with the server's client
    manager, ``app.run(server)`` is executed, and finally all clients are
    disconnected.

    Args:
        ckp: Checkpoint whose ``config.simulation`` section holds the
            simulation settings read here (``max_concurrent_clients``,
            ``ray_init_args``, ``client_resources``, ``num_clients``).
        server: Server whose client manager receives the client proxies.
        app: Application providing the client factory (``get_client_fn``)
            and the ``run`` entry point.
    """
    sim_config = ckp.config.simulation

    max_concurrent = sim_config.get("max_concurrent_clients", 2)
    # Initialize Ray
    ray.init(**sim_config.ray_init_args)
    # Split the GPUs handed to Ray evenly among the clients running concurrently
    gpus_per_client = sim_config.ray_init_args.get("num_gpus", 1) / max_concurrent
    # Allocate client resources. The CPU count must come from `num_cpus`:
    # reading `num_gpus` here would reserve as many CPUs as GPUs per client.
    resources = parse_ray_resources_2(
        sim_config.client_resources.get("num_cpus", 1), gpus_per_client
    )

    # The client factory does not depend on the loop index, so fetch it once.
    client_fn = app.get_client_fn()
    num_clients = sim_config.num_clients

    # Register one RayClientProxy object for each client with the ClientManager
    for i in range(num_clients):
        cid = str(i)
        # A throwaway client is built only to read its `lid`; clients that do
        # not define one fall back to 0, for the log line and the proxy alike.
        temp_client = client_fn(cid)
        lid = getattr(temp_client, "lid", 0)
        print(f"[REGISTER] CID={i}, lid={lid}")

        client_proxy = RayClientProxy(
            client_fn=client_fn,
            cid=cid,
            resources=resources,
        )
        client_proxy.lid = lid
        # Every proxy gets the same share, 1/num_clients
        # (presumably an aggregation weight — confirm against the consumer).
        client_proxy.value = 1 / num_clients
        server.client_manager().register(client=client_proxy)

    app.run(server)

    server.disconnect_all_clients()
