
import os
import time
import copy
import argparse
import numpy as np
import torch
import torch.nn as nn
from torchvision.utils import save_image
from utils import get_network, get_eval_pool, evaluate_synset, get_daparam, match_loss, get_time, TensorDataset, epoch, DiffAugment, ParamDiffAug
import sys
import random

from torch.utils.data import TensorDataset


# Device string used by the functions below when they create tensors and
# networks: 'cuda' when a CUDA device is available, otherwise 'cpu'.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def compute_img_mean_std_per_class(img_set, im_size, num_classes):
    """Compute per-class, per-channel pixel mean and standard deviation.

    Args:
        img_set: Indexable collection supporting ``len()`` whose items are
            ``(img, label)`` pairs. ``img`` is reduced over axes 1 and 2, so
            it is assumed to be laid out as (channel, height, width) with
            three channels -- confirm against the dataset in use. ``label``
            is used as an index into lists of length ``num_classes``.
        im_size: Sequence whose first two entries are the image height and
            width; their product is used as the pixel count of every image.
        num_classes: Number of classes.

    Returns:
        ``(means, stds)``: two lists of length ``num_classes``; entry ``c``
        is a 3-element tensor with the channel means / standard deviations
        of class ``c``. Classes without any sample yield zeros instead of NaN.
    """
    channel_sums = [torch.zeros(3) for _ in range(num_classes)]
    channel_sq_sums = [torch.zeros(3) for _ in range(num_classes)]
    image_counts = [0] * num_classes

    for i in range(len(img_set)):
        img, label = img_set[i]
        channel_sums[label] += img.sum(dim=(1, 2))
        channel_sq_sums[label] += (img ** 2).sum(dim=(1, 2))
        image_counts[label] += 1

    pixels_per_image = im_size[0] * im_size[1]

    total_means = []
    total_stds = []
    for channel_sum, channel_sq_sum, n_images in zip(channel_sums, channel_sq_sums, image_counts):
        # Guard against division by zero for classes that have no samples;
        # their sums are zero, so mean and std come out as zero.
        n_pixels = max(n_images * pixels_per_image, 1)
        class_mean = channel_sum / n_pixels
        # E[x^2] - E[x]^2 can become slightly negative through rounding
        # (e.g. constant images); clamp so sqrt never produces NaN.
        class_var = (channel_sq_sum / n_pixels - class_mean ** 2).clamp(min=0)
        total_means.append(class_mean)
        total_stds.append(torch.sqrt(class_var))

    return total_means, total_stds

def compute_img_mean_std(img_set, im_size, num_classes):
    """Compute the per-channel pixel mean and standard deviation of a dataset.

    Args:
        img_set: Indexable collection supporting ``len()`` whose items are
            ``(img, label)`` pairs. ``img`` is reduced over axes 1 and 2, so
            it is assumed to be laid out as (channel, height, width) with
            three channels -- confirm against the dataset in use. The label
            is ignored.
        im_size: Sequence whose first two entries are the image height and
            width; their product is used as the pixel count of every image.
        num_classes: Unused; kept so the signature mirrors
            ``compute_img_mean_std_per_class``.

    Returns:
        ``(mean, std)``: two 3-element tensors with the channel statistics
        over all images. An empty ``img_set`` yields zeros instead of NaN.
    """
    channel_sum = torch.zeros(3)
    channel_sq_sum = torch.zeros(3)
    for i in range(len(img_set)):
        img, _ = img_set[i]
        channel_sum += img.sum(dim=(1, 2))
        channel_sq_sum += (img ** 2).sum(dim=(1, 2))

    # Guard against division by zero for an empty dataset; the sums are zero
    # in that case, so mean and std come out as zero.
    n_pixels = max(len(img_set) * im_size[0] * im_size[1], 1)

    total_mean = channel_sum / n_pixels
    # E[x^2] - E[x]^2 can become slightly negative through rounding
    # (e.g. constant images); clamp so sqrt never produces NaN.
    total_var = (channel_sq_sum / n_pixels - total_mean ** 2).clamp(min=0)
    total_std = torch.sqrt(total_var)

    return total_mean, total_std


def get_initial_normal(train_set, im_size, num_classes, ipc):
    """Create Gaussian-noise synthetic images and their labels.

    For every class, each of the three channels is sampled from a normal
    distribution parameterised by that class's channel mean / std (computed
    from ``train_set``), the samples are clipped to [0, 1], and each channel
    is then standardised with the dataset-wide channel mean / std.

    Args:
        train_set: Dataset handed to ``compute_img_mean_std_per_class`` and
            ``compute_img_mean_std``.
        im_size: Sequence whose first two entries are image height and width.
        num_classes: Number of classes.
        ipc: Number of synthetic images generated per class.

    Returns:
        ``(image_syn, label_syn)``: a tensor of shape
        ``(num_classes * ipc, 3, im_size[0], im_size[1])`` and a long tensor
        of ``num_classes * ipc`` labels ordered ``[0, .., 0, 1, .., 1, ...]``,
        both created on ``device``.
    """
    class_means, class_stds = compute_img_mean_std_per_class(train_set, im_size, num_classes)
    global_mean, global_std = compute_img_mean_std(train_set, im_size, num_classes)

    plane_shape = (ipc, 1, im_size[0], im_size[1])
    per_class_images = []
    for cls in range(num_classes):
        # One single-channel draw per colour channel, in channel order.
        planes = [
            torch.normal(mean=class_means[cls][ch], std=class_stds[cls][ch], size=plane_shape,
                         dtype=torch.float, requires_grad=False, device=device)
            for ch in range(3)
        ]
        class_images = torch.cat(planes, dim=1).detach()  # [ipc, 3, H, W]

        # Clip to the [0, 1] pixel range before standardising.
        class_images.clamp_(min=0.0, max=1.0)
        for ch in range(3):
            class_images[:, ch] = (class_images[:, ch] - global_mean[ch]) / global_std[ch]

        per_class_images.append(class_images)

    image_syn = torch.cat(per_class_images, dim=0)
    # [0,0,0, 1,1,1, ..., C-1,C-1,C-1] with ipc repetitions per class
    label_syn = torch.arange(num_classes, dtype=torch.long, device=device).repeat_interleave(ipc)

    return image_syn, label_syn


def total_variation(x, signed_image=True):
    """Return the anisotropic total variation of a 4-D tensor.

    The result is the mean absolute difference between neighbours along the
    last axis plus the mean absolute difference between neighbours along the
    third axis.

    Args:
        x: Tensor indexed with four axes.
        signed_image: When true, the absolute value of ``x`` is taken before
            the differences are computed.

    Returns:
        A 0-dim tensor holding the sum of both mean differences.
    """
    img = torch.abs(x) if signed_image else x
    # Neighbour differences along the last axis and along the third axis.
    last_axis_diff = img[:, :, :, :-1] - img[:, :, :, 1:]
    third_axis_diff = img[:, :, :-1, :] - img[:, :, 1:, :]
    return last_axis_diff.abs().mean() + third_axis_diff.abs().mean()

def l2_norm(x, signed_image=True):
    """Return the mean per-sample L2 norm of a batch.

    Args:
        x: Batched tensor; the first axis is the batch axis and every sample
            is flattened with ``view`` before its norm is taken.
        signed_image: When true, the absolute value of ``x`` is taken first.

    Returns:
        A 0-dim tensor with the norms averaged over the batch.
    """
    values = torch.abs(x) if signed_image else x
    flattened = values.view(len(values), -1)
    per_sample_norm = torch.norm(flattened, dim=1)
    return per_sample_norm.mean()

class BNFeatureHook:
    """
    Forward hook that records the per-channel statistics of a module's input
    and scores how far they are from the module's stored running statistics.

    After every forward pass of the hooked module the instance exposes
    ``mean`` and ``var`` (per-channel statistics of the 4-D input) and
    ``r_feature`` (sum of the L2 distances to ``module.running_var`` and
    ``module.running_mean``). The module must therefore provide both
    running buffers.
    """
    def __init__(self, module):
        # Keep the handle so the hook can be detached again via close().
        self.hook = module.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        # Hook to compute DeepInversion's feature distribution regularization.
        features = input[0]
        n_channels = features.shape[1]

        batch_mean = features.mean([0, 2, 3])
        # Bring the channel axis to the front and flatten everything else so
        # the (biased) variance is taken over all values of each channel.
        channel_major = features.permute(1, 0, 2, 3).contiguous().view([n_channels, -1])
        batch_var = channel_major.var(1, unbiased=False)

        # Force mean and variance to match between the two distributions;
        # other measures might work better, e.g. KL divergence.
        var_distance = torch.norm(module.running_var.data - batch_var, 2)
        mean_distance = torch.norm(module.running_mean.data - batch_mean, 2)

        self.mean = batch_mean
        self.var = batch_var
        self.r_feature = var_distance + mean_distance
        # A forward hook must not return anything, otherwise it would
        # replace the module's output.

    def close(self):
        self.hook.remove()



def distribution_matching_DP(image_real, image_syn, optimizer_img, channel, num_classes, im_size, ipc, minibatch_loader, microbatch_loader, net=None):
    """Run one distribution-matching update of the synthetic images using a
    minibatch / microbatch optimizer.

    For every class, real images are drawn in minibatches, each minibatch is
    split into microbatches, and for every microbatch the squared distance
    between the mean embedding of the real microbatch and the mean embedding
    of the class's synthetic images is back-propagated.

    Args:
        image_real: Indexable by class id; ``image_real[c]`` holds the real
            images of class ``c``.
        image_syn: Synthetic images; rows ``c*ipc:(c+1)*ipc`` belong to
            class ``c``.
        optimizer_img: Optimizer over the synthetic images. It must provide
            ``zero_grad``, ``zero_microbatch_grad``, ``microbatch_step`` and
            ``step`` (presumably a DP-SGD style optimizer -- confirm against
            the caller).
        channel: Number of image channels.
        num_classes: Number of classes.
        im_size: Sequence whose first two entries are image height and width.
        ipc: Number of synthetic images per class.
        minibatch_loader: Callable turning a dataset into an iterable of
            ``(images, labels)`` minibatches.
        microbatch_loader: Callable turning a minibatch dataset into an
            iterable of ``(images, labels)`` microbatches.
        net: Embedding network exposing ``embed``; a freshly initialised
            ``'ConvNet'`` is created when omitted.

    Returns:
        ``(loss, image_syn)`` where ``loss`` is the float loss of the last
        processed microbatch (``0.0`` when no microbatch was processed).
    """
    ''' update synthetic data '''
    minibatch_loaders = []
    for c in range(num_classes):
        img_real = image_real[c]
        class_labels = torch.full((len(img_real),), c, dtype=torch.long, device=device)
        minibatch_loaders.append(minibatch_loader(TensorDataset(img_real, class_labels)))

    # Create all iterators up front, before the network is (possibly) built.
    train_iters = [iter(loader) for loader in minibatch_loaders]

    # default we use ConvNet
    if net is None:
        net = get_network('ConvNet', channel, num_classes, im_size).to(device) # get a random model
    net.train()
    embed = net.module.embed if torch.cuda.device_count() > 1 else net.embed # for GPU parallel

    # Stays None when no microbatch is processed so the return below cannot
    # hit an unbound local.
    loss = None
    optimizer_img.zero_grad()
    for c in range(num_classes):
        img_syn = image_syn[c*ipc:(c+1)*ipc].reshape((ipc, channel, im_size[0], im_size[1]))
        for X_minibatch, y_minibatch in train_iters[c]:
            # split the minibatch into microbatches
            for X_microbatch, _ in microbatch_loader(TensorDataset(X_minibatch, y_minibatch)):
                output_real = embed(X_microbatch).detach()
                output_syn = embed(img_syn)

                optimizer_img.zero_microbatch_grad()
                loss = torch.sum((torch.mean(output_real, dim=0) - torch.mean(output_syn, dim=0))**2)
                loss.backward()

                optimizer_img.microbatch_step()
    optimizer_img.step()

    last_loss = loss.item() if loss is not None else 0.0
    return last_loss, image_syn


def distribution_matching(image_real, image_syn, optimizer_img, channel, num_classes, im_size, ipc, image_server=None, net=None, lambda_sim=0.5):
    """Run one distribution-matching update of the synthetic images.

    For every class the real and synthetic images are augmented, embedded
    with ``net``, and the squared distance between their mean embeddings is
    added to the loss. When ``image_server`` is given, the same distance
    between the server images and the synthetic images is added as well,
    weighted by ``lambda_sim``. The summed loss is back-propagated once and
    ``optimizer_img`` takes a single step.

    Args:
        image_real: Indexable by class id; ``image_real[c]`` holds the real
            images of class ``c``.
        image_syn: Synthetic images; rows ``c*ipc:(c+1)*ipc`` belong to
            class ``c``.
        optimizer_img: Optimizer over the synthetic images.
        channel: Number of image channels.
        num_classes: Number of classes.
        im_size: Sequence whose first two entries are image height and width.
        ipc: Number of synthetic images per class.
        image_server: Optional, indexable by class id like ``image_real``.
        net: Embedding network exposing ``embed``; a freshly initialised
            ``'ConvNet'`` is created when omitted.
        lambda_sim: Weight of the server/synthetic matching term.

    Returns:
        ``(loss, image_syn)`` without ``image_server``; otherwise
        ``(loss, image_syn, server_client_loss)`` where ``server_client_loss``
        is the unweighted server term of the last class only.
    """
    augmentation = 'color_crop_cutout_flip_scale_rotate'

    # default we use ConvNet
    if net is None:
        net = get_network('ConvNet', channel, num_classes, im_size).to(device) # get a random model
    net.train()

    embed = net.module.embed if torch.cuda.device_count() > 1 else net.embed # for GPU parallel

    ''' update synthetic data '''
    loss = torch.tensor(0.0).to(device)
    for c in range(num_classes):
        img_real = image_real[c]
        img_syn = image_syn[c*ipc:(c+1)*ipc].reshape((ipc, channel, im_size[0], im_size[1]))

        # The same seed and parameter object are handed to every augmentation
        # call of this class, presumably so real, synthetic and server images
        # receive the same transformation -- confirm against DiffAugment.
        seed = int(time.time() * 1000) % 100000
        dsa_param = ParamDiffAug()
        img_real = DiffAugment(img_real, augmentation, seed=seed, param=dsa_param)
        img_syn = DiffAugment(img_syn, augmentation, seed=seed, param=dsa_param)

        output_real = embed(img_real).detach()
        output_syn = embed(img_syn)

        if image_server is not None:
            img_server = DiffAugment(image_server[c], augmentation, seed=seed, param=dsa_param)
            output_server = embed(img_server).detach()
            server_client_loss = torch.sum((torch.mean(output_server, dim=0) - torch.mean(output_syn, dim=0))**2)
            loss += lambda_sim * server_client_loss

        loss += torch.sum((torch.mean(output_real, dim=0) - torch.mean(output_syn, dim=0))**2)

    optimizer_img.zero_grad()
    loss.backward()
    optimizer_img.step()

    if image_server is not None:
        return loss.item(), image_syn, server_client_loss.item()
    return loss.item(), image_syn
    
