import argparse
import torch
import random
from src.utils import *
from src.trainer import train_model
from src.evaluation import evaluate_model


# Command-line interface for the experiment script.
#
# Exactly one of ``--model_name`` / ``--supervised`` must be given: both are
# registered on the same *required* mutually exclusive group below.
parser = argparse.ArgumentParser(
    description="Train and evaluate self-distillation-based self-supervised learning models",
)
group = parser.add_mutually_exclusive_group(required=True)

# ---- Model and dataset selection ----
group.add_argument(
    "--model_name",
    type=str,
    choices=[
        "SDMI",
        "BYOL",
        "SimSiam",
        "SimCLR",
        "MoCo",
        "SimSiam_SDMI",
        "JMI",
        "BarlowTwins",
        "VICReg",
    ],
)
parser.add_argument(
    "--dataset",
    type=str,
    required=True,
    choices=["CIFAR10", "CIFAR100", "ImageNet100", "ImageNet1000", "TinyImageNet"],
)
parser.add_argument(
    "--architecture",
    type=str,
    default="ResNet18",
    help="Backbone architecture",
)

# ---- Training and optimization parameters ----
parser.add_argument("--initial_lr", type=float, default=0.03)
parser.add_argument("--warmup_initial_lr", type=float, default=0.0)
parser.add_argument(
    "--fixed_lr",
    action="store_true",
    help="Disables learning rate scheduling and uses fixed LR",
)
parser.add_argument("--batch_size", type=int, default=512)
parser.add_argument("--num_workers", type=int, default=32)
parser.add_argument("--epochs", type=int, default=2000)
parser.add_argument("--warmup_epochs", type=int, default=0)
parser.add_argument("--optimizer", type=str, default="SGD", choices=["SGD", "Adam"])
parser.add_argument("--momentum", type=float, default=0.9)
parser.add_argument("--weight_decay", type=float, default=0.0005)

# ---- Model architecture hyperparameters ----
parser.add_argument("--feature_dim", type=int, default=2048)
parser.add_argument("--projection_dim", type=int, default=256)
parser.add_argument("--prediction_dim", type=int, default=256)
parser.add_argument("--projection_layer", type=int, default=2)
parser.add_argument("--prediction_layer", type=int, default=2)
parser.add_argument("--num_classes", type=int, default=10)
parser.add_argument(
    "--temperature",
    type=float,
    default=0.07,
    help="Temperature for contrastive losses",
)
parser.add_argument(
    "--tau",
    type=float,
    default=0.996,
    help="EMA decay hyperparameter for target network update in BYOL or momentum encoder in MoCo",
)
parser.add_argument(
    "--lambd",
    type=float,
    default=0.0051,
    help="BarlowTwins diversity loss hyperparameter",
)
parser.add_argument(
    "--similarity_coeff",
    type=float,
    default=25.0,
    help="VICReg similarity loss hyperparameter",
)
parser.add_argument(
    "--variance_coeff",
    type=float,
    default=25.0,
    help="VICReg variance loss hyperparameter",
)
parser.add_argument(
    "--covariance_coeff",
    type=float,
    default=1.0,
    help="VICReg covariance loss hyperparameter",
)
parser.add_argument(
    "--eps",
    type=float,
    default=1e-4,
    help="Small epsilon added inside √(·) for numerical stability in VICReg",
)

# ---- Evaluation configuration ----
parser.add_argument("--evaluation_lr", type=float, default=0.2)
parser.add_argument("--evaluation_epochs", type=int, default=100)
parser.add_argument("--evaluation_batch_size", type=int, default=512)
parser.add_argument("--evaluation_weight_decay", type=float, default=0.0)

# ---- Model saving and checkpointing ----
parser.add_argument("--model_save_interval", type=int, default=50)
parser.add_argument("--model_evaluation_interval", type=int, default=100)
parser.add_argument(
    "--resume",
    type=str,
    default=None,
    help="Path to checkpoint for resuming training",
)

# ---- General flags ----
parser.add_argument(
    "--augmentation",
    action="store_true",
    help="Enable strong data augmentations",
)
# Registered on the exclusive group, so it cannot be combined with --model_name.
group.add_argument(
    "--supervised",
    action="store_true",
    help="Enable supervised training mode",
)
parser.add_argument("--seed", type=int, default=None)
parser.add_argument(
    "--num_runs",
    type=int,
    default=1,
    help="Number of times to repeat the experiment with different random seeds",
)
parser.add_argument(
    "--fast_debug",
    action="store_true",
    help="Enable fast debug mode (use only one fixed batch for quick experiments)",
)


def main():
    """Parse the command line and execute ``args.num_runs`` experiment runs.

    Runs are numbered from 1. For each run the function resolves the run's
    output paths, creates a logger, logs the arguments and the selected
    device, optionally seeds the random number generators, and then calls
    ``training_and_evaluation_manager``.

    When ``--seed`` is supplied, run ``r`` is seeded with ``seed + r`` (for
    both Python's ``random`` module and PyTorch) and cuDNN is switched to
    deterministic mode with benchmarking disabled. Without ``--seed`` no
    seeding is performed.
    """
    args = parser.parse_args()
    banner = "==================================================================================="

    for run in range(1, args.num_runs + 1):
        run_paths = get_run_paths(args, run)
        logger = setup_logger(args.model_name, args.dataset, run_paths.log_directory)

        logger.info(banner)
        logger.info(f"Run-{run} started with arguments: {args}")
        logger.info(banner)

        # Use the GPU whenever PyTorch reports one as available.
        if torch.cuda.is_available():
            device = 'cuda'
        else:
            device = 'cpu'
        logger.info(f"Using device: {device}")

        if args.seed is not None:
            # Offset the base seed by the run index so repeated runs differ
            # from each other while remaining reproducible.
            run_seed = args.seed + run
            random.seed(run_seed)
            torch.manual_seed(run_seed)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

        training_and_evaluation_manager(args, logger, device, run_paths)


def _is_interval_boundary(epoch, interval):
    """Return True when finishing ``epoch`` (0-based) completes a periodic interval.

    An ``interval`` of 0 means "no periodic trigger" and yields False instead
    of raising ``ZeroDivisionError`` from the modulo. Epoch 0 never counts as
    a boundary, matching the ``epoch > 0`` condition of the training loop.
    """
    return interval != 0 and epoch > 0 and (epoch + 1) % interval == 0


def training_and_evaluation_manager(args, logger, device, run_paths):
    """Run the training loop for one experiment run, with periodic evaluation and checkpointing.

    Steps:
      1. Build the data loaders, the model and the optimizer.
      2. Call ``load_checkpoint``; its return value is treated as the index of
         the last finished epoch, and training continues from the next one.
      3. For every remaining epoch: optionally adjust the learning rate,
         train, then evaluate / plot / checkpoint at the configured intervals
         and always after the final epoch.
      4. Log the linear-probing history stored on the model.

    Args:
        args: Parsed command-line namespace. ``args.max_steps`` is set on it
            when ``args.model_name == "BYOL"``.
        logger: Logger used for progress messages.
        device: Device string forwarded to model building, training and
            evaluation.
        run_paths: Object exposing ``checkpoint_directory``; it is also
            forwarded to the evaluation and plotting helpers.

    A value of 0 for ``args.model_evaluation_interval`` or
    ``args.model_save_interval`` disables the corresponding periodic action
    (previously this raised ``ZeroDivisionError`` after the first epoch had
    already been trained); the final-epoch evaluation and checkpoint still
    happen.
    """
    train_loader, test_loader = get_dataloaders(args, logger, args.batch_size, args.augmentation, args.supervised)

    if args.model_name == "BYOL":
        # Total number of training batches over all epochs, stored on args.
        # Presumably read by the BYOL model for a step-based schedule —
        # confirm in build_model.
        args.max_steps = args.epochs * len(train_loader)

    model = build_model(args, logger, device)
    optimizer = get_optimizer(args, model, args.initial_lr, args.weight_decay, args.momentum)

    # Presumably returns -1 when there is no checkpoint to resume from, so
    # that a fresh run starts at epoch 0 — confirm in load_checkpoint.
    last_epoch = load_checkpoint(args, logger, model, optimizer, run_paths.checkpoint_directory, device)
    start_epoch = last_epoch + 1

    for epoch in range(start_epoch, args.epochs):
        if epoch == 0:
            # Baseline evaluation, plots and checkpoint before any training step.
            # NOTE(review): this checkpoint is tagged with epoch 0 although
            # epoch 0 has not been trained yet, whereas later checkpoints are
            # tagged with the epoch that just finished. If load_checkpoint
            # returns the stored tag, resuming from this file would skip
            # training epoch 0 — confirm.
            evaluate_model(args, logger, device, run_paths, model, epoch, "linear_probing", test_loader)
            graph_plotting_manager(args, run_paths, model)
            save_checkpoint(args, logger, model, optimizer, epoch, run_paths.checkpoint_directory)

        if not args.fixed_lr:
            adjust_lr(args, optimizer, epoch)

        train_model(args, logger, model, train_loader, optimizer, epoch, device)

        is_final_epoch = epoch == args.epochs - 1

        if is_final_epoch or _is_interval_boundary(epoch, args.model_evaluation_interval):
            evaluate_model(args, logger, device, run_paths, model, epoch, "linear_probing", test_loader)
            graph_plotting_manager(args, run_paths, model)

        if is_final_epoch or _is_interval_boundary(epoch, args.model_save_interval):
            save_checkpoint(args, logger, model, optimizer, epoch, run_paths.checkpoint_directory)

    if args.model_name == "SDMI":
        # SDMI models keep separate histories for their E- and M-encoders.
        logger.info(f"E-encoder evaluation history: {model.E_linear_probing_history}")
        logger.info(f"M-encoder evaluation history: {model.M_linear_probing_history}")
    else:
        logger.info(f"Encoder evaluation history: {model.linear_probing_history}")


# Run the experiments only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
