import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import numpy as np
import time
from myutils import AverageMeter, accuracy, save_checkpoint_epoch, adjust_lr_lambda
from myDataLoader import get_train_val_test_loader
import wandb
from ddro_utils import cifar_iid, DatasetSplit, average_weights, update_ykt_batch, update_x_k

# Adjust learning rate and beta function
def adjust_curlr_beta(epoch, args, optimizer=None):
    """
    Apply a three-phase step schedule to the learning rate and beta.

    Phases, as fractions of ``args.epochs``:
      * epoch < 50%        -> lr = args.lr,        beta = args.beta
      * 50% <= epoch < 75% -> lr = args.lr * 0.1,  beta = args.beta * 0.9
      * otherwise          -> lr = args.lr * 0.01, beta = args.beta * 0.8

    The chosen values are stored on ``args.curlr`` / ``args.curbeta``. If a
    (truthy) optimizer is given, every one of its param groups gets the new lr.

    Args:
        epoch (int): Current epoch number.
        args: Namespace with ``epochs``, ``lr`` and ``beta``; mutated in place.
        optimizer: Optional object exposing ``param_groups`` (list of dicts).
    """
    if epoch < args.epochs * 0.5:
        new_lr, new_beta = args.lr, args.beta
    elif epoch < args.epochs * 0.75:
        new_lr, new_beta = args.lr * 0.1, args.beta * 0.9
    else:
        new_lr, new_beta = args.lr * 0.01, args.beta * 0.8

    if optimizer:
        for group in optimizer.param_groups:
            group['lr'] = new_lr

    args.curlr, args.curbeta = new_lr, new_beta

    print(f"Epoch {epoch}: Adjusted curlr to {new_lr}, curbeta to {new_beta}")


# Validate function definition
def validate(args, data_loader, model, criterion, epoch):
    """
    Evaluate ``model`` over every batch of ``data_loader`` without gradients.

    A batch may be ``(inputs, targets)`` or ``(index, inputs, targets)``; the
    leading element of a 3-element batch is discarded. If the model returns a
    tuple, its first element is treated as the output.

    Args:
        args: Namespace providing ``device``.
        data_loader: Iterable of batches as described above.
        model: Model to evaluate; switched to eval mode and left there.
        criterion: Loss callable invoked as ``criterion(output, target)``.
        epoch: Epoch number, used only in the printed summary.

    Returns:
        tuple: (average loss, average top-1 accuracy, average top-5 accuracy,
        None). The trailing None occupies the slot that callers in this file
        unpack as an AUC score.
    """
    model.eval()
    loss_meter, acc1_meter, acc5_meter = AverageMeter(), AverageMeter(), AverageMeter()

    with torch.no_grad():
        for batch in data_loader:
            if len(batch) == 3:
                _, images, labels = batch
            else:
                images, labels = batch

            images = images.to(args.device)
            labels = labels.to(args.device)

            logits = model(images)
            if isinstance(logits, tuple):
                logits = logits[0]

            batch_loss = criterion(logits, labels)
            acc1, acc5 = accuracy(logits, labels, topk=(1, 5))

            batch_size = images.size(0)
            loss_meter.update(batch_loss.item(), batch_size)
            acc1_meter.update(acc1.item(), batch_size)
            acc5_meter.update(acc5.item(), batch_size)

    print(f'Epoch {epoch}: Validation Loss {loss_meter.avg}, Top-1 Accuracy {acc1_meter.avg}, Top-5 Accuracy {acc5_meter.avg}')

    return loss_meter.avg, acc1_meter.avg, acc5_meter.avg, None


# DS-FedDRO implementation
def DS_FedDRO(args, model_new, results):
    """
    Train a model with DS-FedDRO-style federated rounds.

    Each global round:
      1. ``adjust_curlr_beta`` sets ``args.curlr`` / ``args.curbeta``.
      2. For every client, ``update_ykt_batch`` refreshes that client's
         ``y_kt`` from models holding its previous and current local weights.
      3. The clients' ``y_kt`` are averaged into ``y_t``.
      4. For every client, ``update_x_k`` starts from the global model and
         uses ``y_t``. Its returned state dict becomes the client's current
         weights.
      5. Client weights are averaged into the global model, which is then
         evaluated, logged to wandb, and checkpointed.

    Args:
        args: Hyperparameter namespace. Read here: random_seed, num_users,
            local_bs, works, epochs, local_ep, results_dir, and (via
            adjust_curlr_beta) lr and beta. ``resumed_epoch`` is optional
            (defaults to 0). Mutated: device, curlr, curbeta, and defaults
            for lamda / im_ratio / local_opt if they are missing.
        model_new: Initial global model. It is moved to the device and
            updated in place.
        results: Unused; kept for interface compatibility.

    Returns:
        tuple: ``(global_model, best_acc1)``. ``best_acc1`` is the best top-1
        accuracy seen during the training rounds (0 if no round was evaluated).
    """
    # Initialize WandB for logging
    wandb.init(config=args, project="P", entity="aditi1_cse-wayne-state-university")

    # Fix random seeds
    torch.manual_seed(args.random_seed)
    np.random.seed(args.random_seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.device = device  # consumed by validate() and possibly the ddro helpers

    # Ensure lamda, im_ratio, local_opt are in args
    if not hasattr(args, 'lamda'):
        args.lamda = 1.0
    if not hasattr(args, 'im_ratio'):
        args.im_ratio = 1.0
    if not hasattr(args, 'local_opt'):
        args.local_opt = 'sgd'

    # Get dataset and initialize loaders
    train_data, val_data = get_train_val_test_loader(args, None, return_dataset=True)
    train_loader, val_loader, test_loader = get_train_val_test_loader(args, None)

    # Debugging output to verify loader initialization
    print(f"Train loader: {train_loader}")
    print(f"Validation loader: {val_loader}")
    print(f"Test loader: {test_loader}")

    # Fallback if test_loader is None, use validation loader for evaluation
    if test_loader is None:
        print("Test loader is None, using validation loader for evaluation.")
        test_loader = val_loader

    # Split the dataset among clients
    idxs_users = list(range(args.num_users))
    dict_users = cifar_iid(dataset=train_data, num_users=args.num_users)

    # The client split never changes, so build each client's loader once
    # instead of twice per client per round. shuffle=True still reshuffles
    # on every fresh pass over the loader.
    client_loaders = [
        torch.utils.data.DataLoader(
            DatasetSplit(train_data, dict_users[idx]),
            batch_size=args.local_bs,
            shuffle=True,
            num_workers=args.works,
            pin_memory=True,
        )
        for idx in idxs_users
    ]

    # Initialize models and weights
    global_model = model_new.to(device)
    global_weights = global_model.state_dict()
    local_weights_pre = [copy.deepcopy(global_weights) for _ in idxs_users]
    local_weights_cur = [copy.deepcopy(global_weights) for _ in idxs_users]

    y_kt = [0.0 for _ in idxs_users]  # per-client momentum/tracking term

    CE_criterion = nn.CrossEntropyLoss()
    best_acc1 = 0
    start_time = time.time()

    # Initial evaluation of the starting model
    if test_loader is not None:
        _, test_prec1, _, _ = validate(args, test_loader, global_model, CE_criterion, 0)
        print(f'Pretrained test_prec1: {test_prec1:.3f}')

    start_epoch = getattr(args, 'resumed_epoch', 0)

    # Main federated learning loop
    for epoch in range(start_epoch, args.epochs):
        print(f'\n | Global Training Round : {epoch + 1} |\n')

        adjust_curlr_beta(epoch, args, optimizer=None)
        global_model.train()

        # Phase 1: refresh each client's y_kt from its previous/current weights
        for idx in idxs_users:
            model_pre = copy.deepcopy(global_model).to(device)
            model_cur = copy.deepcopy(global_model).to(device)

            model_pre.load_state_dict(local_weights_pre[idx])
            model_cur.load_state_dict(local_weights_cur[idx])

            y_kt[idx] = update_ykt_batch(model_pre=model_pre, model_cur=model_cur,
                                         global_round=epoch, ykt=y_kt[idx],
                                         trainloader=client_loaders[idx], beta=args.curbeta,
                                         lmbda=args.lamda, local_ep=1)

        # Global momentum: average of client terms
        y_t = sum(y_kt) / len(y_kt)

        # Snapshot the current weights as "previous" BEFORE they are replaced.
        # A bare assignment would alias the two lists, so every write to
        # local_weights_cur[idx] below would also overwrite
        # local_weights_pre[idx], and update_ykt_batch would receive identical
        # pre/cur models next round. A shallow copy is enough because entries
        # are replaced, not mutated in place.
        local_weights_pre = list(local_weights_cur)

        # Phase 2: local model updates starting from the global model
        for idx in idxs_users:
            model_cur = copy.deepcopy(global_model).to(device)

            x_k_new = update_x_k(model_cur=model_cur, y_t=y_t, trainloader=client_loaders[idx],
                                 global_round=epoch, lmbda=args.lamda, eta=args.curlr,
                                 local_ep=args.local_ep, args=args)

            local_weights_cur[idx] = copy.deepcopy(x_k_new)

        # Server step: average client weights into the global model
        global_weights = average_weights(local_weights_cur)
        global_model.load_state_dict(global_weights)

        # Evaluation
        global_model.eval()
        if test_loader is None:
            print("Test loader is None, skipping evaluation and logging.")
            continue

        _, test_prec1, test_prec5, _ = validate(args, test_loader, global_model, CE_criterion, epoch)
        print(f'Test Prec@1: {test_prec1:.3f}, Test Prec@5: {test_prec5:.3f}')

        wandb.log({"epoch": epoch, "test_acc1": test_prec1, "test_acc5": test_prec5}, step=epoch)

        # Track the best accuracy so is_best / best_prec1 reflect real progress
        # (previously best_acc1 stayed 0 forever).
        is_best = test_prec1 > best_acc1
        best_acc1 = max(test_prec1, best_acc1)

        save_checkpoint_epoch({'epoch': epoch, 'state_dict': global_model.state_dict(), 'best_prec1': best_acc1},
                              is_best=is_best, path=args.results_dir)

    print(f'Total training time: {time.time() - start_time:.1f}s')

    return global_model, best_acc1
