import torch
import math
from train_utils import test_nlp, train_nlp
from LPSGD import LPSGD
from LMSSGD import LMSSGD
from LPAdamBC import LPAdamBC
from init_utils import lp_parse_args, nlp_task_init, logger_init
from fastDP import PrivacyEngine
#PrivacyEngine_Distributed_extending,PrivacyEngine_Distributed_Stage_2_and_3
# from opacus.accountants.utils import get_noise_multiplier
# from opacus.validators import ModuleValidator
import argparse
import warnings
# import timm
# import os
# from datetime import datetime
# import wandb

def _read_coefficients(path):
    """Read the coefficient lists ``a`` and ``b`` from a text file.

    The first line holds ``a`` and the second line holds ``b``, each as
    comma-separated floats. Empty fields (e.g. from a trailing comma or the
    line's newline) are skipped.

    Raises:
        ValueError: if the file has fewer than two lines or a field is not a
            valid float.
    """
    with open(path, "r", encoding="utf-8") as f:
        coefs = f.readlines()
    if len(coefs) < 2:
        raise ValueError(
            f"Coefficient file {path!r} must contain two lines (a and b), "
            f"found {len(coefs)}"
        )
    a = [float(i) for i in coefs[0].split(",") if i.strip()]
    b = [float(i) for i in coefs[1].split(",") if i.strip()]
    return a, b


def _build_optimizer(args, model, noise, a, b):
    """Construct the optimizer selected by ``args.algo``.

    Supported values are ``"sgd"`` (LPSGD), ``"adam"`` (LPAdamBC) and
    ``"LMSSGD"``.

    Raises:
        RuntimeError: if ``args.algo`` is not one of the supported values.
    """
    if args.algo == "sgd":
        return LPSGD(model.parameters(), lr=args.lr, a=a, b=b)
    if args.algo == 'adam':
        optimizer = LPAdamBC(model.parameters(), lr=args.lr, a=a, b=b, c=args.beta,
                             sigma=math.pow(noise / args.bs, 2), weight_decay=1e-3)
        print(math.pow(noise / args.bs, 2))
        return optimizer
    if args.algo == 'LMSSGD':
        # NOTE(review): this branch reads ``args.noise`` while the others use the
        # ``noise`` returned by nlp_task_init -- confirm these are meant to differ.
        return LMSSGD(model.parameters(), lr=args.lr, sigma=args.noise / math.sqrt(args.bs),
                      n=3, beta=0.95, beta2=0.999)
    print(args.algo)
    raise RuntimeError("Unknown Algorithm!")


if __name__ == '__main__':
    warnings.filterwarnings("ignore")
    parser = argparse.ArgumentParser(description='LP DPSGD')
    parser = lp_parse_args(parser)
    args = parser.parse_args()
    train_dl, test_dl, model, device, sample_size, acc_step, noise, tokenizer = nlp_task_init(args)
    log_file = logger_init(args, noise, sample_size // args.mnbs, type=args.log_type)
    a, b = _read_coefficients(args.coef_file)

    optimizer = _build_optimizer(args, model, noise, a, b)

    checkpoint = None
    if args.load_path is not None:
        # Map tensors onto the device chosen by nlp_task_init instead of a
        # hard-coded 'cuda', so resuming also works on CPU / other devices.
        # Assumes ``device`` is a str or torch.device -- TODO confirm.
        checkpoint = torch.load(args.load_path, map_location=device)
        # NOTE(review): model weights are not restored here (the original
        # ``model.load_state_dict(checkpoint['module'], strict=True)`` was
        # commented out, and the save below uses the key 'model') -- confirm
        # that resuming only the optimizer state is intended.
        optimizer.load_state_dict(checkpoint['optimizer'])

    if args.scheduler:
        from train_utils import CosineAnnealingWarmupRestarts
        lrscheduler = CosineAnnealingWarmupRestarts(optimizer, max_lr=args.lr,
                                                    first_cycle_steps=sample_size // args.bs * args.epoch,
                                                    warmup_steps=10)
        if checkpoint is not None:
            # Checkpoints from runs without a scheduler have no 'scheduler' key.
            if 'scheduler' in checkpoint:
                lrscheduler.load_state_dict(checkpoint['scheduler'])
            else:
                print("Warning: checkpoint has no 'scheduler' state; scheduler starts fresh.")
    else:
        lrscheduler = None

    criterion = torch.nn.CrossEntropyLoss(reduction='mean')
    if args.clipping:
        # Attach fastDP's engine to the optimizer; the noise multiplier and
        # clipping settings come from nlp_task_init / the CLI arguments.
        privacy_engine = PrivacyEngine(model, noise_multiplier=noise, numerical_stability_constant=1e-3,
                                       grad_accum_steps=acc_step, sample_size=sample_size,
                                       batch_size=args.bs, epochs=args.epoch,
                                       per_sample_clip=args.clipping, torch_seed_is_fixed=False,
                                       clipping_fn=args.clipping_fn, clipping_style=args.clipping_style,
                                       max_grad_norm=args.clipping_norm)
        privacy_engine.attach(optimizer)

    for E in range(args.epoch):
        train_nlp(model, train_dl, optimizer, criterion, log_file, device=device, epoch=E,
                  log_frequency=args.log_freq, acc_step=acc_step, lr_scheduler=lrscheduler)
        test_nlp(model, test_dl, criterion, log_file, device=device, epoch=E)
        # Checkpoint every ``save_freq`` epochs, overwriting the same file.
        if args.save_path is not None and (E + 1) % args.save_freq == 0:
            save_dict = {'model': model.state_dict(), 'optimizer': optimizer.state_dict()}
            if lrscheduler is not None:
                save_dict['scheduler'] = lrscheduler.state_dict()
            torch.save(save_dict, args.save_path)
