import os
import warnings
import json
import logging

import torch
from pytorch_lightning import seed_everything

from equislt.args import parse_args_prune
from equislt.methods import PRUNE_METHODS
from equislt.data import prepare_data


def _resolve_device(gpus):
    """Return the CUDA device for the configured GPU, or None when ``gpus`` is None.

    ``gpus`` may be an indexable collection of GPU ids, in which case the first
    entry is used. It may also be a non-indexable scalar (e.g. an int), in which
    case the value itself is used as the device index.
    """
    if gpus is None:
        return None
    try:
        first = gpus[0]
    except TypeError:
        # Not indexable (e.g. a plain int): use the value directly.
        # NOTE(review): if ``gpus`` follows Lightning's "number of GPUs"
        # convention, this selects device index == count; confirm intent.
        first = gpus
    return torch.device('cuda:' + str(first))


def _dump_json(path, obj):
    """Write ``obj`` to ``path`` as JSON; values json can't encode become "<not serializable>"."""
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(obj, f, default=lambda o: "<not serializable>")


def main():
    """Prune a trained target network, evaluate the result and save run metadata.

    Reads the target checkpoint path ``<target_net_dir>/best.ckpt``. It also
    reads the target's training args from ``<target_net_dir>/args.json``.
    Outputs go to the run directory
    ``<target_net_dir>/pruning-C-<overparam_factor>-seed-<seed>/``, which is
    created if missing. That directory receives:

    - ``args.json``: this run's arguments.
    - ``results.json``: the arguments merged with the method's
      ``pruning_results`` and the test-split results.

    The path ``<run dir>/source.ckpt`` is handed to the pruning method.
    Whether and when it is written depends on that method — not visible here.
    """
    args = parse_args_prune()
    seed_everything(args.seed)

    # LOADING TARGET NETWORK
    # ``target_net_dir`` is expected to be a pathlib.Path (``/`` is used on it).
    target_path = args.target_net_dir / 'best.ckpt'
    workdir = args.target_net_dir / f'pruning-C-{args.overparam_factor}-seed-{args.seed}'
    os.makedirs(workdir, exist_ok=True)
    source_path = workdir / 'source.ckpt'

    with open(args.target_net_dir / "args.json", 'r', encoding='utf-8') as f:
        train_args = json.load(f)

    _dump_json(workdir / "args.json", vars(args))

    # DATASET
    device = _resolve_device(args.gpus)
    train_loader, val_loader, test_loader, _ = prepare_data(
        args.dataset,
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        device=device,
    )
    # Only the test split is used below; drop references to the other loaders.
    del train_loader, val_loader

    # METHOD & ARCHITECTURE
    MethodClass = PRUNE_METHODS[args.method]
    model = MethodClass(target_path, train_args, source_path,
                        **vars(args))

    pruning_results = model.pruning_results
    testing_results = model.test_src_model(test_loader, device=device)
    print('Testing results:', testing_results)

    # Copy first: ``vars(args)`` returns the namespace's own ``__dict__``.
    # Updating it in place would write result keys back onto ``args``.
    results = dict(vars(args))
    results.update(pruning_results)
    results.update(testing_results)

    _dump_json(workdir / "results.json", results)


if __name__ == "__main__":
    main()
