import torch
from tqdm import tqdm
import sys
import logging
# Import-time logging setup. getLogger() with no name is the root logger.
# basicConfig() gives it a default stderr handler unless it already has handlers.
# Records at INFO and above are let through.
logger = logging.getLogger()
logging.basicConfig()
logger.setLevel(level=logging.INFO)


def kd_train(epoch, calibrate_dataloader, teacher, global_model, criterion_kd, optimizer_kd, device,
             patch_size=16):
    """Run one epoch of knowledge distillation from ``teacher`` into ``global_model``.

    Each batch is fed to the teacher (under ``torch.no_grad``) and to the student
    with the same three inputs:
    - the images;
    - a position id per token (CLS + every grid patch);
    - the indices of the tokens to keep. Index 0 is the CLS token, followed by the
      1-based indices of the non-zero patches.

    The student is updated to minimise ``criterion_kd(student_out[0], teacher_out[0])``.

    Args:
        epoch: Zero-based epoch index; only used in the progress message.
        calibrate_dataloader: Iterable of image batches. Each batch is assumed to be
            a single tensor of shape (batch, channels, height, width) -- TODO confirm.
        teacher: Model whose first output is the distillation target. Its
            train/eval mode is left to the caller.
        global_model: Student model, updated through ``optimizer_kd``.
        criterion_kd: Loss called as ``criterion_kd(student_out, teacher_out)``.
        optimizer_kd: Optimizer stepped once per batch (presumably over the
            student's parameters).
        device: Device every batch is moved to.
        patch_size: Side length of the square patches. The default of 16 yields
            the previously hard-coded 197 tokens for 224x224 inputs.

    Returns:
        float: fraction of samples whose student argmax matches the teacher
        argmax (0.0 if the loader yielded no samples).
    """
    bar = tqdm(calibrate_dataloader, file=sys.stdout)
    correct = 0
    seen = 0
    for data in bar:
        data = data.to(device)
        batch_size, _, height, width = data.shape

        # One position id per token: the CLS token plus every patch of the grid.
        num_tokens = (height // patch_size) * (width // patch_size) + 1
        position_ids = torch.arange(num_tokens, device=data.device).repeat(batch_size, 1)

        # Tokens to keep: CLS (index 0) followed by the non-zero patches.
        # Computed once and shared by teacher and student.
        cls_index = torch.zeros(batch_size, 1, dtype=torch.long, device=data.device)
        keep_ids = torch.cat(
            [cls_index, find_non_zero_patches(images=data, patch_size=patch_size)], dim=1
        )

        with torch.no_grad():
            output_t = teacher(data, position_ids, keep_ids)
        output_s = global_model(data, position_ids, keep_ids)

        loss_cal = criterion_kd(output_s[0], output_t[0].detach())
        optimizer_kd.zero_grad()
        loss_cal.backward()
        optimizer_kd.step()

        # Top-1 agreement between the (pre-update) student and the teacher.
        pred = output_s[0].argmax(dim=1)
        target = output_t[0].argmax(dim=1)
        correct += pred.eq(target.view_as(pred)).sum().item()
        # Count samples actually seen rather than len(dataset): robust to
        # drop_last, sharding samplers and loaders without a .dataset attribute.
        seen += batch_size

    acc = correct / seen if seen else 0.0

    print(f"Epoch {epoch + 1} kd_feddf_acc:{acc:.4f} ")
    logger.info(f"Epoch {epoch + 1} kd_feddf_acc:{acc:.4f} ")
    return acc


def find_non_zero_patches(images, patch_size):
    """Return the 1-based indices of the patches that contain any non-zero value.

    The images are cut into a grid of ``patch_size`` x ``patch_size`` patches,
    numbered in row-major order. A patch counts as non-zero if any of its values,
    over all channels, is non-zero. Indices are shifted by +1 so that index 0 stays
    free for a CLS token, which ``kd_train`` prepends.

    Args:
        images: Tensor of shape (batch, channels, height, width).
        patch_size: Side length of the square patches.

    Returns:
        Long tensor of shape (batch, k), where k is the number of non-zero patches
        per image. Always 2-D, even when k == 1.

    Raises:
        ValueError: if height or width is not divisible by ``patch_size``.
        RuntimeError: raised by ``torch.stack`` if the images do not all have the
            same number of non-zero patches.
    """
    bs, c, h, w = images.shape
    patch_h, patch_w = patch_size, patch_size
    if h % patch_h != 0 or w % patch_w != 0:
        raise ValueError("Image dimensions are not divisible by patch size")

    grid_h, grid_w = h // patch_h, w // patch_w
    # (bs, c, grid_h, patch_h, grid_w, patch_w) -> (bs, grid_h, grid_w, c, patch_h, patch_w)
    patches = images.reshape(bs, c, grid_h, patch_h, grid_w, patch_w).permute(0, 2, 4, 1, 3, 5)
    # Flatten to (bs, num_patches, patch_dim); patch index = row * grid_w + col.
    patches = patches.reshape(bs, grid_h * grid_w, c * patch_h * patch_w)

    non_zero_mask = (patches != 0).any(dim=2)  # (bs, num_patches)

    # nonzero(as_tuple=True)[0] is always 1-D. The previous .squeeze() collapsed
    # a single hit to a 0-d tensor, so the stacked result became (bs,), not (bs, 1).
    non_zero_indices = [torch.nonzero(mask, as_tuple=True)[0] + 1 for mask in non_zero_mask]
    return torch.stack(non_zero_indices, dim=0)
