import torch
import copy
import time
import numpy as np
from new_myutils import adjust_lr_lambda, AverageMeter, accuracy, save_checkpoint_epoch
from myDataLoader import get_train_val_test_loader, get_train_val_test_feature_loader
import torch.nn as nn
from balgs import validate, adjust_curlr_beta, validate_cifar_val_loader
import wandb

# Import only relevant methods from New_utils
from New_utils import cifar_iid, DatasetSplit, average_weights, update_ykt_batch, update_x_k_v3

# D_DRO Algorithm 3 in FL setting
def DDRO_Algorithm3(args, model_new, results):
    """Run the federated D-DRO (Algorithm 3) training loop.

    Every global round does, in order:

    1. For each client, refresh its estimate ``y_kt[idx]`` with
       ``update_ykt_batch`` using the client's previous and current models.
    2. Average the per-client estimates into ``y_t``.
    3. For each client, compute new local weights with ``update_x_k_v3``.
    4. Every ``args.I`` rounds, average the clients' weights into the global
       model.
    5. If ``args.epochs > 10``, evaluate the global model, record the metrics
       in ``results`` and wandb.
    6. Save a checkpoint of the global model.

    Args:
        args: Run configuration. Attributes read here: ``random_seed``,
            ``num_users``, ``resumed_epoch``, ``epochs``, ``local_bs``,
            ``works``, ``curbeta``, ``lamda``, ``curlr``, ``local_ep`` and
            ``I``. ``curbeta``/``curlr`` are presumably refreshed by
            ``adjust_curlr_beta`` each round -- TODO confirm.
        model_new: Model used as the global model; it is updated in place by
            ``load_state_dict`` on aggregation rounds.
        results: Logger object exposing ``add(**metrics)`` and ``save()``.

    Returns:
        None. Progress is reported through ``results``, wandb and checkpoints.
    """
    ## Initialize Weights & Biases (wandb) logging
    wandb.init(config=args, project="P", entity="aditi1_cse-wayne-state-university")

    ## Fix random seeds for reproducibility
    torch.manual_seed(args.random_seed)
    np.random.seed(args.random_seed)

    ## Get the dataset and DataLoaders
    train_data, _ = get_train_val_test_loader(args, None, return_dataset=True)
    train_loader, val_loader, test_loader = get_train_val_test_loader(args, None)

    ## Split dataset for clients (users) in the FL setting
    idxs_users = list(range(args.num_users))
    dict_users = cifar_iid(dataset=train_data, num_users=args.num_users)
    # The client shards never change, so build each loader once instead of
    # twice per client per round. A shuffling DataLoader reshuffles on every
    # iteration, so reusing the object does not repeat batch order.
    client_loaders = {
        idx: _make_client_loader(args, train_data, dict_users[idx])
        for idx in idxs_users
    }

    ## Initialize variables for FL training
    best_acc1 = 0
    start_time = time.time()
    CE_criterion = nn.CrossEntropyLoss()

    ## Pre-evaluation of test accuracy using pretrained model (if available)
    if test_loader is not None:
        _, pretrained_prec1, _, _ = validate(args, test_loader, model_new, CE_criterion, 0)
        print(f'Pretrained Test Accuracy: {pretrained_prec1:.3f}')

    ## Initialize global model and weights
    global_model = model_new
    global_weights = global_model.state_dict()

    ## Record previous / current local weights for each user
    local_weights_pre = [copy.deepcopy(global_weights) for _ in idxs_users]
    local_weights_cur = [copy.deepcopy(global_weights) for _ in idxs_users]

    ## Per-client y estimates. They were previously never initialised, which
    ## raised NameError on the first local update.
    # NOTE(review): a zero start is assumed (0-dim tensor so it combines with
    # both Python and tensor arithmetic) -- confirm against update_ykt_batch.
    # Resumed runs (resumed_epoch > 0) also restart from zero.
    y_kt = [torch.zeros(()) for _ in idxs_users]

    ## Select the local update method
    update_x_k = update_x_k_v3  # Only v3 is implemented

    ## FL main loop (Algorithm 3)
    for epoch in range(args.resumed_epoch, args.epochs):
        print(f'\n | Global Training Round : {epoch+1} |\n')

        # Adjust learning rate and beta
        adjust_curlr_beta(epoch, args, optimizer=None)
        global_model.train()

        ## Local y updates (one per client)
        for idx in idxs_users:
            # global_model is only used as an architecture template here; the
            # client's own weights are loaded right after the copy.
            model_pre = copy.deepcopy(global_model)
            model_pre.load_state_dict(local_weights_pre[idx])
            model_cur = copy.deepcopy(global_model)
            model_cur.load_state_dict(local_weights_cur[idx])

            y_kt[idx] = update_ykt_batch(
                model_pre=model_pre, model_cur=model_cur, global_round=epoch,
                ykt=y_kt[idx], trainloader=client_loaders[idx],
                beta=args.curbeta, lmbda=args.lamda, local_ep=args.local_ep,
            )

        # Compute the average y_t from all clients (FL aggregation)
        y_t = sum(y_kt) / len(y_kt)

        ## Update local weights (each client)
        for idx in idxs_users:
            model_cur = copy.deepcopy(global_model)
            model_cur.load_state_dict(local_weights_cur[idx])

            new_weights = update_x_k(
                model_cur=model_cur, y_t=y_t, trainloader=client_loaders[idx],
                global_round=epoch, lmbda=args.lamda, eta=args.curlr,
                local_ep=args.local_ep, args=args,
            )
            # Keep the outgoing iterate as "previous" so the next round's
            # model_pre is the prior step rather than the initial weights
            # forever (local_weights_pre was never updated before).
            local_weights_pre[idx] = local_weights_cur[idx]
            local_weights_cur[idx] = new_weights

        ## Global update (Server-side aggregation)
        if (epoch + 1) % args.I == 0:  # Global aggregation every 'I' rounds
            global_weights = average_weights(local_weights_cur)
            global_model.load_state_dict(global_weights)

        ## Validation and logging
        # Accuracy used for best-model selection this round; stays None when
        # no evaluation ran, so a stale value can never mark a round as best.
        round_acc1 = None
        if args.epochs > 10:
            _, train_prec1, _, _ = validate(args, train_loader, global_model, CE_criterion, epoch)
            val_loss, val_prec1, _, _ = validate_cifar_val_loader(args, val_loader, global_model, CE_criterion, epoch)
            if test_loader is not None:
                _, test_prec1, _, _ = validate(args, test_loader, global_model, CE_criterion, epoch)
                round_acc1 = test_prec1
            else:
                # No test split: report NaN and select on validation accuracy.
                test_prec1 = float("nan")
                round_acc1 = val_prec1

            ## Save results and log in wandb
            results.add(epoch=epoch, val_loss=val_loss, train_prec1=train_prec1, val_prec1=val_prec1, test_prec1=test_prec1, overall_running_time=(time.time() - start_time)//60)
            results.save()

            print(f"Epoch {epoch}/{args.epochs} | Train Acc@1 {train_prec1:.3f} | Val Acc@1 {val_prec1:.3f} | Test Acc@1 {test_prec1:.3f}")
            wandb.log({"train_acc1": train_prec1, "val_acc1": val_prec1, "test_acc1": test_prec1}, step=epoch)

        # Save a checkpoint every round; flag it as best only on a fresh,
        # strictly better evaluation.
        is_best = round_acc1 is not None and bool(round_acc1 > best_acc1)
        if is_best:
            best_acc1 = round_acc1
        save_checkpoint_epoch({'epoch': epoch + 1, 'state_dict': global_model.state_dict(), 'best_acc1': best_acc1}, is_best)


def _make_client_loader(args, train_data, sample_idxs):
    """Build the shuffled DataLoader over one client's subset of ``train_data``."""
    return torch.utils.data.DataLoader(
        DatasetSplit(train_data, sample_idxs),
        batch_size=args.local_bs,
        shuffle=True,
        num_workers=args.works,
        pin_memory=True,
    )
