import math
import numpy as np
import torch
import torch.nn as nn
from torchvision.transforms import InterpolationMode
from bdpy.dl.torch.models import model_factory
from metamer.icnn_replication import image_domain
from metamer.icnn_replication.generator import (
    FrozenGenerator, noGenerator, noGeneratorInitImage
)
from metamer.icnn_replication.encoder import (
    Encoder, TransformerEncoder, CLIPEncoder
)


# Channel-wise ImageNet statistics shared by the torchvision models below.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)


def _imagenet_normalize():
    """Return a NormalizeDomain configured with the ImageNet mean/std."""
    return NormalizeDomain(mean=_IMAGENET_MEAN, std=_IMAGENET_STD)


def _normalize_then_crop():
    """Compose ImageNet normalization followed by a 224x224 center crop."""
    return image_domain.ComposedDomain(
        [
            _imagenet_normalize(),
            image_domain.CenterCropDomain((224, 224)),
        ]
    )


def _crop_then_normalize():
    """Compose a 224x224 center crop followed by ImageNet normalization."""
    return image_domain.ComposedDomain(
        [
            image_domain.CenterCropDomain((224, 224)),
            _imagenet_normalize(),
        ]
    )


def _resize_then_normalize(size, interpolation):
    """Compose a fixed-resolution resize followed by ImageNet normalization."""
    return image_domain.ComposedDomain(
        [
            image_domain.FixedResolutionDomain(size, interpolation=interpolation),
            _imagenet_normalize(),
        ]
    )


def _unsupported_model(name):
    """Build the error raised for a model name that has no loader."""
    return NotImplementedError(f"Model {name} is not implemented.")


def load_model_and_domain(config, device, dtype=torch.float32):
    """
    Load the model and its image domain from the config.
    Note that the default generator generates images with size (256, 256).

    Args:
        config (dict): Configuration for the model. It should contain the following keys:
            - name: Name of the model architecture
            - pretrained: Name of the pretrained model checkpoint or path to the checkpoint
        device: Device the model is moved to (also passed to BdPyVGGDomain).
        dtype (torch.dtype): Dtype the model is cast to (also passed to BdPyVGGDomain).

    Returns:
        tuple: ``(model, domain)``. The model has been moved to ``device``/``dtype``
        and switched to eval mode.

    Raises:
        NotImplementedError: If ``name`` (or, for ``regnet_y_32gf``, ``pretrained``)
            does not correspond to an implemented loader. This includes
            variants of a known family that have no loader yet
            (e.g. ``densenet169``).
    """
    name = config['name']
    pretrained = config['pretrained']

    if name.startswith('vgg'):
        if name == 'vgg19-lab':
            # lab default VGG19 model; `pretrained` is the path to the weights.
            model = model_factory('vgg19')
            # Deserialize on CPU so checkpoints saved on a GPU also load on
            # CPU-only hosts; the model is moved to `device` at the end.
            model.load_state_dict(torch.load(pretrained, map_location='cpu'))
            domain = image_domain.ComposedDomain(
                [
                    image_domain.BdPyVGGDomain(device=device, dtype=dtype),
                    image_domain.CenterCropDomain((224, 224)),
                ]
            )
        elif name in ('vgg19', 'vgg19_no_norm', 'vgg19_clamp'):
            # vgg19 from torchvision; the variants differ only in the domain.
            from torchvision.models import vgg19
            model = vgg19(weights=pretrained)
            if name == 'vgg19':
                domain = _normalize_then_crop()
            elif name == 'vgg19_no_norm':
                # Deliberately no normalization stage.
                domain = image_domain.ComposedDomain(
                    [image_domain.CenterCropDomain((224, 224))]
                )
            else:  # 'vgg19_clamp'
                domain = image_domain.ComposedDomain(
                    [
                        image_domain.CenterCropDomain((224, 224)),
                        ClampDomain(min=0, max=1),
                        _imagenet_normalize(),
                    ]
                )
        else:
            raise _unsupported_model(name)

    elif name.lower().startswith('resnet'):  # note some models use resnet as a part of the name
        # load resnet from torchvision
        # available weights contain "IMAGENET1K_V2" or "IMAGENET1K_V1"
        # NOTE(review): every name with a 'resnet' prefix loads ResNet-50,
        # whatever depth the name suggests -- confirm this is intended.
        from torchvision.models import resnet50
        model = resnet50(weights=pretrained)
        domain = _normalize_then_crop()

    elif 'densenet' in name.lower():
        # weights are available for 'IMAGENET1K_V1'
        if name == 'densenet121':
            from torchvision.models import densenet121
            model = densenet121(weights=pretrained)
        elif name == 'densenet161':
            from torchvision.models import densenet161
            model = densenet161(weights=pretrained)
        else:
            # densenet169 / densenet201 were empty placeholders; fail clearly
            # instead of hitting an unbound `model` further down.
            raise _unsupported_model(name)
        domain = _normalize_then_crop()

    elif 'efficientnet' in name.lower():
        if name != 'efficientnetb7':
            raise _unsupported_model(name)
        # representative model for efficientnet
        # weights 'IMAGENET1K_V1' is available
        from torchvision.models import efficientnet_b7
        model = efficientnet_b7(weights=pretrained)
        # efficientnet rescales the image to 600x600
        domain = _resize_then_normalize((600, 600), InterpolationMode.BICUBIC)

    elif 'googlenet' in name.lower():
        # weights 'IMAGENET1K_V1' is available
        from torchvision.models import googlenet
        model = googlenet(weights=pretrained)
        domain = _crop_then_normalize()

    elif 'convnext' in name.lower():
        # weights 'IMAGENET1K_V1' is available
        if name == 'convnext_base':
            from torchvision.models import convnext_base
            model = convnext_base(weights=pretrained)
        elif name == 'convnext_large':
            from torchvision.models import convnext_large
            model = convnext_large(weights=pretrained)
        else:
            raise _unsupported_model(name)
        domain = _crop_then_normalize()

    elif 'regnet' in name.lower():
        if name == 'regnet_y_128gf':
            # weights 'IMAGENET1K_SWAG_E2E_V1' is available
            from torchvision.models import regnet_y_128gf
            model = regnet_y_128gf(weights=pretrained)
            # resize to 384x384
            domain = _resize_then_normalize((384, 384), InterpolationMode.BICUBIC)
        elif name == 'regnet_y_32gf':
            # weights 'IMAGENET1K_SWAG_E2E_V1' and other weights are available
            # different weights may have different preprocessing
            from torchvision.models import regnet_y_32gf
            if pretrained == 'IMAGENET1K_SWAG_E2E_V1':
                # resize to 384x384
                domain = _resize_then_normalize((384, 384), InterpolationMode.BICUBIC)
            elif pretrained == 'IMAGENET1K_V2':
                domain = _crop_then_normalize()
            else:
                raise NotImplementedError(f"Pretrained weights {pretrained} is not implemented.")
            model = regnet_y_32gf(weights=pretrained)
        else:
            raise _unsupported_model(name)

    elif name == 'vit-tfm':
        # Vision Transformer model from HuggingFace transformers
        from transformers import ViTImageProcessor, ViTForImageClassification
        model = ViTForImageClassification.from_pretrained(pretrained)
        processor = ViTImageProcessor.from_pretrained(pretrained)
        size = processor.size['height'], processor.size['width']
        domain = image_domain.ComposedDomain(
            [
                image_domain.CenterCropDomain(size),
                # normalize with the statistics stored in the processor
                NormalizeDomain(
                    mean=processor.image_mean,
                    std=processor.image_std
                ),
            ]
        )

    # Semantic segmentation models
    elif 'deeplabv3' in name.lower():
        if name != 'deeplabv3_resnet50':
            raise _unsupported_model(name)
        from torchvision.models.segmentation import deeplabv3_resnet50
        model = deeplabv3_resnet50(weights=pretrained)
        # rescale the image to 520x520
        domain = _resize_then_normalize((520, 520), InterpolationMode.BILINEAR)

    # Object detection models
    elif 'fasterrcnn' in name.lower():
        if name != 'fasterrcnn_resnet50_fpn':
            raise _unsupported_model(name)
        # weights='COCO_V1' is available
        from torchvision.models.detection import fasterrcnn_resnet50_fpn
        model = fasterrcnn_resnet50_fpn(weights=pretrained)
        domain = CommonDomain()  # no preprocessing for this model needed.

    # Contrastive learning models
    elif name == 'clip-tfm':
        from transformers import CLIPModel
        model = CLIPModel.from_pretrained(pretrained)
        # TODO: Use mean and std stored in the processor for safer normalization
        domain = image_domain.ComposedDomain(
            [
                image_domain.CenterCropDomain((224, 224)),
                NormalizeDomain(
                    mean=(0.48145466, 0.4578275, 0.40821073),
                    std=(0.26862954, 0.26130258, 0.27577711)
                ),
            ]
        )

    # Other learning scheme models
    elif name.startswith('dinov2-tfm'):
        # DiNOv2 model from HuggingFace transformers
        # models, weights are available for
        # facebook/dinov2-small, facebook/dinov2-base, facebook/dinov2-large, facebook/dinov2-giant
        # each with 22.1m, 86.6m, 304m, 1.14b parameters
        from transformers import AutoModel, AutoImageProcessor
        model = AutoModel.from_pretrained(pretrained)
        processor = AutoImageProcessor.from_pretrained(pretrained)
        domain = image_domain.ComposedDomain(
            [
                image_domain.CenterCropDomain((224, 224)),
                NormalizeDomain(
                    mean=processor.image_mean,
                    std=processor.image_std
                ),
            ]
        )

    elif name == 'vae':
        from diffusers.models import AutoencoderKL
        model = AutoencoderKL.from_pretrained(pretrained)
        # mean = std = 0.5: the original note describes this as the
        # [0, 1] -> [-1, 1] input normalization.
        domain = image_domain.ComposedDomain(
            [
                image_domain.CenterCropDomain((224, 224)),
                NormalizeDomain(
                    mean=(0.5, 0.5, 0.5),
                    std=(0.5, 0.5, 0.5)
                ),
            ]
        )

    elif name == 'toy_model_linear':
        # NOTE(review): this branch looks unfinished. `ToyModelLinear` is not
        # defined in the visible part of this file (SingleLayerNN may be the
        # intended class -- confirm), and `domain` is still an Ellipsis
        # placeholder rather than a real image domain.
        model = ToyModelLinear()
        # load weights
        model.load_state_dict(torch.load(pretrained, map_location='cpu'))
        domain = ...

    # Identity model for ablation studies
    elif name == 'identity':
        model = IdentityModel()
        domain = image_domain.Zero2OneImageDomain()

    else:
        raise _unsupported_model(name)

    # set the model mode and device
    model.to(device=device, dtype=dtype)
    model.eval()
    return model, domain


class NormalizeDomain(image_domain.ImageDomain):
    """Image domain whose own space holds channel-wise standardized images.

    ``receive`` subtracts ``mean`` and divides by ``std``; ``send`` applies
    the inverse (multiply by ``std``, add ``mean``).
    """

    def __init__(self, mean: tuple[float, ...], std: tuple[float, ...]):
        super().__init__()
        self.mean = mean
        self.std = std

    def _channel_stats(self, reference: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(mean, std)`` as tensors matching ``reference``'s dtype/device.

        Both are reshaped to ``(1, C, 1, 1)`` so they broadcast along dim 1
        (assumes images are laid out as (batch, channel, H, W) -- TODO confirm).
        """
        def as_column(values):
            stat = torch.tensor(values, dtype=reference.dtype, device=reference.device)
            return stat.view(1, -1, 1, 1)

        return as_column(self.mean), as_column(self.std)

    def send(self, images: torch.Tensor) -> torch.Tensor:
        """Undo the normalization: ``images * std + mean``."""
        shift, scale = self._channel_stats(images)
        return images * scale + shift

    def receive(self, images: torch.Tensor) -> torch.Tensor:
        """Normalize images coming from the common space: ``(images - mean) / std``."""
        shift, scale = self._channel_stats(images)
        return (images - shift) / scale


class CommonDomain(image_domain.ImageDomain):
    """Pass-through domain: the model consumes the shared image domain directly."""

    def __init__(self) -> None:
        super().__init__()

    def send(self, images: torch.Tensor) -> torch.Tensor:
        """Return ``images`` untouched."""
        return images

    def receive(self, images: torch.Tensor) -> torch.Tensor:
        """Return ``images`` untouched."""
        return images
    

class ClampDomain(image_domain.ImageDomain):
    """Domain that restricts pixel values to the closed range ``[min, max]``."""

    def __init__(self, min: float, max: float):
        # `min` / `max` shadow the builtins but are part of the keyword
        # interface used by callers, so the names are kept.
        super().__init__()
        self.min = min
        self.max = max

    def send(self, images: torch.Tensor) -> torch.Tensor:
        """Unique space to common space: the images are returned unchanged."""
        return images

    def receive(self, images: torch.Tensor) -> torch.Tensor:
        """Common space to unique space: clamp every value into the range."""
        lower, upper = self.min, self.max
        return torch.clamp(images, min=lower, max=upper)


def load_encoder_class(config):
    """
    Pick the encoder class that matches the model named in ``config['name']``.

    Returns:
        The encoder class (not an instance): CLIPEncoder for names starting
        with 'clip', TransformerEncoder for names containing 'dinov2' or
        'vit', VAEEncoder for 'vae', and Encoder otherwise.
    """
    model_name = config['name']

    if model_name.startswith('clip'):
        # CLIP extracts features through its get_image_features method.
        return CLIPEncoder

    if 'dinov2' in model_name or 'vit' in model_name:
        # Transformer models: tuple features are turned into tensor features.
        return TransformerEncoder

    if model_name == 'vae':
        from metamer.icnn_replication.encoder import VAEEncoder
        return VAEEncoder

    return Encoder


def load_encoder(config, layers, layer_mapping, device, dtype):
    """
    Build the encoder for the model described by ``config``.

    The model and its domain are loaded first, then wrapped in the encoder
    class selected by ``load_encoder_class``.

    Args:
        config (dict): Config for the model.
        layers: Passed to the encoder as ``layer_names``.
        layer_mapping: Passed to the encoder as ``layer_mapping``.
        device: Device used for the model and handed to the encoder.
        dtype: Dtype used when loading the model.

    Returns:
        The constructed encoder instance.
    """
    feature_network, model_domain = load_model_and_domain(
        config, device=device, dtype=dtype
    )
    wrapper_cls = load_encoder_class(config)
    return wrapper_cls(
        feature_network=feature_network,
        layer_names=layers,
        layer_mapping=layer_mapping,
        domain=model_domain,
        device=device,
    )


def _load_latent_upperbound(feature_range_file, device, dtype):
    """Read the latent upper-bound vector from a '3x' or '2x' feature range file."""
    if '3x' in feature_range_file:
        bound = np.loadtxt(feature_range_file, delimiter=" ")
    elif '2x' in feature_range_file:
        n_units = 4096
        bound = np.loadtxt(
            feature_range_file,
            delimiter=' ',
            usecols=np.arange(0, n_units),
            unpack=True,
        )
        bound = bound.reshape((n_units,))
    else:
        # Previously this case left the bound undefined and crashed later
        # with an UnboundLocalError.
        raise ValueError(
            f"Cannot tell the format of feature range file {feature_range_file!r}: "
            "its path must contain '3x' or '2x'."
        )
    return torch.tensor(bound, device=device, dtype=dtype)


def load_generator(config, device, batch_size, dtype=torch.float32):
    """
    Build the image generator described by ``config``.

    Args:
        config (dict): Configuration for generator. ``config['name']`` selects
            the generator: 'relu7generator', 'nogenerator',
            'nogenerator_init_image' or 'deepimageprior'.
            'relu7generator' additionally reads 'generator_param_file' and
            'feature_range_file'; 'deepimageprior' reads the optional
            'use_sigmoid' (default True).
        device: Device for the generator.
        batch_size (int): Number of images generated per batch.
        dtype (torch.dtype): Dtype for the generator.

    Returns:
        The constructed generator.

    Raises:
        NotImplementedError: If ``config['name']`` is not a known generator.
        ValueError: If the 'relu7generator' feature range file path contains
            neither '3x' nor '2x'.
    """
    generator_name = config['name']

    if generator_name == 'relu7generator':
        generator_network = model_factory(generator_name)
        # Deserialize on CPU so GPU-saved checkpoints also load on CPU-only
        # hosts; the network is moved to `device` right after.
        generator_network.load_state_dict(
            torch.load(config["generator_param_file"], map_location='cpu')
        )
        generator_network.to(device)

        latent_upperbound = _load_latent_upperbound(
            config['feature_range_file'], device, dtype
        )

        generator = FrozenGenerator(
            generator_network=generator_network,
            latent_shape=(batch_size, 4096),
            latent_upperbound=latent_upperbound,
            latent_lowerbound=0.0,
            device=device,
            dtype=dtype,
            domain=image_domain.BdPyVGGDomain(device=device, dtype=dtype),
        )
    elif generator_name == 'nogenerator':
        generator = noGenerator(
            image_shape=(224, 224),
            batch_size=batch_size,
            device=device,
            dtype=dtype
        )
    elif generator_name == 'nogenerator_init_image':
        # no generator with arbitrary initial image
        generator = noGeneratorInitImage(
            image_shape=(224, 224),
            batch_size=batch_size,
            device=device,
            dtype=dtype
        )
    elif generator_name == 'deepimageprior':
        from metamer.icnn_replication.generator import DeepImagePriorGenerator
        generator = DeepImagePriorGenerator(
            image_shape=(224, 224),
            batch_size=batch_size,
            device=device,
            dtype=dtype,
            use_sigmoid=config.get('use_sigmoid', True)
        )
    else:
        # Previously an unknown name fell through to `return generator` and
        # failed with an UnboundLocalError.
        raise NotImplementedError(f"Generator {generator_name} is not implemented.")

    return generator


class SingleLayerNN(nn.Module):
    """
    Single layer neural network as a toy model.

    One ``nn.Linear`` maps ``prod(input_shape)`` features to
    ``prod(output_shape)`` features.

    Note: with the default shapes the linear layer has
    (3*224*224)**2 weights; pass smaller shapes for practical use.
    """
    def __init__(self, input_shape=(3, 224, 224), output_shape=(3, 224, 224)):
        super().__init__()
        self.input_shape = input_shape
        self.output_shape = output_shape
        self.input_dim = math.prod(input_shape)
        self.output_dim = math.prod(output_shape)

        self.layer = nn.Linear(
            self.input_dim, self.output_dim
        )

    def forward(self, x: torch.Tensor):
        """
        Apply the linear layer.

        Args:
            x (torch.Tensor): input image in common domain format. Either
                already flattened (last dimension equal to ``input_dim``) or
                with trailing dimensions equal to ``input_shape``.

        Returns:
            torch.Tensor: For flattened input, the raw output of the linear
            layer (last dimension ``output_dim``). For image-shaped input,
            the output with the trailing dimensions reshaped to
            ``output_shape``; leading (batch) dimensions are preserved.
        """
        n_dims = len(self.input_shape)
        # Image-shaped input cannot go through nn.Linear directly because the
        # layer acts on the last dimension only. It is detected only when the
        # last dimension is not already `input_dim`, so inputs that worked
        # before keep their exact previous behavior.
        image_shaped = (
            n_dims > 0
            and x.dim() >= n_dims
            and x.shape[-1] != self.input_dim
            and tuple(x.shape[-n_dims:]) == tuple(self.input_shape)
        )
        if not image_shaped:
            return self.layer(x)

        lead = x.shape[:-n_dims]
        flat = x.reshape(*lead, self.input_dim)
        return self.layer(flat).reshape(*lead, *self.output_shape)
    

class IdentityModel(nn.Module):
    """
    Model whose output is its input, unchanged.
    Intended for ablation studies.
    """
    def __init__(self) -> None:
        super().__init__()
        # Kept as a named submodule (`identity`) rather than inlined.
        self.identity = nn.Identity()

    def forward(self, x):
        """Route ``x`` through the identity submodule and return the result."""
        passthrough = self.identity(x)
        return passthrough