import os 
import torch
import torch.nn as nn
from models import ResNet18, ResNet32, ResNet34, ResNet50, Projector, LinearClassifier
from losses import ConLoss, SupConLoss, SupConCELoss
from args import args
# from data import train_loader,val_loader
from data import set_dataloader
from utils import get_logger

# Number of target classes for every dataset this script accepts.
_NUM_CLASSES_BY_DATASET = {
    'cifar10': 10,
    'mnist': 10,
    'cifar100': 100,
    'tiny200': 200,
    'BuS': 2,
}

# Output dimension handed to the Projector; the same for every dataset.
_PROJECTION_DIM = 128

# Backbones for which a weights file path is defined.
_SUPPORTED_BACKBONES = ('resnet18', 'resnet32', 'resnet34', 'resnet50')

# Location of the per-backbone weights file; formatted with the backbone name.
_WEIGHTS_PATH_TEMPLATE = './contrastive_learning_cifar10-main/pnn/{}.pth'

# Log line emitted before the model is assembled, keyed by training method.
# The keys double as the list of supported methods.
_START_MESSAGES = {
    'sl': "Starting supervised learning...",
    'scl': "Starting supervised contrastive learning...",
    'hybrid': "Starting hybrid training...",
    'ae': "Starting autoencoder training...",
    'cl': "Starting contrastive learning...",
}


def _build_encoder(embed_only, path_to_weights):
    """Instantiate the encoder chosen by ``args.dataset`` / ``args.backbone``.

    ``mnist`` always uses ``Net4mnist``; every other dataset uses the ResNet
    variant named by ``args.backbone``.

    Raises:
        NotImplementedError: if ``args.backbone`` names no known ResNet.
    """
    if args.dataset == 'mnist':
        from models import Net4mnist
        return Net4mnist().to(args.device)

    # Built inside the function so the lookup table is only evaluated when
    # an encoder is actually requested.
    constructors = {
        'resnet18': ResNet18,
        'resnet32': ResNet32,
        'resnet34': ResNet34,
        'resnet50': ResNet50,
    }
    if args.backbone not in constructors:
        raise NotImplementedError('backbone not supported: {}'.format(args.backbone))
    return constructors[args.backbone](embed_only=embed_only,
                                       from_scratch=args.scratch,
                                       path_to_weights=path_to_weights,
                                       device=args.device)


def _build_losses(method):
    """Return ``(train_loss_fn, val_loss_fn)`` for a non-autoencoder method.

    Raises:
        NotImplementedError: if ``method`` is not one of sl/scl/hybrid/cl.
    """
    if method == 'sl':
        criterion = nn.CrossEntropyLoss()
        return criterion, criterion
    if method == 'scl':
        criterion = SupConLoss(temperature=args.tau, device=args.device)
        return criterion, nn.CrossEntropyLoss()
    if method == 'hybrid':
        criterion = SupConCELoss(
            temperature=args.tau,
            alpha=args.alpha,
            device=args.device,
            num_classes=args.num_classes
        )
        return criterion, nn.CrossEntropyLoss()
    if method == 'cl':
        criterion = ConLoss(temperature=args.tau, device=args.device)
        return criterion, criterion
    raise NotImplementedError('method not supported: {}'.format(method))


def main():
    """Assemble the model for ``args.method`` and run its training loop.

    All configuration is read from the module-level ``args`` object
    (``method``, ``dataset``, ``backbone``, ``device``, ``scratch``, ``tau``,
    ``alpha``, ``size``, ``save_folder``). ``args.num_classes`` is overwritten
    with the class count of the selected dataset.

    The configuration is validated before any model or data loader is built,
    so a mistyped option fails immediately.

    Raises:
        ValueError: if ``args.dataset`` is not supported.
        NotImplementedError: if ``args.backbone`` or ``args.method`` is not
            supported.
    """
    METHOD = args.method

    if args.dataset not in _NUM_CLASSES_BY_DATASET:
        raise ValueError('dataset not supported: {}'.format(args.dataset))
    num_classes = _NUM_CLASSES_BY_DATASET[args.dataset]
    args.num_classes = num_classes

    # Checked before the weights-path lookup: previously an unknown backbone
    # surfaced as a bare KeyError and this error was never reached.
    if args.backbone not in _SUPPORTED_BACKBONES:
        raise NotImplementedError('backbone not supported: {}'.format(args.backbone))

    # Fail fast instead of after the models and data loaders have been built.
    if METHOD not in _START_MESSAGES:
        raise NotImplementedError('method not supported: {}'.format(METHOD))

    embed_only = True
    path_to_weights = _WEIGHTS_PATH_TEMPLATE.format(args.backbone)

    projector = Projector(name=args.backbone, out_dim=_PROJECTION_DIM, device=args.device)
    classifier = LinearClassifier(name=args.backbone, num_classes=num_classes, device=args.device)
    encoder = _build_encoder(embed_only, path_to_weights)

    train_loader, val_loader = set_dataloader(args)

    logger = get_logger(logpath=f"{args.save_folder}/logs.log", displaying=False)
    logger.info(f"Training method: {METHOD}")
    logger.info(f"Dataset: {args.dataset}")
    logger.info(_START_MESSAGES[METHOD])

    from hybrid import Hybrid

    if METHOD == 'ae':
        from models import AutoEncoder
        ae_model = AutoEncoder(encoder_name=args.backbone,
                               num_classes=args.num_classes,
                               embed_only=embed_only,
                               from_scratch=args.scratch,
                               path_to_weights=path_to_weights,
                               device=args.device,
                               target_size=args.size)
        # The autoencoder's own encoder/decoder take the encoder/projector
        # slots of Hybrid.
        # NOTE(review): the encoder and projector built above are unused on
        # this path; they are still constructed to keep the original
        # construction order unchanged.
        encoder = ae_model.encoder
        projector = ae_model.decoder
        train_loss_fn = val_loss_fn = nn.MSELoss()
    else:
        train_loss_fn, val_loss_fn = _build_losses(METHOD)

    model = Hybrid(
        encoder=encoder,
        projector=projector,
        classifier=classifier,
        train_loss_fn=train_loss_fn,
        val_loss_fn=val_loss_fn,
        args=args,
        device=args.device
    )

    model.train(train_dataloader=train_loader, val_dataloader=val_loader)

# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
