import copy
import torch
import numpy
import scipy
import argparse
import os
import train_func
import cubic_func
import submission.settings.data as data
from tqdm import tqdm

# Run on the first CUDA device when one is available, otherwise on the CPU.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# NOTE: the command line is parsed at import time, so importing this module
# from another program makes it consume that program's sys.argv.
parser = argparse.ArgumentParser()
# --seed is mandatory: torch.manual_seed() rejects None with an opaque
# TypeError, and the seed is also embedded in the results directory name, so
# a missing seed is reported up front as a normal argparse usage error.
parser.add_argument('--seed', type=int, required=True, help='random_seed')
args = parser.parse_args()

# Seed both the PyTorch and the NumPy global RNGs with the same value.
torch.manual_seed(args.seed)
numpy.random.seed(args.seed)


if __name__ == '__main__':
    # Script entry point: train the FashionMNIST ConvNet with plain SGD and,
    # after every epoch, record the rank of the matrix produced by
    # cubic_func.hessian / compose_param_matrix together with the test
    # accuracy. On the final epoch the full eigenvalue spectrum is stored too.
    # The accumulated log is re-saved after each epoch so that an interrupted
    # run still leaves the results of the completed epochs on disk.

    # Import the submodule explicitly: a bare `import scipy` does not
    # guarantee that `scipy.linalg` is loaded on every SciPy release, and the
    # eigenvalue call below is only reached on the last epoch, i.e. after all
    # the training work has already been done.
    import scipy.linalg

    cfg = argparse.Namespace(batch_size=128, gamma=0.001, lr=0.005, n_epoch=30)
    train_set = data.FashionMNIST(train=True)
    test_set = data.FashionMNIST(train=False)
    model = data.FashionMNIST_ConvNet()
    print(cfg)
    train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, batch_size=cfg.batch_size)
    test_loader = torch.utils.data.DataLoader(test_set, shuffle=False, batch_size=cfg.batch_size)
    model.to(DEVICE)
    loss_fn = torch.nn.CrossEntropyLoss()
    print('params: %d' % sum(p.numel() for p in model.parameters() if p.requires_grad))

    # Create the output location once, before training, so an unwritable
    # path fails immediately instead of after the first (expensive) epoch.
    out_dir = f'appx_rank_dynamics/results/fmnist_{args.seed}'
    results_path = f'{out_dir}/results.pth'
    os.makedirs(out_dir, exist_ok=True)

    all_logs = []

    print('========= TRAINING & LOGS RANK DYNAMICS ==========')
    sgd_optimizer = torch.optim.SGD(model.parameters(), lr=cfg.lr)
    # Every scheduler.step() multiplies the current learning rate by 0.5.
    scheduler = torch.optim.lr_scheduler.MultiplicativeLR(sgd_optimizer, lambda epoch: 0.5)
    last_epoch = cfg.n_epoch - 1
    for epoch in tqdm(range(cfg.n_epoch), desc='train'):
        model.train()
        print('Epoch {}:'.format(epoch))
        for i, batch in enumerate(train_loader):
            # Ask the training step for verbose output once every 100 batches.
            train_func.generic_step(cfg, model, loss_fn, batch, sgd_optimizer, verbose=(i % 100 == 0))
        # NOTE(review): this condition also holds for epoch 0, so the learning
        # rate is halved right after the first epoch (then after epochs 10
        # and 20) -- confirm this schedule is intended.
        if epoch % 10 == 0:
            scheduler.step()

        # NOTE(review): the model is still in train mode while the Hessian is
        # computed (eval() is only called afterwards) -- confirm this is
        # intended if the network contains dropout or batch-norm layers.
        tuple_params = tuple(p for p in model.parameters() if p.requires_grad)
        H = cubic_func.hessian(cfg, model, loss_fn, train_loader)
        H = cubic_func.compose_param_matrix(H, tuple_params).detach().cpu().numpy()
        # hermitian=True tells NumPy the matrix is symmetric, letting it use
        # the cheaper eigenvalue-based rank computation.
        rk = numpy.linalg.matrix_rank(H, hermitian=True)

        model.eval()
        test_acc = train_func.test(model, test_loader)
        print('test_acc:', test_acc)
        print('hessian_rk:', rk)

        record = {'epoch': epoch, 'rank': rk, 'test_acc': test_acc}
        if epoch == last_epoch:
            # The full spectrum is only computed and stored for the final epoch.
            record['eigenvals'] = scipy.linalg.eigvalsh(H)
        all_logs.append(record)
        torch.save(all_logs, results_path)