"""
Create adversarial images using the reconstruction pipeline.

TODO: 


"""

from __future__ import annotations


from pathlib import Path
import argparse
import json
import os
from filelock import FileLock
from itertools import product
from typing import Any
import traceback

import numpy as np
import yaml
import pandas as pd
import torch
import wandb
from torchvision.transforms import Compose, Resize, ToTensor
from PIL import Image

from bdpy.dataform import Features

from metamer.icnn_replication.critic import Critic, ImageCritic
from metamer.icnn_replication import image_domain
from metamer.icnn_replication.pipeline import FeatureInversionPipeline
from metamer.icnn_replication.generator import noGeneratorInitImage
from metamer.reconstruct.models import load_generator, load_encoder_class, load_model_and_domain
from metamer.reconstruct.recon_common import (
    load_optimizer_and_scheduler, set_seed,
    DictImageDataset, DictFeaturesDataset
)

from metamer.icnn_replication.evaluation import (
    PixelCorrelation, PixelCosineSimilarity, PixelMSE,
    TrueFeatureCosineSimilarity, TrueFeatureCorrelation, TrueFeatureMSE
)


class TargetNormalizedAdversarialMSE(Critic):
    """Adversarial squared-error critic normalized by the target norm.

    Both the feature and the target feature are divided by the per-sample
    L2 norm of the target feature. The criterion returns the *negated* sum
    of squared differences, so a lower value corresponds to a feature that
    is farther from the target.
    """

    def criterion(
        self, feature: torch.Tensor, target_feature: torch.Tensor, layer_name: str
    ) -> torch.Tensor:
        """Return the negated, target-normalized squared error per sample.

        Args:
            feature: Feature tensor to evaluate; dim 0 is treated as the
                sample axis and all remaining dims are reduced.
            target_feature: Reference feature tensor; its per-sample L2 norm
                is used to normalize both inputs.
            layer_name: Name of the layer; not used by this criterion.

        Returns:
            Tensor holding ``-sum((feature - target) ** 2) / ||target||^2``
            for each sample, reduced over every dim except the first.
        """
        reduce_dims = tuple(range(1, target_feature.ndim))
        target_norm = (
            target_feature.pow(2).sum(dim=reduce_dims, keepdim=True).sqrt()
        )
        # An all-zero target has zero norm, which would turn the loss (and
        # its gradient) into NaN/inf. Clamping to machine epsilon keeps the
        # result finite and leaves every non-degenerate case untouched.
        target_norm = target_norm.clamp_min(torch.finfo(target_norm.dtype).eps)

        normalized_feature = feature / target_norm
        normalized_target = target_feature / target_norm
        squared_error = (normalized_feature - normalized_target).pow(2)
        return -1 * squared_error.sum(dim=reduce_dims)
    

class PixelMSERegulation(ImageCritic):
    """Weighted per-sample pixel MSE between reference and generated images.

    Intended as a pixel-space regularizer: the criterion grows as the
    generated image departs from the stored reference image.
    """

    def __init__(self, true_images: torch.Tensor = None, weight: float = 1.0, domain=None):
        """Store the reference images, the loss weight and an optional domain.

        Args:
            true_images: Reference images; may be supplied later through
                ``set_true_images``.
            weight: Factor multiplied onto the per-sample MSE.
            domain: Optional object whose ``receive`` method is applied to
                the generated images before they are compared.
        """
        super().__init__()
        self.true_images = true_images
        self.weight = weight
        self.domain = domain

    def set_true_images(self, true_images: torch.Tensor):
        """Replace the stored reference images (no copy is made)."""
        self.true_images = true_images

    def enable_wandb(self):
        """Set the ``with_wandb`` flag on this critic."""
        self.with_wandb = True

    def criterion(self, generated_images: torch.Tensor) -> torch.Tensor:
        """Return ``weight * MSE`` for each sample, computed on a x255 scale.

        Args:
            generated_images: Images to score; must be 4-D after the optional
                domain conversion because dims 1-3 are averaged.

        Returns:
            One weighted MSE value per sample.
        """
        candidate = generated_images
        if self.domain is not None:
            # Bring the generated images into the common domain first.
            candidate = self.domain.receive(candidate)

        # Both tensors are multiplied by 255 (assumes values in [0, 1], so
        # the error is expressed on a [0, 255] pixel scale -- TODO confirm).
        scale = 255
        reference = self.true_images * scale
        candidate = candidate * scale

        per_pixel = torch.nn.functional.mse_loss(
            reference, candidate, reduction='none'
        )
        per_sample = per_pixel.mean(dim=(1, 2, 3))
        return per_sample * self.weight
    

class FeatureAdversarialCosineSimilarityLoss(Critic):
    """
    Adversarial feature critic based on cosine similarity.

    The criterion is the cosine similarity itself (not its negation), so
    driving it down pushes the feature away from the target feature.
    """

    def criterion(
            self, feature: torch.Tensor, target_feature: torch.Tensor, layer_name: str
    ):
        """Return the per-sample cosine similarity of feature and target.

        Every dim after the first is flattened into one vector per sample.
        ``layer_name`` is not used by this criterion.
        """
        flat_feature = feature.flatten(start_dim=1)
        flat_target = target_feature.flatten(start_dim=1)
        similarity = torch.nn.functional.cosine_similarity(
            flat_feature, flat_target, dim=1
        )
        return similarity

    
class PixelAdversarialCosineRegularization(ImageCritic):
    """
    -1 * cosine similarity between true image and generated image.
    Minimizing this loss means minimizing the difference between the true image and generated image.
    Can be used for regularization in creating adversarial images.

    Args:
        true_images (torch.Tensor): True images to compare with generated images. In common domain.
            May be omitted and supplied later through ``set_true_images``.
        domain: Optional object whose ``receive`` method is applied to the
            generated images before the comparison.
        weight (float): Weight for the loss.
    """
    def __init__(self, true_images: torch.Tensor | None = None, domain=None, weight=1.0):
        super().__init__()
        # Always define the attribute so that calling ``criterion`` before
        # ``set_true_images`` fails with a clear message instead of an
        # AttributeError.
        self.true_images = None if true_images is None else true_images.clone()
        self.domain = domain
        self.weight = weight

    def set_true_images(self, true_images: torch.Tensor):
        """Store a private copy of the reference images."""
        self.true_images = true_images.clone()

    def criterion(self, generated_images: torch.Tensor) -> torch.Tensor:
        """Return ``-weight * cosine_similarity`` for each sample.

        Args:
            generated_images: Images to compare against the stored true images.

        Returns:
            One value per sample (no reduction over the batch).

        Raises:
            RuntimeError: If no true images have been set.
        """
        if self.true_images is None:
            raise RuntimeError(
                "true_images is not set; pass it to the constructor or call "
                "set_true_images() before evaluating the criterion."
            )
        if self.domain is not None:
            # convert the generated images to the common domain
            generated_images = self.domain.receive(generated_images)
        # ``flatten`` (unlike ``view``) also works on non-contiguous tensors,
        # which the domain conversion may return.
        cos = torch.nn.functional.cosine_similarity(
            self.true_images.flatten(start_dim=1),
            generated_images.flatten(start_dim=1),
            dim=1
        )
        # per-sample values; the batch axis is kept
        return -1 * cos * self.weight
    

def load_true_images_and_features(image_name, image_dataset, features_dataset):
    """Look up one stimulus and return it as a batch of size one.

    Args:
        image_name: Key used to index both datasets.
        image_dataset: Mapping-like object returning an image tensor per key.
        features_dataset: Mapping-like object returning a ``{layer: tensor}``
            dict per key.

    Returns:
        Tuple ``(images, features)`` where the image tensor and every feature
        tensor have a new leading axis of length 1.
    """
    image = image_dataset[image_name]
    features = features_dataset[image_name]

    # Prepend the batch axis to the image and to each layer's feature.
    batched_image = image.unsqueeze(0)
    batched_features = {}
    for layer, feature in features.items():
        batched_features[layer] = feature.unsqueeze(0)
    return batched_image, batched_features


def main(config_path=None, config=None, device=None, re_run_error=False):
    """Create adversarial images for every seed, layer set and stimulus.

    For each combination, the generator is initialized with the true image
    and optimized by ``FeatureInversionPipeline`` with a feature critic
    (cosine similarity to the true features) and a pixel critic (weighted
    negative cosine similarity to the true image). The final image, the
    per-iteration history and the last history entry are written to
    ``output/adversarial_image/<model>/<experiment>/<layer_set>/<image>/seed_<seed>/``
    as ``final.png``, ``history.csv`` and ``summary.json``.

    Args:
        config_path: Path to a YAML configuration file; read only when
            ``config`` is None.
        config: Already-loaded configuration mapping. It is not modified.
        device: If given, overrides ``config['device']``.
        re_run_error: Accepted for CLI compatibility; currently unused.
    """
    if config is None:
        with open(config_path) as f:
            config = yaml.safe_load(f)

    if device is not None:
        # Shallow copy so a config passed in by the caller is not mutated.
        config = {**config, 'device': device}

    # Spatial size (height, width) used for loading, generating and cropping.
    image_hw = (224, 224)

    # Load dataset
    # we don't use dataloader in this script
    stimulus_names = config['image_names']
    image_dataset = DictImageDataset(
        root_path=config['image_dir'],
        extension=".JPEG",
        stimulus_names=stimulus_names,
        transform=Compose([Resize(image_hw), ToTensor()])
    )
    all_layers_used = set(sum(config['layers'].values(), []))
    features_dataset = DictFeaturesDataset(
        root_path=config['feature_dir'],
        layer_path_names=all_layers_used,
        stimulus_names=stimulus_names,
        transform=None,
        return_type='tensor'
    )

    # load model and encoder class
    device = config['device']
    dtype = torch.float32
    model_name = config['model_alias'] if 'model_alias' in config else config['model']['name']

    # Load encoder class: we have to initialize the encoder for each layer setting
    model, domain = load_model_and_domain(config['model'], device=config['device'], dtype=dtype)
    encoder_class = load_encoder_class(config['model'])

    # Critics for feature and image
    critic_feature = FeatureAdversarialCosineSimilarityLoss()
    critic_image = PixelAdversarialCosineRegularization(
        weight=config['pixel_weight'],
        domain=image_domain.CenterCropDomain(image_hw),
    )

    # wandb settings
    use_wandb = 'wandb' in config

    # path
    exp_name = config['experiment_name']
    exp_path = f'output/adversarial_image/{model_name}/{exp_name}/'

    for seed, (layer_set_name, layers) in product(config['seeds'], config['layers'].items()):
        # load encoder for this layer setting
        encoder = encoder_class(
            feature_network=model,
            layer_names=layers,
            layer_mapping=config['layer_mapping'],
            domain=domain,
            device=device,
        )

        for image_name in stimulus_names:
            # set seed
            set_seed(seed)

            # get true image and true feature
            true_images, true_features = load_true_images_and_features(image_name, image_dataset, features_dataset)
            true_images = true_images.to(device=device, dtype=dtype)
            true_features = {k: v.to(device=device, dtype=dtype) for k, v in true_features.items()}
            ref_true_features = {k: v.clone() for k, v in true_features.items()}

            # set true image for generator and pixel critic
            generator = noGeneratorInitImage(image_shape=image_hw, batch_size=1, device=device, dtype=dtype)
            generator.set_images(true_images.clone())
            critic_image.set_true_images(true_images.clone())

            # set eval metrics for this batch
            eval_metrics = [
                TrueFeatureCosineSimilarity(true_features=ref_true_features),
                TrueFeatureCorrelation(true_features=ref_true_features),
                TrueFeatureMSE(true_features=ref_true_features),
            ]
            pixel_eval_metrics = [
                PixelCorrelation(true_images=true_images.clone(), domain=image_domain.CenterCropDomain(image_hw)),
                PixelCosineSimilarity(true_images=true_images.clone(), domain=image_domain.CenterCropDomain(image_hw)),
                PixelMSE(true_images=true_images.clone(), domain=image_domain.CenterCropDomain(image_hw)),
            ]

            if use_wandb:
                wandb_name = f'{exp_name}_{layer_set_name}_seed{seed}_pixel_weight_{config["pixel_weight"]}'
                wandb_names = [image_name]
                wandb.init(project=config['wandb']['project'], name=wandb_name, config=config)
            else:
                wandb_names = None

            # The optimizer and scheduler are rebuilt for every run because
            # they are created from the freshly instantiated generator.
            optimizer, scheduler = load_optimizer_and_scheduler(config['optimizer'], generator)

            # initialize pipeline for this run
            pipeline = FeatureInversionPipeline(
                generator=generator,
                encoder=encoder,
                critic=critic_feature,
                critic_image=critic_image,
                optimizer=optimizer,
                scheduler=scheduler,
                num_iterations=config['num_iterations'],
                log_interval=-1,
                with_wandb=use_wandb,
                eval_metrics=eval_metrics,
                pixel_eval_metrics=pixel_eval_metrics,
                wandb_log_interval=config['wandb']['log_interval'] if use_wandb else 1
            )
            # NOTE(review): pipeline.reset_states() is deliberately not called;
            # it is unclear whether it would also reset the generator's
            # initial image.

            generated_images = pipeline(true_features, wandb_names=wandb_names)
            generated_images = image_domain.finalize(generated_images)
            history = pipeline.history

            # save history and images
            out_dir = os.path.join(exp_path, layer_set_name, image_name, f'seed_{seed}')
            image_path = os.path.join(out_dir, 'final.png')
            hist_path = os.path.join(out_dir, 'history.csv')
            sum_path = os.path.join(out_dir, 'summary.json')
            os.makedirs(out_dir, exist_ok=True)

            # save image
            image = generated_images[0].detach().cpu().numpy().astype(np.uint8)
            # center-crop the (height, width, channel) array to image_hw
            height, width, _ = image.shape
            top = (height - image_hw[0]) // 2
            left = (width - image_hw[1]) // 2
            image = image[top : top + image_hw[0], left : left + image_hw[1]]
            Image.fromarray(image).save(image_path)

            # save history as csv; each record holds the first (only) sample
            hist = [
                {name: value[0] for name, value in record.items()} for record in history
            ]
            with open(sum_path, 'w') as f:
                json.dump(hist[-1], f, indent=4)
            pd.DataFrame(hist).to_csv(hist_path, index=False)

            # Only close the run that this iteration opened.
            if use_wandb:
                wandb.finish()

        # delete encoder for this layer setting
        # to avoid features accumulate in feature_extractor
        del encoder.feature_extractor
        del encoder



if __name__ == "__main__":
    # Command-line entry point: parse the options and forward them to main().
    parser = argparse.ArgumentParser(description='Create adversarial images from features.')
    # Required: main() has no other source for the configuration when run as
    # a script, and would otherwise fail while opening ``None``.
    parser.add_argument('--config_path', type=str, required=True, help='Path to the configuration file.')
    parser.add_argument('--device', type=str, required=False, help='Overwrite the device in the config file.')
    # NOTE: this flag is forwarded to main(), which does not currently use it.
    parser.add_argument('--re_run_error', action='store_true', help='Re-run the experiments that are marked as error.')
    args = parser.parse_args()
    main(config_path=args.config_path, device=args.device, re_run_error=args.re_run_error)
