import torch
import numpy
import json
import os
import argparse
import submission.settings.data as data
import methods
import train_func

# Run on the first CUDA device when one is available, otherwise on the CPU.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

parser = argparse.ArgumentParser()
# Both options are mandatory: neither has a usable default. Without --seed,
# torch.manual_seed(None) fails with an opaque TypeError; without --gamma, a None
# damping factor would flow into the unlearning config and the results filename.
parser.add_argument('--seed', type=int, required=True, help='random_seed')
parser.add_argument('--gamma', type=float, required=True, help='damping factor')
args = parser.parse_args()
# Seed the torch and numpy global RNGs from --seed.
torch.manual_seed(args.seed)
numpy.random.seed(args.seed)

def evaluate(df_set, df_all_set, dr_set, test_set, model, batch_size=256):
    """Evaluate ``model`` on the forget, retain and test splits.

    Args:
        df_set: samples forgotten in the current unlearning round.
        df_all_set: every sample forgotten so far (cumulative forget set).
        dr_set: retained training samples.
        test_set: held-out test split.
        model: the model under evaluation, forwarded to ``train_func.test``.
        batch_size: evaluation batch size for every split (default 256, the
            value previously hard-coded).

    Returns:
        dict with keys 'df_all_acc', 'df_acc', 'dr_acc', 'dtest_acc' holding
        whatever ``train_func.test`` reports for the corresponding split.
    """
    def _accuracy(dataset):
        # One batch-size rule for every split. DataLoader already yields a short
        # final batch, so the old min(256, len(dataset)) cap changed nothing for
        # non-empty sets (it only made an empty set fail with batch_size=0).
        loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
        return train_func.test(model, loader=loader)

    # Same evaluation order as before: round forget set, cumulative forget set,
    # retain set, test set.
    df_acc = _accuracy(df_set)
    df_all_acc = _accuracy(df_all_set)
    dr_acc = _accuracy(dr_set)
    dtest_acc = _accuracy(test_set)
    return {'df_all_acc': df_all_acc, 'df_acc': df_acc, 'dr_acc': dr_acc, 'dtest_acc': dtest_acc}

def main():
    """Run sequential unlearning on FashionMNIST with the damping factor ``args.gamma``.

    Loads the fully trained model and the unlearning order saved for ``args.seed``.
    It then forgets ``erase_bs`` samples per round with ``methods.damped_newton``.
    After every round it appends the ``evaluate`` accuracies to a log. The log is
    rewritten each round to
    ``varying_gamma/results/fmnist_{seed}/gamma={gamma}_results.json``.
    """
    model_path = f'sequential_unlearning/class_results/fmnist_{args.seed}/fully_train.pth'
    unlearn_order_path = f'sequential_unlearning/class_results/fmnist_{args.seed}/unlearn_order.pth'
    cfg = argparse.Namespace(
        batch_size=64, reg_weight=0.001, gamma=args.gamma,
        model_path=model_path, unlearn_order_path=unlearn_order_path)
    train_set = data.FashionMNIST(train=True)
    test_set = data.FashionMNIST(train=False)

    model = data.FashionMNIST_ConvNet()
    # Load onto the CPU first so a checkpoint saved on a GPU also loads on a
    # CPU-only host; the model is moved to DEVICE afterwards.
    checkpoint = torch.load(cfg.model_path, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    model.to(DEVICE)
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'params: {n_params}')

    # Passing the path (not an open() handle) lets torch.load close the file.
    # Elements are normalised to plain ints in case the order was saved as a
    # tensor: tensor elements hash by identity and would never be removed from
    # the retain set below.
    sel = [int(i) for i in torch.load(cfg.unlearn_order_path, map_location='cpu')]
    tot_sel = len(sel)
    erase_bs = 100

    results_dir = f'varying_gamma/results/fmnist_{args.seed}'
    results_path = f'varying_gamma/results/fmnist_{args.seed}/gamma={cfg.gamma}_results.json'
    os.makedirs(results_dir, exist_ok=True)

    print('========= VARYING DAMPING FACTOR (SEQUENTIAL UNLEARNING) =========')
    all_logs = []
    retain = set(range(len(train_set)))  # indices not yet unlearned
    for unlearn_round, from_i in enumerate(range(0, tot_sel, erase_bs), start=1):
        to_i = min(from_i + erase_bs, tot_sel)
        print('unlearn_round:', unlearn_round)
        df_indices = sel[from_i:to_i]
        retain.difference_update(df_indices)
        df = torch.utils.data.Subset(train_set, df_indices)
        # Sorted for a deterministic retain-set order across rounds.
        dr = torch.utils.data.Subset(train_set, sorted(retain))
        df_all = torch.utils.data.Subset(train_set, sel[:to_i])
        methods.damped_newton(df, dr, model, cfg)
        results = evaluate(df, df_all, dr, test_set, model)
        print(results)
        all_logs.append(dict(round=unlearn_round, from_i=from_i, to_i=to_i, **results))
        # Rewrite the full log after every round so an interrupted run still
        # leaves the results of all completed rounds on disk.
        with open(results_path, 'w') as f:
            json.dump(all_logs, f, indent=2)


if __name__ == '__main__':
    main()