"""
Common functions used in the reconstruction scripts.
"""
import numpy as np
import torch
import torch.optim as optim
from PIL import Image
from pathlib import Path

from bdpy.dataform import Features
from metamer.icnn_replication.critic import DistsLoss, TargetNormalizedMSE, CombinationLoss, MSE, NormalizedMSE


def load_critic(config):
    """
    Build the critic (loss) selected by ``config['name']``.

    Args:
        config (dict): Configuration for critic. ``config['name']`` selects
            the critic:

            - ``'dist'``: ``CombinationLoss`` over ``DistsLoss`` and
              ``TargetNormalizedMSE`` with weights ``config['alpha']``
              (default 4.0) and ``config['beta']`` (default 1.0).
            - ``'target_norm_mse'``: ``TargetNormalizedMSE``.
            - ``'mse'``: ``MSE``.
            - ``'normalized_mse'``: ``NormalizedMSE``.

    Returns:
        The constructed critic object.

    Raises:
        KeyError: If ``config`` has no ``'name'`` entry.
        NotImplementedError: If ``config['name']`` is not one of the above.
    """
    name = config['name']

    if name == 'dist':
        # Weights passed to CombinationLoss in the same order as the losses
        # (DISTS first, target-normalized MSE second); overridable from config.
        alpha = config.get('alpha', 4.0)
        beta = config.get('beta', 1.0)
        return CombinationLoss([DistsLoss(), TargetNormalizedMSE()], [alpha, beta])

    if name == 'target_norm_mse':
        return TargetNormalizedMSE()

    if name == 'mse':
        return MSE()

    if name == 'normalized_mse':
        return NormalizedMSE()

    raise NotImplementedError(f"Critic {config['name']} is not implemented")


def load_optimizer_and_scheduler(config, generator):
    """
    Build an optimizer over ``generator``'s parameters and an optional LR scheduler.

    Args:
        config (dict): Optimizer configuration. Must contain ``'name'``
            (``'adamw'`` or ``'adam'``) and ``'lr'``. May contain
            ``'scheduler'``: a dict with ``'name'`` (``'LinearLR'`` or
            ``'PolynomialLR'``); every other entry of that dict is passed as a
            keyword argument to the scheduler constructor. A missing or
            ``None`` ``'scheduler'`` entry means no scheduler.
        generator: Object whose ``.parameters()`` are optimized
            (presumably a ``torch.nn.Module``).

    Returns:
        tuple: ``(optimizer, scheduler)``; ``scheduler`` is ``None`` when no
        scheduler is configured.

    Raises:
        NotImplementedError: If the optimizer or scheduler name is unsupported.
    """
    optimizer_name = config['name']
    if optimizer_name == 'adamw':
        optimizer_cls = optim.AdamW
    elif optimizer_name == 'adam':
        optimizer_cls = optim.Adam
    else:
        raise NotImplementedError(f"Optimizer {config['name']} is not implemented")

    # One param group per parameter tensor (kept from the original design so
    # optimizer/scheduler state layout is unchanged).
    param_groups = [{'params': param} for param in generator.parameters()]
    optimizer = optimizer_cls(param_groups, lr=config['lr'])

    # Treat an explicit `scheduler: None` (e.g. an empty YAML key) the same
    # as an absent entry instead of crashing on None['name'].
    scheduler_config = config.get('scheduler')
    if scheduler_config is None:
        return optimizer, None

    scheduler_name = scheduler_config['name']
    if scheduler_name not in ('LinearLR', 'PolynomialLR'):
        raise NotImplementedError(f"Scheduler {scheduler_name} is not implemented")

    print(f'Using {scheduler_name} scheduler')
    params = {n: v for n, v in scheduler_config.items() if n != 'name'}
    # Resolve the class only after validation so an unused scheduler class is
    # never looked up.
    scheduler_cls = getattr(optim.lr_scheduler, scheduler_name)
    scheduler = scheduler_cls(optimizer, **params)

    return optimizer, scheduler


def set_seed(seed: int):
    """
    Seed the global NumPy and torch random number generators.

    The CUDA generators of every visible device are seeded as well, but only
    when CUDA is available.

    Args:
        seed: Value handed to every seeding function.
    """
    seeders = [np.random.seed, torch.manual_seed]
    if torch.cuda.is_available():
        seeders.append(torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)


class DictImageDataset:
    """
    Image dataset that returns a image by its name.
    """
    def __init__(
        self,
        root_path: str,
        stimulus_names: list[str] | None = None,
        extension: str = ".jpg",
        transform=None,
    ):
        if stimulus_names is None:
            stimulus_names = [
                path.name.removesuffix(extension)
                for path in Path(root_path).glob(f"*{extension}")
            ]
        # load the data
        self.transform = transform
        self.data = {}
        for name in stimulus_names:
            image = Image.open(Path(root_path) / f"{name}{extension}")
            if self.transform is not None:
                image = self.transform(image)
            self.data[name] = image

    def __len__(self):
        return len(self.data)

    def __getitem__(self, stimulus_name):
        return self.data[stimulus_name]


class DictFeaturesDataset:
    """
    Feature dataset that returns features by its name.
    """
    def __init__(
        self,
        root_path: str,
        layer_path_names: list[str],
        stimulus_names: list[str] | None = None,
        transform=None,
        return_type='tensor'
    ):
        self.features_store = Features(Path(root_path).as_posix())
        self.layer_path_names = layer_path_names
        self.transform = transform
        self.stimulus_names = stimulus_names
        self.return_type = return_type

    def __len__(self):
        return len(self.stimulus_names)

    def __getitem__(self, stimulus_name: str):
        features = {}
        for layer_path_name in self.layer_path_names:
            feature = self.features_store.get(
                layer=layer_path_name, label=stimulus_name
            )
            feature = feature[0]  # NOTE: remove batch axis

            if self.return_type == 'tensor':
                feature = torch.tensor(feature)
            elif self.return_type == 'numpy':
                pass
            else:
                raise ValueError(f"Unknown return type: {self.return_type}")

            features[layer_path_name] = feature
        if self.transform is not None:
            features = self.transform(features)
        return features
