import torch
import torch.optim as op
import numpy as np
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from torch.utils.data.dataset import Subset
from IPython import embed


# mm_imdb
import models.search.mmimdb_darts_searchable as mmimdb
from datasets import mmimdb as mmimdb_data



from models.utils import parse_opts
import models.search.tools as tools


class MMIMDB_Searcher:
    """Architecture-search driver for the MM-IMDB dataset.

    On construction it builds ``DataLoader`` objects for the ``train``,
    ``dev`` and ``test`` stages; ``search`` then hands them, together with
    the run configuration, to ``mmimdb.train_darts_model``.
    """

    def __init__(self, args, device, logger):
        """Store the run configuration and build the per-stage data loaders.

        Args:
            args: parsed options. This class reads ``args.datadir``,
                ``args.batchsize`` and ``args.num_workers``. The whole object
                is also forwarded to ``mmimdb_data.MM_IMDB`` and to
                ``mmimdb.train_darts_model``.
            device: kept as-is and forwarded to ``train_darts_model``.
            logger: kept as-is and forwarded to ``train_darts_model``.
        """
        self.args = args
        self.device = device
        self.logger = logger

        # Two separate pipelines are built (evaluation first, then training),
        # even though both currently apply only ``mmimdb_data.ToTensor``.
        # The evaluation pipeline is shared by the dev and test splits.
        eval_transform = transforms.Compose([mmimdb_data.ToTensor()])
        train_transform = transforms.Compose([mmimdb_data.ToTensor()])

        # Stage name -> transform. Order matters: datasets (and loaders) are
        # created train, dev, test.
        stage_transforms = (
            ('train', train_transform),
            ('dev', eval_transform),
            ('test', eval_transform),
        )

        splits = {}
        for stage, tfm in stage_transforms:
            splits[stage] = mmimdb_data.MM_IMDB(
                args.datadir,
                transform=tfm,
                stage=stage,
                feat_dim=300,
                args=args,
            )

        # NOTE(review): every split, dev and test included, is shuffled.
        # This is presumably intentional if train_darts_model samples dev
        # batches during the search. Confirm before changing, since it is
        # unusual for pure evaluation loaders.
        loaders = {}
        for stage, split in splits.items():
            loaders[stage] = DataLoader(
                split,
                batch_size=args.batchsize,
                shuffle=True,
                num_workers=args.num_workers,
                drop_last=False,
            )
        self.dataloaders = loaders

    def search(self):
        """Run the architecture search over the prepared data loaders.

        Returns:
            A ``(best_f1, best_genotype)`` pair, unpacked from the result of
            ``mmimdb.train_darts_model`` and returned as a new tuple.
        """
        outcome = mmimdb.train_darts_model(
            self.dataloaders, self.args, self.device, self.logger
        )
        best_f1, best_genotype = outcome
        return best_f1, best_genotype


