import pdb
import random
import torch.backends.cudnn as cudnn
import numpy as np
import torch
import logging
import math
import time
import copy
from torch.nn import functional as F
import torch.optim as optim
import sys
import torch.nn as nn
from torch.autograd import Variable
from tqdm import tqdm
# Configure the root logger at import time: basicConfig() attaches a default
# stream handler (if the root logger has none yet) and the level is raised to
# INFO so that messages such as the accuracy lines logged by compute_accuracy
# are actually emitted.
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)


def set_random_seed(args):
    """Seed Python, NumPy and PyTorch (CPU + CUDA) RNGs from ``args.seed``.

    Also switches cuDNN to deterministic algorithm selection.
    """
    seed = args.seed
    # Each call below targets a different generator, so their order is irrelevant.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn.deterministic = True


def get_log_name(args):
    """Build the run/log identifier string from the experiment arguments.

    Training runs (``args.type == 'train'``) are tagged with ``args.local_method``;
    every other run type is tagged with ``args.server_method``. For Dirichlet
    partitions the concentration ``args.beta`` is appended directly after the
    partition name.
    NOTE(review): the labels ``beta1``/``beta2`` print ``args.alpha1``/``args.alpha2``;
    kept as-is because existing log names depend on it.
    """
    method = args.local_method if args.type == 'train' else args.server_method
    pieces = [
        f'{args.type}_seed{args.seed}_{args.dataset}_{method}_{args.model}',
        f'_lr{args.lr}_bs{args.bs}_le{args.local_epochs}_n{args.n_clients}_{args.partition}',
    ]
    if args.partition == 'Dirichlet':
        pieces.append(f'{args.beta}')
    pieces.append(f'_beta1_{args.alpha1}_beta2_{args.alpha2}_ratio{args.prune_ratio}')
    return ''.join(pieces)


def get_model_name(args):
    """Build the checkpoint/model identifier string from the experiment arguments.

    Dirichlet partitions get ``args.beta`` appended directly after the partition
    name; pretrained backbones get a trailing ``_pretrained``.
    """
    suffix = ''
    if args.partition == 'Dirichlet':
        suffix += f'{args.beta}'
    if args.pretrained:
        suffix += '_pretrained'
    return (f'seed{args.seed}_{args.dataset}_lr{args.lr}_bs{args.bs}_{args.model}'
            f'_le{args.local_epochs}_n{args.n_clients}_{args.partition}{suffix}')


def compute_accuracy(model, dataloader, info, device, num_tokens=197):
    """Evaluate top-1 accuracy of ``model`` over every batch of ``dataloader``.

    Side effects: moves ``model`` to ``device`` and leaves it in eval mode; prints
    and logs one ``"<info>: acc: <value>"`` line.

    The model is called as ``model(x, ids, ids)`` where both index tensors hold
    ``0 .. num_tokens - 1`` repeated once per sample, and element 0 of its output
    is treated as per-class scores along dim 1. The default of 197 presumably
    matches a ViT token sequence (e.g. 14x14 patches + class token) — TODO
    confirm against the model definitions.

    Args:
        model: module returning an indexable output whose element 0 holds scores.
        dataloader: iterable of ``(x, target)`` batches.
        info: label prefixed to the printed/logged accuracy line.
        device: device the model and each batch are moved to.
        num_tokens: length of each index sequence passed to the model.

    Returns:
        Fraction of correctly classified samples as a float; 0.0 (with a
        warning) when the dataloader yields no samples.
    """
    model.to(device)
    model.eval()
    correct, total = 0, 0
    # Built once on the target device; repeated per batch because the last
    # batch may be smaller than the others.
    token_ids = torch.arange(num_tokens, dtype=torch.int64, device=device)
    with torch.no_grad():
        for x, target in tqdm(dataloader, file=sys.stdout):
            x = x.to(device)
            target = target.to(device, dtype=torch.int64)
            batch_size = x.shape[0]
            output = model(x, token_ids.repeat(batch_size, 1),
                           token_ids.repeat(batch_size, 1))
            _, pred_label = torch.max(output[0], 1)

            total += batch_size
            correct += (pred_label == target).sum().item()

    # Previously an empty dataloader left `acc` unbound (UnboundLocalError).
    if total == 0:
        logger.warning(f'{info}: dataloader yielded no samples; reporting acc 0.0')
        acc = 0.0
    else:
        acc = correct / float(total)
    print(f'{info}: acc: {acc:.4f}')
    logger.info(f'{info}: acc: {acc:.4f}')
    return acc


class Ensemble(torch.nn.Module):
    """Average the first output (logits) of several models.

    Every member is called with the same ``(x, y, z)`` arguments. The result is
    ``(mean of output[0] over members, output[1], output[2])``, where the last two
    items come from the last member in the list.
    """

    def __init__(self, model_list):
        super(Ensemble, self).__init__()
        # ModuleList (not a plain list) registers the members as sub-modules, so
        # .to(device), .eval()/.train(), .parameters() and .state_dict() on the
        # ensemble propagate to them.
        self.models = torch.nn.ModuleList(model_list)

    def forward(self, x, y, z):
        if len(self.models) == 0:
            raise ValueError('Ensemble has no member models')
        logits_total = 0
        for model in self.models:
            output = model(x, y, z)
            logits_total = logits_total + output[0]
        logits_avg = logits_total / len(self.models)

        return logits_avg, output[1], output[2]


class WEnsemble(torch.nn.Module):
    """Weighted sum of the first output (logits) of several models.

    Member ``i`` contributes ``mdl_w_list[i] * output[0]``. The weights are used
    as given and are not normalised. The result is
    ``(weighted sum, output[1], output[2])``, where the last two items come from
    the last member.
    """

    def __init__(self, model_list, mdl_w_list):
        super(WEnsemble, self).__init__()
        # ModuleList registers the members so .to()/.eval()/.parameters() reach them.
        self.models = torch.nn.ModuleList(model_list)
        self.mdl_w_list = mdl_w_list

    def forward(self, x, y, z):
        if len(self.models) == 0:
            raise ValueError('WEnsemble has no member models')
        logits_avg = 0
        for model_id, model in enumerate(self.models):
            output = model(x, y, z)
            logits_avg = logits_avg + self.mdl_w_list[model_id] * output[0]

        return logits_avg, output[1], output[2]


class Ensemble_soft(torch.nn.Module):
    """Average the softmax probabilities (over dim 1) of several models.

    Each member's output[0] goes through ``softmax(dim=1)`` before averaging, so
    the first returned item is an averaged probability tensor, not logits. The
    result is ``(mean probabilities, output[1], output[2])``, where the last two
    items come from the last member.
    """

    def __init__(self, model_list):
        super(Ensemble_soft, self).__init__()
        # ModuleList registers the members so .to()/.eval()/.parameters() reach them.
        self.models = torch.nn.ModuleList(model_list)

    def forward(self, x, y, z):
        if len(self.models) == 0:
            raise ValueError('Ensemble_soft has no member models')
        probs_total = 0
        for model in self.models:
            output = model(x, y, z)
            probs_total = probs_total + torch.softmax(output[0], dim=1)
        probs_avg = probs_total / len(self.models)

        return probs_avg, output[1], output[2]


def average_weights(w):
    """
    Returns the element-wise average of the weights (equal-weight FedAvg).

    Args:
        w: non-empty sequence of state dicts that share the same keys.

    Returns:
        A new state dict (a deep copy of ``w[0]``'s container, so its type and
        metadata are preserved) holding the averaged tensors. The input dicts
        are not modified. Integer entries such as ``num_batches_tracked`` come
        back as floating-point tensors (true division).

    Raises:
        ValueError: if ``w`` is empty.
    """
    if len(w) == 0:
        raise ValueError('average_weights() requires at least one state dict')
    # Copy first: the in-place accumulation below would otherwise overwrite w[0],
    # and a state_dict() shares storage with the first model's parameters.
    w_avg = copy.deepcopy(w[0])
    n = len(w)
    for key in w_avg.keys():
        for other in w[1:]:
            w_avg[key] += other[key]
        # true_divide handles integer buffers (e.g. num_batches_tracked) and is
        # identical to torch.div for floating-point tensors.
        w_avg[key] = torch.true_divide(w_avg[key], n)
    return w_avg


def grad_False(model, select_frozen_layers=None):
    """Freeze parameters of ``model`` by setting ``requires_grad = False``.

    Args:
        model: module whose parameters are frozen.
        select_frozen_layers: if None, every parameter is frozen. Otherwise the
            i-th parameter of ``model.named_parameters()`` is frozen when
            ``select_frozen_layers in model.named_parameter_layers[i]``.
            NOTE(review): ``named_parameter_layers`` is not a torch.nn.Module
            attribute; the model class must supply it as a sequence aligned
            with ``named_parameters()`` — confirm against the model definitions.
    """
    if select_frozen_layers is None:
        for _, param in model.named_parameters():
            param.requires_grad = False
        return

    for i, (_, param) in enumerate(model.named_parameters()):
        if select_frozen_layers in model.named_parameter_layers[i]:
            param.requires_grad = False


def grad_True(model):
    """Unfreeze ``model``: mark every one of its parameters as trainable."""
    for param in model.parameters():
        param.requires_grad = True


def find_non_zero_patches(images, patch_size):
    """Return 1-based indices of the patches that contain any non-zero value.

    The images are cut into a grid of non-overlapping ``patch_h x patch_w``
    patches, numbered row-major from 0. The returned indices are shifted by +1
    (presumably to skip a leading class token in the transformer sequence —
    TODO confirm against the caller).

    Args:
        images: tensor of shape (bs, c, h, w).
        patch_size: patch side length (int), or a ``(patch_h, patch_w)`` pair.

    Returns:
        int64 tensor of shape (bs, k), where k is the number of non-zero patches.
        k must be the same for every image in the batch. The shape is (bs, 1) even
        when k == 1.

    Raises:
        ValueError: if h or w is not divisible by the patch size, or if the
            images in the batch have different numbers of non-zero patches.
    """
    if isinstance(patch_size, (tuple, list)):
        patch_h, patch_w = patch_size
    else:
        patch_h, patch_w = patch_size, patch_size
    bs, c, h, w = images.shape
    if h % patch_h != 0 or w % patch_w != 0:
        raise ValueError("Image dimensions are not divisible by patch size")

    # (bs, c, H, ph, W, pw) -> (bs, H, W, c, ph, pw) -> (bs, H*W, c*ph*pw)
    patches = (images.reshape(bs, c, h // patch_h, patch_h, w // patch_w, patch_w)
                     .permute(0, 2, 4, 1, 3, 5)
                     .reshape(bs, -1, c * patch_h * patch_w))
    non_zero = torch.any(patches != 0, dim=2)  # (bs, n_patches) boolean mask

    counts = non_zero.sum(dim=1)
    if bs > 0 and bool((counts != counts[0]).any()):
        raise ValueError(
            "Images in the batch have different numbers of non-zero patches: "
            f"{counts.tolist()}")
    k = int(counts[0]) if bs > 0 else 0
    # nonzero() lists (batch, patch) pairs in row-major order, so with exactly k
    # hits per image the patch column reshapes directly into (bs, k).
    return torch.nonzero(non_zero, as_tuple=False)[:, 1].reshape(bs, k) + 1