import torch
import math
from KFOptimizer import wrap_optimizer
from train_utils import noisy_train, train, test
from init_utils import base_parse_args, task_init, logger_init
from fastDP import PrivacyEngine
from AdamBC import AdamBC
import argparse
import warnings
import gc
import os

def resolve_checkpoint_path(save_path, epoch):
    """Return the file path the checkpoint for ``epoch`` should be written to.

    ``save_path`` is interpreted as follows:

    * An existing directory, or a path whose last component has no file
      extension (this includes a path ending in a separator), is treated as a
      directory. It is created if needed, and the checkpoint goes to
      ``<save_path>/checkpoint_epoch<epoch>.pt``.
    * Otherwise it is treated as a file-name template ``<dir>/<name><ext>``.
      ``<dir>`` is created if needed, and the checkpoint goes to
      ``<dir>/<name>_epoch<epoch><ext>``.

    Args:
        save_path: Directory or file-name template (``args.save_path``).
        epoch: Epoch index embedded in the returned file name.

    Returns:
        The resolved checkpoint file path (directories are created as a side
        effect).
    """
    name, ext = os.path.splitext(os.path.basename(save_path))
    if os.path.isdir(save_path) or ext == '':
        # Directory mode: one file per epoch inside save_path.
        os.makedirs(save_path, exist_ok=True)
        return os.path.join(save_path, f'checkpoint_epoch{epoch}.pt')
    # Template mode: splice the epoch between the stem and the extension.
    parent = os.path.dirname(save_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    return os.path.join(parent, f'{name}_epoch{epoch}{ext}')


if __name__ == '__main__':
    # Entry point: build the task, pick an optimizer (optionally KF-wrapped and/or
    # attached to a fastDP PrivacyEngine), then train/test for args.epoch epochs,
    # saving checkpoints every args.save_freq epochs.
    warnings.filterwarnings("ignore")
    parser = argparse.ArgumentParser(description='LP DPSGD')
    parser = base_parse_args(parser)
    args = parser.parse_args()
    torch.manual_seed(args.seed)
    train_dl, test_dl, model, device, sample_size, acc_step, noise = task_init(args)
    print(f"[run_KFSGD] device={device}, cuda_available={torch.cuda.is_available()}, cuda_device={torch.cuda.current_device() if torch.cuda.is_available() else 'N/A'}")

    log_file = logger_init(args, noise, sample_size//args.mnbs, type=args.log_type)
    if args.data == 'imgnet1k':
        # For imgnet1k the returned train/test "loaders" are callables that are
        # invoked once per epoch below to build fresh loaders.
        train_dl_1 = train_dl
        test_dl_1 = test_dl

    # Without per-sample clipping, DP noise is injected manually by noisy_train.
    # Noise is scaled per micro-batch and lr per accumulation step.
    use_manual_noise = not args.clipping and noise > 0
    if use_manual_noise:
        noise = noise/args.mnbs
        args.lr = args.lr/acc_step
        print('use manual noise')

    if args.algo == "sgd":
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0)
    elif args.algo == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    elif args.algo == 'adamw':
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.04)
    elif args.algo == 'adambc':
        optimizer = AdamBC(model.parameters(), lr=args.lr, dp_batch_size=args.bs, dp_l2_norm_clip=1, dp_noise_multiplier=noise, eps_root=1e-8)
    elif args.algo == 'Fiber':
        from Fiber import Fiber
        optimizer = Fiber(
            model.parameters(),
            lr=args.lr,
            weight_decay=0.00,
            kappa=args.kappa,
            gamma=args.gamma,
            omega=args.omega,
            use_filter_aware_adambc=True,
            batch_size=args.bs,
            dp_noise_std=noise,
        )
    else:
        print(args.algo)
        raise RuntimeError("Unknown Algorithm!")

    start = 0

    if args.load_path is not None:
        print("loading optimizer")
        # weights_only=False unpickles arbitrary objects: only load trusted checkpoints.
        # Map tensors onto the run's own device rather than a hard-coded 'cuda'.
        checkpoint = torch.load(args.load_path, map_location=device, weights_only=False)
        optimizer.load_state_dict(checkpoint['optimizer'])
        start = checkpoint['epoch'] + 1

    if args.scheduler:
        from train_utils import CosineAnnealingWarmupRestarts
        # NOTE(review): last_epoch assumes one scheduler step per batch of size
        # args.bs; confirm against how train/noisy_train call lr_scheduler.step().
        lrscheduler = CosineAnnealingWarmupRestarts(optimizer, max_lr=args.lr, first_cycle_steps=sample_size//args.bs * args.epoch, warmup_steps=(sample_size*args.epoch)//(args.bs*20), last_epoch=start*sample_size//args.bs-1)
    else:
        lrscheduler = None

    if args.kf:
        print(f"Using fixed kappa={args.kappa} in KFOptimizer")
        optimizer = wrap_optimizer(optimizer=optimizer, kappa=args.kappa, gamma=args.gamma)

    criterion = torch.nn.CrossEntropyLoss(reduction='mean')
    if args.clipping:
        # PrivacyEngine is imported at file level.
        privacy_engine = PrivacyEngine(
            model,
            noise_multiplier=noise,
            numerical_stability_constant=1e-8,
            sample_size=sample_size,
            batch_size=args.bs,
            epochs=args.epoch,
            torch_seed_is_fixed=False,
            clipping_fn=args.clipping_fn,
            clipping_style=args.clipping_style,
            max_grad_norm=args.clipping_norm,
            num_GPUs=1,  # Make sure this is set correctly
        )
        privacy_engine.attach(optimizer)

    # The KF wrapper's own state is restored after wrapping (the inner optimizer
    # state was restored above, before wrapping).
    if args.kf and args.load_path is not None:
        optimizer.load_state_dict(checkpoint['kf_optimizer'])

    for E in range(start, args.epoch):
        if args.data == 'imgnet1k':
            train_dl = train_dl_1()
            test_dl = test_dl_1()
        # Exactly one training pass per epoch: the noisy pass when noise is
        # injected manually, the plain pass otherwise. Running both would add an
        # extra noise-free update pass every epoch.
        if use_manual_noise:
            noisy_train(model, train_dl, optimizer, criterion, log_file, device=device, epoch=E, noise=noise, log_frequency=args.log_freq, acc_step=acc_step, lr_scheduler=lrscheduler)
        else:
            train(model, train_dl, optimizer, criterion, log_file, device=device, epoch=E, log_frequency=args.log_freq, acc_step=acc_step, lr_scheduler=lrscheduler)
        test(model, test_dl, criterion, log_file, device=device, epoch=E)
        if args.data == 'imgnet1k':
            # Drop per-epoch loaders so their memory can be reclaimed below.
            del train_dl
            del test_dl
        gc.collect()
        torch.cuda.empty_cache()
        if args.save_freq > 0 and E % args.save_freq == 0 and args.save_path is not None:
            save_file = resolve_checkpoint_path(args.save_path, E)
            if args.kf:
                # Save both the KF wrapper state and the inner optimizer state so
                # the resume path above can restore each one.
                torch.save({'model': model.state_dict(), 'kf_optimizer': optimizer.state_dict(), 'optimizer': optimizer.original_optimizer.state_dict(), 'epoch': E}, save_file)
            else:
                torch.save({'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': E}, save_file)
            print(f"Saved checkpoint to {save_file}")
