"""
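Save true features of a model using the lab default format.

For every image and layer, the feature tensor is written to
`<feature_dir>/<layer>/<image_name>.mat` under the key `feat`.

Command-line arguments:
    --config_path (str): Path to the YAML config file.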

Config example
--------------
```yaml
device: cuda
model_name: CLIP  # name of the model class
model:  # model parameters
    model_name: 'ViT-B-32'
    pretrained: 'laion2b_s34b_b79k'
layer_mapping:  # mapping whose keys are the layers to extract features from
    transformer.resblocks.0: <value>  # value format is defined by load_encoder
    transformer.resblocks.1: <value>
    transformer.resblocks.2: <value>
    transformer.resblocks.3: <value>
    ...
feature_dir: 'features/imagenet'  # directory to save features
image_dir: 'images/imagenet'  # directory to load `.JPEG` images from
image_names:  # optional; image names without the `.JPEG` extension (defaults to every `.JPEG` file in image_dir)
- n01443537_22563  # goldfish
- n02128385_20264  # leopard
- n01943899_24131  # sea shell
- n03345837_12501  # fire extinguisher
```
"""
import argparse
import os
import tqdm
from PIL import Image
import numpy as np
import torch
from scipy.io import savemat
import yaml

from metamer.icnn_replication.encoder import Encoder
from metamer.icnn_replication.image_domain import PILDomainWithExplicitCrop
from metamer.reconstruct.models import load_encoder

# Target (width, height) passed to PIL's Image.resize before preprocessing;
# this is the default input size used for reconstruction.
IMAGE_SIZE = (224, 224)


def _list_image_names(image_dir, ext='.JPEG'):
    """Return the sorted names (with ``ext`` stripped) of files in ``image_dir`` ending in ``ext``."""
    # Strip exactly the known extension so names containing extra dots survive intact;
    # sort so the processing order is deterministic across filesystems.
    return sorted(
        name[:-len(ext)] for name in os.listdir(image_dir) if name.endswith(ext)
    )


def main(config_path):
    """Extract features for the configured images and save them as ``.mat`` files.

    Loads the YAML config, builds the encoder with ``load_encoder``, runs each
    image through it and writes one file per (layer, image) to
    ``<feature_dir>/<layer>/<image_name>.mat`` with the tensor under key ``'feat'``.

    Args:
        config_path (str): Path to the YAML config file. Keys read: ``device``,
            ``model``, ``layer_mapping``, ``feature_dir``, ``image_dir`` and,
            optionally, ``image_names`` (defaults to every ``.JPEG`` file in
            ``image_dir``).

    Raises:
        KeyError: If a required config key is missing.
        FileNotFoundError: If the config file or an image file does not exist.
    """
    # load config; safe_load suffices for plain-data configs and cannot
    # construct arbitrary Python objects
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)
    device = config['device']

    print('Loading model...')
    all_used_layers = list(config['layer_mapping'].keys())
    encoder = load_encoder(
        config['model'],
        layers=all_used_layers,
        layer_mapping=config['layer_mapping'],
        device=device, dtype=torch.float32
    )
    print('Model loaded.')

    # image domain for pil images
    pil_domain = PILDomainWithExplicitCrop()

    save_dir = config['feature_dir']
    os.makedirs(save_dir, exist_ok=True)
    image_dir = config['image_dir']
    if 'image_names' in config:
        image_names = config['image_names']
    else:
        image_names = _list_image_names(image_dir)

    # Features are only saved, never back-propagated through, so skip building
    # the autograd graph to reduce memory use.
    with torch.no_grad():
        for image_name in tqdm.tqdm(image_names):
            image_path = os.path.join(image_dir, f'{image_name}.JPEG')
            # Caution: The image should be resized before preprocessing to avoid different results.
            # The context manager closes the underlying file handle after decoding.
            with Image.open(image_path) as pil_image:
                rgb_image = pil_image.convert('RGB').resize(IMAGE_SIZE)

            # np.array of an RGB PIL image is (H, W, 3) uint8; add a leading batch axis
            image = torch.tensor(np.array(rgb_image)).unsqueeze(0).to(device)
            image = pil_domain.send(image)
            features = encoder(image)

            # Save features: one directory per layer, one .mat file per image
            for layer, feat in features.items():
                layer_dir = os.path.join(save_dir, layer)
                os.makedirs(layer_dir, exist_ok=True)
                savemat(
                    os.path.join(layer_dir, f'{image_name}.mat'),
                    {'feat': feat.detach().cpu().numpy()},
                )
        

if __name__ == "__main__":
    # Command-line entry point: parse --config_path and run feature extraction.
    parser = argparse.ArgumentParser(description='Save true features of a model.')
    # required=True: otherwise a missing flag yields config_path=None and an
    # opaque TypeError from open(); argparse reports a clear usage error instead.
    parser.add_argument('--config_path', type=str, required=True,
                        help='Path to the config file.')
    args = parser.parse_args()
    main(args.config_path)