import logging
import traceback
import copy
import torch
import math
from torch import nn
import torch.nn.functional as F
import numpy as np
# import wandb
import time
import sys
from model_trainer import ModelTrainer
def _im2col_columns(act, ksz, stride, out_size, bsz, in_channels):
    """Gather ``ksz x ksz`` patches of ``act`` as the columns of a matrix.

    Vectorized equivalent of the explicit loop::

        k = 0
        for kk in range(bsz):
            for ii in range(out_size):
                for jj in range(out_size):
                    mat[:, k] = act[kk, :, stride*ii:stride*ii+ksz,
                                    stride*jj:stride*jj+ksz].reshape(-1)
                    k += 1

    Args:
        act: 4-D tensor indexed as ``act[sample, channel, row, col]``; any
            padding must already have been applied by the caller.
        ksz: patch height and width.
        stride: step between consecutive patch origins along each spatial axis.
        out_size: number of patch positions taken along each spatial axis,
            starting from the top-left corner.
        bsz: number of leading samples of ``act`` to use.
        in_channels: expected channel count of ``act``.

    Returns:
        Tensor of shape ``(in_channels*ksz*ksz, bsz*out_size*out_size)`` in the
        default floating dtype. Rows are ordered (channel, row offset, col
        offset) and columns (sample, patch row, patch col), matching the loop.

    Raises:
        RuntimeError: if ``act`` has fewer than ``bsz`` samples, a channel count
            other than ``in_channels``, or is too small to hold
            ``out_size x out_size`` fully in-bounds patches (the cases in which
            the explicit loop would also fail).
    """
    if act.shape[0] < bsz or act.shape[1] != in_channels:
        raise RuntimeError(
            f"activation of shape {tuple(act.shape)} does not provide "
            f"{bsz} samples with {in_channels} channels")
    # Tensor.unfold yields only windows that lie fully inside the tensor, so
    # windows[n, c, i, j, a, b] == act[n, c, stride*i + a, stride*j + b].
    windows = act[:bsz].unfold(2, ksz, stride).unfold(3, ksz, stride)
    windows = windows[:, :, :out_size, :out_size]
    if windows.shape[2] != out_size or windows.shape[3] != out_size:
        raise RuntimeError(
            f"activation of shape {tuple(act.shape)} is too small for "
            f"{out_size}x{out_size} patches of size {ksz} with stride {stride}")
    # (N, C, s, s, k, k) -> (C, k, k, N, s, s): rows (c, a, b), cols (n, i, j).
    mat = windows.permute(1, 4, 5, 0, 2, 3).reshape(
        in_channels * ksz * ksz, bsz * out_size * out_size)
    # The original filled a torch.zeros(...) buffer, i.e. the default dtype.
    return mat.contiguous().to(dtype=torch.get_default_dtype())


def collect_activations(net, train_data_loader, device, orth_set, dataset):  # distributed
    """Collect randomly projected per-layer input-patch matrices for a ResNet-18.

    The network is moved to ``device`` and put in eval mode (a side effect on
    ``net``). It is then run over every batch of ``train_data_loader``. After
    each forward pass, the activations the network stored in ``net.act`` and
    ``net.layerX[b].act`` are copied to the CPU. For every one of the 20 conv
    layers, the layer's inputs are unrolled into an im2col matrix whose columns
    are ``ksz*ksz*in_channels`` patches.

    If ``orth_set[layer]`` is not None, the component ``U @ U.T @ mat`` is
    subtracted, which is the projection onto span(U) assuming U has orthonormal
    columns (TODO confirm with the producer of ``orth_set``). The matrix is
    finally multiplied by a fresh standard-normal matrix with ``5 * rows``
    columns (a random projection).

    Args:
        net: ResNet-18 style model. It must expose ``act`` dicts on itself and
            on ``layer1..layer4`` blocks, plus ``compute_conv_output_size``.
        train_data_loader: iterable of ``(inputs, labels)`` batches.
        device: device used for the forward pass and the projections.
        orth_set: dict mapping every layer name below to a basis tensor or None.
        dataset: selects spatial sizes / input channels. Recognized values are
            "digit10", "office31", "office31_ca", and "tiny_imagenet"; anything
            else uses the default table.

    Returns:
        dict mapping each layer name (e.g. ``'layer1.0.conv1.weight'``) to
        ``[projected, ratio, bsz]``:

        - ``projected`` is a CPU tensor of shape ``(rows, 5 * rows)``.
        - ``ratio`` is ``1`` when no basis was given, otherwise the 0-dim CPU
          tensor ``||mat - proj|| / ||mat||``.
        - ``bsz`` is the number of samples seen.
    """
    layer_names = [
        'conv1.weight',
        'layer1.0.conv1.weight', 'layer1.0.conv2.weight',
        'layer1.1.conv1.weight', 'layer1.1.conv2.weight',
        'layer2.0.conv1.weight', 'layer2.0.conv2.weight', 'layer2.0.shortcut.0.weight',
        'layer2.1.conv1.weight', 'layer2.1.conv2.weight',
        'layer3.0.conv1.weight', 'layer3.0.conv2.weight', 'layer3.0.shortcut.0.weight',
        'layer3.1.conv1.weight', 'layer3.1.conv2.weight',
        'layer4.0.conv1.weight', 'layer4.0.conv2.weight', 'layer4.0.shortcut.0.weight',
        'layer4.1.conv1.weight', 'layer4.1.conv2.weight',
    ]
    # Per-layer conv stride, input spatial size, and input channel count,
    # aligned index-by-index with layer_names.
    stride_list = [2, 1,1,1,1, 2,1,2,1,1, 2,1,2,1,1, 2,1,2,1,1]
    map_list    = [16, 8,8,8,8, 8,4,8,4,4, 4,2,4,2,2, 2,1,2,1,1]
    in_channel  = [ 3, 64,64,64,64, 64,128,64,128,128, 128,256,128,256,256, 256,512,256,512,512]
    if dataset == "digit10":
        in_channel  = [ 1, 64,64,64,64, 64,128,64,128,128, 128,256,128,256,256, 256,512,256,512,512]
        map_list    = [14, 7,7,7,7, 7,4,7,4,4, 4,2,4,2,2, 2,1,2,1,1]
    elif dataset == "office31" or dataset == "office31_ca":
        map_list    = [112, 56,56,56,56, 56,28,56,28,28, 28,14,28,14,14, 14,7,14,7,7]
    elif dataset == "tiny_imagenet":
        map_list    = [32, 16,16,16,16, 16,8,16,8,8, 8,4,8,4,4, 4,2,4,2,2]

    net.to(device)
    net.eval()
    activation = {key: [] for key in layer_names}
    total = 0
    # Every stored activation is detached immediately, so building an autograd
    # graph during these forward passes would only waste memory.
    with torch.no_grad():
        for x, labels in train_data_loader:
            total += labels.size(0)
            net(x.to(device))  # populates the .act dicts read below
            # Inputs of each conv. A shortcut conv sees the same input as the
            # block's first conv ('conv_0').
            act_list = [net.act['conv_in'],
                net.layer1[0].act['conv_0'], net.layer1[0].act['conv_1'], net.layer1[1].act['conv_0'], net.layer1[1].act['conv_1'],
                net.layer2[0].act['conv_0'], net.layer2[0].act['conv_1'], net.layer2[0].act['conv_0'], net.layer2[1].act['conv_0'], net.layer2[1].act['conv_1'],
                net.layer3[0].act['conv_0'], net.layer3[0].act['conv_1'], net.layer3[0].act['conv_0'], net.layer3[1].act['conv_0'], net.layer3[1].act['conv_1'],
                net.layer4[0].act['conv_0'], net.layer4[0].act['conv_1'], net.layer4[0].act['conv_0'], net.layer4[1].act['conv_0'], net.layer4[1].act['conv_1']]
            for key, layer_act in zip(layer_names, act_list):
                activation[key].append(layer_act.detach().cpu())

    for name in activation:
        activation[name] = torch.cat(activation[name], dim=0)
        # 3x3 convs use padding 1; 1x1 shortcut convs use no padding.
        if "shortcut" not in name:
            activation[name] = F.pad(activation[name], (1, 1, 1, 1), "constant", 0)

    bsz = total  # every collected sample is used
    for i, layer_name in enumerate(layer_names):
        st = stride_list[i]
        if i == 0:
            # NOTE(review): conv1 is treated as 7x7 with padding 3, but its
            # input was padded by only 1 above. Confirm this is intended.
            ksz, pad = 7, 3
        elif "shortcut" in layer_name:
            ksz, pad = 1, 0
        else:
            ksz, pad = 3, 1  # resnet conv3x3

        s = net.compute_conv_output_size(map_list[i], ksz, stride=st, padding=pad)
        mat = _im2col_columns(activation[layer_name], ksz, st, s, bsz, in_channel[i])
        mat = mat.to(device)

        ratio = 1
        if orth_set[layer_name] is not None:
            # Keep only the part of the patches outside the stored basis, and
            # record what fraction of the norm remains.
            U = orth_set[layer_name].to(device)
            projected = U @ U.T @ mat
            remaining = mat - projected
            rem_norm = torch.norm(remaining)
            orj_norm = torch.norm(mat)
            ratio = (rem_norm / orj_norm).cpu()
            mat = remaining
        # Random projection: mat (rows x cols) times a CPU-sampled N(0, 1)
        # matrix (cols x 5*rows). The result replaces the raw activation as
        # [projected activations, remaining-norm ratio, batch size].
        activation[layer_name] = [(mat @ (torch.normal(0, 1, size=(mat.shape[1], mat.shape[0] * 5))).to(device)).cpu(), ratio, bsz]
    return activation
# # alexnet
# def collect_activations(model,train_data_loader, device, orth_set):
#     start = time.time()
#     model.to(device)
#     model.eval()
#     activation = {}
#     layer_names = ['conv1.weight', 'conv2.weight', 'conv3.weight', 'fc1.weight', 'fc2.weight']
#     for key in layer_names:
#         activation[key] = []
#     total = 0
#     for batch_index, (x, y) in enumerate(train_data_loader):
#         total += y.size(0)
#         _ = model(x.to(device))
#         for key in model.act.keys():
#             activation[key].append(model.act[key].detach().cpu())
#     for name in activation.keys():
#         activation[name] = torch.cat(activation[name],dim=0)
    
    
#     act_key = list(activation.keys())
#     # bsz = len(train_data_loader.dataset)
#     bsz = total
#     for i in range(len(model.map)):
#         k=0
#         if i<3: #conv layers
#             ksz= model.ksize[i]
#             act = activation[act_key[i]]
#             # print(act.device)
            
#             unfolder = torch.nn.Unfold(ksz, dilation=1, padding=0, stride= 1)
#             mat = unfolder(act.to(device))
#             mat = mat.permute(0,2,1)
#             mat = mat.reshape(-1, mat.shape[2])
#             mat = mat.T

#             mat = mat.to(device)
    
#             ratio = 1
#             if orth_set[act_key[i]] is not None:
#                 U = orth_set[act_key[i]].to(device)
#                 projected = U @ U.T @ mat
#                 remaining = mat - projected
#                 rem_norm = torch.norm(remaining)
#                 orj_norm = torch.norm(mat)
#                 ratio = (rem_norm / orj_norm).cpu()
#                 mat = remaining
#             activation[act_key[i]] = [(mat @ (torch.normal(0, 1, size=(mat.shape[1], mat.shape[0]))).to(device)).cpu(), ratio, bsz]
#         else:
#             mat = activation[act_key[i]].T.to(device)
#             ratio = 1
#             if orth_set[act_key[i]] is not None:
#                 U = orth_set[act_key[i]].to(device)
#                 projected = U @ U.T @ mat
#                 remaining = mat - projected
#                 rem_norm = torch.norm(remaining)
#                 orj_norm = torch.norm(mat)
#                 ratio = (rem_norm / orj_norm).cpu()
#                 mat = remaining
#             activation[act_key[i]] = [(mat @ (torch.normal(0, 1, size=(mat.shape[1], mat.shape[0] * 5))).to(device)).cpu(), ratio, bsz]
#     end = time.time()
#     print(f'Activations collection time {end-start}')
#     print(f'Activations size: {sys.getsizeof(activation)}')
    
#     return activation


class MyModelTrainer(ModelTrainer):
    """Client-side trainer wrapping a single ``nn.Module``.

    Holds the live model (``self.model``) and a deep copy taken at construction
    (``self.model_p``). It exposes training, evaluation, parameter get/set, and
    activation collection.
    """

    def __init__(self, model, args=None):
        self.model = model
        # Independent snapshot of the model as it was at construction time.
        self.model_p = copy.deepcopy(model)
        self.args = args

    def get_model_params(self):
        """Return the model's ``state_dict``.

        Side effect: moves ``self.model`` to the CPU first.
        """
        return self.model.cpu().state_dict()

    def get_model_name_params(self):
        """Return ``named_parameters()``: a generator of ``(name, param)``, not a list."""
        return self.model.named_parameters()

    def set_model_params(self, model_parameters):
        """Load ``model_parameters`` (a state dict) into the model."""
        self.model.load_state_dict(model_parameters)

    def get_activations(self, train_data, device, orth_set, dataset):
        """Run the module-level ``collect_activations`` on this trainer's model."""
        return collect_activations(self.model, train_data, device, orth_set, dataset)

    def train(self, train_data, device, args, ta_id, used_B):
        """Train with Adam until ``args.epochs`` epochs finish or the batch budget runs out.

        Args:
            train_data: indexable sequence of ``(inputs, labels)`` batches.
            device: device to train on.
            args: needs ``B``, ``batch_size``, ``epochs``, ``incremental_round``
                and ``lr``.
            ta_id: task id; not used by the current class-incremental loss.
            used_B: number of batches already consumed before this call.

        Returns:
            The updated batch count (``used_B`` plus the batches trained here).
            Returns None if an exception occurred; the traceback is logged.
        """
        try:
            model = self.model
            model.to(device)

            current_B = used_B
            # Total batch budget over all epochs and incremental rounds.
            num = math.floor((args.B / args.batch_size) * args.epochs * args.incremental_round)

            model.train()
            criterion = nn.CrossEntropyLoss().to(device)
            # Adam over trainable parameters only; a fresh optimizer per call.
            optimizer = torch.optim.Adam(
                filter(lambda p: p.requires_grad, model.parameters()),
                lr=args.lr, betas=(0.9, 0.999))
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, mode="min", factor=0.2, patience=2)

            # Drop a trailing batch holding a single sample (presumably to
            # avoid BatchNorm failing on batch size 1; TODO confirm). The
            # emptiness guard avoids an IndexError on an empty train_data.
            if train_data and len(train_data[-1][0]) == 1:
                train_data = train_data[:-1]

            for epoch in range(args.epochs):
                batch_loss = []
                for x, labels in train_data:
                    x, labels = x.to(device), labels.to(device)
                    model.zero_grad()
                    log_probs = model(x)
                    # The loss uses all output logits. A per-task (multi-head)
                    # variant would instead slice the logits to this task's
                    # classes [2*ta_id, 2*ta_id + 2) and subtract 2*ta_id from
                    # the labels.
                    loss = criterion(log_probs, labels)
                    loss.backward()
                    current_B += 1
                    optimizer.step()
                    batch_loss.append(loss.item())
                    # NOTE(review): the budget is checked only after a step, so
                    # a call with used_B >= num still trains one batch.
                    if current_B >= num:
                        break

                if not batch_loss:
                    # No batches at all: nothing to average or schedule on.
                    break

                mean_loss = sum(batch_loss) / len(batch_loss)
                scheduler.step(mean_loss)
                budget_spent = current_B >= num
                if epoch == args.epochs - 1 or budget_spent:
                    # `id` is never set in this class; presumably the base class
                    # or caller sets it (e.g. via set_id). Fall back to None.
                    print('Client Index = {}\tEpoch: {}\tLoss: {:10f}'.format(
                        getattr(self, 'id', None), epoch, mean_loss))
                if budget_spent:
                    print('--------------------B has been used in ' + str(epoch) + ' epoch ' + str(len(batch_loss)) + ' batch-----------')
                    break
            return current_B
        except Exception:
            # Best effort: log the full traceback and signal failure with None.
            logging.error(traceback.format_exc())
            return None

    def test(self, test_data, device, args):
        """Evaluate the model on ``test_data``.

        Returns:
            dict with these keys:

            - ``test_correct``: number of correct top-1 predictions.
            - ``test_loss``: sum over batches of the batch-mean cross-entropy
              times the batch size.
            - ``test_total``: number of samples.
        """
        model = self.model
        model.to(device)
        model.eval()

        metrics = {
            'test_correct': 0,
            'test_loss': 0,
            'test_total': 0
        }

        criterion = nn.CrossEntropyLoss().to(device)

        with torch.no_grad():
            for x, target in test_data:
                x = x.to(device)
                target = target.to(device)
                pred = model(x)

                loss = criterion(pred, target)

                _, predicted = torch.max(pred, -1)
                metrics['test_correct'] += predicted.eq(target).sum().item()
                metrics['test_loss'] += loss.item() * target.size(0)
                metrics['test_total'] += target.size(0)
            print(metrics['test_total'])
        return metrics

    def test_on_the_server(self, train_data_local_dict, test_data_local_dict, device, args=None) -> bool:
        """Server-side testing is not implemented for this trainer; always returns False."""
        return False

