import torch
from models.get_model import get_model
from server.aggregation import *
import copy
import numpy as np
from torchvision import transforms, datasets
from torch.utils.data import ConcatDataset
import torch
from torch import nn
from tqdm import tqdm
from torch.nn import functional as F
from torch.utils.data import DataLoader
import copy, wandb
from abc import ABC
from torch.autograd import Variable

from kornia import augmentation
import os
from PIL import Image
from torch.utils.data import Dataset


class MyDataset(Dataset):
    """Map-style dataset backed by an in-memory sequence of ``(image, label)`` pairs."""

    def __init__(self, data_list):
        # Stored by reference; every entry must unpack into exactly two values.
        self.data = data_list

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as an ``(image, label)`` tuple."""
        sample_image, sample_label = self.data[idx]
        return sample_image, sample_label

def _select_logits(model_output, method):
    """Pick the classification logits out of a method-specific forward output."""
    if method == 'FedCIL':
        return model_output[2]
    if method == 'AFFCL':
        return model_output[0]
    return model_output


def _count_correct(model, dataset, args):
    """Return ``(num_correct, num_samples)`` for ``model`` over ``dataset``.

    The whole pass runs under ``torch.no_grad()`` so no autograd graph is kept.
    """
    loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False)
    correct, seen = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(args.device)
            labels = labels.to(args.device)
            logits = _select_logits(model(images), args.method)
            _, predicted = torch.max(logits, 1)
            correct += (predicted == labels).sum().item()
            seen += labels.size(0)
    return correct, seen


def _safe_ratio(correct, seen):
    """Return ``correct / seen``, or 0 when nothing was evaluated."""
    return correct / seen if seen > 0 else 0


def evaluate_model(train_data, test_data, model, previous_test, args):
    """Compute top-1 accuracies of ``model`` on several data splits.

    Args:
        train_data: iterable of iterables of ``(image, label)`` pairs; it is
            flattened into one training set.
        test_data: iterable of ``(image, label)`` pairs.
        model: network to evaluate; it is switched to eval mode and left there.
        previous_test: list of per-task test datasets. Each one is scored on
            its own, and all but the last are also scored jointly.
        args: namespace providing ``batch_size``, ``device`` and ``method``.
            ``method`` selects which element of the forward output holds the
            logits ('FedCIL' -> index 2, 'AFFCL' -> index 0, otherwise the
            output itself).

    Returns:
        Tuple ``(avg_train_acc, avg_test_acc, a_t_i, all_previous_test_acc,
        all_test_acc)`` where ``a_t_i`` is the list of per-task accuracies
        rounded to 4 decimals and ``all_test_acc`` is their mean (0 when
        ``previous_test`` is empty).
    """
    # Unpacking enforces the (image, label) structure while flattening; the
    # resulting lists are valid map-style datasets for DataLoader.
    train_set = [(image, label) for sublist in train_data for image, label in sublist]
    test_set = [(image, label) for image, label in test_data]
    model.eval()

    avg_train_acc = _safe_ratio(*_count_correct(model, train_set, args))
    avg_test_acc = _safe_ratio(*_count_correct(model, test_set, args))

    # Per-task accuracy (used for forgetting measurements).
    a_t_i = [
        round(_safe_ratio(*_count_correct(model, task_dataset, args)), 4)
        for task_dataset in previous_test
    ]
    # Guard against an empty task list instead of dividing by zero.
    all_test_acc = sum(a_t_i) / len(a_t_i) if a_t_i else 0

    # Joint accuracy over every task except the last entry of previous_test.
    earlier_tasks = previous_test[:-1]
    if len(earlier_tasks) >= 1:
        all_previous_test_acc = _safe_ratio(
            *_count_correct(model, ConcatDataset(earlier_tasks), args)
        )
    else:
        all_previous_test_acc = 0

    return avg_train_acc, avg_test_acc, a_t_i, all_previous_test_acc, all_test_acc
#
# def evaluate_model(train_data, val_data, test_data, model, previous_data, args):
#     train_set = [item for sublist in train_data for item in sublist]
#     validation_set = [item for item in val_data]
#     test_set = [item for item in test_data]
#     model.eval()
#     train_dataset = MyDataset(train_set)
#     val_dataset = MyDataset(validation_set)
#     test_dataset = MyDataset(test_set)
#     train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False)
#     val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)
#     test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)
#     total_correct_train, total_samples_train = 0, 0
#     total_correct_val, total_samples_val = 0, 0
#     total_correct_test, total_samples_test = 0, 0
#
#     with torch.no_grad():
#         for images, labels in train_dataloader:
#             images = images.to(args.device)
#             labels = labels.to(args.device)
#             if args.method == 'FedCIL':
#                 outputs = model(images)[2]
#             elif args.method == 'AFFCL':
#                 outputs = model(images)[0]
#             else:
#                 outputs = model(images)
#             # time.sleep(10)
#             _, predicted = torch.max(outputs, 1)
#             total_correct_train += (predicted == labels).sum().item()
#             total_samples_train += labels.size(0)
#
#         for images, labels in val_dataloader:
#             images = images.to(args.device)
#             labels = labels.to(args.device)
#             if args.method == 'FedCIL':
#                 outputs = model(images)[2]
#             elif args.method == 'AFFCL':
#                 outputs = model(images)[0]
#             else:
#                 outputs = model(images)
#             _, predicted = torch.max(outputs, 1)
#             total_correct_val += (predicted == labels).sum().item()
#             total_samples_val += labels.size(0)
#
#         for images, labels in test_dataloader:
#             images = images.to(args.device)
#             labels = labels.to(args.device)
#             if args.method == 'FedCIL':
#                 outputs = model(images)[2]
#             elif args.method == 'AFFCL':
#                 outputs = model(images)[0]
#             else:
#                 outputs = model(images)
#             _, predicted = torch.max(outputs, 1)
#             total_correct_test += (predicted == labels).sum().item()
#             total_samples_test += labels.size(0)
#
#     avg_train_acc = total_correct_train / total_samples_train if total_samples_train > 0 else 0
#     avg_val_acc = total_correct_val / total_samples_val if total_samples_val > 0 else 0
#     avg_test_acc = total_correct_test / total_samples_test if total_samples_test > 0 else 0
#
#     ##forgetting
#     a_t_i = []
#     for task_id in range(len(previous_data)):
#         val_dataloader_id = torch.utils.data.DataLoader(previous_data[task_id],
#                                                         batch_size=args.batch_size, shuffle=False)
#         total_correct_val_id, total_samples_val_id = 0, 0
#         for images, labels in val_dataloader_id:
#             images = images.to(args.device)
#             labels = labels.to(args.device)
#             if args.method == 'FedCIL':
#                 outputs = model(images)[2]
#             elif args.method == 'AFFCL':
#                 outputs = model(images)[0]
#             else:
#                 outputs = model(images)
#             _, predicted = torch.max(outputs, 1)
#             total_correct_val_id += (predicted == labels).sum().item()
#             total_samples_val_id += labels.size(0)
#         a_t_i.append(round(total_correct_val_id / total_samples_val_id if total_samples_val_id > 0 else 0, 4))
#
#     ##all val
#     all_val_data = previous_data[:-1]
#     if len(all_val_data) >= 1:
#         all_val_dataset = ConcatDataset(all_val_data)
#         all_val_loader = torch.utils.data.DataLoader(all_val_dataset, batch_size=args.batch_size, shuffle=False)
#         total_correct_all_val, total_samples_all_val = 0, 0
#         with torch.no_grad():
#             for images, labels in all_val_loader:
#                 images = images.to(args.device)
#                 labels = labels.to(args.device)
#                 if args.method == 'FedCIL':
#                     outputs = model(images)[2]
#                 elif args.method == 'AFFCL':
#                     outputs = model(images)[0]
#                 else:
#                     outputs = model(images)
#                 # time.sleep(10)
#                 _, predicted = torch.max(outputs, 1)
#                 total_correct_all_val += (predicted == labels).sum().item()
#                 total_samples_all_val += labels.size(0)
#         all_val_acc = total_correct_all_val / total_samples_all_val if total_samples_all_val > 0 else 0
#     else:
#         all_val_acc = 0
#
#     return avg_train_acc, avg_val_acc, avg_test_acc, a_t_i, all_val_acc


#### toolkits for TARGET:
## functions:
def save_image_batch(imgs, root, targets, idxx):
    """Write a batch of images as PNG files grouped into per-class folders.

    Each image is saved to ``<root>/class_<label>/<idxx>_<position>.png``.

    Args:
        imgs: iterable of images. ``torch.Tensor`` entries are clamped to
            [0, 1], scaled to 8-bit and converted from CHW to HWC; single-channel
            and 2-D tensors are saved as greyscale. Non-tensor entries are
            handed to ``Image.fromarray`` unchanged.
        root: output directory (created if missing).
        targets: iterable of labels aligned with ``imgs``; each is cast to int.
        idxx: batch identifier used as the file-name prefix.
    """
    os.makedirs(root, exist_ok=True)
    for idx, (img, label) in enumerate(zip(imgs, targets)):
        label = int(label)
        label_dir = os.path.join(root, f"class_{label}")
        os.makedirs(label_dir, exist_ok=True)  # create the class directory if it does not exist
        if isinstance(img, torch.Tensor):
            img = (img.detach().clamp(0, 1).cpu().numpy() * 255).astype("uint8")
            if img.ndim == 3:
                img = img.transpose(1, 2, 0)  # CHW -> HWC
                if img.shape[2] == 1:
                    # Image.fromarray rejects H x W x 1 arrays; drop the
                    # channel axis so the image is stored as greyscale.
                    img = img[:, :, 0]
        img_pil = Image.fromarray(img)

        filename = os.path.join(label_dir, f"{idxx}_{idx}.png")
        img_pil.save(filename)


def fomaml_grad(src, tar):
    """Accumulate the gradients of ``tar`` into the gradients of ``src``.

    Parameters are paired positionally via ``zip`` over ``parameters()``.
    A missing ``src`` gradient is initialised to zeros with the same device,
    dtype and shape as the parameter (no hard-coded CUDA placement). Target
    parameters without a gradient are skipped.

    Args:
        src: module whose ``.grad`` buffers receive the accumulated gradients.
        tar: module whose current ``.grad`` values are added.
    """
    with torch.no_grad():
        for p, tar_p in zip(src.parameters(), tar.parameters()):
            if tar_p.grad is None:
                # Nothing to contribute for this parameter.
                continue
            if p.grad is None:
                p.grad = torch.zeros_like(p)
            p.grad.add_(tar_p.grad)


def kd_train(student, teacher, criterion, optimizer, data_loader, kd_steps):
    """Distil ``teacher`` into ``student`` for ``kd_steps`` passes over the data.

    The student is put in train mode and the teacher in eval mode. For every
    batch the teacher output (computed without gradients) is the target of
    ``criterion(student_output, teacher_output)``. Labels from the loader are
    ignored.

    Args:
        student: module being optimised.
        teacher: module providing the soft targets; never updated here.
        criterion: callable ``(student_out, teacher_out) -> loss``.
        optimizer: optimiser over the student's parameters.
        data_loader: iterable yielding ``(images, labels)`` batches.
        kd_steps: number of full passes over ``data_loader``.
    """
    student.train()
    teacher.eval()
    # Run on whatever device the student lives on instead of assuming CUDA;
    # fall back to CUDA only for a parameter-less student.
    first_param = next(student.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')
    for _ in range(kd_steps):
        for images, _labels in data_loader:
            images = images.to(device).detach()
            with torch.no_grad():
                t_out = teacher(images)
            s_out = student(images)
            loss_s = criterion(s_out, t_out)
            optimizer.zero_grad()
            loss_s.backward()
            optimizer.step()


def kldiv(logits, targets, T=1.0, reduction='batchmean'):
    """Temperature-scaled KL divergence between two sets of logits.

    ``logits`` are turned into log-probabilities and ``targets`` into
    probabilities (both softened by ``T`` along dim 1); the result of
    ``F.kl_div`` is multiplied by ``T * T``.
    """
    student_log_probs = F.log_softmax(logits / T, dim=1)
    teacher_probs = F.softmax(targets / T, dim=1)
    divergence = F.kl_div(student_log_probs, teacher_probs, reduction=reduction)
    return divergence * (T * T)


def KD_loss(pred, soft, T):
    """Soft-target cross-entropy between ``pred`` and ``soft`` at temperature ``T``.

    Both inputs are divided by ``T``; ``pred`` becomes log-probabilities and
    ``soft`` probabilities along dim 1. The summed negative product is divided
    by the batch size (``pred.shape[0]``).
    """
    log_probs = torch.log_softmax(pred / T, dim=1)
    soft_probs = torch.softmax(soft / T, dim=1)
    batch_size = log_probs.shape[0]
    weighted_sum = torch.mul(soft_probs, log_probs).sum()
    return -1 * weighted_sum / batch_size


def weight_init(m):
    """Initialise one module in place; meant for use with ``Module.apply``.

    Conv2d and Linear layers get Kaiming-normal weights and a zero bias (when
    a bias exists). BatchNorm2d layers get weight 1 and bias 0. Any other
    module is left untouched.
    """
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
        return
    if isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)


## class:
class GlobalSynthesizer(ABC):
    """Synthesises labelled images from a teacher network via a meta-learned generator.

    Each call to ``synthesize`` clones the shared generator, optimises the clone
    together with a fresh latent batch against the teacher (BN-statistics loss,
    cross-entropy loss and an optional adversarial student/teacher term), feeds
    the clone's gradients back into the shared generator, and may store the
    best image batch in an on-disk ``ImagePool``.
    """

    def __init__(self, teacher, student, generator, nz, num_classes, img_size,
                 init_dataset=None, iterations=100, lr_g=0.0002,
                 synthesis_batch_size=256, sample_batch_size=128,
                 adv=1.0, bn=1, oh=1,
                 save_dir='run/fast', transform=None, autocast=None, use_fp16=False,
                 normalizer=None, distributed=False, lr_z=0.1,
                 warmup=10, reset_l0=0, reset_bn=0, bn_mmt=0.9,
                 is_maml=1, args=None):
        """Store hyper-parameters, hook the teacher's BatchNorm layers and build optimisers.

        Args:
            teacher: network whose outputs and BatchNorm2d statistics drive synthesis.
            student: network used for the adversarial term.
            generator: module mapping latent vectors to images; must provide
                ``clone()``. It is moved to CUDA and put in train mode.
            nz: latent vector size.
            num_classes: upper bound (exclusive) for the sampled target labels.
            img_size: sequence whose last two entries give the random-crop size
                (presumably ``(..., H, W)`` -- TODO confirm with caller).
            iterations: inner optimisation steps per ``synthesize`` call.
            lr_g: learning rate of the cloned generator; the meta optimiser
                uses ``lr_g * iterations``.
            synthesis_batch_size: number of latent vectors / images per call.
            adv, bn, oh: weights of the adversarial, BN-statistics and
                cross-entropy loss terms.
            save_dir: root directory handed to ``ImagePool``.
            transform: stored on the instance; not used in this class.
            normalizer: last stage of the augmentation pipeline (presumably a
                callable normalisation transform -- TODO confirm).
            lr_z: learning rate for the latent batch.
            warmup: value of ``self.ep`` from which the adversarial term is active.
            bn_mmt: momentum rate passed to every ``DeepInversionHook``.

        The remaining parameters (``init_dataset``, ``sample_batch_size``,
        ``autocast``, ``use_fp16``, ``distributed``, ``reset_l0``, ``reset_bn``,
        ``is_maml``, ``args``) are accepted but not read in this class.
        """
        self.synthesis_batch_size = synthesis_batch_size
        self.teacher = teacher
        self.student = student
        self.lr_z = lr_z
        self.lr_g = lr_g
        self.iterations = iterations
        self.ep = 0  # number of synthesize() calls so far
        self.ep_start = warmup
        self.bn = bn
        self.oh = oh
        self.nz = nz
        self.bn_mmt = bn_mmt
        self.num_classes = num_classes
        self.data_pool = ImagePool(root=save_dir)
        self.img_size = img_size

        # One statistics hook per BatchNorm2d layer of the teacher.
        self.hooks = []
        for m in teacher.modules():
            if isinstance(m, nn.BatchNorm2d):
                self.hooks.append(DeepInversionHook(m, self.bn_mmt))

        # Augmentations applied to generated images before the teacher/student forward.
        self.aug = transforms.Compose([
            augmentation.RandomCrop(size=[self.img_size[-2], self.img_size[-1]], padding=4),
            augmentation.RandomHorizontalFlip(),
            normalizer,
        ])
        self.transform = transform
        self.generator = generator.cuda().train()
        # Meta optimiser updates the shared generator once per synthesize() call.
        self.meta_optimizer = torch.optim.Adam(self.generator.parameters(), self.lr_g * self.iterations,
                                               betas=[0.5, 0.999])

        self.bn_mmt = bn_mmt  # repeated assignment; same value as above
        self.adv = adv

    def synthesize(self, roundd):
        """Run one synthesis episode and optionally store the best image batch.

        Args:
            roundd: current round index; images are added to the pool only
                when ``roundd >= 20``.

        Side effects: increments ``self.ep``, steps ``self.meta_optimizer``
        once, updates every hook's momentum statistics, leaves the student in
        train mode and the teacher in eval mode, and sets ``self.prev_z``.
        """
        self.ep += 1
        self.student.eval()
        self.teacher.eval()
        best_cost = 1e6

        best_inputs = None
        # Fresh latent batch and random target labels, both on CUDA.
        z = torch.randn(size=(self.synthesis_batch_size, self.nz)).cuda()
        z.requires_grad = True
        targets = torch.randint(low=0, high=self.num_classes, size=(self.synthesis_batch_size,))
        targets = targets.cuda()

        # Inner loop optimises a copy of the generator plus the latent batch.
        fast_generator = self.generator.clone()
        optimizer = torch.optim.Adam([
            {'params': fast_generator.parameters()},
            {'params': [z], 'lr': self.lr_z}
        ], lr=self.lr_g, betas=[0.5, 0.999])

        for it in range(self.iterations):
            inputs = fast_generator(z)
            inputs_aug = self.aug(inputs)

            # The teacher forward also refreshes every hook's r_feature.
            t_out = self.teacher(inputs_aug)

            # calculate loss
            loss_bn = sum([h.r_feature for h in self.hooks])
            loss_oh = F.cross_entropy(t_out, targets)
            if self.ep >= self.ep_start:
                s_out = self.student(inputs_aug)
                # Only samples where student and teacher predict the same class contribute.
                mask = (s_out.max(1)[1] == t_out.max(1)[1]).float()
                # Negated so that minimising the loss maximises the divergence.
                loss_adv = -(kldiv(s_out, t_out, reduction='none').sum(1) * mask).mean()
            else:
                # Adversarial term disabled before ep_start: zero tensor matching loss_oh.
                loss_adv = loss_oh.new_zeros(1)
            loss = self.bn * loss_bn + self.oh * loss_oh + self.adv * loss_adv
            with torch.no_grad():
                # Keep the generated batch with the lowest total loss so far.
                if best_cost > loss.item() or best_inputs is None:
                    best_cost = loss.item()
                    best_inputs = inputs.data.cpu()

            optimizer.zero_grad()
            loss.backward()

            # Meta gradients are cleared on the first iteration, accumulated from
            # the clone on every iteration, and applied on the last iteration.
            if it == 0:
                self.meta_optimizer.zero_grad()
            fomaml_grad(self.generator, fast_generator)
            if it == (self.iterations - 1):
                self.meta_optimizer.step()

            optimizer.step()

        for h in self.hooks:
            h.update_mmt()

        self.student.train()
        self.prev_z = (z, targets)
        # Hard-coded threshold: nothing is saved during the first 20 rounds.
        if roundd >= 20:
            self.data_pool.add(best_inputs, targets.tolist())



class Generator(nn.Module):
    """Convolutional generator mapping a latent vector to an image in [0, 1].

    A linear layer expands the latent vector to an ``(ngf * 2)``-channel map
    of side ``img_size // 4``; two x2 upsampling stages and a final sigmoid
    produce an ``nc``-channel image.
    """

    def __init__(self, nz=100, ngf=64, img_size=32, nc=3):
        super().__init__()
        # Constructor arguments are kept so clone() can rebuild the same architecture.
        self.params = (nz, ngf, img_size, nc)
        self.init_size = img_size // 4
        wide = ngf * 2
        self.l1 = nn.Sequential(nn.Linear(nz, wide * self.init_size ** 2))

        stages = [
            nn.BatchNorm2d(wide),
            nn.Upsample(scale_factor=2),
            # first upsampled stage
            nn.Conv2d(wide, wide, 3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(wide),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            # second upsampled stage
            nn.Conv2d(wide, ngf, 3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ngf, nc, 3, stride=1, padding=1),
            nn.Sigmoid(),
        ]
        self.conv_blocks = nn.Sequential(*stages)

    def forward(self, z):
        """Generate a batch of images from the latent batch ``z``."""
        projected = self.l1(z)
        batch = projected.shape[0]
        side = self.init_size
        feature_map = projected.view(batch, -1, side, side)
        return self.conv_blocks(feature_map)

    def clone(self):
        """Return a CUDA copy of this generator with identical weights."""
        replica = Generator(*self.params)
        replica.load_state_dict(self.state_dict())
        return replica.cuda()


class DeepInversionHook():
    '''
    Forward hook that tracks per-channel feature statistics of a module's input
    and turns their distance to the module's running statistics into a loss.
    Mean and variance are compared with an L2 norm.
    '''

    def __init__(self, module, mmt_rate):
        self.hook = module.register_forward_hook(self.hook_fn)
        self.module = module
        self.mmt_rate = mmt_rate
        self.mmt = None      # momentum-averaged (mean, var); None until update_mmt()
        self.tmp_val = None  # (mean, var) of the most recent forward pass

    def hook_fn(self, module, input, output):
        # hook to compute DeepInversion's feature distribution regularization
        features = input[0]
        channels = features.shape[1]
        mean = features.mean([0, 2, 3])
        flattened = features.permute(1, 0, 2, 3).contiguous().view([channels, -1])
        var = flattened.var(1, unbiased=False)
        # Force mean and variance to match between the two distributions;
        # other measures might work better, e.g. KL divergence.
        if self.mmt is None:
            var_gap = module.running_var.data - var
            mean_gap = module.running_mean.data - mean
        else:
            # Blend the current statistics with the momentum history.
            mean_mmt, var_mmt = self.mmt
            rate = self.mmt_rate
            var_gap = module.running_var.data - (1 - rate) * var - rate * var_mmt
            mean_gap = module.running_mean.data - (1 - rate) * mean - rate * mean_mmt

        self.r_feature = torch.norm(var_gap, 2) + torch.norm(mean_gap, 2)
        self.tmp_val = (mean, var)

    def update_mmt(self):
        # Fold the statistics of the latest forward pass into the momentum history.
        mean, var = self.tmp_val
        if self.mmt is None:
            self.mmt = (mean.data, var.data)
            return
        rate = self.mmt_rate
        mean_mmt, var_mmt = self.mmt
        self.mmt = (rate * mean_mmt + (1 - rate) * mean.data,
                    rate * var_mmt + (1 - rate) * var.data)


class ImagePool:
    """On-disk pool of synthesised images rooted at a single directory."""

    def __init__(self, root):
        # Resolve to an absolute path and make sure the directory exists.
        self.root = os.path.abspath(root)
        os.makedirs(self.root, exist_ok=True)
        # Running batch counter, used as the file-name prefix for saved images.
        self._idx = 0

    def add(self, imgs, targets=None):
        """Save one batch via ``save_image_batch`` and advance the batch counter."""
        batch_index = self._idx
        save_image_batch(imgs, self.root, targets, batch_index)
        self._idx = batch_index + 1


class KLDiv(nn.Module):
    """Module wrapper around ``kldiv`` with a fixed temperature and reduction."""

    def __init__(self, T=1.0, reduction='batchmean'):
        super().__init__()
        self.reduction = reduction
        self.T = T

    def forward(self, logits, targets):
        """Return ``kldiv(logits, targets)`` using the stored ``T`` and ``reduction``."""
        temperature, mode = self.T, self.reduction
        return kldiv(logits, targets, T=temperature, reduction=mode)
