import logging
import sys
import torch
import torch.nn as nn
import numpy as np
from torch.nn import functional as F
# Shared record layout (timestamp, level, message) applied to the file handlers built by setup_logger.
formatter = logging.Formatter(fmt='%(asctime)s %(levelname)s %(message)s')
import copy


def hetero_average_weights_subnet(track_channel_sync,w,record=False):
    """Aggregate client models channel by channel, weighted by per-channel update counts.

    ``track_channel_sync[i]`` maps channel ids to client ``i``'s count. The keys
    of the first client's dict define the set of ids; ``max_key`` is the largest.

    Each count is first replaced by the sum of the counts for that id and every
    larger id (presumably because a client that trained a wider sub-network also
    updated the narrower channels -- confirm against the training loop). The
    counts are then normalised across clients so that, for every channel id, the
    client weights sum to 1. If the total for an id is zero, every client gets the
    uniform weight ``1 / num_clients``.

    Note: ``track_channel_sync`` is modified in place. On return it holds the
    normalised weights.

    Args:
        track_channel_sync: list with one ``{channel_id: count}`` dict per client.
        w: client models (objects exposing ``state_dict()``), aligned with
            ``track_channel_sync``.
        record: unused; kept for interface compatibility.

    Returns:
        A deep copy of ``w[0].state_dict()``. For every entry whose name contains
        ``'weight'`` or ``'bias'``, index ``c`` along dim 0 (``0 <= c < max_key``)
        is ``sum_i w[i][key][c] * weights[i][c + 1]``. All other entries are taken
        from ``w[0]`` unchanged.
    """
    num_clients = len(track_channel_sync)
    max_key = max(track_channel_sync[0])

    # Replace each count with the suffix sum over ids key..max_key. Sums are
    # taken from a snapshot of the raw counts, so the result does not depend on
    # dict iteration order. Updating in place while iterating double-counts when
    # keys are not visited in ascending order.
    overall_updates = {key: 0 for key in track_channel_sync[0]}
    for client_track_channel in track_channel_sync:
        raw_counts = dict(client_track_channel)
        for key in raw_counts:
            suffix = raw_counts[key] + sum(raw_counts[k] for k in range(key + 1, max_key + 1))
            client_track_channel[key] = suffix
            overall_updates[key] += suffix

    # Normalise across clients so each channel's weights sum to 1.
    for client_track_channel in track_channel_sync:
        for key in client_track_channel:
            total = overall_updates[key]
            if total == 0:
                # Nobody updated this channel: average uniformly. Giving every
                # client weight 1 would sum the client copies instead.
                client_track_channel[key] = 1 / num_clients
            else:
                client_track_channel[key] = client_track_channel[key] / total

    # Build each client's state dict once instead of once per (channel, key).
    client_states = [w[client_id].state_dict() for client_id in range(num_clients)]
    w_avg = copy.deepcopy(client_states[0])
    merge_keys = [key for key in w_avg if 'bias' in key or 'weight' in key]
    for key in merge_keys:
        merged = w_avg[key]
        for channel_index in range(max_key):
            # Index only dim 0 so parameters of any rank work (conv, linear,
            # norm weights, biases), not just 4-D conv kernels.
            merged[channel_index] = merged[channel_index] * track_channel_sync[0][channel_index + 1]
            for client_id in range(1, num_clients):
                merged[channel_index] += (
                    client_states[client_id][key][channel_index]
                    * track_channel_sync[client_id][channel_index + 1]
                )

    return w_avg



def average_weights(track_channel_sync,w,record=False):
    """Return the weighted average of the clients' state dicts.

    Each client's weight is the sum of all counts in its ``track_channel_sync``
    dict, normalised so the weights sum to 1. If every count is zero, the
    clients are averaged uniformly instead of dividing by zero.

    Args:
        track_channel_sync: list with one ``{key: count}`` dict per client.
        w: client models (objects exposing ``state_dict()``), aligned with
            ``track_channel_sync``.
        record: unused; kept for interface compatibility.

    Returns:
        A deep copy of ``w[0].state_dict()`` in which every entry is replaced by
        ``sum_i weight_i * w[i].state_dict()[key]``.
    """
    client_totals = [sum(client_track_channel.values()) for client_track_channel in track_channel_sync]
    overall = sum(client_totals)
    if overall == 0:
        client_weights = [1 / len(client_totals)] * len(client_totals)
    else:
        client_weights = [total / overall for total in client_totals]

    # Fetch each state dict once rather than once per key.
    states = [model.state_dict() for model in w]
    w_avg = copy.deepcopy(states[0])
    for key in w_avg:
        w_avg[key] = client_weights[0] * w_avg[key]
        for i in range(1, len(states)):
            w_avg[key] += client_weights[i] * states[i][key]

    return w_avg

def setup_logger(name, log_file, level=logging.INFO, console_out = True):
    """Configure and return the logger ``name``, replacing any handlers it already has.

    Args:
        name: logger name passed to ``logging.getLogger``.
        log_file: path of the file that log records are appended to, formatted
            with the module-level ``formatter``.
        level: logging level set on the logger.
        console_out: if True, also echo records to stdout. That handler gets no
            formatter, so the logging default (message only) applies.

    Returns:
        The configured ``logging.Logger``.
    """
    handler = logging.FileHandler(log_file, mode='a')
    handler.setFormatter(formatter)

    logger = logging.getLogger(name)
    logger.setLevel(level)
    # Detach and close only this logger's own handlers. Closing them stops
    # repeated calls from leaking open files. Do not loop on hasHandlers():
    # it also reports ancestor (e.g. root) handlers, which removeHandler()
    # cannot remove here, so indexing logger.handlers[0] would hit an empty list.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()
    logger.addHandler(handler)
    if console_out:
        stdout_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stdout_handler)
    return logger

def accuracy(output, target, topk=(1,), compress_V4shadowlabel = False, num_client = 10):
    """Compute precision@k, in percent, for each k in ``topk``.

    Args:
        output: scores; the top-k is taken along dim 1 (one row per sample).
        target: ground-truth label for each row of ``output``.
        topk: the k values to evaluate.
        compress_V4shadowlabel: if True, predicted indices are reduced modulo
            ``num_client`` before comparison (presumably to fold shadow labels
            back onto the base label space -- confirm with the caller).
        num_client: modulus used when ``compress_V4shadowlabel`` is True.

    Returns:
        A list of scalar tensors, one per k: the percentage of rows whose target
        is among the top-k predictions.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    if compress_V4shadowlabel:
        pred = pred % num_client
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape, not view: ``correct`` derives from a transposed tensor and is
        # non-contiguous, so view(-1) raises for k > 1 on recent PyTorch.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


class CrossEntropyLoss_BDKS_median(torch.nn.modules.loss._Loss):
    """Per-sample loss for the median-width model.

    The loss has two terms:
    - distillation from the wide model's soft labels;
    - an optional distillation term from the narrow model's soft labels, in
      which all majority classes are merged into a single bucket.
    """

    def forward(self, output, target,soft_label_wide,soft_label_narrow,client_label_size,channel_label_size,overall_samples,logit_16,tau=1,threshold=0.1,weight_constant=1):
        """Compute the per-sample loss and record true-class probabilities.

        Args:
            output: logits; softmax is taken over dim 1 (classes).
            target: integer class label for each row of ``output``.
            soft_label_wide: soft targets from the wide model, same shape as ``output``.
            soft_label_narrow: soft targets from the narrow model, same shape as ``output``.
            client_label_size: unused; kept for interface compatibility.
            channel_label_size: per-class counts (array-like, one entry per class).
            overall_samples: normaliser for ``channel_label_size``.
            logit_16: indexable by class label. For each sample, the softmax
                probability of its true class is appended to ``logit_16[label]``.
            tau: unused; kept for interface compatibility.
            threshold: classes with ``channel_label_size / overall_samples > threshold``
                are majority classes. The narrow-model term is only added when
                at least one class is at or below the threshold.
            weight_constant: scale of the narrow-model term.

        Returns:
            Per-sample loss, squeezed (a 0-d tensor when the batch has one row).
        """
        device = output.device
        weight = torch.as_tensor(channel_label_size / overall_samples, device=device)
        # Split classes with the same ``threshold`` the branch below tests, so the
        # two agree for non-default thresholds.
        majority = weight > threshold
        minority = ~majority

        # Record p(true class) per sample with one batched softmax, outside autograd.
        with torch.no_grad():
            labels = target.reshape(-1, 1).to(device=device, dtype=torch.long)
            true_prob = F.softmax(output, dim=1).gather(1, labels).squeeze(1).cpu()
        for label, prob in zip(target.tolist(), true_prob):
            logit_16[label].append(prob)

        log_prob = F.log_softmax(output, dim=1)

        # Distillation from the wide model's soft labels (targets are not back-propagated).
        target_wide = soft_label_wide.detach().to(device=device, dtype=torch.float32)
        W2Me_loss = torch.squeeze(-torch.bmm(target_wide.unsqueeze(1), log_prob.unsqueeze(2)))

        if not torch.any(weight <= threshold):
            return W2Me_loss

        # Distillation from the narrow model, with every majority class merged into one bucket.
        target_narrow = soft_label_narrow.detach().to(device=device, dtype=torch.float32)
        grouped_log_prob = log_prob[:, minority]
        grouped_target = target_narrow[:, minority]
        # With no majority class the merged bucket is omitted. Otherwise its
        # log(0) * 0 term would turn every sample's loss into NaN.
        if torch.any(majority):
            # log(sum of majority probabilities), computed stably from log-probs
            # instead of log(sum(softmax)), which can underflow to -inf.
            merged_log_prob = torch.logsumexp(log_prob[:, majority], dim=1, keepdim=True)
            merged_target = target_narrow[:, majority].sum(dim=1, keepdim=True)
            grouped_log_prob = torch.cat((merged_log_prob, grouped_log_prob), dim=1)
            grouped_target = torch.cat((merged_target, grouped_target), dim=1)
        smallbig_loss = torch.squeeze(-torch.bmm(grouped_target.unsqueeze(1), grouped_log_prob.unsqueeze(2)))

        return W2Me_loss + weight_constant * smallbig_loss




class CrossEntropyLoss_BDKS_high(torch.nn.modules.loss._Loss):
    """Per-sample loss for the wide model.

    The loss has two terms:
    - a logit-calibrated cross-entropy;
    - an optional distillation term from ``soft_label``, in which all majority
      classes are merged into a single bucket.
    """

    def forward(self, output, target,soft_label,client_label_size,channel_label_size,overall_samples,logit_16,tau=1,threshold=0.1,weight_constant=1):
        """Compute the per-sample loss and record true-class probabilities.

        Args:
            output: logits; softmax is taken over dim 1 (classes).
            target: integer class label for each row of ``output``.
            soft_label: soft targets for the grouped ("small to big")
                distillation term, same shape as ``output``.
            client_label_size: per-class counts used for logit calibration.
                Zero entries are treated as 1.
            channel_label_size: per-class counts (array-like, one entry per class).
            overall_samples: normaliser for ``channel_label_size``.
            logit_16: indexable by class label. For each sample, the softmax
                probability of its true class is appended to ``logit_16[label]``.
            tau: calibration strength. Logits are shifted by ``-tau * n_c ** (-1/4)``.
            threshold: classes with ``channel_label_size / overall_samples > threshold``
                are majority classes. The distillation term is only added when at
                least one class is at or below the threshold.
            weight_constant: scale of the distillation term.

        Returns:
            Per-sample loss, squeezed (a 0-d tensor when the batch has one row).
        """
        device = output.device
        weight = torch.as_tensor(channel_label_size / overall_samples, device=device)
        # Split classes with the same ``threshold`` the branch below tests, so the
        # two agree for non-default thresholds.
        majority = weight > threshold
        minority = ~majority

        # Record p(true class) per sample with one batched softmax, outside autograd.
        labels = target.reshape(-1, 1).to(device=device, dtype=torch.long)
        with torch.no_grad():
            true_prob = F.softmax(output, dim=1).gather(1, labels).squeeze(1).cpu()
        for label, prob in zip(target.tolist(), true_prob):
            logit_16[label].append(prob)

        # Logit calibration: subtract tau * n_c^(-1/4) per class, with zero counts treated as 1.
        label_size_caled = torch.as_tensor([x if x != 0 else 1 for x in client_label_size], device=device)
        output_cal = (output - (tau * label_size_caled ** (-1 / 4))[None, :]).to(dtype=torch.float32)
        output_cal_log_prob = F.log_softmax(output_cal, dim=1)
        # Cross-entropy against the hard label, i.e. -log p_cal(target).
        logit_cal_loss = torch.squeeze(-output_cal_log_prob.gather(1, labels))

        if not torch.any(weight <= threshold):
            return logit_cal_loss

        # Distillation on the uncalibrated output, with every majority class merged into one bucket.
        log_prob = F.log_softmax(output, dim=1)
        target_smallbig = soft_label.detach().to(device=device, dtype=torch.float32)
        grouped_log_prob = log_prob[:, minority]
        grouped_target = target_smallbig[:, minority]
        # With no majority class the merged bucket is omitted. Otherwise its
        # log(0) * 0 term would turn every sample's loss into NaN.
        if torch.any(majority):
            # log(sum of majority probabilities), computed stably from log-probs
            # instead of log(sum(softmax)), which can underflow to -inf.
            merged_log_prob = torch.logsumexp(log_prob[:, majority], dim=1, keepdim=True)
            merged_target = target_smallbig[:, majority].sum(dim=1, keepdim=True)
            grouped_log_prob = torch.cat((merged_log_prob, grouped_log_prob), dim=1)
            grouped_target = torch.cat((merged_target, grouped_target), dim=1)
        smallbig_loss = torch.squeeze(-torch.bmm(grouped_target.unsqueeze(1), grouped_log_prob.unsqueeze(2)))

        return logit_cal_loss + weight_constant * smallbig_loss



class CrossEntropyLossCal(torch.nn.modules.loss._Loss):
    """Logit-calibrated cross-entropy for image classification.

    Logits are shifted by ``-tau * n_c ** (-1/4)`` per class before a standard
    cross-entropy, where ``n_c`` is the class's entry in ``label_size`` (zeros
    are treated as 1). Depending on ``channel``, the true-class softmax
    probability of each sample is also recorded (see ``forward``).
    """

    def __init__(self,slow_channel,fast_channel):
        super(CrossEntropyLossCal, self).__init__()
        # ``channel`` values that route probability recording in forward().
        self.slow_channel = slow_channel
        self.fast_channel = fast_channel

    def forward(self, output, target,label_size,logit_1_part,logit_16, channel,high_end_client, tau=1):
        """Return the per-sample calibrated cross-entropy.

        Args:
            output: logits; softmax is taken over dim 1 (classes).
            target: integer class label for each row of ``output``.
            label_size: per-class counts used for calibration; zeros are treated as 1.
            logit_1_part: receives true-class probabilities when
                ``channel == self.slow_channel`` and ``high_end_client`` is truthy.
            logit_16: receives true-class probabilities when
                ``channel == self.fast_channel`` (and the slow-channel test did not match).
            channel: selects which record, if any, is appended to.
            high_end_client: gates recording into ``logit_1_part``.
            tau: calibration strength.

        Returns:
            Per-sample loss, squeezed (a 0-d tensor when the batch has one row).
        """
        device = output.device
        labels = target.reshape(-1, 1).to(device=device, dtype=torch.long)

        # Decide once which record (if any) this batch is appended to.
        if channel == self.slow_channel:
            record = logit_1_part if high_end_client else None
        elif channel == self.fast_channel:
            record = logit_16
        else:
            record = None
        if record is not None:
            with torch.no_grad():
                true_prob = F.softmax(output, dim=1).gather(1, labels).squeeze(1).cpu()
            for label, prob in zip(target.tolist(), true_prob):
                record[label].append(prob)

        label_size_caled = torch.as_tensor([x if x != 0 else 1 for x in label_size], device=device)
        output_cal = (output - (tau * label_size_caled ** (-1 / 4))[None, :]).to(dtype=torch.float32)

        # Cross-entropy against the hard label, i.e. -log p_cal(target).
        output_log_prob = F.log_softmax(output_cal, dim=1)
        cross_entropy_loss = torch.squeeze(-output_log_prob.gather(1, labels))

        return cross_entropy_loss




class CrossEntropyLossCal_3sets(torch.nn.modules.loss._Loss):
    """Logit-calibrated cross-entropy for image classification.

    Logits are shifted by ``-tau * n_c ** (-1/4)`` per class before a standard
    cross-entropy, where ``n_c`` is the class's entry in ``label_size`` (zeros
    are treated as 1). Every call also records each sample's true-class softmax
    probability into ``logit_record``.
    """

    def __init__(self):
        super(CrossEntropyLossCal_3sets, self).__init__()

    def forward(self, output, target,label_size,logit_record,tau=1):
        """Return the per-sample calibrated cross-entropy.

        Args:
            output: logits; softmax is taken over dim 1 (classes).
            target: integer class label for each row of ``output``.
            label_size: per-class counts used for calibration; zeros are treated as 1.
            logit_record: indexable by class label. For each sample, the softmax
                probability of its true class is appended to ``logit_record[label]``.
            tau: calibration strength.

        Returns:
            Per-sample loss, squeezed (a 0-d tensor when the batch has one row).
        """
        device = output.device
        labels = target.reshape(-1, 1).to(device=device, dtype=torch.long)

        # Record p(true class) per sample with one batched softmax, outside autograd.
        with torch.no_grad():
            true_prob = F.softmax(output, dim=1).gather(1, labels).squeeze(1).cpu()
        for label, prob in zip(target.tolist(), true_prob):
            logit_record[label].append(prob)

        label_size_caled = torch.as_tensor([x if x != 0 else 1 for x in label_size], device=device)
        output_cal = (output - (tau * label_size_caled ** (-1 / 4))[None, :]).to(dtype=torch.float32)

        # Cross-entropy against the hard label, i.e. -log p_cal(target).
        output_log_prob = F.log_softmax(output_cal, dim=1)
        cross_entropy_loss = torch.squeeze(-output_log_prob.gather(1, labels))

        return cross_entropy_loss







class CrossEntropyLosssmalltobig(torch.nn.modules.loss._Loss):
    """Per-sample cross-entropy against soft labels ("small to big" distillation)."""

    def forward(self, output, target,soft_label,logit_16):
        """Return ``-sum_c soft_label[:, c] * log_softmax(output)[:, c]`` per sample.

        Args:
            output: logits; softmax is taken over dim 1 (classes).
            target: integer class label for each row of ``output``. It is only
                used to record probabilities, not in the loss.
            soft_label: soft targets, same shape as ``output``; detached, so no
                gradient flows into them.
            logit_16: indexable by class label. For each sample, the softmax
                probability of its true class is appended to ``logit_16[label]``.

        Returns:
            Per-sample loss, squeezed (a 0-d tensor when the batch has one row).
        """
        device = output.device

        # Record p(true class) per sample with one batched softmax, outside autograd.
        with torch.no_grad():
            labels = target.reshape(-1, 1).to(device=device, dtype=torch.long)
            true_prob = F.softmax(output, dim=1).gather(1, labels).squeeze(1).cpu()
        for label, prob in zip(target.tolist(), true_prob):
            logit_16[label].append(prob)

        target_modified = soft_label.detach().to(device=device, dtype=torch.float32)
        output_log_prob = F.log_softmax(output, dim=1)
        cross_entropy_loss = -torch.bmm(target_modified.unsqueeze(1), output_log_prob.unsqueeze(2))
        return torch.squeeze(cross_entropy_loss)




class AverageMeter(object):
    """Keep the most recent value together with a running, count-weighted mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the last value, running mean, running sum and sample count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Store ``val`` as the latest value and fold it into the mean with weight ``n``."""
        self.val = val
        weighted = val * n
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count
