import glob
import os

import torch
import wandb

import gcip.utils.io as playbook_io
import gcip.utils.load as playbook_load
from gcip.config import *
from gcip.utils.constants import Cte

# W&B notebook-name environment variable (placeholder value).
os.environ['WANDB_NOTEBOOK_NAME'] = 'name_of_the_notebook'

# `args_list` is forwarded to build_config below together with the parsed args.
args_list, args = parse_args()

# Passing a string for `load_model` switches the script to evaluation-only mode:
# training is skipped and checkpoints found in that folder are tested instead.
load_model = isinstance(args.load_model, str)
if load_model:
    playbook_io.print_info(f"Loading model: {args.load_model}")

config = build_config(
    config_file=args.config_file,
    args_list=args_list,
    config_default_file=args.config_default_file,
)

# NOTE(review): presumably checks that the global `cfg` (star-imported from
# gcip.config) agrees with the freshly built `config` — confirm.
assert_cfg_and_config(cfg, config)

# An explicit CPU request hides every GPU from CUDA.
if cfg.device == 'cpu':
    os.environ['CUDA_VISIBLE_DEVICES'] = "-1"
playbook_load.set_reproducibility(cfg)

# Turn the symbolic "auto"/"gpu" choice into a concrete torch device string,
# falling back to the CPU when CUDA is not available.
if cfg.device in ("auto", "gpu"):
    cfg.device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Data: the preparator builds the dataloaders; later code indexes them in step
# with `preparator.split_names`, so presumably one loader per split.
preparator = playbook_load.load_preparator(cfg=cfg, prepare_data=True)
loaders = preparator.get_dataloaders(
    batch_size=cfg.train.batch_size,
    num_workers=cfg.train.num_workers,
)
for i, loader in enumerate(loaders):
    playbook_io.print_info(f"[{i}] num_batchees: {len(loader)}")

# Model: built from the config and the preparator.
model = playbook_load.load_model(cfg=cfg, preparator=preparator)

# Store the parameter count in `config`, which is later passed to wandb.init
# and written out with the run's logs.
param_count = model.param_count()
config['param_count'] = param_count

if not load_model:
    # Training run: open a W&B run (the mode may be 'disabled') carrying the
    # full config.
    if not isinstance(args.project, str):
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        raise ValueError(
            "A W&B project name (args.project) is required to train, "
            f"got {args.project!r}")
    run = wandb.init(mode=args.wandb_mode,
                     group=args.wandb_group,
                     project=args.project,
                     config=config)

    import uuid

    # Reuse the W&B run id when W&B is active; otherwise mint a local id
    # (32 lowercase hex characters, no dashes).
    if args.wandb_mode != 'disabled':
        run_uuid = run.id
    else:
        run_uuid = uuid.uuid1().hex
else:
    # Evaluation run: the run id is the name of the checkpoint folder.
    # normpath drops a trailing separator; without it basename() returns ''
    # and every output directory collapses onto cfg.root_dir itself.
    run_uuid = os.path.basename(os.path.normpath(args.load_model))

# Trainer options (epochs, GPUs, clusters, ...) are configured via `cfg` and
# applied inside playbook_load.load_trainer.

# Run directory, named after the run id and handed to the trainer as `dirpath`.
dirpath = os.path.join(cfg.root_dir, run_uuid)

# Training logs go straight into the run directory. Evaluation runs log into a
# timestamped sub-folder, so repeated evaluations get distinct folders.
if not load_model:
    logger_dir = dirpath
else:
    from datetime import datetime

    now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    logger_dir = os.path.join(dirpath, 'evaluate', now)

trainer, logger = playbook_load.load_trainer(
    cfg=cfg,
    dirpath=dirpath,
    logger_dir=logger_dir,
    include_logger=True,
    model_checkpoint=cfg.train.model_checkpoint,
    cfg_early=cfg.early_stopping,
    preparator=preparator,
)

playbook_io.print_info(f'Experiment folder: {logger.save_dir}\n\n')

# NOTE(review): `wandb_local` is not imported explicitly in this file;
# presumably it is provided by `from gcip.config import *` — confirm.
wandb_local.log_config(dict(config), root=logger.save_dir)

if not load_model:
    if cfg.train.auto_lr_find:
        # Sweep learning rates on the training loader with Lightning's LR finder.
        # Raw results live in `lr_finder.results`; `lr_finder.plot(suggest=True)`
        # draws the curve if a visual check is needed.
        lr_finder = trainer.tuner.lr_find(model, train_dataloaders=loaders[0])

        # suggestion() returns None when the sweep cannot produce a suggestion.
        # Assigning None would presumably break optimizer construction, so keep
        # the configured learning rate in that case.
        new_lr = lr_finder.suggestion()
        if new_lr is None:
            playbook_io.print_warning(
                "LR finder could not suggest a learning rate; "
                f"keeping base_lr={model.optim_config.base_lr}")
        else:
            playbook_io.print_info(f"LR finder suggested base_lr={new_lr}")
            model.optim_config.base_lr = new_lr

    # Presumably stores copies of the default and experiment config files next
    # to the logs — confirm in wandb_local.copy_config.
    wandb_local.copy_config(config_default=Cte.DEFAULT_CONFIG_FILE,
                            config_experiment=args.config_file,
                            root=logger.save_dir)
    trainer.fit(model,
                train_dataloaders=loaders[0],
                val_dataloaders=loaders[1])

# With a single split configured, evaluate only the first dataloader.
if isinstance(preparator.single_split, str):
    loaders = [loaders[0]]

model.save_dir = dirpath


def _test_all_splits(net, ckpt_name, **test_kwargs):
    """Test on every loader in ``loaders`` and log the metrics of each split.

    Args:
        net: model whose ``ckpt_name`` is set before each ``trainer.test`` call
            and whose ``metrics_stats`` are read (and tagged with the trainer's
            ``current_epoch``) after it.
        ckpt_name: label of the evaluated checkpoint, logged under ``'epoch'``.
        **test_kwargs: forwarded to ``trainer.test`` alongside the split's
            ``dataloaders`` (e.g. ``model=...`` or ``ckpt_path=...``).
    """
    for i, loader_i in enumerate(loaders):
        s_name = preparator.split_names[i]
        playbook_io.print_info(f'Testing {s_name} split')
        preparator.set_current_split(i)
        net.ckpt_name = ckpt_name
        trainer.test(dataloaders=loader_i, **test_kwargs)
        metrics_stats = net.metrics_stats
        metrics_stats['current_epoch'] = trainer.current_epoch
        wandb_local.log_v2({s_name: metrics_stats, 'epoch': ckpt_name},
                           root=trainer.logger.save_dir)


if load_model:
    # Evaluation-only mode: test every checkpoint file in the given folder.
    # glob() returns files in arbitrary order; sort for reproducible runs.
    ckpt_name_list = sorted(glob.glob(os.path.join(args.load_model, '*ckpt')))
    if not ckpt_name_list:
        playbook_io.print_warning(f"No checkpoint files found in {args.load_model}")
    for ckpt_file in ckpt_name_list:
        model = playbook_load.load_model(cfg=cfg,
                                         preparator=preparator,
                                         ckpt_file=ckpt_file)
        model.eval()
        model.save_dir = dirpath
        _test_all_splits(model, preparator.get_ckpt_name(ckpt_file), model=model)
else:
    # After training: test the last checkpoint and, with early stopping, the best.
    ckpt_name_list = ['last']
    if cfg.early_stopping.activate:
        ckpt_name_list.append('best')
    for ckpt_name in ckpt_name_list:
        # No model is passed, so the trainer presumably restores `ckpt_path`
        # into the model it was fitted with, i.e. `model`.
        _test_all_splits(model, ckpt_name, ckpt_path=ckpt_name)

    run.finish()
    if args.delete_ckpt:
        for f in glob.iglob(os.path.join(logger.save_dir, '*.ckpt')):
            playbook_io.print_warning(f"Deleting {f}")
            os.remove(f)

print(f'Experiment folder: {logger.save_dir}')
