#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import pathlib
import numpy as np
import torch
import warnings
from torch.utils.data import DataLoader
from helpers import get_neigh, get_last_hparam
from train import train
from folders import folders
from data import ImageFolder
from evaluate import save_BSD
from get_module import get_module, get_shifts


# Enable the standard-library fault handler so that a Python traceback is
# dumped when the interpreter hits a fatal error (e.g. a segmentation fault).
import faulthandler
faulthandler.enable()


def _load_state(module, path, device):
    """Load the state dict stored at ``path`` into ``module``, mapped to ``device``."""
    module.load_state_dict(torch.load(path, map_location=torch.device(device)))


def main(version, action, neigh=99, learning_rate=0.001, epochs=1, crop_size=256,
         batch_size=32, t_max=np.inf, device='cpu', cont=False,
         num_workers=0, verbose=False, downsample=0, model='pixel',
         pars_train='pred', loss_type='shuffle', momentum=0,
         separate=False, interpolate=False, weight_decay=10**-4, split='train',
         n_report=10000, n_acc=0, transform='quant', lr_fact=10, q_zero=0.3,
         **kwargs):
    """Run one action for a given model and version folder.

    The working folder is ``folders['models']/<model>/version<version>``; it
    is created if it does not exist yet.

    Supported actions:

    * ``'train'``: build an ``ImageFolder`` over ``folders['train_images']``
      and call ``train``. With ``cont=True`` training resumes: the start
      epoch is the first index ``n`` for which ``checkpoints/losses_<n>.npy``
      does not exist, and weights are loaded from ``pars.pth`` if present,
      otherwise from ``checkpoints/cp_<start_epoch - 1>.pth``.
    * ``'save_BSD'``: load ``pars.pth`` if present and call ``save_BSD``,
      reading from ``folders['BSD']/<split>`` and writing to
      ``folders['BSD_predictions']/<model>/version<version>/<split>_<transform>``.

    Any other action only triggers a warning.

    Args:
        version: Integer used to select the ``version%d`` sub-folder.
        action: Name of the action to run (see above).
        neigh: Neighborhood size passed to ``get_neigh``. The value 99 is a
            sentinel meaning "read ``neigh`` from the last entry of
            ``hparam.json`` in the model folder, or fall back to 4".
        learning_rate, epochs, t_max, pars_train, loss_type, momentum,
        weight_decay, n_report, n_acc, lr_fact: Forwarded to ``train``
            (``learning_rate`` as ``lr``, ``epochs`` as ``n_epoch``,
            ``t_max`` as ``max_samples``). ``t_max`` is also forwarded to
            ``save_BSD``.
        crop_size: Passed to ``ImageFolder`` for the training data.
        batch_size, num_workers: Passed to the training ``DataLoader``;
            ``num_workers`` is also forwarded to ``save_BSD``.
        device: Device string used for ``map_location`` when loading
            parameters and forwarded to ``train``.
        cont: If true, continue training from existing checkpoints.
        verbose, downsample, separate, interpolate, transform, q_zero:
            Forwarded to ``save_BSD``.
        model: Model name; selects the folder and is passed to
            ``get_shifts`` and ``get_module``.
        split: Sub-folder of ``folders['BSD']`` used by ``save_BSD``.
        **kwargs: Extra keyword arguments, forwarded to ``train`` only.
    """
    folder = os.path.join(folders['models'], model, 'version%d' % version)
    shift, subsamp = get_shifts(model)
    if neigh == 99:
        # 99 is the "not specified" sentinel: reuse the stored hyperparameter.
        hparam_f = os.path.join(folder, 'hparam.json')
        if os.path.isfile(hparam_f):
            hparam = get_last_hparam(hparam_f)
            neigh = hparam["neigh"]
        else:
            neigh = 4
            warnings.warn('no hparam found, using neigh=4')
    neighbors = get_neigh(neigh)
    # The second return value of get_module is not needed here.
    module, _dim_out = get_module(model, neighbors)
    # exist_ok=True already makes this a no-op for an existing directory.
    pathlib.Path(folder).mkdir(parents=True, exist_ok=True)
    pars_file = os.path.join(folder, 'pars.pth')
    if action == 'train':
        start_epoch = 0
        if cont:
            checkpoint_dir = os.path.join(folder, 'checkpoints')
            # Resume after the last epoch that has a stored loss file.
            while os.path.exists(os.path.join(
                    checkpoint_dir, 'losses_%d.npy' % start_epoch)):
                start_epoch += 1
            if os.path.isfile(pars_file):
                # pars.pth takes precedence over per-epoch checkpoints.
                _load_state(module, pars_file, device)
            elif start_epoch >= 1:
                _load_state(
                    module,
                    os.path.join(checkpoint_dir,
                                 'cp_%d.pth' % (start_epoch - 1)),
                    device)
            else:
                warnings.warn('No module checkpoint found! Continuing with random init.')
        data = ImageFolder(folders['train_images'], crop_size)
        data_loader = DataLoader(data, batch_size=batch_size, shuffle=True,
                                 num_workers=num_workers)
        train(module, data_loader, folder, lr=learning_rate, n_epoch=epochs,
              max_samples=t_max, device=device, start_epoch=start_epoch,
              pars_train=pars_train, loss_type=loss_type, momentum=momentum,
              weight_decay=weight_decay, n_report=n_report, n_acc=n_acc,
              lr_fact=lr_fact, neigh=neigh, **kwargs)
    elif action == 'save_BSD':
        out_dir_BSD = os.path.join(folders['BSD_predictions'], model, 'version%d' % version,
                                   split + '_%s' % transform)
        if os.path.isfile(pars_file):
            _load_state(module, pars_file, device)
        else:
            # Nothing is trained in this branch; the predictions are simply
            # computed with whatever weights get_module returned.
            warnings.warn('No module checkpoint found! '
                          'Saving BSD predictions with initial weights!')
        save_BSD(
            module, shift, subsamp, transform=transform,
            out_dir=out_dir_BSD, in_dir=os.path.join(folders['BSD'], split),
            verbose=verbose, t_max=t_max, num_workers=num_workers, q_zero=q_zero,
            downsample=downsample, separate=separate, interpolate=interpolate)
    else:
        # Make unhandled actions visible instead of silently doing nothing.
        warnings.warn('main() has no handler for action %r; nothing was done.'
                      % (action,))


if __name__ == '__main__':
    # Command-line entry point: every parsed option is passed to main() as a
    # keyword argument of the same name.
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("-d", "--device",
                        help="device to run on [cuda,cpu]",
                        choices=['cuda', 'cpu'], default='cpu')
    parser.add_argument("-E", "--epochs",
                        help="number of epochs",
                        type=int, default=1)
    parser.add_argument("-b", "--batch_size",
                        help="size of a batch",
                        type=int, default=12)
    parser.add_argument("-c", "--crop_size",
                        help="size image crop used for training",
                        type=int, default=256)
    parser.add_argument("--num_workers",
                        help="number of data_loader workers",
                        type=int, default=0)
    parser.add_argument("--downsample",
                        help="downsampling factor [no downsampling]",
                        type=int, default=0)
    parser.add_argument("-r", "--learning_rate",
                        help="learning rate",
                        type=float, default=10**-3)
    parser.add_argument("--lr_fact",
                        help="lr factor for log_c and prior_w",
                        type=float, default=10)
    parser.add_argument("-w", "--weight_decay",
                        help="weight decay rate",
                        type=float, default=10**-4)
    parser.add_argument("--momentum",
                        help="momentum",
                        type=float, default=0)
    # --noise has no named parameter in main(); it arrives via **kwargs.
    parser.add_argument("--noise",
                        help="amount of noise added to the features before prediction",
                        type=float, default=0)
    parser.add_argument("-q", "--q_zero",
                        help="quantile to set to zero for normalizing connectivity maps",
                        type=float, default=0.15)
    # argparse applies `type` only to string defaults, so the non-string
    # default np.inf is used as is when -s is not given.
    parser.add_argument("-s", "--t_max",
                        help="number of training steps",
                        type=int, default=np.inf)
    parser.add_argument("--n_report",
                        help="report how often? In images [default=10000]",
                        type=int, default=10000)
    parser.add_argument("--n_acc",
                        help="How many loss evaluations to accumulate per batch [default=0]",
                        type=int, default=0)
    # --n_pos has no named parameter in main(); it arrives via **kwargs.
    parser.add_argument("--n_pos",
                        help="How many positions to use as negative set [default=10]",
                        type=int, default=10)
    parser.add_argument("-v", "--version", type=int, default=0,
                        help="which version folder to use")
    # 99 is the sentinel main() interprets as "take neigh from hparam.json".
    parser.add_argument("-n", "--neigh", type=int, default=99,
                        choices=[4, 8, 12, 20, 28, 40, 52, 68, 99],
                        help="how many neighbors to use")
    parser.add_argument("-p", "--pars_train",
                        help="which pars to train ['pred', 'other', 'all']",
                        choices=['pred', 'other', 'all'],
                        default='all')
    parser.add_argument("-t", "--transform",
                        help="which transform to edge_weights ['quant', 'expit']",
                        choices=['expit', 'quant'],
                        default='quant')
    parser.add_argument("--split",
                        help="which data split to work on. Applies only to evaluations",
                        choices=['train', 'val', 'test'],
                        default='train')
    parser.add_argument("-l", "--loss_type",
                        help="which loss to use",
                        choices=['shuffle', 'batch', 'pos', 'pos2', 'pos3'],
                        default='shuffle')
    parser.add_argument("--cont", action='store_true')
    parser.add_argument("--verbose", action='store_true')
    parser.add_argument("--separate", action='store_true')
    parser.add_argument("--interpolate", action='store_true')
    parser.add_argument("-a", "--action",
                        help="what to do? [train, train_readout, save_eval, eval"
                             ", reset, save_BSD]",
                        choices=['train', 'train_readout', 'save_eval',
                                 'eval', 'reset', 'save_BSD'],
                        default='train')
    # Implicit string concatenation keeps source indentation out of the
    # help text (a backslash continuation inside the literal would embed it).
    parser.add_argument("-m", "--model",
                        help="which model to train [pixel, linear, linear3, "
                             "linearbig, conv1, resdl, predseg1]",
                        choices=['pixel', 'linear', 'linear3', 'linearbig',
                                 'conv1', 'resdl', 'predseg1'],
                        default='pixel')
    args = parser.parse_args()
    main(**vars(args))
