import torch
import numpy
import json
import os
import time
import argparse
import submission.settings.data as data
import methods
import train_func
import sequential_unlearning 

# Run on the first GPU when one is visible, otherwise on the CPU.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Command-line interface. --seed and --method are required: the seed feeds both
# the RNGs and every input/output path, and the method selects the unlearning
# branch, so neither has a meaningful fallback (a missing seed would otherwise
# crash in torch.manual_seed(None), and a missing method only failed after the
# data and checkpoint had been loaded).
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, required=True, help='random_seed')
parser.add_argument('--method', type=str, required=True, choices=['retrain', 'cr_newton', 'scr_newton'], help='unlearning method')
parser.add_argument('--M', type=float, default=None, help='Hessian Lipschitz constant')
args = parser.parse_args()
# Seed both torch and numpy so runs with the same --seed are reproducible.
torch.manual_seed(args.seed)
numpy.random.seed(args.seed)



if __name__ == '__main__':
    # Sequential unlearning on AG News: walk the saved unlearning order in
    # chunks of `erase_bs` points, unlearn each chunk from the current model
    # with the selected method, evaluate, and append one log record per round
    # to a JSON results file.
    model_path = f'sequential_unlearning/class_results/agnews_{args.seed}/fully_train.pth'
    unlearn_order_path = f'sequential_unlearning/class_results/agnews_{args.seed}/unlearn_order.pth'
    train_cfg = dict(batch_size=1024, reg_weight=2e-4)
    cfg = argparse.Namespace(method=args.method, M=args.M, model_path=model_path, unlearn_order_path=unlearn_order_path)

    # 1. load data and model
    train_set = data.AGNEWS(train=True)
    test_set = data.AGNEWS(train=False)
    model = data.AGNEWS_LSTM()
    # Pass the path (not a bare open() handle) so the file is closed;
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    ckpt = torch.load(cfg.model_path, map_location=DEVICE)
    model.load_state_dict(ckpt['state_dict'])
    model.to(DEVICE)
    print('params: %d' % sum(p.numel() for p in model.parameters() if p.requires_grad))

    # 2. load unlearn order
    # Normalise to a flat list of Python ints. If the file holds a tensor, its
    # elements hash by identity, so set-based removal below would otherwise
    # silently remove nothing from the retained set.
    sel = torch.as_tensor(torch.load(cfg.unlearn_order_path, map_location='cpu')).tolist()
    tot_sel = len(sel)
    erase_bs = 250

    # The output location does not change between rounds, so compute it once.
    method_name = cfg.method if cfg.method == 'retrain' else f'{cfg.method}_M={cfg.M}'
    out_dir = f'varying_M/results/agnews_{args.seed}'
    out_path = f'{out_dir}/{method_name}_results.json'
    os.makedirs(out_dir, exist_ok=True)

    print('========= VARYING M (SEQUENTIAL UNLEARNING) ==========')
    all_logs = []
    dr_indices = list(range(len(train_set)))
    for round_idx, from_i in enumerate(range(0, tot_sel, erase_bs), start=1):
        to_i = min(from_i + erase_bs, tot_sel)
        df_indices = sel[from_i:to_i]
        # Filter instead of list(set(...)) so the retained indices keep a
        # deterministic ascending order from round to round.
        erased = set(df_indices)
        dr_indices = [i for i in dr_indices if i not in erased]
        df = torch.utils.data.Subset(train_set, df_indices)
        dr = torch.utils.data.Subset(train_set, dr_indices)
        df_all = torch.utils.data.Subset(train_set, sel[:to_i])

        start_time = time.time()
        alpha = None  # only produced by cr_newton
        if cfg.method == 'retrain':
            unlearn_cfg = argparse.Namespace(
                n_epoch=70, lr=0.005, init_state_dict=ckpt['init_state_dict'], **train_cfg)
            methods.retrain(df, dr, model, unlearn_cfg)
        elif cfg.method == 'cr_newton':
            unlearn_cfg = argparse.Namespace(M=cfg.M, **train_cfg)
            alpha = methods.cr_newton(df, dr, model, unlearn_cfg)
        elif cfg.method == 'scr_newton':
            unlearn_cfg = argparse.Namespace(M=cfg.M, L=cfg.M, **train_cfg)
            methods.scr_newton(df, dr, model, unlearn_cfg)
        else:
            raise ValueError(f'unknown unlearning method: {cfg.method}')
        # CUDA kernels run asynchronously: wait for them before reading the
        # clock, and stop timing before evaluation so the logged time measures
        # the unlearning step rather than evaluate() and printing.
        if DEVICE.type == 'cuda':
            torch.cuda.synchronize(DEVICE)
        tot_time = time.time() - start_time

        results = sequential_unlearning.evaluate(df, df_all, dr, test_set, model)
        print(results)

        # Same keys, in the same order, as before: round/from_i/to_i/time,
        # then M (Newton methods), alpha (cr_newton only), then the results.
        log = dict(round=round_idx, from_i=from_i, to_i=to_i, time=tot_time)
        if cfg.method != 'retrain':
            log['M'] = cfg.M
        if cfg.method == 'cr_newton':
            log['alpha'] = alpha
        log.update(results)
        all_logs.append(log)

        # Rewrite the whole log after every round so partial results survive
        # an interrupted run.
        with open(out_path, 'w') as f:
            json.dump(all_logs, f, indent=2)
