import logging
import os
import sys
import json
import torch.distributed as dist
from os.path import dirname, join

from utils.config import Config
from utils.distributed import init_distributed_mode, is_main_process
from utils.logger import setup_logger

# Module-level logger, named after this module, shared by the setup helpers in this file.
logger = logging.getLogger(__name__)


def setup_config():
    """Build the run configuration.

    The configuration is obtained from ``Config.get_config()``, which is
    expected to combine the YAML config with command-line overrides via
    OmegaConf and to convert types, e.g. ``'None'`` (str) --> ``None``.
    When the ``debug`` flag is set, wandb logging is switched off.

    Returns:
        The configuration object returned by ``Config.get_config()``.
    """
    cfg = Config.get_config()
    debugging = cfg.debug
    if debugging:
        # Debug runs should not report to wandb.
        cfg.wandb.enable = False
    return cfg


def setup_evaluate_config(config):
    """Apply the default settings used for evaluation runs.

    wandb is disabled, and when no output directory was configured the
    results go to an ``eval`` folder next to the pretrained checkpoint.

    Args:
        config: Run configuration; ``config.evaluate`` must be truthy.

    Returns:
        The same ``config`` object, modified in place.
    """
    assert config.evaluate
    config.wandb.enable = False

    # An explicitly configured output directory is left untouched.
    if config.output_dir is not None:
        return config

    checkpoint_dir = dirname(config.pretrained_path)
    config.output_dir = join(checkpoint_dir, "eval")
    return config


def setup_output_dir(output_dir, excludes=("code",)):
    """Create ``output_dir`` if it is missing; otherwise report leftover content.

    When the directory already exists nothing is deleted and no exception is
    raised for its content. Entries are only reported through a warning so
    that an accidental re-use of a non-empty output directory is visible in
    the log.

    Args:
        output_dir: Directory that will receive the outputs of the run.
        excludes: Names of entries that may already be present in
            ``output_dir`` without being reported. Entries whose name
            contains ``"slurm"`` or ``".out"`` are never reported either.
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=False)
        return

    ignored = set(excludes)
    # Sorted so that the logged list is deterministic.
    leftovers = sorted(
        entry
        for entry in os.listdir(output_dir)
        if entry not in ignored and "slurm" not in entry and ".out" not in entry
    )
    # Only warn when there is something to report; an empty list is just noise.
    if leftovers:
        logger.warning(f"remaining dirs or files: {leftovers}")


def setup_deepspeed_zero_config(stage):
    """Return the ``zero_optimization`` section for the given ZeRO stage.

    Args:
        stage: ZeRO stage; ``1``, ``2`` and ``3`` are supported.

    Returns:
        A dict with the ZeRO settings used for that stage. Stage 2 offloads
        the optimizer to CPU; stage 3 offloads both the optimizer and the
        parameters to CPU.

    Raises:
        ValueError: If ``stage`` is not one of the supported stages.
    """
    if stage == 1:
        return {"stage": 1, "reduce_bucket_size": 5e8}
    if stage == 2:
        return {
            "stage": 2,
            "contiguous_gradients": False,
            "overlap_comm": False,
            "reduce_scatter": True,
            "reduce_bucket_size": 5e8,
            "allgather_bucket_size": 5e8,
            "offload_optimizer": {
                "device": "cpu"
            },
        }
    if stage == 3:
        return {
            "stage": 3,
            "contiguous_gradients": True,
            "stage3_max_live_parameters": 1e9,
            "stage3_max_reuse_distance": 1e9,
            "stage3_prefetch_bucket_size": 1e7,
            "stage3_param_persistence_threshold": 1e5,
            "reduce_bucket_size": 1e7,
            "sub_group_size": 1e9,
            "offload_optimizer": {
                "device": "cpu"
            },
            "offload_param": {
                "device": "cpu"
            }
        }

    # ``stage`` is the stage value itself, so it is formatted directly.
    raise ValueError("Wrong stage for deepspeed {}".format(stage))

def setup_deepspeed_config(config):
    """Generate the DeepSpeed JSON config for this run and record its path.

    ``config.deepspeed_config`` is set to
    ``<output_dir>/deepspeed_config.json`` on every process; only the main
    process builds the settings and writes the file.

    Args:
        config: Run configuration. Reads ``output_dir``, ``batch_size``,
            ``optimizer`` (``lr``, ``weight_decay``, ``opt_betas`` and the
            optional ``max_grad_norm``), ``deepspeed.stage``,
            ``use_half_precision`` and the optional ``use_bf16``.

    Returns:
        The same ``config`` object.

    Raises:
        AssertionError: If a non-zero ZeRO stage is requested while
            ``use_half_precision`` is off.
    """
    config.deepspeed_config = os.path.join(config.output_dir, "deepspeed_config.json")
    opts = config.optimizer
    logger.info(f'Write deepspeed config to {config.deepspeed_config}')
    if not is_main_process():
        return config

    # Build and validate the complete config before touching the file system,
    # so that a failure below cannot leave an empty or truncated JSON file.
    ds_config = {
        "train_batch_size": config.batch_size * dist.get_world_size(),
        "train_micro_batch_size_per_gpu": config.batch_size,
        "steps_per_print": 100,
        "optimizer": {
            "type": "Adam",
            "adam_w_mode": True,
            "params": {
                "lr": opts.lr,
                "weight_decay": opts.weight_decay,
                "bias_correction": True,
                "betas": [
                    opts.opt_betas[0],
                    opts.opt_betas[1],
                ],
                "eps": 1e-8
            }
        }
    }
    if config.deepspeed.stage != 0:
        ds_config["zero_optimization"] = setup_deepspeed_zero_config(config.deepspeed.stage)

    if config.use_half_precision:
        if config.get('use_bf16', False):
            ds_config["bf16"] = {
                "enabled": True
            }
        else:
            ds_config["fp16"] = {
                "enabled": True,
                "auto_cast": False,
                "loss_scale": 0,
                "initial_scale_power": 16,
                "loss_scale_window": 1000,
                "hysteresis": 2,
                "consecutive_hysteresis": False,
                "min_loss_scale": 1
            }
    else:
        assert config.deepspeed.stage == 0, "You must use fp16 or bf16 when using ZERO!!!"

    # Gradient clipping is only enabled for a positive ``max_grad_norm``.
    if opts.get("max_grad_norm", -1) > 0:
        ds_config["gradient_clipping"] = opts.max_grad_norm

    os.makedirs(config.output_dir, exist_ok=True)
    with open(config.deepspeed_config, mode="w") as writer:
        writer.write(json.dumps(ds_config, indent=2))

    return config


def setup_main():
    """Setup config, logger, output_dir, etc.

    Shared for pretrain and all downstream tasks. The steps run in this
    order: build the config, apply evaluation defaults when ``evaluate`` is
    set, initialise distributed mode, generate the DeepSpeed config when
    DeepSpeed is enabled, and finally, on the main process only, prepare the
    output directory, configure the logger and dump the config to
    ``<output_dir>/config.json``.

    Returns:
        The fully prepared configuration object.
    """
    config = setup_config()
    if hasattr(config, "evaluate") and config.evaluate:
        config = setup_evaluate_config(config)
    init_distributed_mode(config)

    if hasattr(config, "deepspeed") and config.deepspeed.enable:
        config = setup_deepspeed_config(config)

    if is_main_process():
        setup_output_dir(config.output_dir, excludes=["code"])
        setup_logger(output=config.output_dir, color=True, name="vindlu")
        logger.info(f"config: {Config.pretty_text(config)}")
        Config.dump(config, os.path.join(config.output_dir, "config.json"))

    # Hold every rank until the main process has finished the setup above.
    # barrier() needs an initialised process group, so it is skipped when
    # none exists. NOTE(review): presumably init_distributed_mode can leave
    # the group uninitialised for single-process runs; confirm.
    if dist.is_available() and dist.is_initialized():
        dist.barrier()

    return config
