import os
import random
import torch
import torch.nn as nn
import argparse
import numpy as np
import json
from tqdm import tqdm
from feature_dataset import FeaturesDataset, collate
from torch.utils.data import DataLoader
from model import Trainer

import gc
import pdb



# Command-line interface of the script. Every option has a default, so the
# script can be launched without arguments.
parser = argparse.ArgumentParser()
parser.add_argument('--device',
                    help='what device to perform training on',
                    type=str,
                    default='cuda:0')
parser.add_argument('--data_path',
                    help='dir of features',
                    type=str,
                    default='../')
parser.add_argument('--ckpt_dir',
                    help='dir for checkpoint to save',
                    type=str,
                    default='checkpoint')
parser.add_argument('--features_dim',
                    help='number of features dimension',
                    type=int,
                    default=1024)
# Forwarded to Trainer as `num_f_maps`. The default is a real int rather than
# the string '64' so it matches the declared type.
parser.add_argument('--num_f_maps', default=64, type=int)

parser.add_argument('--num_classes',
                    help='total number of different steps',
                    type=int,
                    default=779)
parser.add_argument('--num_epochs',
                    help='number of epochs to train for',
                    type=int,
                    default=100)
parser.add_argument('--batch_size',
                    help='training batch size',
                    type=int,
                    default=4)
parser.add_argument('--lr',
                    help='learning rate',
                    type=float,
                    default=2e-3)
parser.add_argument('--seed',
                    help='random seed',
                    type=int,
                    default=0)
parser.add_argument('--num_workers',
                    help='number of workers for DataLoader',
                    type=int,
                    default=16)
# Only these two actions are handled further down; reject anything else up
# front instead of silently doing nothing.
parser.add_argument('--action', help='train or predict', default='train',
                    choices=('train', 'predict'))
# Only the exact string 'yes' enables causal mode; any other value disables it.
parser.add_argument('--causal', help='use causal intervention or not', default='yes')
# Kept as a string because it is split on ',' before each token is converted
# to int in the predict branch.
parser.add_argument('--pred_epochs',
                    help='comma-separated list of epochs to run prediction for (e.g. "50,100")',
                    type=str,
                    default='100')
# Original author's note: "Need input". Both options below have defaults, so
# argparse itself does not require them.
parser.add_argument('--num_layers', type=int, default=5)
parser.add_argument('--conv_len', type=int, default=25)
args = parser.parse_args()
print(args)


# Use the requested device when CUDA is usable, otherwise fall back to CPU.
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
# torch.cuda.set_device only accepts CUDA devices with an explicit index and
# raises for anything else, so skip it on the CPU fallback (or for a bare
# 'cuda' device, which keeps the current default GPU).
if device.type == 'cuda' and device.index is not None:
    torch.cuda.set_device(device)
print("device: ", device)

# Seed every random number generator the script touches so runs are
# repeatable for a given --seed.
if args.seed is not None:
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(args.seed)
    # CUDA generators are seeded only when a GPU is actually available.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

# Unpack the parsed options into the module-level names used further down.
ckpt_dir, num_epochs = args.ckpt_dir, args.num_epochs
features_dim, batch_size, lr = args.features_dim, args.batch_size, args.lr

# Model hyper-parameters handed to Trainer.
num_layers, conv_len = args.num_layers, args.conv_len
num_f_maps, num_classes = args.num_f_maps, args.num_classes

# Only the exact string 'yes' switches causal mode on.
causal = args.causal == 'yes'

# The trainer is built for both actions (train and predict).
trainer = Trainer(num_layers, conv_len, num_f_maps, features_dim, num_classes)

# Training entry point: save the run configuration, build the training
# loader and hand everything to the trainer.
if args.action == "train":
    os.makedirs(args.ckpt_dir, exist_ok=True)
    # Store the parsed options next to the checkpoints.
    with open(os.path.join(args.ckpt_dir, 'conf.json'), 'w') as conf_json:
        json.dump(vars(args), conf_json)
    print("##### preparing train dataset #####")
    dataset = FeaturesDataset(args.data_path, train=True)
    # `device` is a torch.device, so compare its type string; the previous
    # identity check against the literal 'cpu' was always False and enabled
    # pinned memory even on CPU-only runs.
    data_loader = DataLoader(dataset=dataset, batch_size=args.batch_size, shuffle=True,
                             num_workers=args.num_workers, pin_memory=device.type != 'cpu',
                             collate_fn=collate)

    trainer.train(ckpt_dir, data_loader, num_epochs=num_epochs, learning_rate=lr, device=device, causal=causal)

# Prediction entry point: build the test loader and run the trainer's
# predict step once per requested epoch.
if args.action == "predict":
    print("##### preparing test dataset #####")
    dataset = FeaturesDataset(args.data_path, train=False)
    # `device` is a torch.device, so compare its type string; the previous
    # identity check against the literal 'cpu' was always False and enabled
    # pinned memory even on CPU-only runs.
    # NOTE(review): unlike the training loader, no collate_fn is passed here —
    # presumably intentional because batch_size is 1; confirm.
    data_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False,
                             num_workers=args.num_workers, pin_memory=device.type != 'cpu')
    # Parse the whole comma-separated list first so a malformed entry fails
    # before any prediction is run; empty tokens (e.g. a trailing comma) are
    # ignored.
    epochs = [int(token) for token in args.pred_epochs.split(',') if token.strip()]
    for epoch in epochs:
        trainer.predict(ckpt_dir, data_loader, epoch, device)

