import torch, copy, time
import numpy as np
from myutils import adjust_lr_lambda, AverageMeter, accuracy, save_checkpoint_epoch
from myDataLoader import get_train_val_test_loader, get_train_val_test_feature_loader
import torch.nn as nn
import torch.nn.functional as F
from sklearn import metrics
from torch.autograd import Variable
from torch.nn.parallel._functions import Broadcast
import torch.cuda.comm as comm
from torch.autograd import Function
from balgs import validate, adjust_curlr_beta, validate_cifar_val_loader
import wandb

# for FL
from ddro_utils import cifar_iid, DatasetSplit, average_weights
from ddro_utils import update_ykt_batch, update_ykt
from ddro_utils import update_x_k_v0, update_x_k_v1, update_x_k_v2, update_x_k_v3, update_x_k_v4

# our D_DRO algorithm in FL setting
def DS_FedDRO_M(args, model_new, results):
    """Run the DS_FedDRO_M federated training loop and log its metrics.

    In every round each client first refreshes its ``y_k`` value with
    ``update_ykt_batch`` (using its previous and current local weights) and
    then updates its local weights with the ``update_x_k`` variant selected
    by ``args.update_x_k``. Every ``args.I`` rounds the server mixes the
    client weights into the global model and synchronises ``y_k`` across all
    clients.

    Args:
        args: Run configuration. Fields read here: ``random_seed``,
            ``num_users``, ``resumed_epoch``, ``epochs``, ``local_bs``,
            ``works``, ``local_ep``, ``lamda``, ``update_x_k``, ``I``,
            ``gamma_x``, ``gamma_y``, plus ``curlr`` / ``curbeta``
            (presumably refreshed by ``adjust_curlr_beta`` -- confirm).
        model_new: Model used as the global model. It is modified in place
            (``load_state_dict``) at every aggregation step.
        results: Recorder exposing ``add(**metrics)`` and ``save()``.

    Returns:
        None. Metrics go to ``results``, stdout and Weights & Biases.
    """
    ## initialize weights and bias logging
    wandb.init(config=args, project="avg", entity="hq1351-wayne-state-universit")

    ## fix random seeds
    torch.manual_seed(args.random_seed)
    np.random.seed(args.random_seed)

    ## get the dataset (cifar10, cifar100) and the train/val/test loaders
    train_data, _ = get_train_val_test_loader(args, None, return_dataset=True)
    train_loader, val_loader, test_loader = get_train_val_test_loader(args, None)

    ## split the whole train dataset across the clients (users)
    idxs_users = list(range(args.num_users))
    dict_users = cifar_iid(dataset=train_data, num_users=args.num_users)

    ## initialize for FL training
    best_acc1 = 0
    start_time = time.time()
    CE_criterion = nn.CrossEntropyLoss()

    if test_loader is not None:
        _, test_prec1, _, _ = validate(args, test_loader, model_new, CE_criterion, 0)
        print('>>>>>>>> Pretrained test_prec1 {:.3f}'.format(test_prec1))

    y_kt = [0.0 for _ in idxs_users]
    global_model = model_new
    global_weights = global_model.state_dict()

    # Per-client weights of the previous and the current local iterate.
    local_weights_pre = [copy.deepcopy(global_weights) for _ in idxs_users]
    local_weights_cur = [copy.deepcopy(global_weights) for _ in idxs_users]

    # Number of server aggregation steps performed so far.
    tau = 0

    # Select the x_k update version; unknown values fall back to v1.
    update_x_k = {
        'v0': update_x_k_v0,
        'v1': update_x_k_v1,
        'v2': update_x_k_v2,
        'v3': update_x_k_v3,
        'v4': update_x_k_v4,
    }.get(args.update_x_k, update_x_k_v1)

    # FL outer loop, rounds
    for epoch in range(args.resumed_epoch, args.epochs):
        print(f'\n | Global Training Round : {epoch+1} |\n')

        # Adjust learning rate
        adjust_curlr_beta(epoch, args, optimizer=None)
        print(f"Learning rate for epoch {epoch}: {args.curlr}")

        global_model.train()

        # Step 1: every client refreshes y_k from its previous/current weights.
        for idx in idxs_users:
            model_pre = copy.deepcopy(global_model)
            model_pre.load_state_dict(local_weights_pre[idx])
            model_cur = copy.deepcopy(global_model)
            model_cur.load_state_dict(local_weights_cur[idx])

            train_loader_idx = _build_client_loader(args, train_data, dict_users[idx])

            y_kt[idx] = update_ykt_batch(model_pre=model_pre, model_cur=model_cur,
                                         global_round=epoch, ykt=y_kt[idx], trainloader=train_loader_idx,
                                         beta=args.curbeta, lmbda=args.lamda, local_ep=1)

            # Log and check gradients
            for name, param in model_cur.named_parameters():
                if param.grad is not None:
                    print(f"{name} gradient norm: {param.grad.norm()}")

        # Snapshot the current weights as "previous" BEFORE they are replaced
        # below. A plain `pre = cur` would alias the two lists, so the writes
        # to local_weights_cur[idx] would also overwrite the previous iterate.
        # A shallow copy suffices because entries are replaced, never mutated.
        local_weights_pre = list(local_weights_cur)

        # Step 2: every client updates its local weights x_k.
        for idx in idxs_users:
            model_cur = copy.deepcopy(global_model)
            model_cur.load_state_dict(local_weights_cur[idx])

            train_loader_idx = _build_client_loader(args, train_data, dict_users[idx])

            # NOTE(review): the client's own y_k is passed as `y_t`; the
            # original code also computed the client average here but never
            # used it -- confirm which one the algorithm intends.
            x_k_new = update_x_k(model_cur=model_cur, y_t=y_kt[idx], trainloader=train_loader_idx,
                                 global_round=epoch, lmbda=args.lamda, eta=args.curlr,
                                 local_ep=args.local_ep, args=args)

            # Record the updated local weights
            local_weights_cur[idx] = copy.deepcopy(x_k_new)

        # Step 3: periodic aggregation at the server, every args.I rounds.
        if (epoch + 1) % args.I == 0:
            y_kt = _aggregate_on_server(global_model, local_weights_cur, y_kt,
                                        args.gamma_x, args.gamma_y)
            tau += 1

        # Validation, Logging, and Output
        # NOTE(review): evaluation is skipped entirely for runs configured
        # with 10 or fewer epochs -- confirm this is intended.
        if args.epochs > 10:
            train_loss, train_prec1, train_prec5, _ = validate(args, train_loader, global_model, CE_criterion, epoch)
            val_loss, val_prec1, val_prec5, _ = validate(args, val_loader, global_model, CE_criterion, epoch)

            if test_loader is not None:
                test_loss, test_prec1, test_prec5, _ = validate(args, test_loader, global_model, CE_criterion, epoch)
            else:
                # Without a test split the validation metrics stand in for it.
                test_loss, test_prec1, test_prec5 = val_loss, val_prec1, val_prec5

            overall_running_time = (time.time() - start_time) // 60
            best_acc1 = max(best_acc1, test_prec1)

            # Store and save the results
            results.add(epoch=epoch, val_loss=val_loss,
                        train_prec1=train_prec1, val_prec1=val_prec1, test_prec1=test_prec1,
                        train_prec5=train_prec5, val_prec5=val_prec5, test_prec5=test_prec5,
                        overall_running_time=overall_running_time)
            results.save()

            # Output to console
            output = (f'Train: [{epoch}/{args.epochs}], lr: {args.curlr:.5f}\t'
                      f'Train Loss {train_loss:.4f} Val Loss {val_loss:.4f}\t'
                      f'Train Prec@1 {train_prec1:.3f} Val Prec@1 {val_prec1:.3f} Test Prec@1 {test_prec1:.3f}\t'
                      f'Train Prec@5 {train_prec5:.3f} Val Prec@5 {val_prec5:.3f} Test Prec@5 {test_prec5:.3f}')
            print(output)
            print(f"Lambda Variable value: {args.lamda}")
            print(f'Total running time: {overall_running_time:.3f} minutes')

            # Log all metrics of this round to WandB in one call.
            wandb.log({
                "lr": args.curlr,
                'Optimized Lambda Variable': args.lamda,
                "train loss": train_loss,
                'train acc1': train_prec1,
                'train acc5': train_prec5,
                "best test acc": best_acc1,
                'beta': args.curbeta,
                'training time': overall_running_time,
                "test loss": test_loss,
                'test acc1': test_prec1,
                'test acc5': test_prec5,
            }, step=epoch)

    print(f'Total number of rounds: {tau}')


def _build_client_loader(args, train_data, sample_idxs):
    """Return a shuffling DataLoader over one client's share of ``train_data``."""
    return torch.utils.data.DataLoader(
        DatasetSplit(train_data, sample_idxs),
        batch_size=args.local_bs, shuffle=True,
        num_workers=args.works, pin_memory=True)


def _aggregate_on_server(global_model, local_weights, y_kt, gamma_x, gamma_y):
    """Mix client weights into ``global_model`` and return the synced ``y_k`` list.

    For every state-dict entry the global value is moved by ``gamma_x`` times
    the mean difference to the client values, and the result is loaded into
    ``global_model`` in place. The client-average of ``y_kt`` gets the
    analogous ``gamma_y`` step, and every client receives that value.
    """
    num_clients = len(local_weights)

    # Aggregation Step (Equation 8)
    bar_x = global_model.state_dict()
    for key in bar_x:
        drift = sum(bar_x[key] - client_weights[key] for client_weights in local_weights)
        bar_x[key] = bar_x[key] - gamma_x * drift / num_clients
    global_model.load_state_dict(bar_x)

    bar_y = sum(y_kt) / num_clients
    bar_y = bar_y - gamma_y * sum(bar_y - y_k for y_k in y_kt) / num_clients

    # Synchronise ALL clients (the original only wrote the last client's slot
    # through a stale loop index). Tensors are cloned so that clients never
    # share one object; plain floats are immutable and can be reused.
    return [bar_y.clone() if torch.is_tensor(bar_y) else bar_y for _ in range(num_clients)]
