import torch, copy, time
import numpy as np
from myutils import adjust_lr_lambda, AverageMeter, accuracy, save_checkpoint_epoch
from myDataLoader import get_train_val_test_loader, get_train_val_test_feature_loader
import torch.nn as nn
import torch.nn.functional as F
from sklearn import metrics
from torch.autograd import Variable
from torch.nn.parallel._functions import Broadcast
import torch.cuda.comm as comm
from torch.autograd import Function
from balgs import validate, adjust_curlr_beta, validate_cifar_val_loader
import wandb

# for FL
from ddro_utils import cifar_iid, DatasetSplit, average_weights
from ddro_utils import update_ykt_batch, update_ykt
from ddro_utils import update_x_k_v0, update_x_k_v1, update_x_k_v2, update_x_k_v3, update_x_k_v4



def _make_client_loader(args, train_data, client_indices):
    """Build the shuffled training DataLoader over one client's shard of ``train_data``."""
    client_subset = DatasetSplit(train_data, client_indices)
    return torch.utils.data.DataLoader(
        client_subset,
        batch_size=args.local_bs,
        shuffle=True,
        num_workers=args.works,
        pin_memory=True,
    )


def _mix_into_global_state(global_state, client_states, gamma):
    """Apply ``x_bar <- x_bar - gamma * mean_k(x_bar - x_k)`` in place on ``global_state``."""
    num_clients = len(client_states)
    for key, global_value in global_state.items():
        step = gamma * sum((global_value - client_state[key]) for client_state in client_states) / num_clients
        if global_value.is_floating_point() or global_value.is_complex():
            global_value -= step
        else:
            # Integer entries (e.g. BatchNorm's num_batches_tracked) cannot take an
            # in-place float update: PyTorch refuses to cast the Float result back to
            # an integer dtype. Compute out of place, round, and let copy_ do the cast.
            global_value.copy_(torch.round(global_value - step))


def DS_FedDRO(args, model_new, results):
    """Run the DS_FedDRO federated training loop and record metrics.

    Each iteration of the main loop:
      1. refreshes every client's tracker ``y_k`` with ``update_ykt_batch``, using
         that client's previous and current weights;
      2. updates every client's weights with the selected ``update_x_k_*`` routine;
      3. every ``args.I`` iterations, mixes the client weights and trackers into the
         global model / global tracker (step sizes ``args.gamma_x`` / ``args.gamma_y``)
         and resets all client trackers to the mixed value;
      4. if ``args.epochs > 10``, evaluates the global model and logs the metrics.

    Args:
        args: Run configuration. Attributes read here: ``random_seed``, ``num_users``,
            ``I``, ``update_x_k``, ``resumed_epoch``, ``local_bs``, ``works``,
            ``curbeta``, ``curlr``, ``lamda``, ``local_ep``, ``gamma_x``, ``gamma_y``
            and ``epochs``. ``curlr`` / ``curbeta`` are presumably set by
            ``adjust_curlr_beta`` -- confirm in ``balgs``.
        model_new: Model used as the global model; it is updated in place.
        results: Result recorder; ``add(...)`` and ``save()`` are called on it.

    Returns:
        None. Metrics are reported through ``results``, wandb and stdout.
    """
    # Initialize wandb logging
    wandb.init(config=args, project="P", entity="aditi1_cse-wayne-state-university")

    # Fix random seeds for reproducibility
    torch.manual_seed(args.random_seed)
    np.random.seed(args.random_seed)

    # Load the raw training dataset (for client sharding) and the evaluation loaders
    train_data, _ = get_train_val_test_loader(args, None, return_dataset=True)
    train_loader, val_loader, test_loader = get_train_val_test_loader(args, None)

    # Split dataset for FL across clients
    idxs_users = list(range(args.num_users))
    dict_users = cifar_iid(dataset=train_data, num_users=args.num_users)

    best_acc1 = 0
    start_time = time.time()
    CE_criterion = nn.CrossEntropyLoss()

    # Evaluate the incoming model on the test set (if available)
    if test_loader is not None:
        _, test_prec1, _, _ = validate(args, test_loader, model_new, CE_criterion, 0)
        print(f'Pretrained test accuracy: {test_prec1:.3f}')

    # Global model plus per-client "previous" and "current" weight snapshots
    global_model = model_new
    global_weights = global_model.state_dict()
    local_weights_pre = [copy.deepcopy(global_weights) for _ in idxs_users]
    local_weights_cur = [copy.deepcopy(global_weights) for _ in idxs_users]
    y_kt = [0.0 for _ in idxs_users]

    # Both counters are incremented once per server aggregation.
    # NOTE(review): `com` starts at 0 even when args.resumed_epoch > 0 -- confirm
    # that restarting the lr/beta schedule on resume is intended.
    com = 0  # communication round; drives the lr/beta schedule and logging step
    tau = 0  # aggregation count reported at the end

    # T communication rounds of args.I local iterations each (T = I * 120)
    T = args.I * 120
    total_epochs = T * args.I

    # Select the x_k update routine by name; warn instead of silently falling back
    update_fn_name = f'update_x_k_{args.update_x_k}'
    module_globals = globals()
    if update_fn_name in module_globals:
        update_x_k = module_globals[update_fn_name]
    else:
        print(f"Warning: '{update_fn_name}' is not defined; falling back to update_x_k_v1")
        update_x_k = update_x_k_v1

    for epoch in range(args.resumed_epoch, total_epochs):
        print(f'\n | Global Training Round : {epoch + 1} |')

        # Adjust learning rate
        adjust_curlr_beta(com, args, optimizer=None)
        print(f"Learning rate for communication round {com}: {args.curlr}")

        global_model.train()

        # Step 1: refresh each client's tracker from its previous and current weights
        for idx in idxs_users:
            model_pre = copy.deepcopy(global_model)
            model_pre.load_state_dict(local_weights_pre[idx])
            model_cur = copy.deepcopy(global_model)
            model_cur.load_state_dict(local_weights_cur[idx])

            train_loader_idx = _make_client_loader(args, train_data, dict_users[idx])

            y_kt[idx] = update_ykt_batch(
                model_pre=model_pre, model_cur=model_cur, global_round=epoch,
                ykt=y_kt[idx], trainloader=train_loader_idx,
                beta=args.curbeta, lmbda=args.lamda, local_ep=1,
            )

        # Snapshot the current weights as "previous" BEFORE they are overwritten
        # below. A plain `local_weights_pre = local_weights_cur` would alias the two
        # lists, so the per-client assignments in step 2 would also replace the
        # "previous" weights and model_pre would always equal model_cur.
        # A shallow list copy suffices: entries are replaced, never mutated in place.
        local_weights_pre = list(local_weights_cur)

        # Step 2: update each client's weights
        for idx in idxs_users:
            model_cur = copy.deepcopy(global_model)
            model_cur.load_state_dict(local_weights_cur[idx])

            train_loader_idx = _make_client_loader(args, train_data, dict_users[idx])

            # NOTE(review): the client's own tracker is passed as y_t (the original
            # also computed the cross-client mean here but never used it) -- confirm
            # which of the two is intended.
            x_k_new = update_x_k(
                model_cur=model_cur, y_t=y_kt[idx], trainloader=train_loader_idx,
                global_round=epoch, lmbda=args.lamda, eta=args.curlr,
                local_ep=args.local_ep, args=args,
            )
            local_weights_cur[idx] = copy.deepcopy(x_k_new)

        # Step 3: server-side aggregation every I iterations
        if (epoch + 1) % args.I == 0:
            com += 1

            # state_dict() tensors share storage with the model parameters, so the
            # in-place mixing already updates global_model; load_state_dict is kept
            # as in the original.
            bar_x_tau = global_model.state_dict()
            _mix_into_global_state(bar_x_tau, local_weights_cur, args.gamma_x)
            global_model.load_state_dict(bar_x_tau)

            bar_y_tau = sum(y_kt) / len(y_kt)
            bar_y_tau -= args.gamma_y * sum((bar_y_tau - y_kt[idx]) for idx in idxs_users) / len(idxs_users)
            # NOTE(review): every client entry refers to the same object; safe as long
            # as update_ykt_batch returns a new value instead of mutating its input.
            y_kt = [bar_y_tau] * len(idxs_users)

            tau += 1

        # Step 4: evaluation and logging, indexed by communication round `com`.
        # NOTE(review): this runs on every iteration, not only right after an
        # aggregation, so up to args.I consecutive records share the same `com` --
        # confirm that is intended.
        if args.epochs > 10:
            train_loss, train_prec1, train_prec5, _ = validate(args, train_loader, global_model, CE_criterion, com)
            val_loss, val_prec1, val_prec5, _ = validate(args, val_loader, global_model, CE_criterion, com)

            # Without a test loader the validation metrics stand in for the test metrics
            if test_loader is not None:
                test_loss, test_prec1, test_prec5, _ = validate(args, test_loader, global_model, CE_criterion, com)
            else:
                test_loss, test_prec1, test_prec5 = val_loss, val_prec1, val_prec5

            # Elapsed wall-clock time in whole minutes (floor division)
            overall_running_time = (time.time() - start_time) // 60
            best_acc1 = max(best_acc1, test_prec1)

            results.add(
                epoch=com,  # Communication round instead of epoch
                val_loss=val_loss,
                train_prec1=train_prec1, val_prec1=val_prec1,
                test_prec1=test_prec1,
                train_prec5=train_prec5, val_prec5=val_prec5,
                test_prec5=test_prec5,
                overall_running_time=overall_running_time
            )
            results.save()

            # Progress is reported as completed rounds out of the T planned rounds
            output = (f'Train: [{com}/{T}], lr: {args.curlr:.5f}\t'
                      f'Train Loss {train_loss:.4f} Val Loss {val_loss:.4f}\t'
                      f'Train Prec@1 {train_prec1:.3f} Val Prec@1 {val_prec1:.3f} Test Prec@1 {test_prec1:.3f}\t'
                      f'Train Prec@5 {train_prec5:.3f} Val Prec@5 {val_prec5:.3f} Test Prec@5 {test_prec5:.3f}')
            print(output)
            print(f"Lambda Variable value: {args.lamda}")
            print(f'Total running time: {overall_running_time:.3f} minutes')

            wandb.log({
                "lr": args.curlr,
                "Optimized Lambda Variable": args.lamda,
                "train loss": train_loss,
                "train acc1": train_prec1,
                "train acc5": train_prec5,
                "best test acc": best_acc1,
                "beta": args.curbeta,
                "training time": overall_running_time,
                "test loss": test_loss,
                "test acc1": test_prec1,
                "test acc5": test_prec5
            }, step=com)

    print(f'Total number of rounds: {tau}')
