import argparse
import os
import time
from datetime import datetime
import torch
import torch.backends.cudnn as cudnn
from mydataset import get_imbalanced_dataset, get_num_classes
import models
from myutils import ResultsLog, model_resume
from preprocess import get_transform_medium_scale_data
from balgs import PDSGDDRO, FastDRO, MBSGD
from qalgs import RECOVER, SCCMA, ACCSCCMA
from ddro import DDRO  # Import DDRO from ddro file
from ddro_updated import DS_FedDRO  # Import DS_FedDRO for FS-DRO

# Argument Parser
# Command-line interface for this training script. Arguments are only declared
# here; they are parsed (and some of them post-processed) inside main().
parser = argparse.ArgumentParser(description="Pytorch D_DRO Training")
parser.add_argument('--results_dir', metavar="RESULTS_DIR", default='./TrainingResults', help='results dir')
# main() joins --saveFolder onto --results_dir; an empty value is replaced by a timestamp there.
parser.add_argument('--saveFolder', metavar='SAVE', default='cifar10/MBSGD/Wm_FastDRO_wlr_0.1_rho_1_beta_0.6_plr_1e-4_lambda1_50_batch_128_epochs_120_model_resnet20_DR_10_Repeats_3', help='save folder')
parser.add_argument('--res_filename', default='cifar10_MBSGD_wlr_0.1_rho_1_beta_0.6_plr_1e-4_lambda0__batch_128_epochs_120_model_resnet20_DR_10_Repeats_3_class_tau_0.05', type=str, help='results file name')
parser.add_argument('--dataset', metavar='DATASET', default='cifar10', help='dataset name or folder')
parser.add_argument('--model', metavar='MODEL', default='resnet20', help='model architecture')
# main() forces this back to 'torch.FloatTensor' when running on CPU.
parser.add_argument('--type', default='torch.FloatTensor', help='types of tensor - e.g torch.FloatTensor for CPU')
parser.add_argument('--gpus', default='', help='gpus used for training - leave empty to use CPU')
# The help strings below use %(default)s so they can never drift from the real default again.
parser.add_argument('--workers', default=8, type=int, metavar='N', help='number of data loading workers (default: %(default)s)')
parser.add_argument('--batch-size', default=128, type=int, metavar='N', help='mini-batch size (default: %(default)s)')
parser.add_argument('--optimizer', default='SGD', type=str, metavar='OPT', help='optimizer function used')
parser.add_argument('--momentum', default=0.9, type=float, metavar="M", help="momentum parameter")
parser.add_argument('--scale_size', default=32, type=int, help='image scale size for data preprocessing')
parser.add_argument('--input_size', default=32, type=int, help='the size of image. e.g. 32 for cifar10, 224 for imagenet')
# NOTE(review): --works overlaps with --workers; which one the data loaders read is not visible here.
parser.add_argument('--works', default=8, type=int, help='number of threads used for loading data')
parser.add_argument('--weight_decay', default=2e-4, type=float, help='weight decay parameters')
parser.add_argument('--print_freq', '-p', default=50, type=int, help='print frequency (default: %(default)s)')
parser.add_argument('--epochs', default=120, type=int, help='number of total epochs')
parser.add_argument('--lr', default=0.1, type=float, metavar='WLR', help='initial learning rate of w')
parser.add_argument('--plr', default=1e-4, type=float, help='Dual Variable P')
parser.add_argument('--rho', default=1, type=float, help='Constraint of DRO: rho')
# Spelled 'lamda' on purpose: `lambda` is a Python keyword, so `args.lambda` would be a syntax error.
parser.add_argument('--lamda', default=1.0, type=float, help='Lambda parameter for regularization')
# Default is the bool False (not a path string); main() only tests its truthiness.
parser.add_argument('--resume', default=False, type=str, help='Path to the checkpoint file to resume training from')
parser.add_argument('--resumed_epoch', default=0, type=int, help="continuing training from a saved checkpoint")
# main() overwrites this with get_num_classes(args).
parser.add_argument('--num_classes', default=10, type=int, help="classes of different datasets")
parser.add_argument('--num_users', default=1, type=int, help='the number of clients to be used')
parser.add_argument('--local_ep', default=1, type=int, help='local_epoch for updates')
parser.add_argument('--update_x_k', default='v1', type=str, help='local update algorithm type')
parser.add_argument('--curlr', default=0.1, type=float, help='current learning rate')
parser.add_argument('--curbeta', default=0.1, type=float, help='momentum parameter')
parser.add_argument('--random_seed', default=123, type=int, help='random seed for reproducibility')
# main() may replace the default with a dataset-specific value for cifar10/cifar100.
parser.add_argument('--im_ratio', default=1.0, type=float, help='imbalance ratio for the dataset')
parser.add_argument('--beta', default=0.6, type=float, help='beta parameter for the optimizer')
parser.add_argument('--local_bs', default=64, type=int, help='Local batch size for each client')
parser.add_argument('--local_opt', default='sgd', type=str, help='Optimizer for local updates, e.g., sgd or adam')
parser.add_argument('--I', default=10, type=int, help='Number of local updates before global update (I)')


parser.add_argument('--alg', default='DS_FedDRO', type=str, choices=['DDRO', 'DS_FedDRO', 'RECOVER', 'FastDRO', 'PDSGD', 'ACCSCCMA', 'SCCMA', 'MBSGD'], help='The choice of algorithms')


def main():
    """Parse CLI arguments, build the model(s), and run the selected algorithm.

    Side effects:
      * sets the module-level globals ``args`` (parsed namespace) and
        ``best_prec1`` (reset to 0; not read in this function);
      * seeds torch's RNG with ``--random_seed``;
      * creates ``<results_dir>/<saveFolder>`` if missing and opens a
        ``ResultsLog`` at ``<res_filename>_results.csv`` inside it;
      * on GPU, sets CUDA environment variables and enables cuDNN benchmarking.

    Raises:
        ValueError: if ``args.alg`` names no known algorithm (normally already
            rejected by argparse ``choices``).
    """
    global args, best_prec1
    best_prec1 = 0
    args = parser.parse_args()
    # Seed from the CLI flag (default 123) so --random_seed actually takes effect.
    torch.manual_seed(args.random_seed)
    args.start_training_time = time.time()

    if args.saveFolder == '':
        args.saveFolder = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

    args.results_dir = os.path.join(args.results_dir, args.saveFolder)
    os.makedirs(args.results_dir, exist_ok=True)
    results_file = os.path.join(args.results_dir, args.res_filename + '_results.csv')
    results = ResultsLog(results_file)

    # Use CUDA only when it is available AND GPUs were requested via --gpus.
    device = torch.device("cuda" if torch.cuda.is_available() and args.gpus else "cpu")
    if device.type == 'cuda':
        print("CUDA is available, using GPU(s)")
        # NOTE(review): torch.cuda.is_available() above has already queried the
        # CUDA driver, which typically reads CUDA_VISIBLE_DEVICES only once at
        # initialization, so these assignments may not restrict visible devices
        # in this process. To make them effective they would have to be set
        # before the availability check (and args.gpus renumbered from 0) — confirm
        # how downstream code uses args.gpus before changing this.
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        # Skip empty items so inputs like "0,1," do not crash int('').
        args.gpus = [int(i) for i in args.gpus.split(',') if i.strip()]
        cudnn.benchmark = True
    else:
        print("CUDA not available or not requested, using CPU")
        args.gpus = None
        args.type = 'torch.FloatTensor'

    args.num_classes = get_num_classes(args)
    model = models.__dict__[args.model]
    model_cur = model(num_classes=args.num_classes)
    model_prev = model(num_classes=args.num_classes)

    # Move models to the selected device
    model_cur = model_cur.to(device)
    model_prev = model_prev.to(device)

    if args.gpus:
        model_cur = torch.nn.DataParallel(model_cur)
        model_prev = torch.nn.DataParallel(model_prev)

    if args.resume:
        if os.path.isfile(args.resume):
            model_resume(args, args.resume, model_cur)
        else:
            print(f"=> no checkpoint found at '{args.resume}'")

    # Dataset-specific imbalance ratios. They are applied only when --im_ratio was
    # left at its parser default, so an explicitly passed value is not discarded.
    dataset_im_ratio = {'cifar10': 0.02, 'cifar100': 0.2}
    if args.dataset in dataset_im_ratio and args.im_ratio == parser.get_default('im_ratio'):
        args.im_ratio = dataset_im_ratio[args.dataset]

    # Algorithm selection: all of these take (args, model_cur, results).
    single_model_algs = {
        'PDSGD': PDSGDDRO,
        'SCCMA': SCCMA,
        'MBSGD': MBSGD,
        'RECOVER': RECOVER,
        'FastDRO': FastDRO,
        'DDRO': DDRO,
        'DS_FedDRO': DS_FedDRO,  # Algorithm 3 from FedDRO
    }
    if args.alg == 'ACCSCCMA':
        # ACCSCCMA is the only algorithm that also receives model_prev.
        ACCSCCMA(args, model_cur, model_prev, results)
    elif args.alg in single_model_algs:
        single_model_algs[args.alg](args, model_cur, results)
    else:
        raise ValueError(f"Invalid algorithm choice: {args.alg!r}")


if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()
