from copy import deepcopy
import sys
import torch.nn as nn
from torch.optim import SGD


def new_train_iter(rank, criterion, w_optimizer, w_model, w_model_backup, 
    data_ratio_pairs, global_direction, args, device):

    w_model.train()

    # --- disable nomalization layers ---
    for module in w_model.modules():
        # print(module)
        if isinstance(module, nn.BatchNorm2d):
            if hasattr(module, 'weight'):
                module.weight.requires_grad_(False)
            if hasattr(module, 'bias'):
                module.bias.requires_grad_(False)
            module.eval()

    epoch_train_loss = 0
    epoch_batch_cnt = 0

    correct = 0
    total = 0
    
    last_model = deepcopy(w_model)
    last_optimizer = SGD(last_model.parameters(), lr=args.lr)
    last_model.train()
    # for p_idx, param in enumerate(w_model_backup.parameters()):
    #     param_tmp = deepcopy(param.data)
    #     last_model.append(param_tmp)
    learning_rate = args.lr
    param_grad = deepcopy(global_direction)

    total_batch = len(data_ratio_pairs[rank][0])
    
    if args.K == 1:
        maximum_steps = 640 / args.bsz
    elif args.K == 2:
        maximum_steps = 320 / args.bsz
    else:
        maximum_steps = 1280 / args.bsz

    print("lr", learning_rate, "bsz cnt", total_batch, "step", maximum_steps)
    
    completed_steps = 0
    while completed_steps < maximum_steps:
    
        for batch_idx, (data, target) in enumerate(data_ratio_pairs[rank][0]):
            
            if completed_steps == maximum_steps:
                break
            
            data, target = data.to(device), target.to(device)
            last_optimizer.zero_grad()
            last_output = last_model(data)
            last_loss = criterion(last_output, target)
            last_loss.backward()

            w_optimizer.zero_grad()
            output = w_model(data)
            loss = criterion(output, target)
            loss.backward()

            # param_grad=[]  #一个client的所有梯度
            for p_idx, param in enumerate(w_model.parameters()):
                
                param_grad[p_idx] = param_grad[p_idx] - list(last_model.parameters())[p_idx].grad.data.clone().detach() \
                    + list(w_model.parameters())[p_idx].grad.data.clone().detach()
                # param_grad[p_idx] = list(w_model.parameters())[p_idx].grad.data.clone().detach()
            
            # last_optimizer.zero_grad()
            # w_optimizer.zero_grad()

            last_model = deepcopy(w_model)
            last_optimizer = SGD(last_model.parameters(), lr=args.lr)
            
            for p_idx, param in enumerate(last_model.parameters()):
                param.data = list(w_model.parameters())[p_idx].data.clone().detach()
            
            for p_idx, param in enumerate(w_model.parameters()):
                param.data -= learning_rate * param_grad[p_idx]
            
            # w_optimizer.step(param_grad)
            # w_optimizer.step()

            # for p_idx, param in enumerate(model.parameters()):
            #     print(rank, batch_id, p_idx, "intraining", 
            #         param.data)

            # accuracy calculation
            epoch_train_loss += loss.data.item()
            epoch_batch_cnt += 1
            
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()

            completed_steps += 1
            # print(rank, batch_id, 'Acc: %.3f%% (%d/%d)'
            #       % (100. * correct / total, correct, total), loss.data.item(), time.time() - batch_start_time)
            sys.stdout.flush()

    delta_ws = []

    for p_idx, param in enumerate(w_model_backup.parameters()):
        delta_ws.append(param - list(w_model.parameters())[p_idx].data)

    return delta_ws, (epoch_train_loss / epoch_batch_cnt), correct / total

def compute_grad(rank, criterion, w_optimizer, w_model, w_model_backup, 
    data_ratio_pairs_full_batch, device):
    w_model.train()

    # --- disable nomalization layers ---
    for module in w_model.modules():
        # print(module)
        if isinstance(module, nn.BatchNorm2d):
            if hasattr(module, 'weight'):
                module.weight.requires_grad_(False)
            if hasattr(module, 'bias'):
                module.bias.requires_grad_(False)
            module.eval()

    epoch_train_loss = 0
    epoch_batch_cnt = 0

    correct = 0
    total = 0
    train_data = iter(data_ratio_pairs_full_batch[rank][0])
    data, target = next(train_data)

    data, target = data.to(device), target.to(device)
    w_optimizer.zero_grad()
    output = w_model(data)
    loss = criterion(output, target)
    loss.backward(retain_graph = True)
    # w_optimizer.step()

    # accuracy calculation
    epoch_train_loss += loss.data.item()
    
    delta_ws = []

    for p_idx, param in enumerate(w_model.parameters()):
        grad_tmp = param.grad.data.clone().detach()
        delta_ws.append(grad_tmp)

    return delta_ws, epoch_train_loss