import copy
import numpy as np
from scipy.io import loadmat
from collections import OrderedDict

import torch
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.models as models

from utils.utils import Adam16
from utils import single2tensor4
from networks.usrnet import USRNet
from networks.discriminator import Discriminator
from networks.generator import Ccgenerator, Generator
from networks.metanet import MetaNN


class Attacker:
    """Gradient-inversion attacker with optional post-processing models.

    The attack optimizes a random low-resolution batch (``sample_size / sf``
    per side) that is bicubically upsampled by ``sf`` before it is fed to the
    victim model, so that the gradients (or multi-step weight updates) it
    induces match the observed ones.
    """

    def __init__(self, config, criterion):
        """
        Initialize the attacker.

        Args:
            config: experiment configuration. Reads ``T_max``, ``half``,
                ``device``, ``sample_size``, ``rog_lr``, ``printevery`` and
                ``fed_lr``. Optionally reads ``postprocess`` (a dict-like with
                ``enable``, ``generator_ckpt`` and ``denoiser_ckpt`` keys) and
                ``joint_postmodel``. ``num_classes`` is read only when a
                generator checkpoint is configured.
            criterion: loss callable invoked as ``criterion(pred, target)``.
        """
        self.sf = 4  # upsampling factor between the optimized latent and the model input
        self.alpha_tv = 3.e-7  # not used by any method of this class
        self.T_max = config.T_max
        self.half = config.half

        self.device = config.device
        self.sample_size = config.sample_size
        self.eta = config.rog_lr
        self.printevery = config.printevery
        self.criterion = criterion

        self.fed_lr = config.fed_lr
        self.resize = T.Resize((config.sample_size[0], config.sample_size[1]))

        # Checkpoint that init_discriminator() reads; init_attacker_models()
        # overwrites it. Stays None when the config does not provide a path.
        self.postmodel_dir = getattr(config, "joint_postmodel", None)

        self.post_cfg = getattr(config, "postprocess", None)
        self.post_enabled = bool(self.post_cfg and self.post_cfg.get("enable", False))

        # Every post-processing slot always exists so joint_postprocess() can
        # test them without raising AttributeError.
        self.generator = None
        self.denoiser = None
        self.usrnet = None
        self.kernel = None
        self.sigma = None

        if not self.post_enabled:
            return

        # Conditional generator (optional); num_classes must match the dataset.
        gen_ckpt = self.post_cfg.get("generator_ckpt", None)
        if gen_ckpt:
            self.generator = self._load_frozen(
                Ccgenerator(in_channels=3, num_classes=config.num_classes),
                gen_ckpt, "gen_state_dict")

        # Denoiser (optional).
        den_ckpt = self.post_cfg.get("denoiser_ckpt", None)
        if den_ckpt:
            self.denoiser = self._load_frozen(Generator(), den_ckpt, "gen_state_dict")

    def _load_frozen(self, net, ckpt_path, state_key=None):
        """
        Load weights into ``net``, switch it to eval mode, freeze it and move
        it to ``self.device``.

        Args:
            net: freshly constructed module.
            ckpt_path: checkpoint path handed to ``torch.load``.
            state_key: key of the state dict inside the checkpoint, or None
                when the checkpoint is the state dict itself.

        Returns:
            The module as returned by ``net.to(self.device)``.
        """
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        # Loading on CPU first keeps GPU-saved checkpoints usable on hosts
        # without that device; the module is moved to self.device below.
        state = torch.load(ckpt_path, map_location="cpu")
        if state_key is not None:
            state = state[state_key]
        net.load_state_dict(state)
        net.eval()
        for p in net.parameters():
            p.requires_grad_(False)
        return net.to(self.device)

    def _init_dummy(self, batch_size):
        """
        Create the random low-resolution dummy batch with its optimizer and
        learning-rate scheduler.

        Returns:
            Tuple ``(dummy_data, optimizer, scheduler)``.
        """
        height = int(self.sample_size[0] / self.sf)
        width = int(self.sample_size[1] / self.sf)
        latent = np.random.rand(batch_size, 3, height, width)
        latent = latent.astype(np.float16 if self.half else np.float32)

        dummy_data = torch.from_numpy(latent).to(self.device).requires_grad_(True)

        if self.half:
            optimizer = Adam16([dummy_data], lr=self.eta)
        else:
            optimizer = optim.Adam([dummy_data], lr=self.eta)

        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.T_max, eta_min=0.1*self.eta)
        return dummy_data, optimizer, scheduler

    @staticmethod
    def _grad_distance(dummy_grads, target_grads):
        """Sum of squared differences between two aligned tensor sequences."""
        total = 0
        for gx, gy in zip(dummy_grads, target_grads):
            total += ((gx - gy) ** 2).sum()
        return total

    def init_attacker_models(self, config):
        """
        Load post-processing models.

        Reads the checkpoint paths ``joint_postmodel``, ``denoiser`` and
        ``usrnet``, the kernel file ``kernel`` and ``noise_level`` from
        ``config``.
        """
        # NOTE(review): this does not touch self.post_enabled, so
        # joint_postprocess() keeps returning its input unchanged unless
        # config.postprocess enabled it - confirm this is intended.

        # conditional generator
        self.postmodel_dir = config.joint_postmodel
        self.generator = self._load_frozen(Ccgenerator(), config.joint_postmodel, "gen_state_dict")

        # denoising block
        self.denoiser = self._load_frozen(Generator(), config.denoiser, "gen_state_dict")

        # super resolution block
        usrnet = USRNet(n_iter=8, h_nc=64, in_nc=4, out_nc=3, nc=[64, 128, 256, 512],
                        nb=2, act_mode="R", downsample_mode='strideconv', upsample_mode="convtranspose")
        self.usrnet = self._load_frozen(usrnet, config.usrnet)

        kernels = loadmat(config.kernel)['kernels']
        kernel = kernels[0, self.sf-2].astype(np.float64)
        self.kernel = single2tensor4(kernel[..., np.newaxis])

        self.sigma = torch.tensor(config.noise_level).float().view([1, 1, 1, 1])

    def free_models(self):
        """
        Remove all post-processing models in the memory.
        """
        self.generator = None
        self.denoiser = None
        self.usrnet = None
        self.kernel = None
        self.sigma = None

        torch.cuda.empty_cache()

    def init_classifier(self):
        """
        Initialize the imagenet classifier

        Sets ``self.transform`` (normalization followed by a resize to 256)
        and ``self.classifier`` (a frozen, eval-mode pretrained
        EfficientNet-B7 on ``self.device``).
        """
        normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])
        self.transform = T.Compose([normalize, T.Resize(256)])

        efficientnet_b7 = models.efficientnet_b7(pretrained=True)
        efficientnet_b7.eval()
        for p in efficientnet_b7.parameters():
            p.requires_grad_(False)

        efficientnet_b7.to(self.device)
        self.classifier = efficientnet_b7

    def init_discriminator(self):
        """
        Initialize the discriminator

        Loads the ``dis_state_dict`` entry of the checkpoint at
        ``self.postmodel_dir``.

        Raises:
            ValueError: if no checkpoint path is known, i.e. the config had no
                ``joint_postmodel`` and ``init_attacker_models`` was not called.
        """
        if self.postmodel_dir is None:
            raise ValueError(
                "discriminator checkpoint path is not set; provide "
                "config.joint_postmodel or call init_attacker_models() first")

        self.discriminator = Discriminator()

        state = torch.load(self.postmodel_dir, map_location="cpu")
        self.discriminator.load_state_dict(state["dis_state_dict"])

        self.discriminator.to(self.device)

    def grad_inv(self, grad, x, onehot, model, logger=None):
        """
        Perform gradient inversion attack.

        Args:
            grad: target gradients, aligned with ``model.parameters()``.
            x: batch whose first dimension gives the number of samples to
                reconstruct (only ``x.shape[0]`` is read).
            onehot: targets forwarded to ``self.criterion``.
            model: victim model.
            logger: optional object with an ``info`` method; progress logging
                is skipped when None.

        Returns:
            The optimized low-resolution batch as a detached float32 tensor.
        """
        dummy_data, optimizer, scheduler = self._init_dummy(x.shape[0])
        params = tuple(model.parameters())

        def closure():
            optimizer.zero_grad()

            pred = model(F.interpolate(dummy_data, scale_factor=self.sf, mode='bicubic'))
            dummy_loss = self.criterion(pred, onehot)
            # create_graph=True: the matching loss is differentiated through
            # these gradients with respect to dummy_data.
            dummy_dy_dx = torch.autograd.grad(dummy_loss, params, create_graph=True)

            grad_diff = self._grad_distance(dummy_dy_dx, grad)
            grad_diff.backward()

            return grad_diff

        for iters in range(self.T_max):
            optimizer.step(closure)

            if logger is not None and iters % self.printevery == 0:
                # Re-evaluate so the reported loss reflects the updated dummy data.
                current_loss = closure()
                logger.info("iter: {:d} loss: {:.4e}".format(iters, current_loss.item()))

            if iters > 0:
                scheduler.step()

        # return the float32 tensor
        result = dummy_data.detach()
        if self.half:
            result = result.to(torch.float32)

        return result

    def multistep_attack(self, grad, x, y, model, tau=1, logger=None):
        """
        Perform a multi-step gradient inversion attack.

        Simulates ``tau`` gradient steps of size ``self.fed_lr`` on the dummy
        data and matches ``(initial_weights - final_weights) / fed_lr``
        against ``grad``.

        Args:
            grad: target update, aligned with ``model.named_parameters()``.
            x: batch whose first dimension gives the number of samples to
                reconstruct (only ``x.shape[0]`` is read).
            y: targets forwarded to ``self.criterion``.
            model: victim model.
            tau: number of simulated local steps.
            logger: optional object with an ``info`` method; progress logging
                is skipped when None.

        Returns:
            The optimized low-resolution batch as a detached float32 tensor.
        """
        dummy_data, optimizer, scheduler = self._init_dummy(x.shape[0])

        # Snapshot of the starting weights. Nothing in this method updates the
        # model's parameters, so one snapshot serves every optimizer step.
        net0 = OrderedDict((name, param.detach().clone())
                           for name, param in model.named_parameters())

        def closure():
            optimizer.zero_grad()

            meta_net = MetaNN(model)

            for _ in range(tau):
                pred = meta_net(F.interpolate(dummy_data, scale_factor=self.sf, mode='bicubic'))
                dummy_loss = self.criterion(pred, y)
                g = torch.autograd.grad(dummy_loss, tuple(meta_net.parameters.values()),
                            retain_graph=True, create_graph=True)

                meta_net.parameters = OrderedDict((name, param - self.fed_lr * grad_part)
                                        for ((name, param), grad_part)
                                        in zip(meta_net.parameters.items(), g))

            # Accumulated update expressed as an equivalent gradient.
            dummy_dy_dx = [(param_origin - param) / self.fed_lr
                           for param, param_origin
                           in zip(meta_net.parameters.values(), net0.values())]

            grad_diff = self._grad_distance(dummy_dy_dx, grad)
            grad_diff.backward()

            return grad_diff

        for iters in range(self.T_max):
            optimizer.step(closure)

            if logger is not None and (iters % self.printevery == 0) and iters > 0:
                current_loss = closure()
                logger.info("iter: {:d} loss: {:.4e}".format(iters, current_loss.item()))

            if iters > 0:
                scheduler.step()

        result = dummy_data.detach()
        if self.half:
            result = result.to(torch.float32)

        return result

    def joint_postprocess(self, x, y):
        """
        Refine reconstructed data with whichever post-processing models are loaded.

        Args:
            x: reconstructed batch.
            y: conditioning input forwarded to the generator.

        Returns:
            Tuple ``(synth_data, recon_data)``. ``synth_data`` is the generator
            output, or the resized input when no generator is loaded.
            ``recon_data`` is its equal-weight average with the denoised and
            upscaled input when a denoiser is loaded, otherwise ``synth_data``.
            When post-processing is disabled both are ``x`` itself.
        """
        if not self.post_enabled:
            return x, x  # identity

        with torch.no_grad():
            # Bring the input to config.sample_size.
            up_x = self.resize(x)

            # Compare against None: a module's truthiness is not a reliable
            # "is loaded" test.
            out_gen = self.generator(up_x, y) if self.generator is not None else up_x

            if self.denoiser is None:
                fused = out_gen
            else:
                den = self.denoiser(x)
                if self.usrnet is not None:
                    batch = x.size(0)
                    kernel = self.kernel.repeat(batch, 1, 1, 1).to(self.device)
                    sigma = self.sigma.repeat(batch, 1, 1, 1).to(self.device)
                    upscaled = self.usrnet(den, kernel, self.sf, sigma)
                else:
                    # fall back: resize the denoised input instead of super-resolving it
                    upscaled = self.resize(den)
                fused = 0.5 * upscaled + 0.5 * out_gen

        return out_gen, fused  # (synth_data, recon_data)

    def refine(self, x, y):
        """Placeholder; not implemented (returns None)."""
        pass

def grad_inv(attacker, grad, x, onehot, model, config, logger):
    """Run the attack variant selected by ``config.fedalg``.

    ``"fedavg"`` dispatches to ``attacker.multistep_attack`` with
    ``config.tau`` local steps; any other value dispatches to
    ``attacker.grad_inv``. The chosen method's return value is passed through.
    """
    if config.fedalg == "fedavg":
        return attacker.multistep_attack(grad, x, onehot, model, config.tau, logger)

    return attacker.grad_inv(grad, x, onehot, model, logger)
