import pickle
from tqdm import tqdm

import torch
from torch.utils.data import Dataset
import torchvision.models as models

from unlabeled_extrapolation.datasets.connectivity_utils import transforms, train, data
from unlabeled_extrapolation.datasets import breeds, domainnet

import argparse
# Command-line interface for this script.
# NOTE: the first positional parameter of ArgumentParser is `prog`, so the
# string below is what appears as the program name in usage/help output.
parser = argparse.ArgumentParser(prog='For pruning')

# --- Model ---
parser.add_argument(
    '--ckpt_path',
    type=str,
    help='Path of trained network weights to use',
)

# --- Script-specific ---
# Number of transformed copies evaluated per image (see CustomDataset).
parser.add_argument(
    '--num_transforms',
    type=int,
    default=10,
)

# --- Dataset ---
parser.add_argument(
    '--dataset_name',
    type=str,
    required=True,
    choices=['breeds', 'domainnet'],
    help='Which dataset on which to test connectivity.',
)
parser.add_argument(
    '--data_path',
    type=str,
    help='Root path of the data.',
)
parser.add_argument(
    '--domain_1',
    type=str,
    required=True,
    help='Name of domain 1.',
)
parser.add_argument(
    '--domain_2',
    type=str,
    help='Name of domain 2 for domainnet.',
)
parser.add_argument(
    '--version',
    type=str,
    choices=['full', 'sentry'],
    default='sentry',
    help='For domainnet, which version to use.',
)
parser.add_argument(
    '--transform',
    type=str,
    choices=['imagenet', 'simclr', 'swav-simclr'],
    default='swav-simclr',
    help='Choice of data augmentation to use.',
)

# --- Misc ---
parser.add_argument(
    '--batch_size',
    type=int,
    default=60,
    help='Number of examples to do forward pass on at a time',
)
parser.add_argument(
    '--save_path',
    type=str,
    required=True,
    help='Where to save the results',
)

def main():
    """Score every image with a two-class domain classifier and pickle the results.

    Steps:
      1. Parse the command line and load ResNet-50 weights (2 output classes)
         from the ``'model'`` entry of the checkpoint at ``--ckpt_path``.
      2. Build the two datasets via ``load_dataset`` and wrap them in
         ``CustomDataset`` with labels 0 and 1 respectively.
      3. For each image, run the model on ``--num_transforms`` copies and
         record the fraction of copies whose argmax prediction equals the
         dataset's label.
      4. Pickle ``(all_accs, all_paths)`` to ``--save_path``; the two lists
         are aligned index-by-index.

    Requires a CUDA device. Exits with a usage error when ``--ckpt_path`` is
    not supplied.
    """
    args = parser.parse_args()
    # --ckpt_path is optional in the parser; fail with a clear usage error
    # here instead of an obscure failure inside torch.load(None).
    if args.ckpt_path is None:
        parser.error('--ckpt_path is required to load the trained network weights')

    # get model
    model = models.__dict__['resnet50'](num_classes=2)
    # Load onto the CPU first so a checkpoint saved on a different GPU index
    # still deserializes; load_state_dict copies the values into the model.
    # NOTE(review): torch.load unpickles the file - only use trusted checkpoints.
    checkpoint = torch.load(args.ckpt_path, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    model = model.cuda()
    model.eval()

    # get data
    transform = transforms.get_transforms(args.transform)
    # presumably fills in args.data_path when it was not given - confirm in
    # connectivity_utils.data
    data.fill_default_data_path(args)
    ds1, ds2, data_attr = load_dataset(args, transform)
    datasets = [
        CustomDataset(ds1, data_attr, 0, args.num_transforms),
        CustomDataset(ds2, data_attr, 1, args.num_transforms),
    ]

    # evaluate
    all_accs, all_paths = [], []
    with torch.no_grad():
        for ds in datasets:
            num_examples = len(ds)
            for start in tqdm(range(0, num_examples, args.batch_size)):
                end = min(start + args.batch_size, num_examples)
                X, labels, paths = get_batch(ds, start, end)
                preds = model(X.cuda()).argmax(dim=1)
                # One row per image, one column per transformed copy. No
                # detach/clone needed: no graph is recorded under no_grad.
                preds = preds.reshape(-1, args.num_transforms).cpu()
                labels = torch.tensor(labels).view(-1, 1)
                # Per-image fraction of copies predicted as the dataset's label.
                accs = (preds == labels).float().mean(dim=1)
                all_accs.extend(accs.tolist())
                all_paths.extend(paths)

    print(f'Saving to {args.save_path} now...')
    with open(args.save_path, 'wb') as f:
        pickle.dump((all_accs, all_paths), f)

def get_batch(ds, start, end):
    """Collect items ``start`` (inclusive) to ``end`` (exclusive) of ``ds``.

    Each ``ds[i]`` must yield a ``(tensor, label, path)`` triple. The tensors
    are concatenated along dim 0 with ``torch.cat``; labels and paths are
    returned as plain lists in index order.

    Returns:
        ``(X, labels, paths)``.
    """
    samples = [ds[i] for i in range(start, end)]
    X = torch.cat([tiled for tiled, _, _ in samples])
    labels = [label for _, label, _ in samples]
    paths = [path for _, _, path in samples]
    return X, labels, paths

def load_dataset(args, transform):
    """Build the two training datasets selected by ``args.dataset_name``.

    Only ``'breeds'`` is implemented. For it, both datasets come from
    ``breeds.Breeds`` with ``args.data_path`` and ``args.domain_1`` on the
    ``'train'`` split: the first with ``source=True, target=False``, the
    second with ``source=False, target=True``.

    Returns:
        ``(source_train, target_train, data_attr)``, where ``data_attr`` is
        the name of the dataset attribute that callers index to get per-item
        path records.

    Raises:
        NotImplementedError: for any dataset name other than ``'breeds'``.
    """
    if args.dataset_name != 'breeds':
        raise NotImplementedError()

    # Arguments common to both splits; breeds uses domain_1 for both sides.
    common = dict(split='train', transform=transform)
    source_train = breeds.Breeds(
        args.data_path, args.domain_1, source=True, target=False, **common)
    target_train = breeds.Breeds(
        args.data_path, args.domain_1, source=False, target=True, **common)
    return source_train, target_train, '_image_paths_by_class'

class CustomDataset(Dataset):
    """Wrap a dataset so each item is a stack of repeated fetches plus metadata.

    Item ``i`` is ``(tiled, label, path)``:
      * ``tiled`` - ``torch.stack`` of ``dataset[i][0]`` fetched
        ``num_transforms`` times (presumably the wrapped dataset applies a
        random transform on each access, so the copies differ - confirm).
      * ``label`` - the constant label given at construction.
      * ``path`` - first element of the ``i``-th record in the attribute of
        the wrapped dataset named ``data_attr_name``.
    """

    def __init__(self, dataset, data_attr_name, label, num_transforms):
        super().__init__()
        self.dataset = dataset
        self.data_attr_name = data_attr_name
        self.label = label
        self.num_transforms = num_transforms

    def __len__(self):
        # Same length as the wrapped dataset.
        return len(self.dataset)

    def __getitem__(self, i):
        views = []
        for _ in range(self.num_transforms):
            # Re-index on every iteration rather than caching the sample.
            views.append(self.dataset[i][0])
        tiled = torch.stack(views)

        records = getattr(self.dataset, self.data_attr_name)
        path, _ = records[i]
        return tiled, self.label, path

# Run the evaluation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__':
    main()
