# -*- coding: utf-8 -*-

"""
DeepInversion

credits:
    https://github.com/NVlabs/DeepInversion
    https://github.com/GT-RIPL/AlwaysBeDreaming-DFCIL
"""

import math
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import trange

from .feature_hook import DeepInversionFeatureHook
from .gaussian_smoothing import Gaussiansmoothing
from .generator import create as gen_create


class GenerativeInversion(nn.Module):
    """Train a class-conditional generator against a frozen model.

    The generator (built by ``gen_create(dataset, protos)``) is optimised
    with Adam so that the images it produces

    * yield ``model(inputs)["features"]`` whose cosine similarity to the
      class prototypes matches the requested class (``criterion_proto``), and
    * minimise the ``r_feature`` value reported by the hooks attached to
      every ``nn.BatchNorm2d`` of the model (``criterion_rf``).

    Calling the module (``forward``) runs the whole optimisation; ``sample``
    then draws images from the trained generator.

    Args:
        model: Network to invert. Its forward pass must return a mapping
            with a ``"features"`` entry.
        num_classes: Number of classes to generate labels for.
        dataset: Dataset name, forwarded to the generator factory.
        device: Device the prototypes and the generator are moved to.
        batch_size: Default batch size used by ``sample``.
        max_iters: Number of generator optimisation steps run by ``forward``.
        lr: Adam learning rate for the generator.
        alpha_proto: Weight of the prototype loss.
        protos: One prototype tensor per class; they are stacked along a new
            first dimension. Required despite the ``None`` default.

    Raises:
        ValueError: If ``protos`` is not provided.
    """

    # Scale applied to the cosine-similarity logits in ``criterion_proto``.
    PROTO_LOGIT_SCALE = 50

    def __init__(
        self,
        model: nn.Module,
        num_classes: int,
        dataset: str,
        device: str,
        batch_size: int = 256,
        max_iters: int = 5000,
        lr: float = 1e-3,
        alpha_proto: float = 0.2,
        protos: Optional[List[torch.Tensor]] = None,
    ):
        super().__init__()

        if protos is None:
            # Fail early with a readable message instead of an opaque error
            # raised from inside ``torch.stack``.
            raise ValueError(
                "`protos` is required: pass one prototype tensor per class"
            )

        self.batch_size = batch_size
        self.num_classes = num_classes
        self.dataset = dataset
        self.max_iters = max_iters
        self.lr = lr
        self.alpha_proto = alpha_proto
        self.device = device
        self.feature_hooks = []

        self.model = model
        self.protos = torch.stack(protos).to(device)
        self.generator = gen_create(dataset, self.protos).to(device)
        print(f'Generator: {self.generator}')

    def setup(self):
        """Freeze the model and (re-)attach the BatchNorm feature hooks."""
        freeze(self.model)
        self.register_feature_hooks()

    def register_feature_hooks(self):
        """Attach one feature hook to every ``nn.BatchNorm2d`` in the model.

        Hooks from a previous call are removed and forgotten first, so
        calling this repeatedly never accumulates stale hooks.
        """
        for hook in self.feature_hooks:
            hook.remove()

        # Rebuild the list from scratch: keeping the removed hooks around
        # would make ``criterion_rf`` average their outdated ``r_feature``.
        self.feature_hooks = [
            DeepInversionFeatureHook(module)
            for module in self.model.modules()
            if isinstance(module, nn.BatchNorm2d)
        ]

    def criterion_rf(self):
        """Return the mean ``r_feature`` over all registered hooks."""
        # The mean (rather than the sum) is taken over the hooks.
        return torch.stack([h.r_feature for h in self.feature_hooks]).mean()

    def criterion_proto(self, features: torch.Tensor, targets: torch.Tensor):
        """Soft-target cross-entropy on scaled cosine similarities.

        Args:
            features: Feature batch; L2-normalised along ``dim=1``.
            targets: Per-sample target distribution over the prototypes.

        Returns:
            Scalar loss tensor.
        """
        features = F.normalize(features, p=2, dim=1)
        with torch.no_grad():
            protos = F.normalize(self.protos, p=2, dim=1).detach()
        logits_proto = features @ protos.T * self.PROTO_LOGIT_SCALE
        return custom_cross_entropy(logits_proto, targets.detach())

    def generate_ys_in(self, batch_size: int, cr=0.0):
        """Label generator used for ``"imagenet100"``.

        Behaves exactly like :meth:`generate_ys`; kept as a separate entry
        point for backward compatibility.
        """
        return self.generate_ys(batch_size, cr=cr)

    def generate_ys(self, batch_size: int, cr=0.0):
        """Create class-balanced labels and their (smoothed) one-hot targets.

        The labels are ``batch_size // num_classes`` repetitions of
        ``0 .. num_classes - 1`` followed by ``batch_size % num_classes``
        uniformly random labels.

        Args:
            batch_size: Number of labels to create.
            cr: Smoothing mass. The target class gets ``1 - cr`` and every
                other entry gets ``cr / (num_classes - 1)``.

        Returns:
            ``(target, ys)``: integer labels of shape ``(batch_size,)`` and
            float targets of shape ``(batch_size, num_classes)``, both on CPU.
        """
        full_rounds, remainder = divmod(batch_size, self.num_classes)
        random_tail = torch.randint(self.num_classes, (remainder,))
        balanced = (
            torch.arange(full_rounds * self.num_classes) % self.num_classes
        )
        target = torch.cat((balanced, random_tail))

        # With a single class there is no off-target entry to smooth into,
        # so skip the division (it would be a division by zero).
        if self.num_classes > 1:
            off_value = cr / (self.num_classes - 1)
        else:
            off_value = 0.0
        ys = torch.full((batch_size, self.num_classes), float(off_value))
        ys.scatter_(1, target.unsqueeze(1), 1 - cr)

        return target, ys

    def _generate_targets(self, batch_size: int):
        """Dispatch to the dataset-specific label generator."""
        if self.dataset == "imagenet100":
            return self.generate_ys_in(batch_size, cr=0.0)
        return self.generate_ys(batch_size, cr=0.0)

    @torch.no_grad()
    def sample(self, batch_size: Optional[int] = None):
        """Draw a batch from the generator without tracking gradients.

        Args:
            batch_size: Number of samples; defaults to ``self.batch_size``.

        Returns:
            ``(inputs, targets)``: generator output and the labels it was
            conditioned on (moved to the prototypes' device).
        """
        if self.model.training:
            self.model.eval()
        if batch_size is None:
            batch_size = self.batch_size

        targets, _ = self._generate_targets(batch_size)
        targets = targets.to(self.protos.device)
        inputs = self.generator.sample(targets)
        return inputs, targets

    def train_step(self, batch_size: int = 128):
        """Compute the generator loss for one freshly sampled batch.

        Args:
            batch_size: Number of images to generate for this step. The
                default matches the value that used to be hard-coded here.

        Returns:
            ``(loss, loss_dict)``: the total loss and a dict with the
            ``"proto"``, ``"rf"`` and ``"total"`` terms.
        """
        targets, ys = self._generate_targets(batch_size)
        targets, ys = targets.to(self.protos.device), ys.to(self.protos.device)
        inputs = self.generator.sample(targets)
        # The forward pass also triggers the feature hooks, which is what
        # makes ``criterion_rf`` below refer to this batch.
        output = self.model(inputs)
        features = output["features"]

        loss_proto = self.alpha_proto * self.criterion_proto(features, ys)
        # feature statistics regularization
        loss_rf = self.criterion_rf()

        loss = loss_proto + loss_rf

        loss_dict = {
            "proto": loss_proto,
            "rf": loss_rf,
            "total": loss,
        }

        return loss, loss_dict

    def configure_optimizers(self):
        """Return an Adam optimiser over the generator parameters."""
        return optim.Adam(self.generator.parameters(), lr=self.lr)

    def forward(self):
        """Run the full inversion: optimise the generator for ``max_iters`` steps.

        The model is frozen and hooked, the generator is unfrozen for the
        duration of the loop and frozen again afterwards.
        """
        self.setup()
        unfreeze(self.generator)

        optimizer = self.configure_optimizers()
        # Refresh the progress-bar postfix roughly 100 times per run.
        miniters = max(self.max_iters // 100, 1)
        pbar = trange(self.max_iters, miniters=miniters, desc="Inversion")
        for current_iter in pbar:
            loss, loss_dict = self.train_step()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if (current_iter + 1) % miniters == 0:
                pbar.set_postfix({k: f"{v:.4f}" for k, v in loss_dict.items()})

        freeze(self.generator)
        
def custom_cross_entropy(preds, target):
    """Cross-entropy between logits and soft targets, averaged over the batch.

    ``preds`` are unnormalised logits; ``target`` holds one weight per logit
    along the last dimension. The per-sample loss is
    ``sum(-target * log_softmax(preds))`` over that dimension.
    """
    log_probs = F.log_softmax(preds, dim=-1)
    per_sample = (-target * log_probs).sum(dim=-1)
    return per_sample.mean()

def freeze(module: nn.Module, mock_training: bool = False):
    """Put a module into a frozen, evaluation-mode state.

    Every parameter has ``requires_grad`` switched off and its ``grad``
    dropped. If ``mock_training`` is requested and the module exposes a
    ``mock_training`` attribute, that attribute is set to ``True``. Finally
    the module is switched to ``eval()`` mode.

    Returns:
        Mapping from parameter name to the ``requires_grad`` value it had
        before the call, suitable for passing to ``unfreeze``.
    """
    # Snapshot first, then disable, so the snapshot reflects the prior state.
    saved = {key: tensor.requires_grad for key, tensor in module.named_parameters()}
    for tensor in module.parameters():
        tensor.requires_grad = False
        tensor.grad = None

    supports_mock = hasattr(module, "mock_training")
    if mock_training and supports_mock:
        module.mock_training = True

    module.eval()
    return saved


def unfreeze(module: nn.Module, state: Optional[Dict[str, bool]] = None):
    """Unfreeze a torch Module

    1) restore all parameters's requires_grad state,
    2) switch to training mode,
    3) turn off mock_training (only if the module has that attribute).

    Args:
        module: Module to unfreeze.
        state: Mapping from parameter name to the ``requires_grad`` value to
            restore, as returned by ``freeze``. When omitted, ``None`` or
            empty, every parameter gets ``requires_grad = True``. When
            non-empty, parameters missing from the mapping are left as is.
    """

    # ``None`` replaces the former mutable ``{}`` default; both mean
    # "no saved state".
    saved = state if state is not None else {}
    fallback = None if saved else True
    for name, param in module.named_parameters():
        requires_grad = saved.get(name, fallback)
        if requires_grad is not None:
            param.requires_grad = requires_grad

    module.train()

    if hasattr(module, "mock_training"):
        module.mock_training = False