import os, sys
import argparse, collections, yaml
import random
import logging
import torch.cuda
import pytorch_lightning as pl
import numpy as np

from cpr.llm_transformer.data.lm_datamodule import PlArrowFileModule
from cpr.llm_transformer.pl_module.lm_trainer import LanguageModelTrainer
from cpr.llm_transformer.utils import Config
from cpr.llm_transformer.utils.instantiate import instantiate
from cpr.llm_transformer.utils.folder_manager import get_experiment_folder

# Module-level cuDNN configuration, applied once at import time (before
# main() runs): enable the deterministic flag and disable benchmark mode.
# NOTE(review): presumably paired with the RNG seeding in main() to make runs
# reproducible - confirm this is the intent.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

def bold(msg):
    """Return ``msg`` wrapped in the ANSI "bold on" / "reset" escape codes."""
    return "\033[1m{}\033[0m".format(msg)

def main(cfg):
    """
    Launch pretraining.

    Seeds the Python, NumPy and torch RNGs from ``cfg.train.seed``, resolves
    the experiment folder, builds the data module, the Lightning model module,
    a TensorBoard logger and the trainer, then calls ``trainer.fit``. Training
    resumes from ``<exp_folder>/last.ckpt`` only when ``cfg.resume_training``
    is set and that checkpoint exists.

    Args:
        cfg: Configuration object. This function reads ``cfg.train``,
            ``cfg.trainer``, ``cfg.experiment``, ``cfg.lm_data``,
            ``cfg.transformer`` and ``cfg.resume_training``. It also mutates
            ``cfg``: ``cfg.trainer.devices`` is turned into a list of ints
            when given as a comma separated string, and the vocabulary sizes
            and ``max_len`` are written into ``cfg.transformer``.
    """

    torch.set_float32_matmul_precision('high')

    # Environment variables are always strings, so LOCAL_RANK must be compared
    # with "0" (comparing with the int 0 is never true and would make the
    # rank-zero process skip config saving, logging setup and folder creation).
    # NOTE(review): LOCAL_RANK is per node; with several nodes every node has
    # a local rank 0 - confirm whether a global-rank check is wanted here.
    local_rank = os.environ.get("LOCAL_RANK")
    is_rank_zero = local_rank in (None, "0")
    rank = 0 if is_rank_zero else local_rank

    seed = cfg.train.seed
    os.environ["PL_GLOBAL_SEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    if cfg.resume_training:
        # Reuse the existing experiment folder; resume only if a checkpoint
        # is actually there, otherwise start from scratch in that folder.
        exp_folder = get_experiment_folder(**cfg.experiment, new_folder=False, count_folder=False)
        do_resume_training = (exp_folder / "last.ckpt").exists()
    else:
        do_resume_training = False
        # Only the rank-zero process asks for a new folder.
        exp_folder = get_experiment_folder(**cfg.experiment, new_folder=is_rank_zero)

    if isinstance(cfg.trainer.devices, str):
        # "0,1,2" -> [0, 1, 2]
        cfg.trainer.devices = list(map(int, cfg.trainer.devices.split(",")))
        # The former write to ``cfg.rna_data.num_gpu_worker`` was dropped:
        # no other line here uses ``cfg.rna_data``, and the worker count for
        # the data module is computed from the device list further down.

    logger = logging.getLogger(__name__)

    if is_rank_zero:
        cfg.save_config(exp_folder)

        logging.basicConfig(
            format="[%(asctime)s][%(levelname)s][%(name)s] - %(message)s",
            datefmt="%d/%m/%Y %H:%M:%S",
            level=logging.INFO,
            handlers=[logging.StreamHandler(sys.stdout), logging.FileHandler(exp_folder / "logfile.txt")],
        )

        logger.info(bold("######################################################"))
        logger.info(bold("########          START   TRAINING          ##########"))
        logger.info(bold("######################################################"))

        logger.info(f"########  Project:    {cfg.experiment.project_name}")
        logger.info(f"########  Session:    {cfg.experiment.session_name}")
        logger.info(f"########  Experiment: {cfg.experiment.experiment_name}")
        logger.info(f"save logs and checkpoints in: {exp_folder.as_posix()}")

        logger.info(bold("############### CONFIGURATION"))
        for title, section in (
            ("Data args", cfg.lm_data),
            ("Trainer args", cfg.trainer),
            ("Train args", cfg.train),
            ("Optimizer args", cfg.train.optimizer),
            ("Transformer args", cfg.transformer),
        ):
            logger.info(title)
            logger.info(section)

    logger.info(bold(f"############### LOAD DATA on rank {rank}"))

    # ``devices`` may be an int (a count) or a list of device indices; a list
    # must contribute its length, not be repeated ``num_nodes`` times.
    devices = cfg.trainer.devices
    num_devices = devices if isinstance(devices, int) else len(devices)

    cfg_lm_data = {**cfg.lm_data}
    cfg_lm_data["num_gpu_worker"] = num_devices * cfg.trainer.num_nodes
    data_module = PlArrowFileModule(**cfg_lm_data)

    # The model dimensions depend on the data, so they are filled in here.
    cfg.transformer.seq_vocab_size = data_module.seq_vocab_size
    cfg.transformer.trg_vocab_size = data_module.trg_vocab_size
    cfg.transformer.max_len = cfg.lm_data.max_sample_len
    logger.info(f"#### vocab size: {data_module.seq_vocab_size}")

    model_module = LanguageModelTrainer(
        cfg_train=cfg.train,
        cfg_model=cfg.transformer,
        py_logger=logger,
        val_sets_name=data_module.val_sets_name,
        ignore_index=data_module.ignore_index,
    )

    if is_rank_zero:
        def count_parameters(parameters):
            """Number of scalar entries in all parameters that require grad."""
            return sum(p.numel() for p in parameters if p.requires_grad)
        logger.info(f"#### trainable_parameters {count_parameters(model_module.parameters())}")

        def print_model_param_stats(model):
            """Log name, shape, mean, std and requires_grad of each parameter."""
            for idx, (name, params) in enumerate(model.named_parameters()):
                logger.info(f"{idx:03d} {name:70} shape:{str(list(params.shape)):12} mean:{params.mean():8.4f} std:{params.std():8.6f} grad: {params.requires_grad}")
        print_model_param_stats(model_module.model)

    logger.info(f'#### Load logger on rank {rank}')

    training_logger = pl.loggers.tensorboard.TensorBoardLogger(
        save_dir=exp_folder,
        name="",
        version="tb",
        prefix="",
    )

    logger.info(f'#### Load strategy on rank {rank}')

    if cfg.trainer.devices == 1:
        strategy = "ddp"
    else:
        strategy = pl.strategies.DDPStrategy(
            find_unused_parameters=True,
            static_graph=False
        )

    logger.info(bold(f"############### TRAINER on rank {rank}"))
    trainer = instantiate(
        cfg.trainer,
        instance=pl.Trainer,
        callbacks=[],
        plugins=[],
        strategy=strategy,
        logger=training_logger,
    )

    logger.info(f"Starting training on rank {rank}")

    ckpt_path = exp_folder / 'last.ckpt' if do_resume_training else None
    trainer.fit(
        model=model_module,
        datamodule=data_module,
        ckpt_path=ckpt_path,
    )

    logger.info(f"Finished saving model weights on rank {rank}")
    logger.info("End training!")


if __name__ == "__main__":

    from functools import reduce  # forward compatibility for Python 3
    import operator

    def update(d, u):
        """Recursively merge mapping ``u`` into ``d`` in place and return ``d``.

        Nested mappings in ``u`` are merged key by key into the matching
        nested mapping of ``d``; every other value of ``u`` overwrites the
        entry in ``d``. When ``u`` holds a mapping for a key whose current
        value in ``d`` is missing or not a mutable mapping (e.g. a scalar),
        the merge starts from a fresh dict instead of trying to index into
        that value.
        """
        for k, v in u.items():
            if isinstance(v, collections.abc.Mapping):
                current = d.get(k)
                if not isinstance(current, collections.abc.MutableMapping):
                    # Nothing mergeable is stored under this key yet.
                    current = {}
                d[k] = update(current, v)
            else:
                d[k] = v
        return d


    def getFromDict(dataDict, mapList):
        """Follow the keys in ``mapList`` through nested containers.

        Starts at ``dataDict`` and indexes with each key in turn; an empty
        ``mapList`` returns ``dataDict`` itself.
        """
        node = dataDict
        for key in mapList:
            node = node[key]
        return node

    def setInDict(dataDict, mapList, value):
        """Assign ``value`` at the nested position described by ``mapList``.

        All keys but the last locate the parent container (via
        ``getFromDict``); the last key is the entry that gets written.
        """
        parent = getFromDict(dataDict, mapList[:-1])
        leaf_key = mapList[-1]
        parent[leaf_key] = value

    def convert_string_value(value):
        """Convert a command-line string into a bool, int or float if possible.

        ``'false'``/``'False'`` and ``'true'``/``'True'`` become booleans.
        Otherwise ``int`` is tried first, then ``float``; if neither
        conversion succeeds the value is returned unchanged.
        """
        if value in ('false', 'False'):
            return False
        if value in ('true', 'True'):
            return True
        for cast in (int, float):
            try:
                return cast(value)
            except (TypeError, ValueError):
                # Only a failed conversion is expected here; anything else
                # (e.g. KeyboardInterrupt) must not be swallowed.
                continue
        return value


    print("CUDA AVAILABLE", torch.cuda.is_available())
    print("CUDA DEVICES", torch.cuda.device_count())

    default_config_name = "test_config.yaml"

    parser = argparse.ArgumentParser(description='Train GPT2 LLM')
    parser.add_argument('-c', '--config', type=str, default=default_config_name, help='config file name')

    # Arguments the parser does not know are kept: they are interpreted below
    # as ``dotted.key=value`` overrides of the loaded configuration.
    args, unknown_args = parser.parse_known_args()

    # Allow the config to be named without its extension. The suffix is only
    # appended when the name has no YAML extension and does not already point
    # to an existing file, so every name that resolved before still does.
    config_name = args.config
    has_yaml_suffix = config_name.endswith(('.yaml', '.yml'))
    if not has_yaml_suffix and not os.path.isfile(os.path.join("config", config_name)):
        config_name += '.yaml'

    config_file = os.path.join("config", config_name)
    with open(config_file, 'r') as f:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects.
        # Acceptable for trusted local configs; consider yaml.safe_load if the
        # configs never rely on Python-specific tags.
        config_dict = yaml.load(f, Loader=yaml.Loader)

    for arg in unknown_args:
        if '=' not in arg:
            raise UserWarning(f"argument unknown: {arg}")
        # Split on the first '=' only, so values may themselves contain '='.
        dotted_key, _, raw_value = arg.partition('=')
        keys = dotted_key.split('.')
        value = convert_string_value(raw_value)
        print(keys, value)
        setInDict(config_dict, keys, value)

    config = Config(config_dict=config_dict)

    # NOTE(review): looks like a leftover debug print - confirm before removing.
    print(config.experiment.project_name,2)

    main(cfg=config)
