"""Datasets"""
def initialize_dataset(config, **kwargs):
    """
    Build the WPDS dataset selected by ``config.dataset``.

    Only the module of the selected dataset is imported, so the dependencies
    of the other datasets are never loaded.

    Args:
        config: Namespace-like object; ``dataset``, ``root_dir`` and
            ``download`` are read from it.
        **kwargs: Extra keyword arguments forwarded to the dataset constructor.

    Returns:
        A single dataset instance (AmazonWPDS, FMoWWPDS, CivilCommentsWPDS or
        PovertyWPDS).

    Raises:
        NotImplementedError: if ``config.dataset`` is not a supported name.
    """
    name = config.dataset
    # Lazy, per-branch imports: each dataset module is only loaded when chosen.
    if name == 'amazon':
        from datasets.amazon_wpds import AmazonWPDS as dataset_cls
    elif name == 'fmow':
        from datasets.fmow_wpds import FMoWWPDS as dataset_cls
    elif name == 'civilcomments':
        from datasets.civilcomments_wpds import CivilCommentsWPDS as dataset_cls
    elif name == 'poverty':
        from datasets.poverty_wpds import PovertyWPDS as dataset_cls
    else:
        raise NotImplementedError(
            f"Unsupported dataset {name!r}; choose one of: "
            "amazon, fmow, civilcomments, poverty."
        )
    # All supported datasets share the same constructor keywords.
    return dataset_cls(root_dir=config.root_dir, download=config.download, **kwargs)


"""Transforms"""
import torchvision
import transforms
def initialize_transform(config, dataset, train=True):
    """
    Return the input transform named by ``config.transform``.

    'cifar100' is built here with torchvision: random 32px crop (padding 4)
    plus horizontal flip and tensor conversion when ``train`` is truthy,
    tensor conversion only otherwise. Every other name is delegated to
    ``transforms.initialize_transform``.
    """
    name = config.transform
    if name != 'cifar100':
        return transforms.initialize_transform(name, config, dataset, train)

    tv = torchvision.transforms
    if train:
        steps = [tv.RandomCrop(32, padding=4), tv.RandomHorizontalFlip(), tv.ToTensor()]
    else:
        steps = [tv.ToTensor()]
    return tv.Compose(steps)



"""Algorithms"""
import algs

def initialize_algorithm(config, model) -> "algs.OCLAlgorithm":
    """
    Construct the online continual learning (OCL) algorithm named by ``config.alg``.

    Args:
        config: Namespace-like object; ``alg`` selects the algorithm and the
            whole object is forwarded to the algorithm constructor.
        model: Model passed as the first constructor argument.

    Returns:
        The algorithm object (annotated as ``algs.OCLAlgorithm``). Names ending
        in ``-pl`` / ``-fm`` return ``algs.PseudoLabeling`` wrapping the
        corresponding base algorithm class.

    Raises:
        ValueError: if ``config.alg`` is None.
        NotImplementedError: if ``config.alg`` is not a supported name.

    Note:
        The ``-fm`` variants set ``config.fixmatch = True`` on the caller's
        config object before constructing the algorithm.
    """
    name = config.alg
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if name is None:
        raise ValueError("config.alg must be set to an OCL algorithm name (e.g. via --alg).")
    if name == 'fbo':
        return algs.FBO(model, config)
    elif name == 'nbo':
        return algs.NBO(model, config)
    elif name == 'er-fifo':
        return algs.ER_FIFO(model, config)
    elif name == 'er-fifo-rw':
        return algs.ER_FIFO_RW(model, config)
    elif name == 'gem-pds':
        return algs.GEM_PDS(model, config)
    elif name == 'mir':
        return algs.MIR(model, config)
    elif name == 'maxloss':
        return algs.MaxLoss(model, config)
    elif name == 'l2reg':
        return algs.L2Reg(model, config)
    elif name == 'ewc':
        return algs.EWC(model, config)
    elif name == 'nbo-pl':
        return algs.PseudoLabeling(model, config, algs.NBO)
    elif name == 'er-fifo-pl':
        return algs.PseudoLabeling(model, config, algs.ER_FIFO)
    elif name == 'er-fifo-fm':
        # FixMatch variant: enabled by flipping the flag on the shared config.
        config.fixmatch = True
        return algs.PseudoLabeling(model, config, algs.ER_FIFO)
    elif name == 'er-fifo-rw-pl':
        return algs.PseudoLabeling(model, config, algs.ER_FIFO_RW)
    elif name == 'er-fifo-rw-fm':
        config.fixmatch = True
        return algs.PseudoLabeling(model, config, algs.ER_FIFO_RW)
    else:
        raise NotImplementedError(f"Unsupported OCL algorithm {name!r}.")


"""Models"""
from models.initializer import initialize_bert_based_model, initialize_torchvision_model
def initialize_model(config, dataset):
    """
    Build the model architecture named by ``config.model``.

    Args:
        config: Namespace-like object; ``model`` selects the architecture and
            ``model_kwargs`` (a mapping) is forwarded to the model builder.
            For BERT-based models the whole config is also passed through.
        dataset: Object exposing ``num_classes``, used to size the output head.

    Returns:
        The constructed model.

    Raises:
        ValueError: if ``config.model`` is None.
        NotImplementedError: if ``config.model`` is not a supported name.
    """
    name = config.model
    # Without this guard, `'bert' in None` below raises an opaque TypeError.
    if name is None:
        raise ValueError("config.model must be set to a model name (e.g. via --model).")

    torchvision_archs = ('resnet18', 'resnet34', 'resnet50', 'resnet101', 'wideresnet50', 'densenet121')
    if name in torchvision_archs:
        return initialize_torchvision_model(name, dataset.num_classes, **config.model_kwargs)
    if name == 'resnet18_ms':  # multispectral ResNet-18
        from models.resnet_multispectral import ResNet18
        return ResNet18(num_classes=dataset.num_classes, **config.model_kwargs)
    if 'bert' in name:
        import os
        # NOTE(review): presumably silences the HuggingFace `tokenizers`
        # parallelism/fork warning when data loading uses workers — confirm.
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        return initialize_bert_based_model(config, dataset.num_classes, **config.model_kwargs)
    raise NotImplementedError(f"Unsupported model {name!r}.")

    
"""Feedback"""
from feedback import RandomLabelFeedback
def initialize_feedback(config, eval_metric):
    """
    Return the OCL feedback object selected by ``config.feedback``.

    Only 'rlf' (RandomLabelFeedback, constructed from ``config`` and
    ``eval_metric``) is supported; any other value raises NotImplementedError.
    """
    if config.feedback != 'rlf':
        raise NotImplementedError
    return RandomLabelFeedback(config, eval_metric)


"""Eval metric"""
def initialize_eval_metric(config):
    """
    Return the WILDS evaluation metric selected by ``config.eval_metric``.

    'acc'     -> Accuracy, converting logits with multiclass_logits_to_pred.
    'pearson' -> PearsonCorrelation.
    Any other value raises NotImplementedError. The WILDS metrics module is
    imported on every call, before the name is inspected.
    """
    import wilds.common.metrics.all_metrics as metrics
    metric_name = config.eval_metric
    if metric_name == 'pearson':
        return metrics.PearsonCorrelation()
    if metric_name == 'acc':
        return metrics.Accuracy(metrics.multiclass_logits_to_pred)
    raise NotImplementedError


"""Loss function"""
import torch.nn as nn
def initialize_loss_function(config):
    """
    Return the training loss selected by ``config.loss_function``.

    'xent' -> nn.CrossEntropyLoss, 'mse' -> nn.MSELoss; both are built with
    ``reduction='mean'``. Any other value raises NotImplementedError.
    """
    loss_name = config.loss_function
    if loss_name == 'xent':
        loss_cls = nn.CrossEntropyLoss
    elif loss_name == 'mse':
        loss_cls = nn.MSELoss
    else:
        raise NotImplementedError
    return loss_cls(reduction='mean')


"""Default command line argument parser"""
import argparse
from utils import ParseKwargs
def get_parser():
    """
    Build the default command-line argument parser for OCL experiments.

    Options are registered in this order: seed, dataset/OCL setup, PDS setup,
    evaluation, general training, optimizer, scheduler, OCL algorithm, misc.
    Only --dataset and --root_dir are required. The ``*_kwargs`` options use
    the project's ``ParseKwargs`` action (the --model_kwargs help documents the
    ``key1=value1 key2=value2`` form).
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    add('--seed', type=int)

    # --- Dataset and OCL setup ---
    add('--dataset', required=True, type=str,
        help='Dataset name. Choices: civilcomments, fmow, amazon, poverty.')
    add('--root_dir', required=True, type=str,
        help='Path to the root directory of the dataset.')
    add('--download', default=False, action='store_true',
        help='Set this to download the dataset.')
    add('--transform', type=str, help='The transform method to be used.')
    add('--loader_kwargs', nargs='*', action=ParseKwargs, default={})
    add('--feedback', type=str, default='rlf', help='Type of OCL feedback.')
    add('--alpha', type=float, default=1.0, help='alpha in RLF. 0 <= alpha <= 1.')
    add('--csv_file', type=str,
        help='Set this to save the experiment results into a csv file.')

    # --- PDS setup ---
    add('--max_batches', type=int, help='If not set, use all batches.')
    add('--recent_batches', '-w', type=int, help='Recent time window.')

    # --- Evaluation ---
    add('--eval_metric', type=str,
        help='Name of the evaluation metric. Choices: acc, pearson.')
    add('--eval_batch_size', type=int, help='Batch size for evaluation.')
    add('--eval_post_train', default=False, action='store_true',
        help='If set, evaluate the model on the online batch again after fine-tuning.')
    add('--eval_regression_once', default=False, action='store_true',
        help='If set, only evaluate regression set performance for the initial model.')

    # --- General training parameters (ERM configuration) ---
    add('--device', type=str, help='cuda or cpu, or a particular device.')
    add('--model', type=str, help='Model arch name.')
    add('--initial_model_load', '-i', type=str,
        help='If specified, load this model at t = 0.')
    add('--initial_model_save', '-s', type=str,
        help='If specified, save the initial model after training at t = 0.')
    add('--batch_size', type=int, help='Batch size for training.')
    add('--epochs', type=int, help='Number of epochs of the OCL algorithm.')
    add('--epochs_first_batch', type=int,
        help='Number of epochs of initial model training.')
    add('--loss_function', type=str,
        help='Name of the loss function. Choices: xent, mse.')

    # --- Optimizer ---
    add('--optimizer', type=str, help='Name of optimizer.')
    add('--optimizer_kwargs', nargs='*', action=ParseKwargs, default={})
    add('--lr', type=float, help='Learning rate.')
    add('--weight_decay', '--wd', type=float, help='Weight decay level.')
    add('--max_grad_norm', type=float, help='If set, do grad clipping.')

    # --- Scheduler ---
    add('--scheduler', type=str, help='Name of scheduler.')
    add('--scheduler_kwargs', nargs='*', action=ParseKwargs, default={})

    # --- OCL algorithm ---
    add('--alg', type=str,
        help=('Name of OCL algorithm. Choices: fbo, nbo, '
              'er-fifo, er-fifo-rw, gem-pds, mir, maxloss, '
              'l2reg, ewc, nbo-pl, er-fifo-pl, er-fifo-fm, '
              'er-fifo-rw-pl, er-fifo-rw-fm.'))
    add('--epochs_unlabeled', type=int,
        help='Number of epochs of real update in PL and FM.')
    add('--kr_size', type=int)
    add('--lbd', type=float)
    add('--gamma', type=float)
    add('--fixmatch', default=False, action='store_true')

    # --- Misc ---
    add('--max_token_length', type=int, default=512)
    add('--model_kwargs', nargs='*', action=ParseKwargs, default={},
        help='keyword arguments for model initialization passed as key1=value1 key2=value2')
    add('--randaugment_n', type=int,
        help='Number of RandAugment transformations to apply.')
    add('--loss_function_dummy')  # Don't set this (its purpose is not visible in this file).
    return parser
