import os
import copy
import sys
import yaml
import logging
import argparse

sys.path.append("../vccm")
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import torch
import torch.distributed as dist
from tensorboardX import SummaryWriter
from typing import Any, Dict, Tuple, Union

from datetime import date
from torch.utils.data import DataLoader
from torch.nn.parallel import DistributedDataParallel

from nn_ss.dataloader.semantic_dataset import SemanticDataset_rewrite

from nn_ss.trainer.scheduler import LRScheduler
from nn_ss.trainer.soundstorm.executor import Executor


from nn_ss.sound_synthesis.modeling.build import build_model
from nn_ss.sound_synthesis2.utils.io import load_yaml_config

from utils.checkpoint import save_checkpoint, load_optimizer,load_checkpoint
from utils.common import load_dict, load_json, save_dict


def get_args(argv: Union[list, None] = None) -> argparse.Namespace:
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) keeps the original
            behaviour of reading ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``. The ``--ddp.*`` options are
        stored under the short attribute names ``rank``, ``world_size``,
        ``dist_backend`` and ``init_method``.
    """
    parser = argparse.ArgumentParser(description='Generated Pretrained Language Model')

    # --- experiment configuration ---
    parser.add_argument('--config',
                        type=str,
                        default="config/styletrolnet.yaml",
                        help='location of config yaml')
    parser.add_argument('--seed',
                        type=int,
                        default=1111,
                        help='random seed')
    parser.add_argument('--gpu',
                        type=int,
                        default=8,
                        help='gpu id for this local rank, -1 for cpu')

    # --- data loading ---
    parser.add_argument('--num_workers',
                        default=4,
                        type=int,
                        help='num of subprocess workers for reading')
    parser.add_argument('--pin_memory',
                        action='store_true',
                        default=False,
                        help='Use pinned memory buffers used for reading')
    parser.add_argument('--prefetch',
                        default=64,
                        type=int,
                        help='prefetch number')

    # --- distributed / precision switches ---
    parser.add_argument('--athena_dist',
                        action='store_true',
                        default=False,
                        help='Using Athena for distributed training')
    parser.add_argument('--use_amp',
                        action='store_true',
                        default=False,
                        help='Use automatic mixed precision training')
    parser.add_argument('--fp16_grad_sync',
                        action='store_true',
                        default=False,
                        help='Use fp16 gradient sync for ddp')

    # --- explicit DDP settings (used when --athena_dist is not given) ---
    parser.add_argument('--ddp.rank',
                        dest='rank',
                        default=0,
                        type=int,
                        help='global rank for distributed training')
    parser.add_argument('--ddp.world_size',
                        dest='world_size',
                        default=-1,
                        type=int,
                        help='''number of total processes/gpus for
                        distributed training''')
    parser.add_argument('--ddp.dist_backend',
                        dest='dist_backend',
                        default='nccl',
                        choices=['nccl', 'gloo'],
                        help='distributed backend')
    parser.add_argument('--ddp.init_method',
                        dest='init_method',
                        default=None,
                        help='ddp init method')

    # --- logging ---
    parser.add_argument('--logdir',
                        default="jsp",
                        type=str,
                        help="print log")

    # parse_args(None) falls back to sys.argv[1:], so existing callers that
    # invoke get_args() with no arguments are unaffected.
    return parser.parse_args(argv)


def set_logger(log_dir):
    """Route root-logger output to a dated log file inside ``log_dir``.

    The directory is created if it does not exist. Records of level DEBUG
    and above are appended to ``lm_train-<today's date>.log``.

    Note: ``logging.basicConfig`` is a no-op when the root logger already
    has handlers, so only the first effective call decides the destination.

    Args:
        log_dir: Directory that will hold the log file.
    """
    os.makedirs(log_dir, exist_ok=True)

    file_name = "lm_train-{}.log".format(date.today())
    logging.basicConfig(
        filename=os.path.join(log_dir, file_name),
        filemode='a',
        format='%(asctime)s %(levelname)s %(message)s',
        level=logging.DEBUG,
    )


def init_dist(args: Any,
              configs: Dict) -> Tuple[Any, Dict]:
    """Initialise ``torch.distributed`` (when requested) and record the mode.

    Two launch modes are supported:

    * ``args.athena_dist`` set: rank and world size are read from the
      ``RANK`` / ``WORLD_SIZE`` environment variables (defaulting to 0 / 1)
      and written back onto ``args``; the process group is always created
      with the ``nccl`` backend and the default init method.
    * otherwise: ``CUDA_VISIBLE_DEVICES`` is set to ``args.gpu`` and a
      process group is created only when ``args.world_size > 1`` and
      ``torch.distributed`` is available, using the ``--ddp.*`` options.

    Args:
        args: Parsed command-line namespace; ``rank`` and ``world_size`` are
            overwritten in Athena mode.
        configs: Training configuration; ``configs['distributed']`` is set.

    Returns:
        The same ``(args, configs)`` objects, both mutated in place. (The
        function previously fell off the end and returned ``None`` despite
        its annotation; callers that ignore the result are unaffected.)
    """
    if args.athena_dist:
        # Athena mode is unconditionally distributed.
        # NOTE(review): the backend is fixed to nccl here and ignores
        # --ddp.dist_backend; confirm that is intended.
        distributed = True
        backend = "nccl"
        args.world_size = int(os.environ.get('WORLD_SIZE', 1))
        args.rank = int(os.environ.get('RANK', 0))

        logging.info('training on multiple gpus, this gpu {}'.format(args.gpu))
        logging.info('Using distributed PyTorch with {} backend'.format(backend))
        dist.init_process_group(backend)

        logging.info("dist is_available: {}".format(dist.is_available()))
        logging.info("dist is_initialized: {}".format(dist.is_initialized()))
        logging.info("cuda available: {}".format(torch.cuda.is_available()))
        logging.info("world_size: {}".format(args.world_size))
        logging.info("rank: {}".format(args.rank))
    else:
        # Restrict this process to the requested device before CUDA is used.
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
        distributed = args.world_size > 1 and dist.is_available()
        if distributed:
            logging.info('training on multiple gpus, this gpu {}'.format(args.gpu))
            dist.init_process_group(args.dist_backend,
                                    init_method=args.init_method,
                                    world_size=args.world_size,
                                    rank=args.rank)

    configs['distributed'] = distributed
    return args, configs


def load_model_to_cuda(model: torch.nn.Module,
                       args: Any,
                       configs: Dict,
                       distributed: bool=False,
                       find_unused_parameters: bool=True
                       ) -> Tuple[Union[DistributedDataParallel, torch.nn.Module], Dict, torch.device]:
    """Place ``model`` on its target device and record parameter counts.

    In distributed mode the model is moved to CUDA and wrapped in
    ``DistributedDataParallel``; with ``args.fp16_grad_sync`` the fp16
    compression communication hook is registered on the wrapper. Otherwise
    the model goes to CUDA when ``args.gpu >= 0`` and CUDA is available,
    else to the CPU.

    Args:
        model: Module to place.
        args: Namespace providing ``gpu`` and ``fp16_grad_sync``.
        configs: Config dict; ``total_param`` and ``train_param`` are written.
        distributed: Whether to wrap the model for DDP training.
        find_unused_parameters: Forwarded to ``DistributedDataParallel``.

    Returns:
        ``(model, configs, device)`` where ``model`` may be the DDP wrapper.
    """
    if not distributed:
        on_gpu = args.gpu >= 0 and torch.cuda.is_available()
        device = torch.device('cuda') if on_gpu else torch.device('cpu')
        model = model.to(device)
    else:
        assert torch.cuda.is_available()
        # DistributedDataParallel requires the module to live on CUDA first.
        model.cuda()
        model = DistributedDataParallel(model, find_unused_parameters=find_unused_parameters)
        device = torch.device('cuda')
        if args.fp16_grad_sync:
            from torch.distributed.algorithms.ddp_comm_hooks import default as comm_hooks
            model.register_comm_hook(state=None, hook=comm_hooks.fp16_compress_hook)

    # Count all parameters and the trainable subset in a single pass.
    n_total = 0
    n_trainable = 0
    for param in model.parameters():
        size = param.nelement()
        n_total += size
        if param.requires_grad:
            n_trainable += size
    configs['total_param'] = n_total
    configs['train_param'] = n_trainable

    logging.info(" | Total Parameters: {}".format(configs["total_param"]))
    logging.info(" | Train Parameters: {}".format(configs["train_param"]))

    return model, configs, device


def _make_loader(dataset, args):
    """Build a DataLoader for ``dataset`` from the command-line options.

    ``batch_size=None`` disables automatic batching, so the loader yields
    whatever the dataset itself produces.
    """
    loader_kwargs = {
        'batch_size': None,
        'pin_memory': args.pin_memory,
        'num_workers': args.num_workers,
    }
    # DataLoader rejects an explicit prefetch_factor when no worker
    # processes are used, so only forward it for num_workers > 0.
    if args.num_workers > 0:
        loader_kwargs['prefetch_factor'] = args.prefetch
    return DataLoader(dataset, **loader_kwargs)


def _link_best_checkpoint(checkpoint_dir, model_type, final_epoch, topk_list):
    """Create the ``<model>_epoch_<best>_ppl_<ppl>.pt`` symlink in ``checkpoint_dir``."""
    if len(topk_list) == 0:
        logging.warning('| no ranked checkpoints recorded, skip linking best model')
        return

    best_epoch, best_ppl = topk_list[0]
    checkpoint_dir = os.path.abspath(checkpoint_dir)

    best_model_link = os.path.join(checkpoint_dir, "{}_epoch_{}_ppl_{:.3f}.pt".format(
        model_type, best_epoch, best_ppl))
    # NOTE(review): the link is named after the best epoch but points at the
    # checkpoint of the *final* epoch, as in the original code. Confirm
    # against Executor whether the target should use best_epoch instead.
    best_model_path = os.path.join(checkpoint_dir, '{}_{}.pt'.format(model_type, final_epoch))

    # os.symlink fails if the link name already exists (e.g. on a re-run).
    if os.path.lexists(best_model_link):
        os.remove(best_model_link)
    os.symlink(best_model_path, best_model_link)


def main():
    """Run the training job described by the YAML config.

    Steps: parse the CLI options, load the config, set up logging and
    (optionally) the process group, build the datasets/loaders, the model,
    the optimizer and the LR scheduler, then drive the epoch loop through
    ``Executor``. On rank 0 the effective config and an initial checkpoint
    are written to ``checkpoint_dir`` before training, and a symlink named
    after the best-ranked epoch is created afterwards.
    """
    args = get_args()
    print("args.config:",args.config)
    assert os.path.exists(args.config), "config file not found: {}".format(args.config)

    configs = load_yaml_config(args.config)

    set_logger(args.logdir)
    # Set the random seed manually for reproducibility.
    torch.manual_seed(args.seed)

    # Sets configs['distributed'] (and args.rank / args.world_size in Athena mode).
    init_dist(args, configs)

    model_type = configs["model_type"]
    train_corpus = configs["train_corpus"]
    valid_corpus = configs.get("valid_corpus")
    checkpoint_dir = configs["checkpoint_dir"]
    tensorboard_dir = configs["tensorboard_dir"]
    pretrain_checkpoint = configs.get("pretrained_checkpoint", None)

    if args.rank == 0:
        os.makedirs(checkpoint_dir, exist_ok=True)
        os.makedirs(tensorboard_dir, exist_ok=True)
    logging.info(f" | Train: {train_corpus}")
    logging.info(f" | Valid: {valid_corpus}")
    assert os.path.exists(train_corpus), "train corpus not found: {}".format(train_corpus)

    distributed = configs.get('distributed', False)
    start_epoch = int(configs.get('start_epoch', 0))
    num_epoches = int(configs.get('max_epoch', 40))
    num_checkpoints_to_average = int(configs.get('avg_ckpt_num', 5))

    max_frames_in_batch = int(configs["max_frames_in_batch"])
    train_dataset = SemanticDataset_rewrite(folder=train_corpus,
                                            num_quant=configs["train_stage_nq"],
                                            stage=configs["train_stage"],
                                            max_frames_in_batch=max_frames_in_batch)
    train_data_loader = _make_loader(train_dataset, args)

    # The validation set is optional: build it only when a corpus is
    # configured, instead of constructing a dataset for folder=None.
    dev_data_loader = None
    if valid_corpus:
        dev_dataset = SemanticDataset_rewrite(folder=valid_corpus,
                                              num_quant=configs["valid_stage_nq"],
                                              stage=configs["valid_stage"],
                                              max_frames_in_batch=max_frames_in_batch,
                                              partition=False)
        dev_data_loader = _make_loader(dev_dataset, args)

    # Build the model and optionally warm-start it from a checkpoint.
    model = build_model(configs, args)
    if pretrain_checkpoint is not None:
        load_checkpoint(model,
                        path=os.path.abspath(pretrain_checkpoint))

    model, configs, device = load_model_to_cuda(
        model=model,
        args=args,
        configs=configs,
        distributed=distributed,
        find_unused_parameters=True
        )

    # Add run information to the config before it is saved.
    configs['start_epoch'] = start_epoch
    configs['rank'] = args.rank

    # Rank 0 persists the effective config and the initial weights.
    if args.rank == 0:
        saved_config_path = os.path.join(checkpoint_dir, "train.yaml")
        with open(saved_config_path, 'w') as f_save:
            f_save.write(yaml.dump(configs))
        # Format only the file name so braces in checkpoint_dir are harmless.
        saved_checkpoint = os.path.join(checkpoint_dir, '{}_init.pt'.format(model_type))
        save_checkpoint(model, saved_checkpoint)

    optim = configs['optim']
    optim_conf = configs['optim_conf']
    scheduler_conf = configs['scheduler_conf']
    pretrain_optimizer = configs.get("pretrained_optimizer", None)

    # Only parameters that require gradients are handed to the optimizer.
    train_parameters = [p for p in model.parameters() if p.requires_grad]
    optimizer = getattr(torch.optim, optim)(train_parameters, **optim_conf)
    if pretrain_optimizer is not None:
        load_optimizer(optimizer, pretrain_optimizer)

    optim_scheduler = LRScheduler(optimizer=optimizer,
                                  base_lr=optim_conf['lr'],
                                  start_epoch=start_epoch,
                                  **scheduler_conf)

    # TensorBoard output is written by rank 0 only.
    writer = None
    if args.rank == 0:
        exp_id = os.path.basename(checkpoint_dir)
        writer = SummaryWriter(os.path.join(tensorboard_dir, exp_id))
    scaler = None
    if args.use_amp:
        scaler = torch.cuda.amp.GradScaler()

    ss_executor = Executor(model=model,
                           model_type=model_type,
                           checkpoint_dir=checkpoint_dir,
                           scheduler=optim_scheduler,
                           num_checkpoints_to_average=num_checkpoints_to_average,
                           writer=writer,
                           scaler=scaler)

    final_epoch = None
    for epoch in range(start_epoch + 1, num_epoches + 1):
        train_dataset.set_epoch(epoch)
        ss_executor.epoch(epoch=epoch,
                          train_dataloader=train_data_loader,
                          dev_dataloader=dev_data_loader,
                          parameters=train_parameters,
                          configs=configs,
                          device=device)
        final_epoch = epoch

        if optim_scheduler.is_early_stop:
            logging.warning('| Epoch {} | loss not decay for 10 epochs, early stop training'.format(final_epoch))
            # Actually stop; previously the message was logged but training
            # carried on to max_epoch.
            # NOTE(review): assumes every rank reaches the same early-stop
            # decision in distributed runs - confirm against LRScheduler.
            break

    # Close the writer even when no epoch ran.
    if writer is not None:
        writer.close()

    if final_epoch is not None and args.rank == 0:
        _link_best_checkpoint(checkpoint_dir, model_type, final_epoch,
                              optim_scheduler.topk_list)

    if distributed:
        # Tear down the process group
        dist.destroy_process_group()


# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
