import pdb
import time
from math import sqrt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import torchvision
from tqdm import tqdm
from .utils import *
from .base import BaseSynthesis
from .hooks import DeepInversionHook
from tqdm import tqdm
import sys
import os
from utils import *
import shutil
from torch.utils.data import DataLoader, ConcatDataset


class Timer:
    """Stopwatch that formats the time elapsed since its creation."""

    def __init__(self):
        # Reference instant (seconds, from time.time()) that measure() counts from.
        self.o = time.time()

    def measure(self, p=1):
        """Return the elapsed time divided by ``p`` as a short string.

        The unit is chosen from the scaled value: hours with one decimal
        (``'1.5h'``) from 3600 s upwards, whole rounded minutes (``'2m'``)
        from 60 s upwards, otherwise seconds with two decimals (``'5.00s'``).
        """
        elapsed = (time.time() - self.o) / p
        if elapsed >= 3600:
            text = f'{elapsed / 3600:.1f}h'
        elif elapsed >= 60:
            text = f'{round(elapsed / 60)}m'
        else:
            text = f'{elapsed:.2f}s'
        return text


def get_top_k_relative_indices_including_first(pre_attention, K):
    """Return, per row, index 0 followed by the top-K column indices shifted by one.

    Args:
        pre_attention: 2-D tensor of scores, one row per sample.
        K: number of columns to keep per row; capped at the number of columns.

    Returns:
        Long tensor of shape ``(rows, 1 + min(K, cols))`` on the same device
        as ``pre_attention``. Column 0 is always 0; the remaining entries are
        the positions of the largest scores plus one, in descending score order.
    """
    num_rows, num_cols = pre_attention.shape
    keep = min(K, num_cols)
    # The +1 leaves index 0 free for the entry that is always kept
    # (presumably the class token, which the scores do not cover - TODO confirm).
    shifted = torch.topk(pre_attention, keep, dim=1).indices + 1
    always_first = shifted.new_zeros((num_rows, 1))
    return torch.cat((always_first, shifted), dim=1)


def clip_images(image_tensor, mean, std):
    """Clamp normalised images, in place, to the range of valid pixel values.

    For each channel ``c`` the values are limited to
    ``[-mean[c] / std[c], (1 - mean[c]) / std[c]]``, i.e. the images of 0 and 1
    under ``(x - mean) / std``.

    The number of channels processed is no longer fixed at three: every
    channel that has both an image plane and a ``mean``/``std`` entry is
    clamped, so single-channel inputs work and three-channel inputs behave
    exactly as before. Channels beyond the supplied statistics are left
    untouched.

    Args:
        image_tensor: tensor indexed as ``[batch, channel, ...]``; modified in place.
        mean: per-channel means (sequence or scalar).
        std: per-channel standard deviations (sequence or scalar).

    Returns:
        The same ``image_tensor`` object, after clamping.
    """
    mean = np.asarray(mean).reshape(-1)
    std = np.asarray(std).reshape(-1)
    num_channels = min(image_tensor.shape[1], len(mean), len(std))
    for c in range(num_channels):
        m, s = mean[c], std[c]
        image_tensor[:, c] = torch.clamp(image_tensor[:, c], -m / s, (1 - m) / s)
    return image_tensor


def get_image_prior_losses(inputs_jit):
    """Compute total-variation style smoothness penalties for a batch of images.

    Differences are taken between each pixel and its neighbour along the last
    axis, the second-to-last axis, and both diagonals of the last two axes.

    Args:
        inputs_jit: 4-D tensor; the last two axes are treated as the image plane.

    Returns:
        ``(loss_var_l1, loss_var_l2)``: the sum over the four directions of the
        mean absolute difference, and the sum over the four directions of the
        norm (``torch.norm``) of the differences.
    """
    neighbour_diffs = (
        inputs_jit[:, :, :, :-1] - inputs_jit[:, :, :, 1:],      # along the last axis
        inputs_jit[:, :, :-1, :] - inputs_jit[:, :, 1:, :],      # along the second-to-last axis
        inputs_jit[:, :, 1:, :-1] - inputs_jit[:, :, :-1, 1:],   # anti-diagonal
        inputs_jit[:, :, :-1, :-1] - inputs_jit[:, :, 1:, 1:],   # main diagonal
    )

    loss_var_l2 = torch.norm(neighbour_diffs[0])
    loss_var_l1 = (neighbour_diffs[0].abs() / 255.0).mean()
    for delta in neighbour_diffs[1:]:
        loss_var_l2 = loss_var_l2 + torch.norm(delta)
        loss_var_l1 = loss_var_l1 + (delta.abs() / 255.0).mean()

    # The L1 term is scaled down per element and scaled back up after summation.
    loss_var_l1 = loss_var_l1 * 255.0
    return loss_var_l1, loss_var_l2


def jsdiv(logits, targets, T=1.0, reduction='batchmean'):
    """Jensen-Shannon style divergence between two sets of logits.

    Both inputs are turned into distributions with a temperature-``T`` softmax
    over dim 1. Their equal-weight mixture is formed first; then all three
    distributions are clamped to ``[0.01, 0.99]`` before the logarithm.

    Args:
        logits: first logits tensor, classes on dim 1.
        targets: second logits tensor, same shape as ``logits``.
        T: softmax temperature.
        reduction: forwarded to ``F.kl_div``.

    Returns:
        The average of ``F.kl_div(log P, M)`` and ``F.kl_div(log Q, M)``.
    """
    prob_a = F.softmax(logits / T, dim=1)
    prob_b = F.softmax(targets / T, dim=1)
    # The mixture uses the unclamped probabilities; clamping happens afterwards.
    mixture = 0.5 * (prob_a + prob_b)
    prob_a, prob_b, mixture = (torch.clamp(dist, 0.01, 0.99) for dist in (prob_a, prob_b, mixture))

    kl_a = F.kl_div(torch.log(prob_a), mixture, reduction=reduction)
    kl_b = F.kl_div(torch.log(prob_b), mixture, reduction=reduction)
    return 0.5 * kl_a + 0.5 * kl_b


def jitter_and_flip(inputs_jit, lim=1./8., do_flip=True):
    """Randomly roll a batch along dims 2 and 3 and optionally mirror it.

    Args:
        inputs_jit: tensor whose dims 2 and 3 are rolled (and dim 3 flipped).
        lim: maximum shift as a fraction of the size of the last two axes.
        do_flip: when falsy, the mirror step is never applied.

    Returns:
        ``(augmented, off1, off2, flipped)`` where ``off1``/``off2`` are the
        integer shifts applied to dims 2/3 and ``flipped`` tells whether the
        tensor was mirrored along dim 3.
    """
    max_shift_rows = int(inputs_jit.shape[-2] * lim)
    max_shift_cols = int(inputs_jit.shape[-1] * lim)

    # Draw order (rows, cols, coin) is kept fixed so seeded runs are reproducible.
    off1 = random.randint(-max_shift_rows, max_shift_rows)
    off2 = random.randint(-max_shift_cols, max_shift_cols)
    augmented = torch.roll(inputs_jit, shifts=(off1, off2), dims=(2, 3))

    # The coin is tossed even when flipping is disabled.
    coin = random.random() > 0.5
    flipped = coin and do_flip
    if flipped:
        augmented = torch.flip(augmented, dims=(3,))
    return augmented, off1, off2, flipped


def jitter_and_flip_index(pre_index_matrix, off1, off2, flip, patch_size=16, num_patches_per_dim=14):
    """Remap 1-based patch indices after a pixel roll and optional mirror.

    A grid holding the numbers ``1 .. num_patches_per_dim ** 2`` is rolled by
    the pixel offsets converted to whole patches (rounded to the nearest patch)
    and optionally flipped along its second axis. Every non-zero entry of
    ``pre_index_matrix`` is then replaced by the 1-based position at which its
    value sits in the transformed grid. Zero entries stay zero.

    Args:
        pre_index_matrix: integer tensor of patch indices; 0 is left untouched.
        off1: pixel shift applied along the grid's first axis.
        off2: pixel shift applied along the grid's second axis.
        flip: whether the grid is mirrored along its second axis.
        patch_size: side length of a patch in pixels.
        num_patches_per_dim: number of patches along each grid axis.

    Returns:
        Tensor shaped like ``pre_index_matrix`` holding the remapped indices.
    """
    def patch_shift(pixel_offset):
        # Floor-divide for the whole-patch part, then move one more patch when
        # the leftover fraction of a patch is at least one half.
        whole = int(pixel_offset // patch_size)
        leftover = pixel_offset % patch_size / patch_size
        if abs(leftover) >= 0.5:
            whole += 1 if leftover > 0 else -1
        return whole

    side = num_patches_per_dim
    layout = torch.arange(1, side * side + 1).reshape(side, side).to(pre_index_matrix.device)
    # Consecutive rolls along an axis add up, so one combined roll is enough.
    layout = torch.roll(layout, shifts=(patch_shift(off1), patch_shift(off2)), dims=(0, 1))
    if flip:
        layout = torch.flip(layout, dims=[1])
    layout = layout.flatten()

    occupied = pre_index_matrix != 0
    wanted = pre_index_matrix[occupied]
    # One comparison row per wanted value; the matching column is its new position.
    _, positions = (layout == wanted.unsqueeze(-1)).nonzero(as_tuple=True)

    remapped = torch.zeros_like(pre_index_matrix)
    remapped[occupied] = positions + 1  # back to 1-based numbering
    return remapped


class MI(BaseSynthesis):
    """Synthesise images by inverting each client's teacher, with optional patch pruning.

    For every client a batch of random images is optimised against that
    client's teacher. At the iterations listed in ``prune_it`` the set of
    patches passed to the models is reduced using the teacher's attention.
    The resulting images (full or masked) are added to ``ImagePool`` objects
    and exposed through a ``DataLoader`` returned by :meth:`sample`.

    ``ImagePool``, ``ConcatVerticalDataset`` and ``BaseSynthesis`` are defined
    elsewhere in the project; only the calls made on them here are relied upon.
    """

    def __init__(self, args, clients, teachers, student, num_classes, img_shape=(3, 224, 224), patch_size=16,
                 iterations=4000, lr_g=0.25,
                 synthesis_batch_size=128, sample_batch_size=128,
                 adv=0, bn=0, oh=1, tv1=0.0, tv2=1e-4, l2=0,
                 save_dir='', save_dir2=None, transform=None,
                 normalizer=None, device='cpu',
                 bnsource='resnet50v2', init_dataset=None, test_loader=None):
        """Store the synthesis configuration and create the image pool(s).

        Args:
            args: namespace read for ``cut``, ``server_method`` and ``n_clients``.
            clients: sequence with one entry per client; only its length is used here.
            teachers: per-client teacher models, indexed in parallel with ``clients``.
            student: student model used for the adversarial term.
            num_classes: number of classes used when sampling random targets.
            img_shape: ``(C, H, W)`` of the synthesised images.
            patch_size: side length of one patch in pixels.
            iterations: optimisation steps per client.
            lr_g: Adam learning rate for the image tensor.
            synthesis_batch_size: images synthesised per client and call.
            sample_batch_size: batch size of the loader returned by ``sample``.
            adv, oh, tv1, tv2, l2: weights of the loss terms.
            bn, bnsource: accepted for interface compatibility; not used in this class.
            save_dir: root of the main image pool.
            save_dir2: root of the second pool (only used for ``FedMITR``).
            transform: transform handed to the pool datasets.
            normalizer: object with ``mean``/``std`` attributes, callable as
                ``normalizer(x, True)``; may be ``None``.
            device: device on which images are synthesised.
            init_dataset, test_loader: stored unchanged.
        """
        super(MI, self).__init__(teachers, student)
        assert len(img_shape)==3, "image size should be a 3-dimension tuple"
        self.args = args
        self.clients = clients
        self.save_dir = save_dir
        self.save_dir2 = save_dir2
        self.img_size = img_shape
        self.patch_size=patch_size
        self.iterations = iterations
        self.lr_g = lr_g
        self.normalizer = normalizer
        self.data_pool = ImagePool(root=self.save_dir)
        self.data_iter = None
        self.transform = transform
        self.synthesis_batch_size = synthesis_batch_size
        self.sample_batch_size = sample_batch_size
        self.init_dataset=init_dataset
        # Multiplier for the pool-size limit checked at the end of synthesize().
        self.num = args.cut
        if self.args.server_method == 'FedMITR':
            self.data_pool2 = ImagePool(root=self.save_dir2)
            self.data_iter2 = None

        # Scaling factors
        self.adv = adv
        self.oh = oh
        self.tv1 = tv1
        self.tv2 = tv2
        self.l2 = l2
        self.num_classes = num_classes

        # training configs
        self.device = device
        self.test_loader = test_loader

    @staticmethod
    def _kept_fraction(prune_ratio):
        """Return the product of ``(1 - r)`` over all ratios (1 for an empty list)."""
        kept = 1
        for ratio in prune_ratio:
            kept = kept * (1. - ratio)
        return kept

    def _full_patch_index(self, batch_size, num_patches):
        """Return a fresh ``(batch_size, num_patches)`` long tensor ``0 .. num_patches - 1`` per row."""
        return torch.arange(num_patches, dtype=torch.long).repeat(batch_size, 1).to(self.device)

    def _trim_pool(self, directory):
        """Delete the first ``synthesis_batch_size * n_clients`` files (sorted by name) in ``directory``."""
        files = sorted(os.listdir(directory))
        for name in files[:self.synthesis_batch_size * self.args.n_clients]:
            os.remove(os.path.join(directory, name))

    def synthesize(self, targets=None, num_patches=197, prune_it=None, prune_ratio=None):
        """Run one synthesis round over all clients and rebuild ``self.data_iter``.

        Args:
            targets: class labels for the batch. When ``None`` they are sampled
                once (sorted) and reused for every client of this call.
            num_patches: number of index entries per sample handed to the models
                (presumably patches plus one leading token - TODO confirm).
            prune_it: iteration numbers at which patches are pruned. ``None``
                means no pruning iteration.
            prune_ratio: fraction pruned at the matching entry of ``prune_it``.
                ``None`` is treated as ``[0]`` (no pruning). When it is exactly
                ``[0]`` the unmasked images are stored; otherwise the masked ones.
        """
        # The None defaults used to fail on the first membership/len test.
        if prune_it is None:
            prune_it = []
        if prune_ratio is None:
            prune_ratio = [0]

        store_unmasked = len(prune_ratio) == 1 and prune_ratio[0] == 0
        use_fedmitr = self.args.server_method == 'FedMITR'
        # Patch grid side derived from the configured image size rather than a
        # fixed 224, so it matches the tensors created below.
        patches_per_dim = int(self.img_size[-1] // self.patch_size)
        grid = int(sqrt(num_patches))

        attention_weights = None
        for idx in range(len(self.clients)):
            teacher = self.teachers[idx]
            self.student.eval()
            teacher.eval()
            best_cost = 1e6
            inputs = torch.randn(size=[self.synthesis_batch_size, *self.img_size], device=self.device).requires_grad_()
            if targets is None:
                targets = torch.randint(low=0, high=self.num_classes, size=(self.synthesis_batch_size,))
                targets = targets.sort()[0]  # sort for better visualization
            targets = targets.to(self.device)
            optimizer = torch.optim.Adam([inputs], self.lr_g, betas=(0.5, 0.99))

            # clone(): inputs.data shares storage with the tensor being optimised,
            # so without a copy the snapshot would silently follow every later step.
            best_inputs = inputs.data.clone()
            batch = best_inputs.shape[0]
            current_abs_index = self._full_patch_index(batch, num_patches)
            next_relative_index = self._full_patch_index(batch, num_patches)

            for it in tqdm(range(self.iterations), file=sys.stdout):
                #  perform inversion stopping at t1, t2, ..., tn
                if it+1 in prune_it:
                    # Step right before a pruning step: no augmentation, so the
                    # attention used for pruning refers to the un-jittered image.
                    inputs_aug = inputs
                    current_abs_index_aug = current_abs_index
                    t_out, attention_weights, _ = teacher(inputs_aug, current_abs_index_aug,
                                                          next_relative_index)
                elif it in prune_it:
                    # (B,heads,N,N)->(B,p-1)
                    attention_weights = torch.mean(attention_weights[-1], dim=1)[:, 0, :][:, 1:]
                    prune_ratio_value = prune_ratio[prune_it.index(it)]
                    top_K = int(attention_weights.shape[1] * (1.0 - prune_ratio_value))
                    next_relative_index = get_top_k_relative_indices_including_first(
                        pre_attention=attention_weights, K=top_K).to(self.device)
                    inputs_aug = inputs
                    current_abs_index_aug = current_abs_index
                    # Only here is the index returned by the teacher kept.
                    t_out, attention_weights, current_abs_index = teacher(inputs_aug, current_abs_index_aug,
                                                                          next_relative_index)
                else:
                    inputs_aug, off1, off2, flip = jitter_and_flip(inputs)
                    if current_abs_index.shape[1] == num_patches:
                        # Nothing pruned yet: the index needs no remapping.
                        current_abs_index_aug = current_abs_index
                    else:
                        current_abs_index_aug = jitter_and_flip_index(current_abs_index, off1, off2, flip,
                                                                      self.patch_size, patches_per_dim)
                    t_out, attention_weights, _ = teacher(inputs_aug, current_abs_index_aug,
                                                          next_relative_index)

                loss_oh = F.cross_entropy(t_out, targets)
                s_out, _, _ = self.student(inputs_aug, current_abs_index_aug, next_relative_index)
                loss_adv = -jsdiv(s_out, t_out, T=3)
                loss_tv1, loss_tv2 = get_image_prior_losses(inputs)
                loss_l2 = torch.norm(inputs, 2)
                loss = (self.oh * loss_oh
                        + self.tv1 * loss_tv1 + self.tv2 * loss_tv2
                        + self.l2 * loss_l2 + self.adv * loss_adv)

                if best_cost > loss.item():
                    best_cost = loss.item()
                    best_inputs = inputs.data.clone()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # Guarded like the de-normalisation below; normalizer may be None.
                if self.normalizer:
                    inputs.data = clip_images(inputs.data, self.normalizer.mean, self.normalizer.std)
            self.student.train()
            if self.normalizer:
                best_inputs = self.normalizer(best_inputs, True)
            if store_unmasked:  # add non-masked image
                self.data_pool.add(best_inputs, targets)

            # Attention of the final images over the full, un-pruned index.
            with torch.no_grad():
                t_out, attention_weights, current_abs_index = teacher(
                    best_inputs.detach(),
                    self._full_patch_index(best_inputs.shape[0], num_patches),
                    self._full_patch_index(best_inputs.shape[0], num_patches))

            attention_weights = torch.mean(attention_weights[-1], dim=1)[:, 0, :][:, 1:]  # (B,heads,N,N)->(B,p-1)

            # Patches surviving all pruning stages combined.
            top_K = int(num_patches * self._kept_fraction(prune_ratio))
            next_relative_index = get_top_k_relative_indices_including_first(
                pre_attention=attention_weights, K=top_K).to(self.device)

            # Binary patch mask, created on the same device as the indices used to fill it.
            mask = torch.zeros(next_relative_index.shape[0], grid, grid, device=self.device)
            for b in range(next_relative_index.shape[0]):
                kept = next_relative_index[b][1:] - 1  # drop the leading 0, back to 0-based
                mask[b, kept // grid, kept % grid] = 1
            # Blow every patch cell up to patch_size x patch_size pixels.
            expanded_mask = mask.repeat_interleave(self.patch_size, dim=1).repeat_interleave(self.patch_size, dim=2)
            expanded_mask = expanded_mask.to(self.device)
            masked_best_inputs = best_inputs * expanded_mask.unsqueeze(1)

            if not store_unmasked:  # add masked image
                self.data_pool.add(masked_best_inputs, targets)
            if use_fedmitr:
                # The second pool receives the complementary (pruned-away) region.
                ours_expanded_mask = 1 - expanded_mask.unsqueeze(1)
                masked_best_inputs_ours = best_inputs * ours_expanded_mask
                self.data_pool2.add(masked_best_inputs_ours, targets)

        # Datasets are only needed for the size check, so they are built once
        # after all clients have contributed (also covers an empty client list).
        if use_fedmitr:
            dst = self.data_pool.get_dataset_label(transform=self.transform)
            dst2 = self.data_pool2.get_dataset(transform=self.transform)
        else:
            dst = self.data_pool.get_dataset(transform=self.transform)

        pool_limit = self.synthesis_batch_size * self.args.n_clients * self.num
        if len(dst) > pool_limit:
            self._trim_pool(self.save_dir)
        if use_fedmitr:
            if len(dst2) > pool_limit:
                self._trim_pool(self.save_dir2)
            # Rebuild after trimming so the loader reflects the files left on disk.
            dst = self.data_pool.get_dataset_label(transform=self.transform)
            dst2 = self.data_pool2.get_dataset_label(transform=self.transform)
            combined_dataset = ConcatVerticalDataset(dst, dst2)
            self.data_iter = torch.utils.data.DataLoader(
                combined_dataset, batch_size=self.sample_batch_size, shuffle=True, num_workers=8, pin_memory=True)
        else:
            dst = self.data_pool.get_dataset(transform=self.transform)
            self.data_iter = torch.utils.data.DataLoader(
                dst, batch_size=self.sample_batch_size, shuffle=True, num_workers=8, pin_memory=True)

    def sample(self):
        """Return the loader built by the most recent :meth:`synthesize` call (``None`` before that)."""
        return self.data_iter
