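'''
Evaluate per-class domain classifiers on a Breeds domain.

This script requires a trained model for each class (each one discriminating
between the source and target domains of that class), stored in --ckpt_dir as
`same-class-{class_idx}-save-model-final`. For every class it computes, per
dataset item, the fraction of predictions that match the item's domain label,
and pickles the resulting `(accuracies, paths)` lists to
--save_dir/--save_file_name.
'''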

import pickle
import os
import torch
import numpy as np
import torchvision.models as models
from tqdm import tqdm
from unlabeled_extrapolation.datasets.breeds import Breeds
from unlabeled_extrapolation.datasets.connectivity_utils import data, transforms
from prune import CustomDataset

import argparse
# Command-line interface. Arguments are parsed inside main(), so importing this
# module does not consume sys.argv.
parser = argparse.ArgumentParser()
parser.add_argument('--ckpt_dir', type=str, required=True,
                    help='Directory holding the per-class checkpoints '
                         '(same-class-{class_idx}-save-model-final).')
# Default to the current directory: with an implicit None default,
# os.path.join(args.save_dir, ...) raised TypeError only after all inference
# had already been run.
parser.add_argument('--save_dir', type=str, default=os.curdir,
                    help='Directory the pickled results are written to '
                         '(defaults to the current directory).')
parser.add_argument('--save_file_name', type=str, required=True,
                    help='File name of the pickled results inside --save_dir.')
parser.add_argument('--domain', type=str, help='The name of the Breeds domain', required=True)
parser.add_argument('--batch_size', type=int, default=60,
                    help='Number of dataset items evaluated per forward pass.')
parser.add_argument('--num_transforms', type=int, default=10,
                    help='Number of transforms passed to CustomDataset for each item.')
parser.add_argument('--transform', type=str,
                    choices=['imagenet', 'simclr', 'swav-simclr'], default='swav-simclr',
                    help='Choice of data augmentation to use.')

# Root directory handed to Breeds; None when BREEDS_ROOT is not set in the environment.
ROOT = os.environ.get('BREEDS_ROOT')

def get_breeds_num_classes(domain):
    '''Return the number of classes of the given Breeds domain.

    Only 'living17' (17 classes) and 'entity30' (30 classes) are known; any
    other value raises NotImplementedError.
    '''
    known_domains = (('living17', 17), ('entity30', 30))
    for name, class_count in known_domains:
        if domain == name:
            return class_count
    raise NotImplementedError

def load_single_class_dataset(dataset_name, domain, class_idx, transform):
    '''Build the source and target train splits restricted to one class.

    Only dataset_name == 'breeds' is supported; anything else raises
    NotImplementedError. Both splits are created with the given transform and
    then filtered via data.filter_to_single_class.

    Returns (source, target, data_attr), where data_attr is the name of the
    dataset attribute that was passed to the filter.
    '''
    if not dataset_name == 'breeds':
        raise NotImplementedError()

    data_attr = '_image_paths_by_class'
    # Source split first, then target; each is filtered after both exist.
    splits = [
        Breeds(ROOT, domain, source=is_source, target=not is_source,
               split='train', transform=transform)
        for is_source in (True, False)
    ]
    for split_ds in splits:
        data.filter_to_single_class(split_ds, class_idx, data_attr)
    source, target = splits
    return source, target, data_attr

def get_batch(ds, start, end):
    '''Collect the items ds[start], ..., ds[end - 1] into one batch.

    Every item must be a (tensor, label, path) triple. The tensors are
    concatenated with torch.cat (along dim 0), while labels and paths are
    returned as plain lists in index order.
    '''
    samples = [ds[i] for i in range(start, end)]
    tiles = [tiled for tiled, _, _ in samples]
    labels = [label for _, label, _ in samples]
    paths = [path for _, _, path in samples]
    return torch.cat(tiles), labels, paths

def main():
    '''Score every item of each class with that class's domain classifier.

    For each class index of --domain, the checkpoint
    `same-class-{class_idx}-save-model-final` from --ckpt_dir is loaded into a
    2-output ResNet-50 and run over the class's source and target train data.
    For every dataset item the fraction of its predictions that equal the
    item's label is recorded together with the item's path. The tuple
    (all_accs, all_paths), each a list with one entry per class, is pickled to
    --save_dir/--save_file_name.

    Requires a CUDA device (model and inputs are moved with .cuda()).
    '''
    args = parser.parse_args()

    # Resolve and prepare the output location before any expensive work, so a
    # bad or missing directory fails immediately instead of after inference
    # has finished and the results can no longer be written.
    save_path = os.path.join(args.save_dir, args.save_file_name)
    save_parent = os.path.dirname(save_path)
    if save_parent:
        os.makedirs(save_parent, exist_ok=True)

    num_classes = get_breeds_num_classes(args.domain)
    model = models.resnet50(num_classes=2)
    model = model.cuda()
    # Inference only; load_state_dict below does not change the train/eval mode.
    model.eval()
    transform = transforms.get_transforms(args.transform)

    all_accs, all_paths = [], []
    for class_idx in tqdm(range(num_classes)):
        # load the model
        ckpt = torch.load(os.path.join(args.ckpt_dir, f'same-class-{class_idx}-save-model-final'))
        model.load_state_dict(ckpt['model'])

        # load the data
        source, target, data_attr = load_single_class_dataset('breeds', args.domain, class_idx, transform)
        # presumably the third argument (0 / 1) is the domain label attached to
        # every item — confirm against prune.CustomDataset
        source = CustomDataset(source, data_attr, 0, args.num_transforms)
        target = CustomDataset(target, data_attr, 1, args.num_transforms)

        curr_accs, curr_paths = [], []
        with torch.no_grad():
            for ds in [source, target]:
                for idx in tqdm(range(0, len(ds), args.batch_size), leave=False):
                    X, labels, paths = get_batch(ds, idx, min(idx + args.batch_size, len(ds)))
                    preds = model(X.cuda()).argmax(dim=1)
                    # One row per dataset item; assumes every item contributes
                    # exactly num_transforms rows to X — TODO confirm against
                    # CustomDataset.
                    preds = preds.reshape(-1, args.num_transforms).cpu()
                    labels = torch.tensor(labels).view(-1, 1)
                    # Per-item fraction of predictions matching the label.
                    accs = (preds == labels).float().mean(dim=1)
                    curr_accs.extend(accs.tolist())
                    curr_paths.extend(paths)
        all_accs.append(curr_accs)
        all_paths.append(curr_paths)

    with open(save_path, 'wb') as f:
        pickle.dump((all_accs, all_paths), f)

if __name__ == '__main__':
    main()
