import os 
import torch
import torch.nn as nn
from models import ResNet18, ResNet32, ResNet50, Projector, LinearClassifier
from losses import SupConLoss, SupConCELoss
from args import args
# from data import train_loader,val_loader
from data import set_dataloader
from utils import get_logger

# Number of target classes for every supported dataset.
_NUM_CLASSES_BY_DATASET = {
    'cifar10': 10,
    'mnist': 10,
    'cifar100': 100,
    'tiny200': 200,
    'BuS': 2,
}

# Output dimension passed to the Projector (the SSL embedding space);
# the same for every supported dataset.
_DEFAULT_OUT_DIM = 128

# Methods this script trains end to end ('sl' is handled outside this file).
_SUPPORTED_METHODS = ('scl', 'hybrid')


def _build_encoder(num_classes, embed_only):
    """Build the encoder selected by ``args.dataset`` and ``args.backbone``.

    For MNIST, ``Net4mnist`` is always used (``args.backbone`` is ignored) and
    moved to ``args.device``. Otherwise the matching ResNet constructor is
    called with its pretrained-weights path and ``from_scratch=args.scratch``.

    Args:
        num_classes: number of target classes passed to the ResNet constructor.
        embed_only: forwarded to the ResNet constructor.

    Raises:
        NotImplementedError: if ``args.backbone`` is not a supported ResNet.
    """
    if args.dataset == 'mnist':
        from models import Net4mnist
        return Net4mnist().to(args.device)

    backbones = {
        'resnet18': (ResNet18, './contrastive_learning_cifar10-main/pnn/resnet18.pth'),
        'resnet32': (ResNet32, './contrastive_learning_cifar10-main/pnn/resnet32.pth'),
        'resnet50': (ResNet50, './contrastive_learning_cifar10-main/pnn/resnet50.pth'),
    }
    try:
        constructor, path_to_weights = backbones[args.backbone]
    except KeyError:
        raise NotImplementedError('backbone not supported: {}'.format(args.backbone)) from None

    return constructor(num_classes=num_classes,
                       embed_only=embed_only,
                       from_scratch=args.scratch,
                       path_to_weights=path_to_weights,
                       device=args.device)


def _build_train_loss(method):
    """Return the training loss for ``method`` ('scl' or 'hybrid')."""
    if method == 'scl':
        return SupConLoss(temperature=args.tau, device=args.device)
    return SupConCELoss(
        temperature=args.tau,
        alpha=args.alpha,
        device=args.device,
        num_classes=args.num_classes,
    )


def main():
    """Train an encoder with the 'scl' or 'hybrid' method via ``HybridAT``.

    All configuration is read from the global ``args``. ``args.num_classes`` is
    set from ``args.dataset``. Then a Projector, a LinearClassifier, the encoder
    and the data loaders are built and passed to ``HybridAT``, whose ``train``
    is called. Cross-entropy is used as the validation loss.

    Raises:
        ValueError: if ``args.dataset`` is not supported.
        NotImplementedError: if ``args.method`` is not supported (including
            'sl', which is not implemented in this file) or if
            ``args.backbone`` is not supported.
    """
    method = args.method

    num_classes = _NUM_CLASSES_BY_DATASET.get(args.dataset)
    if num_classes is None:
        raise ValueError('dataset not supported: {}'.format(args.dataset))
    args.num_classes = num_classes

    # Reject unsupported methods before loading pretrained weights or building
    # dataloaders, so a bad choice fails immediately rather than after setup.
    if method == 'sl':
        # not in this file
        raise NotImplementedError('method not supported: {}'.format(method))
    if method not in _SUPPORTED_METHODS:
        raise NotImplementedError('Method not supported: {}'.format(method))

    # Every supported method uses a separate projector and linear classifier,
    # so the encoder is always built with embed_only=True.
    projector = Projector(name=args.backbone, out_dim=_DEFAULT_OUT_DIM, device=args.device)
    classifier = LinearClassifier(name=args.backbone, num_classes=num_classes, device=args.device)
    encoder = _build_encoder(num_classes, embed_only=True)

    train_loader, val_loader = set_dataloader(args)

    from hybrid_attack import HybridAT
    model = HybridAT(
        encoder=encoder,
        projector=projector,
        classifier=classifier,
        train_loss_fn=_build_train_loss(method),
        val_loss_fn=nn.CrossEntropyLoss(),
        adv_rate=args.adv_rate,
        args=args,
        device=args.device,
    )
    model.train(train_dataloader=train_loader, val_dataloader=val_loader)

if __name__ == "__main__":
    # Start training only when this file is executed as a script, not on import.
    main()
