import sys

import torch

from torchmeta.utils.data import BatchMetaDataLoader

from common.args import parse_args
from common.utils import get_optimizer, load_model
from data.dataset import get_meta_dataset
from models.model import get_model
from train.trainer import meta_trainer
from utils import Logger, set_random_seed, cycle, NoSampler_SubsetRandomSampler

from datetime import datetime
def _make_subsampling_sampler(P, meta_set, limit_batches, batch_size):
    """Build the sampler used by the ``P.subsampling`` loaders for one split.

    With ``P.use_sampler`` the project ``SubsetRandomSampler`` is given the whole
    meta-set together with a batch limit. Otherwise every index of the
    underlying ``meta_set.dataset.dataset`` is handed to
    ``NoSampler_SubsetRandomSampler``.
    """
    from utils import SubsetRandomSampler

    if P.use_sampler:
        return SubsetRandomSampler(
            meta_set,
            limit_batches=limit_batches,
            batch_size=batch_size
        )
    return NoSampler_SubsetRandomSampler(
        torch.arange(len(meta_set.dataset.dataset))
    )


def _build_loaders(P):
    """Create the ``(train_loader, test_loader)`` pair selected by ``P``'s flags.

    * ``P.regression``  -> the raw datasets from ``get_meta_dataset`` are
      returned unchanged (no DataLoader is built).
    * ``P.subsampling`` -> finite loaders over ``<split>.dataset`` driven by an
      explicit sampler (see ``_make_subsampling_sampler``).
    * otherwise         -> an endless (``cycle``) shuffled train loader and a
      batch-size-1 test loader; with ``P.subset`` both use a
      ``SubsetRandomSampler`` instead of plain shuffling.
    """
    train_set, val_set = get_meta_dataset(P, dataset=P.dataset)

    if P.regression:
        return train_set, val_set

    if P.subsampling:
        train_sampler = _make_subsampling_sampler(
            P, train_set, P.limit_train_batches, P.batch_size
        )
        train_loader = BatchMetaDataLoader(
            train_set.dataset,
            batch_size=P.batch_size,
            num_workers=4,
            drop_last=True,
            shuffle=False,
            sampler=train_sampler
        )
        val_sampler = _make_subsampling_sampler(
            P, val_set, P.limit_val_batches, 1
        )
        test_loader = BatchMetaDataLoader(
            val_set.dataset,
            batch_size=1,
            num_workers=4,
            sampler=val_sampler,
            drop_last=True,
            shuffle=False,
        )
        return train_loader, test_loader

    kwargs = {'batch_size': P.batch_size, 'shuffle': True,
              'pin_memory': True, 'num_workers': 2}
    test_kwargs = {'batch_size': 1, 'shuffle': False,
                   'pin_memory': True, 'num_workers': 2}

    if P.subset:
        from utils import SubsetRandomSampler

        train_sampler = SubsetRandomSampler(
            torch.arange(len(train_set.dataset))
        )
        val_sampler = SubsetRandomSampler(
            torch.arange(len(val_set.dataset))
        )
        # DataLoader rejects an explicit sampler combined with shuffle=True;
        # the random sampler takes over the shuffling role.
        kwargs['shuffle'] = False
        kwargs['sampler'] = train_sampler
        # The validation sampler must go to the *test* loader's kwargs,
        # otherwise it is silently ignored.
        test_kwargs['sampler'] = val_sampler

    train_loader = cycle(BatchMetaDataLoader(train_set, **kwargs))
    test_loader = BatchMetaDataLoader(val_set, **test_kwargs)
    return train_loader, test_loader


def main(rank, P):
    """Run one training job on device ``rank`` using the parsed options ``P``.

    ``P`` is the namespace returned by ``parse_args()``; this function also
    sets ``P.rank``. It builds the data loaders, model, optimizer, train/eval
    functions and logger, optionally loads a checkpoint, runs
    ``meta_trainer`` and finally saves the log.
    """
    P.rank = rank

    # set torch device
    if torch.cuda.is_available():
        torch.cuda.set_device(P.rank)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # fixing randomness
    set_random_seed(P.seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune (non-deterministically chosen)
    # algorithms, which would defeat the deterministic setting above.
    torch.backends.cudnn.benchmark = False

    # define dataset and dataloader
    train_loader, test_loader = _build_loaders(P)

    # Initialize model and optimizer
    model = get_model(P, P.model).to(device)
    optimizer = get_optimizer(P, model)

    # define train and test type
    from train import setup as train_setup
    from evals import setup as test_setup
    train_func, _fname, _today = train_setup(P.mode, P)
    test_func = test_setup(P.mode, P)

    # define logger; the log directory is <folder_name>/<yymmdd><dataset><model><suffix>
    today_log = datetime.today().strftime("%y%m%d")
    save_path = P.folder_name + '/' + today_log + P.dataset + P.model + P.suffix
    logger = Logger(
        log_dir=save_path,
        exp_name=save_path.replace('/', '_'),
        exp_suffix='',
        write_textfile=P.folder_name != 'debug',
        # only rank 0 reports to wandb
        use_wandb=(not P.no_wandb) and rank == 0,
        wandb_project_name=P.wandb_project_name,
        entity=P.entity
    )
    logger.update_config(P, is_args=True)

    # load model if necessary
    load_model(P, model, logger)
    logger.watch(model)

    # train
    meta_trainer(P, train_func, test_func, model, optimizer, train_loader, test_loader, logger)

    # close wandb / flush logs
    logger.save_log()


if __name__ == "__main__":
    # argument define
    P = parse_args()

    P.world_size = torch.cuda.device_count()
    P.distributed = P.world_size > 1
    if P.distributed:
        # sys.exit(str) prints the message to stderr and exits with status 1,
        # so launch scripts do not mistake this no-op for a successful run.
        # (Restrict visible GPUs, e.g. CUDA_VISIBLE_DEVICES=0, to run single-GPU.)
        sys.exit("currently, ddp is not supported, should consider transductive BN before using ddp")
    main(0, P)
