"""
Save true features of the model.

Args:
    config_path (str): Path to the config file.
    --device (str): Device to use for computation (default: 'cuda').
"""
import argparse
import os
from typing import Any
import tqdm
from PIL import Image
import numpy as np
import torch
from scipy.io import savemat
import yaml

from metamer.icnn_replication.image_domain import PILDomainWithExplicitCrop
from metamer.reconstruct.models import load_encoder

IMAGE_SIZE = (224, 224)  # Size passed to PIL's Image.resize (width, height) for every input image in main()


def parse_model_name(config) -> str:
    """
    Derive the model identifier used as a directory name for outputs.

    Resolution order:
      1. ``config['model']['model_alias']`` when it is present and truthy;
      2. for transformer models (``name`` ending in ``'-tfm'``), the
         ``pretrained`` identifier with every ``'/'`` replaced by ``'_'``
         so it is safe to use as a single path component;
      3. otherwise ``config['model']['name']`` unchanged.
    """
    model_cfg = config['model']

    alias = model_cfg.get('model_alias')
    if alias:
        return alias

    name = model_cfg['name']
    if not name.endswith('-tfm'):
        return name

    # transformer model: flatten hub-style ids such as "org/model"
    return model_cfg['pretrained'].replace('/', '_')


def parse_feature_dir(config: dict[str, Any]) -> str:
    """
    Return the directory under which per-layer feature files are written.

    An explicit ``config['feature_dir']`` entry wins (its value is returned
    as-is). Otherwise the path
    ``output/readout_vision/features/<model_name>/<dataset_name>`` is built,
    with ``<model_name>`` from ``parse_model_name`` and ``<dataset_name>``
    from ``config['data']['dataset_name']``.
    """
    if 'feature_dir' not in config:
        model_name = parse_model_name(config)
        dataset_name = config['data']['dataset_name']
        return os.path.join(
            'output', 'readout_vision', 'features', model_name, dataset_name
        )
    return config['feature_dir']


def main(config: dict, device: str):
    """
    Compute the encoder's features for every configured image and save them.

    Args:
        config: Parsed YAML config. Keys read here: ``model`` (passed to
            ``load_encoder``), ``layer_mapping`` (its keys are the layers
            requested from the encoder), ``data.image_dir``,
            ``data.image_names`` and ``data.image_ext``, plus whatever
            ``parse_feature_dir`` reads.
        device: Torch device string used for the encoder and input tensors.

    Side effects:
        Writes ``<feature_dir>/<layer>/<image_name>.mat`` (variable ``'feat'``)
        for every image and every layer returned by the encoder, creating
        directories as needed.
    """
    # save all layers that will be potentially used
    layers = list(config['layer_mapping'].keys())

    print('Loading model...')
    encoder = load_encoder(
        config['model'],
        layers=layers,
        layer_mapping=config['layer_mapping'],
        device=device, dtype=torch.float32,
    )
    print('Model loaded.')

    # image domain for pil images
    pil_domain = PILDomainWithExplicitCrop()

    # paths
    feature_dir = parse_feature_dir(config)
    image_dir = config['data']['image_dir']
    image_names = config['data']['image_names']
    image_ext = config['data']['image_ext']

    for image_name in tqdm.tqdm(image_names):
        image_path = os.path.join(image_dir, f'{image_name}{image_ext}')
        # Caution: the image must be resized before preprocessing to avoid
        # different results. The `with` block closes the underlying file handle.
        with Image.open(image_path) as raw_image:
            rgb_image = raw_image.convert('RGB').resize(IMAGE_SIZE)

        # turn the pil image into a (1, H, W, C) tensor
        image = torch.tensor(np.array(rgb_image)).unsqueeze(0).to(device)

        # Features are only saved, never back-propagated through: skip building
        # the autograd graph to save memory.
        with torch.no_grad():
            features = encoder(pil_domain.send(image))

        # Save features
        for layer, feat in features.items():
            output_path = os.path.join(feature_dir, layer, f'{image_name}.mat')
            os.makedirs(os.path.dirname(output_path), exist_ok=True)
            # hand SciPy a NumPy array explicitly rather than relying on its
            # implicit coercion of a torch.Tensor
            savemat(output_path, {'feat': feat.detach().cpu().numpy()})
        

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Reconstruct images from the original and perturbed features.")
    parser.add_argument("config_path", type=str, help="Path to the configuration file.")
    parser.add_argument("--device", type=str, default='cuda')
    args = parser.parse_args()

    with open(args.config_path, "r") as f:
        config = yaml.safe_load(f)

    if 'image_names' not in config['data']:
        # use image_names_path
        with open(config['data']['image_names_path'], 'r') as f:
            config['data']['image_names'] = yaml.safe_load(f)

    main(config, device=args.device)