import os
from datetime import timedelta
from pathlib import Path

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

import ser.wavlm.utils.distributed
import ser.wavlm.utils.environment
from ser.configs import create_workshop, get_config
from ser.engine import Engine

# wandb agent settings, exported process-wide at import time: disable the
# agent's flapping detection and allow up to 20 initial run failures.
os.environ.update(
    {
        "WANDB_AGENT_DISABLE_FLAPPING": "True",
        "WANDB_AGENT_MAX_INITIAL_FAILURES": "20",
    }
)


def main_worker(local_rank, cfg, mode, world_size, dist_url, i):
    """Evaluate and test a fine-tuned checkpoint on one GPU of a single-node job.

    Called by ``mp.spawn`` in ``main`` once per GPU. ``local_rank`` is used both
    as the CUDA device index and as the process-group rank, so this assumes a
    single-node run where ``world_size`` equals the number of local GPUs.

    Args:
        local_rank: Process index passed by ``mp.spawn``.
        cfg: Experiment config from ``get_config``. ``cfg.train.resume`` is
            rewritten below to point at a concrete checkpoint file.
        mode: Config mode string forwarded to ``create_workshop`` and ``Engine``.
        world_size: Number of processes in the NCCL process group.
        dist_url: ``init_method`` URL for ``dist.init_process_group``.
        i: Forwarded as ``ith_layer_inference`` to ``Engine.evaluate`` and
            ``Engine.test``.
    """
    # Deterministic cuBLAS requires CUBLAS_WORKSPACE_CONFIG (":4096:8" or
    # ":16:8") to be set before cuBLAS is first used. Set it before any CUDA
    # work in this process, not after device/seed setup.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

    # Bind this process to its GPU first so later CUDA calls do not create a
    # context on the default device.
    torch.cuda.set_device(local_rank)
    torch.cuda.empty_cache()
    ser.wavlm.utils.environment.set_seed(cfg.train.seed)

    dist.init_process_group(
        backend="nccl",
        init_method=dist_url,
        timeout=timedelta(seconds=6000),
        world_size=world_size,
        rank=local_rank,
    )
    try:
        cfg = create_workshop(cfg, mode, local_rank)
        # Point resume at the best-validation checkpoint of this session/speaker
        # split. NOTE(review): the "top_24" vs "top_1" suffix for 'elbo'
        # presumably mirrors how training named its checkpoints — confirm.
        cfg.train.resume = (
            f"{cfg.train.resume}/val_session_{cfg.dataset.eval_session}_test_speaker"
            f"_{cfg.dataset.test_gender}/checkpoint/model_best_val_top_"
            f"{24 if cfg.model.output_rep == 'elbo' else 1}.pt"
        )
        cfg.ckpt_save_path = str(Path(cfg.train.resume).parent)
        print(f"cfg.ckpt_save_path: {cfg.ckpt_save_path}")

        # Skip the run when this workshop already holds test predictions.
        # NOTE(review): each rank checks independently; if ranks disagree,
        # collectives inside Engine could hang — confirm the workshop is shared.
        results_csv = os.path.join(cfg.workshop, "test_predictions_log_softmax.csv")
        if not os.path.exists(results_csv):
            engine = Engine(cfg, mode, local_rank, world_size)
            engine.prepare_staff()
            engine.evaluate(ith_layer_inference=i, save_model=False)
            engine.test(ith_layer_inference=i)
    finally:
        # Release the process group's NCCL resources even if evaluation fails.
        # Recent PyTorch versions warn when the group is never destroyed.
        dist.destroy_process_group()


def main(cfg, mode, i):
    """Launch ``main_worker`` on every visible GPU via ``torch.multiprocessing.spawn``.

    Args:
        cfg: Experiment config from ``get_config``; passed to each worker.
        mode: Config mode string forwarded to each worker.
        i: Layer index forwarded to each worker (used as ``ith_layer_inference``).

    Raises:
        RuntimeError: If no CUDA device is visible. The workers use the NCCL
            backend, and ``mp.spawn`` with ``nprocs=0`` would silently start
            nothing.
    """
    ser.wavlm.utils.environment.set_seed(cfg.train.seed)

    world_size = torch.cuda.device_count()  # one process per GPU
    if world_size < 1:
        raise RuntimeError(
            "No CUDA devices visible; this script requires at least one GPU (NCCL backend)."
        )

    free_port = ser.wavlm.utils.distributed.find_free_port()
    dist_url = f"tcp://127.0.0.1:{free_port}"
    print(f"world_size={world_size} Using dist_url={dist_url}")
    mp.spawn(fn=main_worker, args=(cfg, mode, world_size, dist_url, i), nprocs=world_size)


if __name__ == "__main__":
    # Entry point: evaluate fine-tuned checkpoints with the "_finetune" config.
    run_mode = "_finetune"
    config = get_config(mode=run_mode)
    inference_layer = config.model.layer_used_for_inference
    main(config, run_mode, i=inference_layer)
