import argparse
import os
from pathlib import Path

from offline_evaluation.load_datasets import get_data_loaders
from models.ssl_models import batch_size as total_batch_size
from utils.model_checkpointing import AlteredModelCheckpoint

import urllib3

urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger
import time
import pytorch_lightning as pl
from models.ssl_models import *
import yaml


def load_config_dict(path: str):
    """Load a YAML configuration file and return its parsed content.

    Args:
        path: Location of the YAML file to read.

    Returns:
        The object produced by ``yaml.full_load`` for the document.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # Decode explicitly as UTF-8 so parsing does not depend on the platform's
    # locale default encoding.
    with open(path, 'r', encoding='utf-8') as config_file:
        # NOTE(review): full_load can construct more than plain data types;
        # only use it on trusted config files (consider yaml.safe_load).
        return yaml.full_load(config_file)


def print_dict(d):
    """Print every entry of *d* on its own line, formatted as ``key: value``."""
    for entry in d.items():
        # entry is a (key, value) pair; unpack it straight into the template
        print("{}: {}".format(*entry))


def run_model_train(benchmark_model, dataloader_train_ssl, dataloader_train_eval,
                    dataloader_test_eval, resnet_type=None, depth=5, width=100,
                    output_folder=None, max_epochs=1000):
    """Train one model with PyTorch Lightning and report run statistics.

    Args:
        benchmark_model: Model class. It is instantiated here with the
            evaluation dataloaders and architecture settings; its
            ``__name__`` (minus ``'Model'``) is used as the model name.
        dataloader_train_ssl: Passed to ``trainer.fit`` as the training loader.
        dataloader_train_eval: Passed to the model as ``dataloader_kNN``.
        dataloader_test_eval: Passed to the model as ``dataloader_test`` and
            to ``trainer.fit`` as the validation loader.
        resnet_type: Forwarded unchanged to the model constructor.
        depth: Forwarded unchanged to the model constructor.
        width: Forwarded unchanged to the model constructor.
        output_folder: Directory that receives the CSV logs and a
            ``checkpoints`` sub-folder. It is joined with ``os.path.join``,
            so it must be a real path (not ``None``).
        max_epochs: Number of epochs given to the model and the trainer.

    Returns:
        dict: ``{model_name: [run]}`` where ``run`` holds the model name,
        batch size, epochs, max kNN accuracy, runtime in seconds, peak GPU
        memory and seed of this training run.
    """
    # NOTE(review): `torch`, `distributed` and `classes` are not defined in this
    # file's visible code; presumably they come from the star import of
    # models.ssl_models — confirm.

    # use a GPU if available
    gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0

    if distributed and gpus > 0:
        # reduce batch size for distributed training
        batch_size = total_batch_size // gpus
    else:
        # limit to single gpu if not using distributed training (this branch is
        # also taken when no GPU is present, which avoids dividing by zero)
        gpus = min(gpus, 1)
        # Without this assignment `batch_size` would be an unbound local here.
        batch_size = total_batch_size

    # checkpoint / validation interval in epochs
    step = 20
    seed = 0
    model_name = benchmark_model.__name__.replace('Model', '')
    conf = {'model': benchmark_model, 'model_name': model_name,
            'batch_size': batch_size, 'width': width, 'depth': depth, 'resnet_type': resnet_type}

    print("*********************************************************")
    print_dict(conf)

    logs_path = os.path.join(output_folder, "tb_logs")
    print(f"logging to:{logs_path}")
    logger = CSVLogger(output_folder, name="ssl")
    logger.log_hyperparams(conf)

    pl.seed_everything(seed)

    benchmark_model = benchmark_model(dataloader_kNN=dataloader_train_eval,
                                      dataloader_test=dataloader_test_eval,
                                      num_classes=classes,
                                      resnet_type=resnet_type,
                                      depth=depth, width=width,
                                      max_epochs=max_epochs,
                                      logger=logger)

    checkpoints_folder = os.path.join(output_folder, 'checkpoints')
    print(f"checkpoints folder:{checkpoints_folder}")
    Path(checkpoints_folder).mkdir(parents=True, exist_ok=True)

    # save_top_k=-1 together with every_n_epochs keeps a checkpoint for every
    # `step`-th epoch instead of only the best one.
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath=checkpoints_folder,
        save_top_k=-1, every_n_epochs=step,
        monitor='Loss/total_loss')

    # NOTE(review): the accelerator is always 'dp', even when `distributed`
    # is set — confirm this is intended.
    trainer = pl.Trainer(
        max_epochs=max_epochs,
        accelerator='dp',
        gpus=gpus,
        default_root_dir=logs_path,
        logger=logger,
        callbacks=[checkpoint_callback],
        check_val_every_n_epoch=step,
    )

    start = time.time()
    trainer.fit(
        benchmark_model,
        train_dataloaders=dataloader_train_ssl,
        val_dataloaders=dataloader_test_eval
    )
    end = time.time()
    run = {
        'model': model_name,
        'batch_size': batch_size,
        'epochs': max_epochs,
        'max_accuracy': benchmark_model.max_knn_accuracy,
        'runtime': end - start,
        'gpu_memory_usage': torch.cuda.max_memory_allocated(),
        'seed': seed,
    }
    print_dict(run)
    print(run)

    # delete model and trainer + free up cuda memory
    del benchmark_model
    del trainer
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()

    # Hand the collected statistics back instead of silently dropping them.
    return {model_name: [run]}


def main(task, output_folder):
    """Train the SSL model for *task*, writing everything under ``output_folder/task``.

    Args:
        task: Task name; used as the name of the output sub-directory.
        output_folder: Parent directory for all outputs of this run.
    """
    from offline_evaluation.load_datasets import load_datasets

    # Each task gets its own sub-directory, created on demand.
    task_dir = os.path.join(output_folder, task)
    Path(task_dir).mkdir(parents=True, exist_ok=True)

    ssl_loader, train_eval_loader, test_eval_loader = load_datasets(augment=True)

    # Model class to train (SimCLRModel was the previously used alternative).
    model_cls = VICRegModel

    run_model_train(
        model_cls,
        ssl_loader,
        train_eval_loader,
        test_eval_loader,
        width=250,
        depth=10,
        output_folder=task_dir,
        max_epochs=1000,
    )


def run_main():
    """Parse the command line and start training through ``main``.

    Both ``--task`` (restricted to the supported task names) and
    ``--output_folder`` are required.
    """
    supported_tasks = ['ssl_train']

    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--task',
        type=str,
        help='task type',
        required=True,
        choices=supported_tasks,
    )
    cli.add_argument('--output_folder', type=str, required=True)

    options = cli.parse_args()
    task, output_folder = options.task, options.output_folder
    main(task, output_folder)


# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    run_main()
