import os

import torch

from engine.tools.utils import makedirs, save_as_json, load_json
from engine.transforms.default import build_transform
from engine.datasets.utils import DatasetWrapper, get_few_shot_benchmark
from engine.datasets import dataset_classes

@torch.no_grad()
def prepare_image_dataset(args, clip_model, batch_size=32, num_workers=4):
    """Return cached image features, extracting and saving whatever is missing.

    Three artifacts are cached on disk: the few-shot train/val features, the
    test features, and the ``lab2cname`` object taken from the benchmark.
    The benchmark is only built when at least one of them is absent.

    Args:
        args: Run configuration; this function reads ``data_dir``,
            ``indices_dir``, ``dataset``, ``shot`` and ``seed`` (the path
            helpers read further attributes).
        clip_model: Model handed to ``extract_features``.
        batch_size: Batch size forwarded to ``extract_features``.
        num_workers: Worker count forwarded to ``extract_features``.

    Returns:
        Tuple ``(train_set, val_set, test_set, lab2cname)``, all reloaded
        from the cache files.
    """
    few_shot_set_path = get_few_shot_set_path(args)
    test_set_path = get_test_set_path(args)
    lab2cname_path = os.path.join(
        args.data_dir,
        dataset_classes[args.dataset].dataset_name,
        'lab2cname.pth',
    )
    makedirs(os.path.dirname(test_set_path))

    transform = build_transform('none')  # center crop only

    def _encode(data_source):
        # Shared extraction call so every split uses identical settings.
        return extract_features(
            clip_model,
            data_source,
            transform,
            test_batch_size=batch_size,
            num_workers=num_workers,
        )

    # Building the benchmark is only needed when some cache file is missing.
    cache_files = (few_shot_set_path, test_set_path, lab2cname_path)
    if not all(os.path.exists(cache_file) for cache_file in cache_files):
        few_shot_benchmark = get_few_shot_benchmark(
            args.data_dir, args.indices_dir, args.dataset, args.shot, args.seed
        )
        torch.save(few_shot_benchmark['lab2cname'], lab2cname_path)

    if not os.path.exists(few_shot_set_path):
        few_shot_set = {}
        for split in ('train', 'val'):
            print(f"Extracting features for {split} split ...")
            few_shot_set[split] = _encode(few_shot_benchmark[split])

        print(f"Saving few-shot image features to {few_shot_set_path}")
        torch.save(few_shot_set, few_shot_set_path)

    if not os.path.exists(test_set_path):
        print("Extracting features for test split ...")
        test_features = _encode(few_shot_benchmark['test'])
        print(f"Saving features to {test_set_path}")
        torch.save(test_features, test_set_path)

    # Always hand back what is on disk, whether freshly written or cached.
    cached_few_shot = torch.load(few_shot_set_path)
    cached_test = torch.load(test_set_path)
    cached_lab2cname = torch.load(lab2cname_path)
    return cached_few_shot['train'], cached_few_shot['val'], cached_test, cached_lab2cname


def get_image_dir(args):
    """Return the directory where image features for this configuration are cached.

    The path is ``<feature_dir>/image/<clip_encoder>/<dataset>``, with every
    "/" in the encoder name replaced by "-" so it forms a single path component.
    """
    root = args.feature_dir
    encoder_tag = "-".join(args.clip_encoder.split("/"))
    return os.path.join(root, 'image', encoder_tag, args.dataset)


def get_few_shot_set_path(args):
    """Return the cache file path for few-shot features, keyed by shot count and seed."""
    feature_root = get_image_dir(args)
    filename = "shot_{}-seed_{}.pth".format(args.shot, args.seed)
    return os.path.join(feature_root, filename)


def get_test_set_path(args):
    """Return the cache file path for test-split features (``test.pth`` in the image dir)."""
    feature_root = get_image_dir(args)
    return os.path.join(feature_root, "test.pth")


def extract_features(clip_model, data_source, transform, test_batch_size=32, num_workers=4):
    """Encode every item of ``data_source`` with ``clip_model.encode_image``.

    Args:
        clip_model: Model exposing ``eval()`` and ``encode_image``; it is
            switched to eval mode as a side effect.
        data_source: Items wrapped by ``DatasetWrapper``; each collated batch
            must provide the keys ``"img"``, ``"label"`` and ``"impath"``.
        transform: Transform passed to ``DatasetWrapper``.
        test_batch_size: Batch size used by the loader.
        num_workers: Number of DataLoader worker processes.

    Returns:
        Dict with ``'features'`` (per-batch encoder outputs concatenated
        along dim 0, on CPU), ``'labels'`` (concatenated batch labels) and
        ``'paths'`` (list of image paths), all in loader order. When the
        loader yields nothing, the tensors are empty and ``'paths'`` is ``[]``.
    """
    ######################################
    #   Setup DataLoader
    ######################################
    loader = torch.utils.data.DataLoader(
        DatasetWrapper(data_source, transform=transform),
        batch_size=test_batch_size,
        sampler=None,
        shuffle=False,  # keep features, labels and paths aligned
        num_workers=num_workers,
        drop_last=False,
        pin_memory=torch.cuda.is_available(),
    )

    ########################################
    # Start Feature Extractor
    ########################################
    clip_model.eval()

    # Feed inputs to the device the model actually lives on instead of
    # unconditionally calling .cuda(), which fails on CPU-only machines and
    # mismatches a model placed on a non-default GPU.
    first_param = next(clip_model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Collect per-batch chunks and concatenate once at the end; calling
    # torch.cat inside the loop re-copies everything accumulated so far.
    feature_chunks = []
    label_chunks = []
    paths = []

    with torch.no_grad():
        for batch in loader:
            images = batch["img"].to(device)
            feature = clip_model.encode_image(images)  # This is not L2 normed
            feature_chunks.append(feature.cpu())
            label_chunks.append(batch['label'])
            paths.extend(batch['impath'])

    if not feature_chunks:
        # Nothing was loaded: same empty placeholders as before.
        return {
            'features': torch.Tensor(),
            'labels': torch.Tensor(),
            'paths': [],
        }

    return {
        'features': torch.cat(feature_chunks, 0),
        'labels': torch.cat(label_chunks),
        'paths': paths,
    }