from __future__ import annotations
import copy
import torch
import torch.cuda as cuda
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.utils.data as data
import utils.data_utils as data_utils

def _build_scheduler(optimizer: optim.Optimizer, args):
    """Create the learning-rate scheduler selected by ``args.scheduler``.

    Args:
        optimizer: Optimizer whose learning rate the scheduler controls.
        args: Configuration object. ``args.scheduler`` is ``None``, ``'step'``,
            ``'exponential'`` or ``'multi_step'``; only the attributes needed
            by the selected scheduler (``step_size``, ``milestones``,
            ``gamma``) are read.

    Returns:
        The constructed scheduler, or ``None`` when ``args.scheduler`` is
        ``None``.

    Raises:
        ValueError: If ``args.scheduler`` is not one of the supported values.
    """
    name = args.scheduler
    if name is None:
        return None
    if name == 'step':
        return lr_scheduler.StepLR(optimizer=optimizer, step_size=args.step_size, gamma=args.gamma)
    if name == 'exponential':
        return lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=args.gamma)
    if name == 'multi_step':
        return lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=args.milestones, gamma=args.gamma)
    # Fail with a clear message instead of an UnboundLocalError further down.
    raise ValueError(
        f"Unsupported scheduler {name!r}; expected None, 'step', 'exponential' or 'multi_step'."
    )


def local_update(
    user_tr_data_tensor: torch.Tensor, user_tr_label_tensor: torch.Tensor,
    global_model: nn.Module, communication_round: int,
    args,
) -> tuple[dict[str, torch.Tensor], int]:
    """Train a copy of ``global_model`` on one user's data and return the delta.

    The global model itself is never modified: training runs on a deep copy,
    using SGD with a cross-entropy loss for ``args.n_epochs`` passes over the
    user's data.

    Args:
        user_tr_data_tensor: Training inputs of the user.
        user_tr_label_tensor: Training labels matching ``user_tr_data_tensor``.
        global_model: Model whose current weights are the starting point.
        communication_round: Index of the current round; when a scheduler is
            configured it is stepped to this value before training starts.
        args: Configuration object. Reads ``batch_size``, ``local_lr``,
            ``momentum``, ``weight_decay``, ``scheduler`` (plus the scheduler
            specific fields), ``n_epochs``, ``arch``, ``grad_clip`` and, when
            ``grad_clip`` is truthy, ``max_norm``.

    Returns:
        A pair ``(update, n_samples)``. ``update`` maps every ``state_dict``
        entry name to ``trained_value - initial_value``; ``n_samples`` is
        ``len(user_tr_data_tensor)``.

    Raises:
        ValueError: If ``args.scheduler`` is not a supported value.
    """
    local_model = copy.deepcopy(global_model)
    with torch.no_grad():
        # Independent snapshot of the starting weights, used to compute the delta.
        initial_state = copy.deepcopy(local_model.state_dict())

    dataset = data_utils.TensorDataset(user_tr_data_tensor, user_tr_label_tensor)
    train_loader = data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)

    optimizer = optim.SGD(
        local_model.parameters(),
        lr=args.local_lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )
    scheduler = _build_scheduler(optimizer, args)
    if scheduler is not None:
        # Jump the schedule to the current round so the local learning rate
        # follows the global training progress.
        # NOTE: passing an epoch to step() is deprecated in recent PyTorch
        # releases; it is kept because this jump is the existing behavior.
        scheduler.step(communication_round)

    # Feed batches to whichever device holds the model instead of assuming the
    # current CUDA device.
    device = next(local_model.parameters()).device

    local_model.train()
    for _ in range(args.n_epochs):
        for inputs, targets in train_loader:
            # Single-sample batches are skipped for resnet18 (presumably because
            # its BatchNorm layers cannot train on one sample -- TODO confirm).
            if args.arch in ['resnet18'] and len(inputs) <= 1:
                continue

            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = local_model(inputs)
            # NOTE(review): a squared-distance-to-initial-weights term used to be
            # computed here on every batch but was never added to the loss, so
            # it has been dropped. If a proximal penalty is intended, it needs a
            # coefficient and must be added to `loss`.
            loss = F.cross_entropy(outputs, targets)
            loss.backward()

            if args.grad_clip:
                nn.utils.clip_grad.clip_grad_norm_(local_model.parameters(), args.max_norm)
            optimizer.step()

    with torch.no_grad():
        current_state = local_model.state_dict()
        update = {name: current_state[name] - initial_state[name] for name in current_state}

    n_samples = len(user_tr_data_tensor)
    return update, n_samples

def local_update_batch(
    global_model: nn.Module, communication_round: int,
    client_datas: list[torch.Tensor], client_labels: list[torch.Tensor],
    args,
) -> tuple[list[dict[str, torch.Tensor]], list[int]]:
    """Run ``local_update`` for every client, one after another.

    Args:
        global_model: Model passed unchanged to each ``local_update`` call.
        communication_round: Index of the current round, forwarded as is.
        client_datas: One training-data tensor per client.
        client_labels: One label tensor per client, aligned by position with
            ``client_datas``. Entries beyond ``len(client_datas)`` are ignored.
        args: Configuration object forwarded to ``local_update``.

    Returns:
        A pair ``(updates, n_samples_lt)`` of lists in client order, holding
        the first and second element returned by ``local_update`` for each
        client.

    Raises:
        IndexError: If ``client_labels`` has fewer entries than
            ``client_datas``.
    """
    n_clients = len(client_datas)
    # Fail before any training starts rather than part-way through the clients.
    if len(client_labels) < n_clients:
        raise IndexError(
            f"client_labels has {len(client_labels)} entries but client_datas has {n_clients}"
        )

    updates = []
    n_samples_lt = []
    for client_data, client_label in zip(client_datas, client_labels):
        update, n_samples = local_update(
            client_data, client_label,
            global_model, communication_round,
            args)
        updates.append(update)
        n_samples_lt.append(n_samples)
    return updates, n_samples_lt