import os
from datetime import timedelta

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

import ser.wavlm.utils.distributed
import ser.wavlm.utils.environment
from ser.configs import create_workshop, get_config
from ser.engine import Engine


def main_worker(local_rank, cfg, mode, world_size, dist_url):
    """Run one distributed worker process bound to CUDA device ``local_rank``.

    Launched once per GPU by ``main`` through ``torch.multiprocessing.spawn``,
    which supplies the process index as the first positional argument.

    Args:
        local_rank: Index of this worker. Used both as the CUDA device index
            and as the process-group rank (single-node launch).
        cfg: Configuration object. ``cfg.train.seed`` is read here, and the
            object is passed through ``create_workshop`` before building the
            ``Engine``.
        mode: Run-mode string forwarded to ``create_workshop`` and ``Engine``.
        world_size: Total number of worker processes in the process group.
        dist_url: ``init_method`` URL for the NCCL process group, e.g.
            ``tcp://127.0.0.1:<port>``.
    """
    # Deterministic cuBLAS requires CUBLAS_WORKSPACE_CONFIG=:4096:8 (or :16:8),
    # and it must be in the environment before cuBLAS is first used. Set it
    # before any CUDA work in this process.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    # os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "0"
    # os.environ["NCCL_IB_GID_INDEX"] = "3"

    torch.cuda.set_device(local_rank)
    # Run after set_device so the cache release targets this rank's GPU.
    torch.cuda.empty_cache()
    ser.wavlm.utils.environment.set_seed(cfg.train.seed)

    dist.init_process_group(
        backend="nccl",
        init_method=dist_url,
        timeout=timedelta(seconds=6000),  # 100 minutes
        world_size=world_size,
        rank=local_rank,
    )
    try:
        cfg = create_workshop(cfg, mode, local_rank)
        engine = Engine(cfg, mode, local_rank, world_size)
        engine.run()
        # NOTE: rank-0 CSV result collection (utils.collect_result.path_to_csv)
        # was disabled here as commented-out code; restore from history if needed.
    finally:
        # Release the NCCL communicators even when training raises.
        dist.destroy_process_group()


def main(cfg, mode):
    """Spawn one ``main_worker`` process per visible CUDA device.

    Args:
        cfg: Configuration object. ``cfg.train.seed`` is read here, and the
            object is forwarded unchanged to every worker.
        mode: Run-mode string forwarded to every worker.

    Raises:
        RuntimeError: If no CUDA device is visible. The workers use the NCCL
            backend and bind to a GPU each, so at least one is required.
    """
    # device_id = device_id if device_id is not None else cfg.train.device_id
    # utils.environment.visible_gpus(device_id)
    world_size = torch.cuda.device_count()  # one worker per GPU
    if world_size == 0:
        # Without this check, mp.spawn(nprocs=0) would silently launch nothing.
        raise RuntimeError(
            "No CUDA devices found; main_worker needs at least one GPU (NCCL backend)."
        )
    ser.wavlm.utils.environment.set_seed(cfg.train.seed)
    free_port = ser.wavlm.utils.distributed.find_free_port()
    dist_url = f"tcp://127.0.0.1:{free_port}"
    print(f"world_size={world_size} Using dist_url={dist_url}")
    mp.spawn(fn=main_worker, args=(cfg, mode, world_size, dist_url), nprocs=world_size)


if __name__ == "__main__":
    # Script entry point: build the fine-tuning config and launch one worker per GPU.
    mode = "_finetune"
    main(get_config(mode=mode), mode)
