import argparse
import os
import click
import shutil
import numpy as np
from pathlib import Path
from datetime import datetime
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

from wilds.datasets.waterbirds_dataset import WaterbirdsDataset
from wilds.datasets.celebA_dataset import CelebADataset
from wilds.datasets.coco_places_dataset import COCOonPlacesDataset
from models.initializer import initialize_model
import configs.supported as supported
from utils.misc import ParseKwargs, compute_gram_mat, compute_mean_std, compute_standard
from examples.utils.print_logger import get_logger
# Module-wide logger obtained from the project's print_logger helper,
# requested under this module's name at DEBUG level.
LOGGER = get_logger(__name__, level="DEBUG")

# Layers whose forward outputs are recorded, keyed by the ``--model`` value.
# Each entry is a module name as yielded by ``model.named_modules()``; main()
# registers a forward hook on every module whose name appears in the list.
SELECTED_LAYERS = {
    "vgg19": [
        "features.0",
        "features.5",
        "features.10",
        "features.19",
        "features.28",
        # "features.34" is currently disabled.
    ],
    "resnet18": [
        "layer1.0.conv1",
        "layer2.1.conv1",
        "layer3.1.conv1",
        "layer4.1.conv1",
    ],
    "resnet50": [
        # Disabled alternatives:
        #   "layer1.1.conv1", "layer2.1.conv1", "layer3.1.conv1", "layer4.1.conv1"
        "avgpool",
    ],
}

def main():
    """Record per-layer clustering features for one split of a dataset.

    Pipeline, driven entirely by command-line arguments:

    1. Parse and validate the arguments.
    2. Prepare ``<log_dir>/env_<output_name>`` (asking for confirmation before
       deleting an existing directory unless ``--resume`` is given).
    3. Load the requested dataset split and the checkpointed model found at
       ``<log_dir>/<dataset>_seed:<seed>_epoch:best_model.pth``.
    4. Register forward hooks on the layers listed in ``SELECTED_LAYERS`` for
       the chosen model and write their names to ``hooked_layers.txt``.
    5. Run the split through the model, turn every hooked activation into a
       clustering feature (``standard``, ``gram_matrix`` or ``mean_std``),
       optionally through a random projection, and save one
       ``<split>_vect_cluster_<layer>.pt`` tensor per hooked layer plus
       ``<split>_indices.txt``.

    Raises:
        ValueError: if the model has no entry in ``SELECTED_LAYERS``, if the
            cluster feature type is missing/unknown, or if the dataset name is
            not recognized.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--root_dir', required=True,
                        help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).')
    parser.add_argument("--dataset", type=str, default='waterbirds', help='')
    parser.add_argument("--split", type=str, default='train', help='Apply on train or val set')
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--num_workers', type=int, default=10,
                        help='Number of DataLoader worker processes')
    parser.add_argument('--model', choices=supported.models)
    parser.add_argument('--model_kwargs', nargs='*', action=ParseKwargs, default={},
        help='keyword arguments for model initialization passed as key1=value1 key2=value2')
    parser.add_argument('--log_dir', default='./logs')
    parser.add_argument('--output_name', default=f'env_{datetime.now().strftime("%Y%m%d_%H%M")}')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--cluster_feature_type', choices=['standard', 'gram_matrix','mean_std'])
    parser.add_argument("--normalization", action="store_true", default=False,
                        help="Enable normalization in the gram_matrix / mean_std feature computation")
    parser.add_argument("--random_proj", type=int, default=0, help="Apply random projection")
    parser.add_argument("--resume", action="store_true", default=False, help="Continue recording features")
    parser.add_argument("--reload_proj", action="store_true", default=False, help="Reload saved projection matrices")
    config = parser.parse_args()

    # Fail fast on arguments that would otherwise only blow up after the
    # output directory has been wiped and the model/data have been loaded.
    if config.model not in SELECTED_LAYERS:
        raise ValueError(f"Model {config.model} has no entry in SELECTED_LAYERS")
    if config.cluster_feature_type not in ('standard', 'gram_matrix', 'mean_std'):
        raise ValueError(f"Cluster feature type  {config.cluster_feature_type} not recognized")

    # Directory management
    # NOTE(review): the default --output_name already starts with 'env_', so
    # the default directory is named 'env_env_<date>'. Left unchanged so that
    # existing output directories can still be resumed.
    output_dir = os.path.join(config.log_dir, 'env_'+config.output_name)
    if config.resume:
        LOGGER.warning('Resuming clustering features creation')
    elif Path(output_dir).exists():
        # abort=True makes click abort the program on a negative answer.
        if click.confirm("Removing directory ? ({}).".format(output_dir), abort=True):
            shutil.rmtree(output_dir)
    # makedirs also creates a missing log_dir and tolerates an existing
    # directory (the --resume case).
    os.makedirs(output_dir, exist_ok=True)

    if config.dataset == 'waterbirds':
        LOGGER.info('Loading Waterbirds dataset')
        dataset = WaterbirdsDataset(root_dir=config.root_dir, get_img_idx=True)
    elif config.dataset == 'celebA':
        LOGGER.info('Loading CelebA dataset')
        dataset = CelebADataset(root_dir=config.root_dir, get_img_idx=True)
    elif config.dataset == 'coco_on_places':
        LOGGER.info('Loading COCO-on-Places dataset')
        dataset = COCOonPlacesDataset(root_dir=config.root_dir, get_img_idx=True)
    else:
        raise ValueError(f"Dataset {config.dataset} not recognized")

    test_transform = transforms.Compose([
            transforms.Resize((224,224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    split_data = dataset.get_subset(config.split, transform=test_transform)
    # shuffle=False keeps the saved features aligned with the saved indices.
    split_loader = DataLoader(
        split_data,
        shuffle=False,
        sampler=None,
        pin_memory=True,
        num_workers=config.num_workers,
        collate_fn=dataset.collate,
        batch_size=config.batch_size)

    torch.cuda.empty_cache()
    LOGGER.info(f'Loading {config.model} model' )
    model = initialize_model(config,dataset._n_classes)
    weight_file = os.path.join(config.log_dir, f'{config.dataset}_seed:{config.seed}_epoch:best_model.pth')
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    # The map_location callable keeps every storage on the CPU while loading.
    checkpoint = torch.load(weight_file, map_location=lambda storage, loc: storage)
    # Checkpoint keys carry a 'model.' prefix that the bare model does not use.
    state_dict = {k.replace('model.', ''): v for k, v in checkpoint['algorithm'].items()}
    model.load_state_dict(state_dict)
    model.cuda().eval()

    # Outputs of the hooked layers for the current batch, in firing order.
    activations = []

    def record_activation(module, inputs, output):
        """Forward hook: stash the detached output of a hooked layer."""
        activations.append(output.detach())

    selected_layers = SELECTED_LAYERS[config.model]
    hooked_layer_names = []
    for name, layer in model.named_modules():
        if name in selected_layers:
            layer.register_forward_hook(record_activation)
            print(f'-- Hooked {config.model}.{name}')
            if 'relu' in name:
                name = name + '1'
            hooked_layer_names.append(name)

    # Save names and deal with special case of unregistered relu
    # NOTE(review): this inserts into the list while iterating over it, so the
    # insertion offsets depend on earlier insertions. Presumably it names the
    # extra hook firings of a relu module that is called several times per
    # forward pass -- confirm before changing. With the current
    # SELECTED_LAYERS no name contains 'relu', so the loop is a no-op.
    for i, name in enumerate(hooked_layer_names):
        if 'relu1' in name:
            hooked_layer_names.insert(i + 3, name[:-1]+'2')
            if config.model=='resnet50':
                hooked_layer_names.insert(i + 6, name[:-1] + '3')
    # Overwrite rather than append: appending on --resume glued the previous
    # run's last name to this run's first name (no trailing newline).
    with open(os.path.join(output_dir, 'hooked_layers.txt'), 'w') as f:
        f.write('\n'.join(hooked_layer_names))

    proj_path = os.path.join(output_dir, 'projection_matrices.pt')

    LOGGER.info(f'Looping through {config.split} dataset')
    vect_cluster_features_list, ind_list, random_proj = [], [], []
    for i, (data, _, indices) in enumerate(tqdm(split_loader)):
        # Drop the previous batch's activations; the hooks refill the list.
        activations.clear()
        with torch.no_grad():
            model(data.cuda())

        if i == 0:
            # One projection (or None) per recorded activation, set up once on
            # the first batch and reused for every later batch.
            if config.random_proj <= 0:
                random_proj = [None] * len(activations)
            elif config.reload_proj:
                saved_proj = torch.load(proj_path)
                random_proj = [saved_proj[k] for k in range(len(activations))]
            else:
                nb_examples = len(split_data)
                k0 = int(config.random_proj * np.log(nb_examples))
                for feature in activations:
                    # Random +/-1 matrix scaled by 1/sqrt(k0); the input size
                    # is the squared size of the activation's dimension 1.
                    A = torch.rand(feature.shape[1] ** 2, k0, device='cuda') < .5
                    A = 2. * A.float() - 1.
                    print('Create projection matrix of size', [elem for elem in A.size()])
                    random_proj.append(A / np.sqrt(k0))
                LOGGER.info(f'Saving {len(random_proj)} projection matrices.')
                dict_to_save = {ind_ma_proj: ma_proj for ind_ma_proj, ma_proj in enumerate(random_proj)}
                torch.save(dict_to_save, proj_path)

        for j, feature in enumerate(activations):
            if config.cluster_feature_type == 'standard':
                cluster_features = compute_standard(feature, random_proj=random_proj[j]).cpu()
            elif config.cluster_feature_type == 'gram_matrix':
                cluster_features = compute_gram_mat(feature, normalization=config.normalization,
                                                    random_proj=random_proj[j]).cpu()
            elif config.cluster_feature_type == 'mean_std':
                cluster_features = compute_mean_std(feature, normalization=config.normalization,
                                                    random_proj=random_proj[j]).cpu()
            else:
                raise ValueError(f"Cluster feature type  {config.cluster_feature_type} not recognized")

            # One list of per-batch tensors per recorded activation.
            if i == 0:
                vect_cluster_features_list.append([cluster_features])
            else:
                vect_cluster_features_list[j].append(cluster_features)

        ind_list.extend(indices.tolist())
    vect_cluster_features_list = [torch.cat(f) for f in vect_cluster_features_list]

    LOGGER.info(f'Saving clutering features in {output_dir}')
    for j, vect_cluster_features in enumerate(vect_cluster_features_list):
        out_filename = os.path.join(output_dir, f'{config.split}_vect_cluster_{hooked_layer_names[j]}.pt')
        torch.save(vect_cluster_features, out_filename)
    with open(os.path.join(output_dir, f'{config.split}_indices.txt'), 'w') as f:
        f.write('\n'.join(str(idx) for idx in ind_list))


# Run main() only when this file is executed as a script, not when it is
# imported.
if __name__ == "__main__":
    main()
