import argparse
import os

import torch
import wandb
import tqdm
from lightly.data import LightlyDataset

import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.utilities.rank_zero import rank_zero_only

from src.util.losses import *
from src.util.models import *
from src.datasets.Visualization import *
from src.datasets.CIFAR10LightlyModule import CIFAR10LightlyModule
from src.datasets.CIFAR100LightlyModule import CIFAR100LightlyModule
from src.datasets.STL10LightlyModule import STL10LightlyModule
from src.datasets.ImageNetLightlyModule import ImageNetLightlyModule

# --criterion name -> loss class (looked up when the CLI arguments are parsed).
lossname_to_lossmodule = dict(
    frossl=FroSSL_Loss,
    barlow=Barlow_Twins_Loss,
    vicreg=VICREG_Loss,
    tico=TiCo_Loss,
    logdet=CorInfoMax,
)

# --dataset name -> datamodule class instantiated by train_ssl / train_probe.
dataset_to_datamodule = dict(
    cifar10=CIFAR10LightlyModule,
    cifar100=CIFAR100LightlyModule,
    stl10=STL10LightlyModule,
    imagenet=ImageNetLightlyModule,
)

def bookkeeping(args):
    """Seed all RNGs, request faster float32 matmuls and default the experiment name.

    Mutates ``args`` in place: when ``args.experiment_name`` is None it becomes
    ``"<criterion class name>_<dataset>_<backbone>_<projection>"``.

    Args:
        args: namespace produced by ``parse()``.

    Raises:
        ValueError: a default experiment name is needed but ``args.criterion``
            was never set (``--criterion`` omitted on the command line).
    """
    pl.seed_everything(args.seed)

    try:
        torch.set_float32_matmul_precision('high')
    except Exception:
        # Best-effort: e.g. an older torch without this API. Narrowed from a bare
        # ``except`` so KeyboardInterrupt / SystemExit are no longer swallowed.
        print("Could not change float32 precision! Not a problem unless device is a TPU")

    # set experiment name if not set
    if args.experiment_name is None:
        if args.criterion is None:
            raise ValueError("--criterion must be given to derive a default --experiment_name")
        args.experiment_name = f'{args.criterion.__name__}_{args.dataset}_{args.backbone}_{args.projection}'

def _criterion_type(name):
    """argparse ``type=`` hook that maps a ``--criterion`` name to its loss class.

    Raises ``argparse.ArgumentTypeError`` so argparse reports a clean usage error.
    A plain dict lookup would raise KeyError, which argparse does not catch, so
    the user would see a raw traceback.
    """
    try:
        return lossname_to_lossmodule[name]
    except KeyError:
        valid = ", ".join(lossname_to_lossmodule)
        raise argparse.ArgumentTypeError(f"invalid criterion {name!r} (choose from {valid})") from None


def parse():
    """Build the command-line parser and return the parsed arguments.

    ``args.criterion`` is the loss *class* (not its name), resolved through
    ``lossname_to_lossmodule``; every other option keeps its parsed value.
    """
    parser = argparse.ArgumentParser()

    #pytorch lightning arguments
    parser.add_argument('--devices', type=int, default=1)
    parser.add_argument('--num_nodes', type=int, default=1)
    parser.add_argument('--num-workers', type=int, default=8)

    # training arguments
    parser.add_argument('--experiment_name', default=None, type=str)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--ssl-epochs', type=int, default=500)
    parser.add_argument('--probe-epochs', type=int, default=100)
    parser.add_argument('--dataset', type=str, choices=["cifar10", "cifar100", "stl10", "imagenet"], default="cifar10")
    parser.add_argument('--batch-size', type=int, default=1024)
    parser.add_argument('--probe-batch-size', type=int, default=1024)
    parser.add_argument('--log-eigenvalues', default=False, action=argparse.BooleanOptionalAction)

    # knn validation arguments
    parser.add_argument('--knn_k', type=int, default=1)
    parser.add_argument('--knn_t', type=float, default=0.1)
    # This is a *number of chunks* the validation set is split into (it feeds an
    # array_split-style call), so it must be an integer.
    parser.add_argument('--knn_chunk_size', type=int, default=16,
                        help="number of chunks the validation set is split into for the full kNN evaluation")

    # architecture arguments
    # required: both the default experiment name and the model need a loss class.
    parser.add_argument('--criterion', type=_criterion_type, required=True,
                        help=f"loss to train with, one of: {', '.join(lossname_to_lossmodule)}")
    parser.add_argument('--backbone',  type=str, choices=["resnet18", "resnet50"], default="resnet18")
    parser.add_argument('--projection', type=str, choices=["mlp", "identity", "linear"], default="mlp")
    parser.add_argument('--projector-dim', type=int, default=1024)
    parser.add_argument('--online-classifier', default=True, action=argparse.BooleanOptionalAction)
    parser.add_argument('--small-init', default=False, action=argparse.BooleanOptionalAction)
    parser.add_argument('--small-init-factor', type=float, default=0.1)

    # optimizer arguments
    parser.add_argument('--ssl-optimizer', type=str, default="SGD", choices=["SGD", "Adam"])
    parser.add_argument('--ssl-learning-rate', type=float, default=1e-2)
    parser.add_argument('--ssl-weight-decay', type=float, default=5e-4)
    parser.add_argument('--ssl-lr-scheduler', type=str, default="cosinewarmlr", choices=["cosineannealing", "cosinewarmlr", "steplr"])
    parser.add_argument('--ssl-warm-restart-epochs', type=float, default=10)
    parser.add_argument('--probe-optimizer', type=str, default="Adam", choices=["SGD", "Adam"])
    parser.add_argument('--probe-learning-rate', type=float, default=5e-3)

    # logging arguments
    parser.add_argument('--logs-folder', type=str, default="logs/")
    parser.add_argument('--model-checkpoint-folder', type=str, default="checkpoints/")
    parser.add_argument('--data-folder', type=str, default="data/")
    parser.add_argument('--wandb', default=True, action=argparse.BooleanOptionalAction)
    parser.add_argument('--resume_checkpoint_path', type=str, default=None)
    parser.add_argument('--resume_make_fresh_trainer', default=False, action=argparse.BooleanOptionalAction)
    parser.add_argument('--wandb_resume_version', type=str, default=None)

    # loss arguments
    parser.add_argument('--frossl_kernel_type', type=str, choices=["linear", "gaussian"], default="linear")
    parser.add_argument('--frossl_alpha', type=float, default=2)
    parser.add_argument('--frossl_rank_calculator',  type=str, choices=["nuclear", "frobenius"], default="frobenius")
    parser.add_argument('--barlow_lambda',  type=float, default=5e-3)

    return parser.parse_args()

def create_wandb_logger(args):
    """Return a Weights & Biases logger, or None when ``--no-wandb`` was passed.

    Logs in to wandb first. If ``args.wandb_resume_version`` is set, the logger
    is created with that ``version`` and ``resume="must"`` so the existing run is
    continued; otherwise a new run named ``args.experiment_name`` is started.
    """
    if not args.wandb:
        return None

    wandb.login()

    logger_kwargs = dict(project='ssl', name=args.experiment_name, log_model=False)
    if args.wandb_resume_version:
        logger_kwargs.update(version=args.wandb_resume_version, resume="must")
    return WandbLogger(**logger_kwargs)
        

def train_ssl(args, loggers):
    """Pre-train an ``SSL_Model`` on ``args.dataset`` and return it.

    The resume flags select one of three modes:
      * no ``--resume_checkpoint_path``: build a fresh model and train from scratch;
      * resume path + ``--resume_make_fresh_trainer``: load the model from the
        checkpoint with ``load_from_checkpoint`` and start a new Trainer run;
      * resume path only: build a fresh model and pass the checkpoint to
        ``trainer.fit(ckpt_path=...)`` so Lightning restores the training state.

    Args:
        args: parsed CLI namespace.
        loggers: ``[tensorboard_logger]`` or ``[tensorboard_logger, wandb_logger]``
            (the wandb logger is present exactly when ``args.wandb`` is true).

    Returns:
        The model as it is at the end of ``trainer.fit`` (not necessarily the
        checkpoint with the best validation accuracy).
    """
    wandb_logger = loggers[1] if args.wandb else None
    data_module = dataset_to_datamodule[args.dataset](args, is_ssl_run=True, logger=wandb_logger)

    # The TensorBoard version may be an int or a string like "version_3";
    # keep only the trailing number in the string case.
    version = loggers[0].version
    if isinstance(version, str):
        version = version.split('_')[-1]

    # os.path.join (instead of string concatenation) so a --model-checkpoint-folder
    # given without a trailing slash still yields "<folder>/<experiment>/version_<N>".
    checkpoint_path = os.path.join(args.model_checkpoint_folder, args.experiment_name, f'version_{version}')
    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_path,
        save_on_train_epoch_end=True,
        save_top_k=-1,
        monitor="val_online_cls_top1",
        filename='{epoch}-{train_loss:.2f}-{val_online_cls_top1:.2f}',
        # Every epoch on ImageNet (presumably because its epochs are long), every 10 otherwise.
        every_n_epochs=1 if args.dataset == "imagenet" else 10,
        mode='max',
    )
    lr_monitor = LearningRateMonitor(logging_interval='step')

    trainer = pl.Trainer(
        devices=args.devices,
        num_nodes=args.num_nodes,
        accelerator='gpu',
        strategy='ddp_find_unused_parameters_false',
        max_epochs=args.ssl_epochs,
        callbacks=[checkpoint_callback, lr_monitor],
        logger=loggers,
        log_every_n_steps=10,
    )

    if args.resume_checkpoint_path is None:
        model = SSL_Model(args, num_classes=data_module.num_classes)

        # start fresh training
        print(f'Beginning training of {args.experiment_name}')
        trainer.fit(model, datamodule=data_module)
    elif args.resume_make_fresh_trainer:
        print(f'Resuming training of {args.experiment_name} with fresh trainer')
        # Only the model comes from the checkpoint; the Trainer starts a new run.
        model = SSL_Model.load_from_checkpoint(checkpoint_path=args.resume_checkpoint_path,
                                               args=args, num_classes=data_module.num_classes)
        trainer.fit(model, datamodule=data_module)
    else:
        print(f'Resuming training of {args.experiment_name}')
        model = SSL_Model(args, num_classes=data_module.num_classes)
        # ckpt_path hands the checkpoint to the Trainer, which restores its training state from it.
        trainer.fit(model, datamodule=data_module, ckpt_path=args.resume_checkpoint_path)

    # return current backbone, not necessarily backbone with best knn accuracy
    return model

# converts a dataset of images to a dataset of backbone embeddings
def convert_image_dataset_to_embedding_dataset(ssl_model, desc, dataloader):
    """Convert a dataset of images into a dataset of backbone embeddings.

    Every ``(images, labels, _)`` batch from ``dataloader`` is moved to CUDA and
    passed through ``ssl_model.backbone`` under ``torch.no_grad()``. The model is
    left in eval mode.

    Args:
        ssl_model: model exposing a ``backbone`` module; assumed to already live
            on the GPU (the caller moves it there).
        desc: split name, used only in the progress-bar label.
        dataloader: iterable yielding ``(images, labels, _)`` triples.

    Returns:
        ``(dataset, embeddings, labels)``:
          * ``embeddings`` is 2-D, with one row per sample holding the backbone
            output flattened after the batch dimension;
          * ``labels`` is a 1-D int64 tensor;
          * ``dataset`` is both wrapped as a ``TensorDataset`` inside a ``LightlyDataset``.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    embedding_batches = []
    label_batches = []

    ssl_model.eval()
    with torch.no_grad():
        for x, y, _ in tqdm.tqdm(dataloader, desc=f"Converting {desc} set to embeddings"):
            x = x.cuda(non_blocking=True)
            embedding = ssl_model.backbone(x)

            # flatten(start_dim=1) / reshape(-1) instead of squeeze():
            # squeeze() also removes the batch dimension when a batch has a single
            # sample. That turned the embedding into loose scalars and the labels
            # into a 0-d array, which cannot be extended.
            embedding_batches.append(embedding.flatten(start_dim=1).cpu())
            label_batches.append(y.reshape(-1).cpu())

    if not embedding_batches:
        raise ValueError(f"{desc} dataloader produced no batches; cannot build embeddings")

    embeddings = torch.cat(embedding_batches)
    labels = torch.cat(label_batches).to(torch.int64)
    tensor_dataset = torch.utils.data.TensorDataset(embeddings, labels)

    return LightlyDataset.from_torch_dataset(tensor_dataset), embeddings, labels

def _full_knn_accuracy(train_embeds, train_labels, val_embeds, val_labels, num_classes, knn_k, knn_t, num_chunks):
    """Return top-1 kNN accuracy of the validation embeddings against the full train set.

    Both embedding sets are L2-normalised along dim 1. The validation set is scored
    in ``num_chunks`` contiguous pieces to limit memory on large datasets such as
    ImageNet. Returns None when there are no validation samples.
    """
    num_val = val_embeds.shape[0]
    if num_val == 0:
        return None

    val_embeds = F.normalize(val_embeds, dim=1)
    train_embeds = F.normalize(train_embeds, dim=1)
    feature_bank = train_embeds.t()

    # tensor_split follows np.array_split semantics. Clamp the count so it is a
    # positive int and no chunk is empty.
    num_chunks = max(1, min(int(num_chunks), num_val))
    chunks = zip(torch.tensor_split(val_embeds, num_chunks), torch.tensor_split(val_labels, num_chunks))

    total_preds, correct_preds = 0, 0
    for val_embeds_chunk, val_labels_chunk in tqdm.tqdm(chunks, total=num_chunks, desc="full knn accuracy"):
        knn_preds = knn_predict(val_embeds_chunk, feature_bank, train_labels.t(), num_classes, knn_k, knn_t)
        correct_preds += (knn_preds[:, 0] == val_labels_chunk).float().sum().item()
        total_preds += knn_preds.shape[0]

    return correct_preds / total_preds


@rank_zero_only
def train_probe(args, ssl_model, loggers):
    """Evaluate the SSL backbone with a full kNN classifier and a linear probe (rank 0 only).

    Steps:
      1. embed the train/val images (and test images, if the datamodule has a test
         set) with ``ssl_model.backbone`` on the GPU;
      2. print, and log to wandb when enabled, the top-1 kNN accuracy of val
         against the full train set;
      3. destroy the distributed process group if one exists, then train a
         ``Linear_Probe`` on the cached embeddings with a single-device Trainer,
         and test it when a test set is available.

    Args:
        args: parsed CLI namespace.
        ssl_model: trained model exposing ``backbone``.
        loggers: ``[tensorboard_logger]`` or ``[tensorboard_logger, wandb_logger]``.
    """
    # get version number from tensorboard logger ("version_3" -> "3")
    version = loggers[0].version
    if isinstance(version, str):
        version = version.split('_')[-1]

    # make callbacks
    checkpoint_path = os.path.join(args.model_checkpoint_folder, args.experiment_name, f'version_{version}')
    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_path,
        save_on_train_epoch_end=True,
        save_top_k=1,
        monitor="validation_accuracy",
        filename='linprobe-{epoch}-{train_loss:.2f}-{validation_accuracy:.2f}',
        every_n_epochs=1,
        mode='max',
    )
    lr_monitor = LearningRateMonitor(logging_interval='step')

    wandb_logger = loggers[1] if args.wandb else None
    data_module = dataset_to_datamodule[args.dataset](args, is_ssl_run=False, logger=wandb_logger)

    # precompute the image embeddings for the linear probe
    ssl_model.cuda()

    test_embeddings = None
    if data_module.has_test_set:
        data_module.setup(stage="test")
        test_embeddings, _, _ = convert_image_dataset_to_embedding_dataset(ssl_model, "test", data_module.test_dataloader())

    data_module.setup(stage="fit")
    train_embeddings, train_embeds, train_labels = convert_image_dataset_to_embedding_dataset(ssl_model, "train", data_module.train_dataloader())
    val_embeddings, val_embeds, val_labels = convert_image_dataset_to_embedding_dataset(ssl_model, "val", data_module.val_dataloader())

    ssl_model.cpu()

    top1_accuracy = _full_knn_accuracy(train_embeds, train_labels, val_embeds, val_labels,
                                       data_module.num_classes, args.knn_k, args.knn_t, args.knn_chunk_size)
    if top1_accuracy is None:
        print("Validation set is empty; skipping full kNN accuracy")
    else:
        print(f"Full kNN accuracy on validation set: {top1_accuracy}")
        if args.wandb:
            loggers[1].log_metrics({'full_kNN_accuracy': top1_accuracy})

    # make dataloaders for embeddings
    common_options = {
        'num_workers': args.num_workers,
        'pin_memory': True,
        'batch_size': args.probe_batch_size,
        'drop_last': False
    }
    dataloader_train = torch.utils.data.DataLoader(train_embeddings, shuffle=True, **common_options)
    dataloader_val = torch.utils.data.DataLoader(val_embeddings, shuffle=False, **common_options)
    dataloader_test = (torch.utils.data.DataLoader(test_embeddings, shuffle=False, **common_options)
                       if test_embeddings is not None else None)

    # Tear down the process group left over from the SSL run before building the
    # single-device probe Trainer. destroy_process_group raises when no group was
    # initialised, so only call it when one exists.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()

    trainer = pl.Trainer(devices=1, num_nodes=1,
                         accelerator='gpu', max_epochs=args.probe_epochs,
                         callbacks=[checkpoint_callback, lr_monitor], logger=loggers, log_every_n_steps=10)

    # only train the probe on rank 0
    if trainer.is_global_zero:
        model = Linear_Probe(args, ssl_model.backbone, num_classes=data_module.num_classes)

        trainer.fit(model, dataloader_train, dataloader_val)
        if dataloader_test is not None:
            trainer.test(model, dataloader_test)

def main():
    """Script entry point: parse args, build loggers, pre-train the SSL model, then run the probe."""
    args = parse()
    bookkeeping(args)

    # When resuming, reuse the TensorBoard version encoded in the checkpoint's
    # parent directory name (".../version_<N>/<file>"); otherwise let
    # TensorBoard allocate a new version.
    tb_version = None
    if args.resume_checkpoint_path is not None:
        parent_dir = args.resume_checkpoint_path.split('/')[-2]
        tb_version = f"version_{parent_dir.split('_')[-1]}"

    loggers = [pl_loggers.TensorBoardLogger(save_dir=args.logs_folder, name=args.experiment_name, version=tb_version)]
    wandb_logger = create_wandb_logger(args)
    if wandb_logger is not None:
        loggers.append(wandb_logger)

    ssl_model = train_ssl(args, loggers)
    train_probe(args, ssl_model, loggers)


if __name__ == "__main__":
    main()
