import torch
import numpy as np
import random
import argparse

from torch.utils.data import DataLoader
# from pytorch_lightning.plugins import DDPPlugin
# from pytorch_lightning.strategies.ddp import DDPStrategy


from sessions import train, test
from datasets import get_dataset, get_transform, dataset_info
from models import get_network_by_name
from feature_aug import get_feature_aug_by_name

# Model names that are constructed with feature-augmentation transforms.
_FEATURE_AUG_MODELS = frozenset({
    'resnet50_aug',
    'wresnet_aug',
    'shakeshake_aug',
    'resnet50_aug_imagenet',
})


def _build_parser():
    """Return the argument parser describing this script's command line."""
    parser = argparse.ArgumentParser(description='Investigate how feature augmentation would help model training')

    parser.add_argument('--lr', default=5e-4, type=float, help='learning rate')
    parser.add_argument('--model', default='resnet50', type=str, help='model type')
    parser.add_argument('--dataset', default='cifar10', type=str, help='dataset name')
    parser.add_argument('--dataset_path', default=None, type=str, help='dataset path')

    parser.add_argument('--save', '-save', default="./test", type=str, help='saving dir')
    parser.add_argument('--load', '-load', default=None, type=str, help='resume from checkpoint')
    parser.add_argument('--eval', '-eval', action='store_true', help='only eval')

    parser.add_argument('--max_epochs', '-m', default=200, type=int, help='max epochs')
    parser.add_argument('--batch_size', '-b', default=256, type=int, help='batch size')

    parser.add_argument('--num_workers', '-w', default=0, type=int, help='number of workers for the dataset')
    parser.add_argument('--num_devices', '-n', default=None, type=int, help='number of devices')
    parser.add_argument('--accelerator', '-a', default=None, type=str, help='accelerator style')
    parser.add_argument('--strategy', '-s', default='ddp', type=str, help='accelerator strategy')

    parser.add_argument('--debug', '-debug', action='store_true', help='debug')
    parser.add_argument('--seed', '-seed', default=0, type=int, help='seed')

    parser.add_argument('--feature_aug', '-fa', default=None, type=str, help='feature augmentation strategy')
    parser.add_argument('--dropblock', '-db', default=None, type=str, help='dropblock style')
    parser.add_argument(
        '--feature_aug_config', '-fac', default='./configs/augmentation/default.yaml',
        type=str, help='feature augmentation strategy config')
    parser.add_argument('--optimizer', '-opt', default='adam', type=str, help='optimizer setup')
    parser.add_argument('--imagenet_path', '-ip', default=None, type=str, help='imagenet path')
    return parser


def _build_model(args, num_cls):
    """Build the network named by ``args.model``.

    Returns a ``(model, dropblock_style)`` tuple. For the models listed in
    ``_FEATURE_AUG_MODELS`` the feature-augmentation transforms are looked up
    from ``args.feature_aug`` / ``args.feature_aug_config`` and passed to the
    network, and ``dropblock_style`` is ``args.dropblock``. For every other
    model ``dropblock_style`` is ``None``.
    """
    if args.model not in _FEATURE_AUG_MODELS:
        return get_network_by_name(args.model, num_class=num_cls), None

    my_transforms = get_feature_aug_by_name(
        name=args.feature_aug, config=args.feature_aug_config)
    model = get_network_by_name(
        args.model, model_args={'my_transforms': my_transforms}, num_class=num_cls)
    return model, args.dropblock


def main(argv=None):
    """Parse the command line, build data and model, then train or evaluate.

    Args:
        argv: Optional list of command-line tokens. When ``None`` (the
            default) argparse reads ``sys.argv[1:]``, which is the behaviour
            of running the script directly.

    Raises:
        ValueError: If a feature-augmentation model is selected without
            ``--feature_aug``.
    """
    args = _build_parser().parse_args(argv)

    # Fail fast: reject an inconsistent model/augmentation combination before
    # any dataset is loaded, and name the model that was actually requested.
    if args.model in _FEATURE_AUG_MODELS and args.feature_aug is None:
        raise ValueError(f'feature_aug is required for {args.model}')

    # seeding
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    # get data and transforms
    train_transform, test_transform = get_transform(name=args.dataset)
    train_dataset = get_dataset(
        name=args.dataset, my_path=args.dataset_path, train=True,
        transform=train_transform, imagenet_path=args.imagenet_path)
    test_dataset = get_dataset(
        name=args.dataset, my_path=args.dataset_path, train=False,
        transform=test_transform, imagenet_path=args.imagenet_path)

    # get data loader (only the training split is shuffled)
    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
    test_loader = DataLoader(
        test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

    # build model
    print(f"Building model {args.model}")
    num_cls = dataset_info[args.dataset]['num_classes']
    model, dropblock_style = _build_model(args, num_cls)

    # Options shared by evaluation and training.
    # NOTE(review): these keys look like PyTorch Lightning Trainer kwargs;
    # confirm against sessions.train / sessions.test. The strategy is passed
    # through as a plain string; a customised DDP strategy (for example with
    # find_unused_parameters=False) would have to be constructed here.
    common_trainer_args = {
        'devices': args.num_devices,
        'accelerator': args.accelerator,
        'strategy': args.strategy,
    }

    if args.eval:
        test(
            model=model,
            test_loader=test_loader,
            trainer_args=common_trainer_args,
            load_path=args.load)
        return

    trainer_args = {
        'max_epochs': args.max_epochs,
        **common_trainer_args,
        'fast_dev_run': args.debug,
    }
    if args.load is not None:
        trainer_args['resume_from_checkpoint'] = args.load

    # The test split doubles as the validation set during training.
    train(
        model=model,
        train_loader=train_loader,
        val_loader=test_loader,
        learning_rate=args.lr,
        trainer_args=trainer_args,
        save_path=args.save,
        optimizer=args.optimizer,
        dropblock_style=dropblock_style)


# Run the experiment only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
