import os
import math
import copy
import numpy
import torch
import json
import argparse
import train_func
import cubic_func
import submission.settings.data as data
import methods as unlearn_methods
from tqdm import tqdm

# Run on the first CUDA device when one is available, otherwise on the CPU.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Command-line options are parsed at import time.
# Every option is required: there is no sensible default for any of them
# (e.g. torch.manual_seed(None) raises a TypeError, and a None replay size or
# method does not describe a runnable experiment).
parser = argparse.ArgumentParser()
parser.add_argument('--method', type=str, required=True, help='online learning method', 
                    choices=[
                        'cold_start', 'warm_start', 'lure_gd', 'lure_cr_newton', 'lure_scr_newton',
                        'cr_newton_then_learn', 'scr_newton_then_learn',
                    ])
parser.add_argument('--replay_size', type=int, required=True, help='size of replay buffer')
parser.add_argument('--seed', type=int, required=True, help='random_seed')
args = parser.parse_args()

# Seed both RNGs the script draws from (torch for model init / shuffling,
# numpy for the data split and replay sampling).
torch.manual_seed(args.seed)
numpy.random.seed(args.seed)

def split(train_set, tot_split):
    """Randomly partition the indices of ``train_set`` into ``tot_split`` chunks.

    The indices ``0 .. len(train_set) - 1`` are shuffled with the global numpy
    RNG and cut into consecutive pieces of ``ceil(len(train_set) / tot_split)``
    elements. The last pieces may be shorter, or empty when ``tot_split``
    exceeds what is needed to cover the data.

    Returns:
        A list of ``tot_split`` numpy index arrays.
    """
    n_samples = len(train_set)
    shuffled = numpy.random.permutation(n_samples)
    chunk = math.ceil(n_samples / tot_split)
    # Slicing clamps at the end of the array, so the final chunks need no
    # explicit bound.
    return [shuffled[k * chunk:(k + 1) * chunk] for k in range(tot_split)]

def learn(train_set, model, cfg):
    """Train ``model`` in place on ``train_set`` with plain SGD.

    The learning rate starts at ``cfg.lr`` and is halved after every 10
    completed epochs (i.e. epochs 0-9 use ``cfg.lr``, 10-19 use half of it, ...).

    Args:
        train_set: map-style dataset, iterated through a shuffling DataLoader.
        model: torch module whose parameters are updated in place.
        cfg: namespace providing ``batch_size``, ``lr`` and ``n_epoch``; it is
            also forwarded to ``train_func.generic_step``.
    """
    loader = torch.utils.data.DataLoader(train_set, shuffle=True, batch_size=cfg.batch_size)
    loss_fn = torch.nn.CrossEntropyLoss()
    sgd_optimizer = torch.optim.SGD(model.parameters(), lr=cfg.lr)
    # Stepped once per epoch. Stepping a multiplicative decay on
    # `epoch % 10 == 0` would also fire after epoch 0, so the configured lr
    # would only ever be used for a single epoch.
    scheduler = torch.optim.lr_scheduler.StepLR(sgd_optimizer, step_size=10, gamma=0.5)
    for epoch in tqdm(range(cfg.n_epoch), desc='sgd'):
        model.train()
        for batch in loader:
            train_func.generic_step(cfg, model, loss_fn, batch, sgd_optimizer, verbose=False)
        scheduler.step()

def influence_func(train_set, test_set, model, cfg):
    """Score every sample of ``train_set`` by its influence on ``test_set``.

    For a training sample ``z`` the score is
    ``-g_test^T @ pinv(H) @ g_z``, where

    * ``g_test`` is the gradient that ``cubic_func.gradient`` returns for
      ``test_set``,
    * ``H`` is the Hessian that ``cubic_func.hessian`` returns for
      ``train_set`` (pseudo-inverted, so a singular ``H`` is tolerated), and
    * ``g_z`` is the gradient on the single sample ``z``.

    Only parameters with ``requires_grad`` are included. The model is put in
    eval mode and left there.

    Returns:
        numpy array with one score per element of ``train_set``, in dataset
        order.
    """
    model.eval()
    loss_fn = torch.nn.CrossEntropyLoss()
    tuple_params = tuple(p for p in model.parameters() if p.requires_grad)
    train_loader = torch.utils.data.DataLoader(train_set, shuffle=False, batch_size=cfg.batch_size)
    test_loader = torch.utils.data.DataLoader(test_set, shuffle=False, batch_size=cfg.batch_size)

    test_gradient = cubic_func.gradient(cfg, model, loss_fn, test_loader)
    test_gradient = cubic_func.compose_param_vector(test_gradient, tuple_params).detach().cpu().numpy()
    hessian = cubic_func.hessian(cfg, model, loss_fn, train_loader)
    hessian = cubic_func.compose_param_matrix(hessian, tuple_params).detach().cpu().numpy()
    inv_H = numpy.linalg.pinv(hessian)
    # `-g_test^T @ inv_H` does not depend on the training sample: compute it
    # once so each per-sample score is a single dot product instead of a
    # vector-matrix product over the full parameter dimension. `@` is
    # left-associative, so this is exactly the original evaluation order.
    s_test = -test_gradient.T @ inv_H
    influence = []
    for sample in tqdm(train_set, desc='if'):
        loader = torch.utils.data.DataLoader([sample])
        gradient = cubic_func.gradient(cfg, model, loss_fn, loader)
        gradient = cubic_func.compose_param_vector(gradient, tuple_params).detach().cpu().numpy()
        influence.append(s_test @ gradient)
    return numpy.array(influence)


if __name__ == '__main__':
    # n_forget: how many of the highest-influence samples from the previous
    # chunk the lure_* methods unlearn each round.
    cfg = argparse.Namespace(tot_split=15, replay_size=args.replay_size, method=args.method, n_forget=100)
    train_cfg = argparse.Namespace(batch_size=128, reg_weight=0.001, lr=0.005, n_epoch=30, M=80, L=80)

    # Fail before any training. A None method would otherwise fall into the
    # influence branch below and relearn without unlearning anything.
    if cfg.method is None or cfg.replay_size is None:
        parser.error('--method and --replay_size are required')

    # 1. load data and model
    train_set = data.FashionMNIST(train=True)
    test_set = data.FashionMNIST(train=False)
    test_loader = torch.utils.data.DataLoader(test_set, shuffle=False, batch_size=train_cfg.batch_size)
    model = data.FashionMNIST_ConvNet()
    model.to(DEVICE)
    # Snapshot of the initial weights; cold_start resets to it every round.
    init_state_dict = copy.deepcopy(model.state_dict())

    # 2. partition data into cfg.tot_split random chunks, one per round
    subset_ids = split(train_set, cfg.tot_split)

    # The results file is rewritten after every round, so an interrupted run
    # still leaves the logs of the completed rounds.
    save_dir = f'online_learning/results/fmnist_{args.seed}_{args.replay_size}_replay'
    os.makedirs(save_dir, exist_ok=True)
    results_path = f'{save_dir}/{cfg.method}_results.json'

    print('======== ONLINE LEARNING =======')
    all_logs = []

    # 3. learn the first chunk from the initial weights
    # `cer` accumulates, per round, (100 - accuracy on the new chunk) scaled
    # by the chunk size. This assumes train_func.test reports accuracy in
    # percent (TODO confirm), in which case it counts misclassified samples.
    round_idx = 0
    cer = 0

    curr_d = torch.utils.data.Subset(train_set, subset_ids[0])
    train_loader = torch.utils.data.DataLoader(curr_d, shuffle=False, batch_size=min(len(curr_d), train_cfg.batch_size))
    learn(curr_d, model, train_cfg)
    train_acc = train_func.test(model, train_loader)
    test_acc = train_func.test(model, test_loader)
    generalization_gap = train_acc - test_acc
    cer += (100 - train_acc) * len(curr_d) / 100
    all_logs.append(dict(round=round_idx, dtest_acc=test_acc, generalization_gap=generalization_gap, cer=cer))
    print('round', round_idx)
    print('test acc:', test_acc)

    for round_idx in range(1, cfg.tot_split):
        prev_d = curr_d
        curr_d = torch.utils.data.Subset(train_set, subset_ids[round_idx])
        all_d = torch.utils.data.Subset(train_set, numpy.concatenate(subset_ids[:round_idx + 1]))
        # Replay buffer: a uniform sample (without replacement) of all earlier
        # chunks. It is capped at their total size so numpy.random.choice
        # cannot fail when replay_size is large.
        seen_ids = numpy.concatenate(subset_ids[:round_idx])
        replay_ids = numpy.random.choice(seen_ids, size=min(cfg.replay_size, len(seen_ids)), replace=False)
        replay_d = torch.utils.data.Subset(train_set, replay_ids)
        augment_d = torch.utils.data.ConcatDataset((curr_d, replay_d))

        if cfg.method == 'cold_start':
            # Retrain from the initial weights on every chunk seen so far.
            model.load_state_dict(init_state_dict)
            learn(all_d, model, train_cfg)
        elif cfg.method == 'warm_start':
            learn(augment_d, model, train_cfg)
        elif cfg.method == 'cr_newton_then_learn':
            unlearn_methods.cr_newton(None, prev_d, model, train_cfg)
            learn(augment_d, model, train_cfg)
        elif cfg.method == 'scr_newton_then_learn':
            unlearn_methods.scr_newton(None, prev_d, model, train_cfg)
            learn(augment_d, model, train_cfg)
        else:
            # 4.1. score every previous-chunk sample by its influence on the new chunk
            influence = influence_func(prev_d, curr_d, model, train_cfg)
            assert len(influence) == len(prev_d)
            # Indices (unordered) of the n_forget highest-influence samples.
            # Using kth = n_forget - 1 keeps argpartition valid even when
            # prev_d holds no more than n_forget samples.
            n_forget = min(cfg.n_forget, len(influence))
            df_ids = numpy.argpartition(-influence, n_forget - 1)[:n_forget]
            print(influence[df_ids].max(), influence.max())

            # 4.2. unlearn: df holds the selected samples, dr the rest of prev_d
            df = torch.utils.data.Subset(prev_d, df_ids)
            dr = torch.utils.data.Subset(prev_d, numpy.setdiff1d(numpy.arange(len(prev_d)), df_ids).tolist())
            if cfg.method == 'lure_gd':
                unlearn_methods.gd(df, dr, model, train_cfg)
            elif cfg.method == 'lure_cr_newton':
                unlearn_methods.cr_newton(df, dr, model, train_cfg)
            elif cfg.method == 'lure_scr_newton':
                unlearn_methods.scr_newton(df, dr, model, train_cfg)

            # 5. relearn
            learn(augment_d, model, train_cfg)

        all_train_loader = torch.utils.data.DataLoader(all_d, shuffle=False, batch_size=min(len(all_d), train_cfg.batch_size))
        train_loader = torch.utils.data.DataLoader(curr_d, shuffle=False, batch_size=min(len(curr_d), train_cfg.batch_size))
        train_acc = train_func.test(model, train_loader)
        all_train_acc = train_func.test(model, all_train_loader)
        test_acc = train_func.test(model, test_loader)
        generalization_gap = all_train_acc - test_acc
        cer += (100 - train_acc) * len(curr_d) / 100
        print('round', round_idx)
        # Every method except cold_start trains on augment_d, which includes the replay buffer.
        if cfg.method != 'cold_start':
            print('replay size:', len(replay_ids))
        print('test acc:', test_acc)
        all_logs.append(dict(round=round_idx, dtest_acc=test_acc, generalization_gap=generalization_gap, cer=cer))

        with open(results_path, 'w') as f:
            json.dump(all_logs, f, indent=2)