import argparse
import os
import random

import numpy as np
import torch
from sklearn.model_selection import train_test_split

from src.utils import read_dataset_from_npy, Logger
from src.rogra.model import RoGra, RoGraTrainer
from create_dataset import create_dataset

# Root directory for the preprocessed few-shot dataset files
# (<data_dir>/{eeg,ucr}_datasets_<shot>_shot/<dataset>.npy) and for the
# non-EEG DTW distance files (<data_dir>/ucr_datasets_dtw/<dataset>.npy).
data_dir = './tmp'
# Root directory for per-run result logs (<log_dir>/<dataset>_subj<subject>/...).
log_dir = './logs'

# Dataset names treated as EEG data: the script entry point uses membership in
# this list to pick the EEG-specific data/DTW paths, and passes it to
# create_dataset.
eeg_datasets = ['bci', 'mamem', 'bcicha']
# NOTE(review): not referenced anywhere in this file; kept in case other
# modules import it.
tsc_datasets = []

def train(X, y, train_idx, test_idx, distances, device, logger, args):
    """Build a RoGra model, fit it with RoGraTrainer and evaluate it.

    Args:
        X: Input samples; ``X.shape[1]`` is passed to RoGra as its input size
            (presumably the number of channels/features -- TODO confirm).
        y: Labels; the number of classes is the count of unique entries of
            ``y`` along axis 0 (unique rows if ``y`` is 2-D, e.g. one-hot).
        train_idx: Indices handed to ``trainer.fit`` as the training split.
        test_idx: Indices handed to ``trainer.test`` for evaluation.
        distances: Distance data forwarded unchanged to ``trainer.fit``.
        device: torch device the model is moved to and the trainer runs on.
        logger: Logger forwarded to RoGraTrainer.
        args: Parsed CLI namespace; ``num_layers``, ``hidden_dim`` and
            ``dropout`` configure the model, and the whole namespace is
            forwarded to ``trainer.fit``.

    Returns:
        Whatever ``RoGraTrainer.test`` returns; the caller logs it as the
        test accuracy.
    """
    n_classes = len(np.unique(y, axis=0))
    in_size = X.shape[1]

    net = RoGra(
        in_size,
        n_classes,
        num_layers=args.num_layers,
        n_feature_maps=args.hidden_dim,
        dropout=args.dropout,
    ).to(device)

    trainer = RoGraTrainer(device, logger)
    fitted = trainer.fit(
        model=net,
        X=X,
        y=y,
        train_idx=train_idx,
        distances=distances,
        args=args,
    )
    return trainer.test(fitted, test_idx)


def argsparser():
    """Create the command-line parser for a RoGra training run.

    Options are registered in the order listed below, each with its type,
    default and help text.

    Returns:
        argparse.ArgumentParser: an unparsed parser whose program name is
        "RoGra".
    """
    option_table = (
        ('--dataset', {'help': 'Dataset name - Options: bci, mamem, bcicha', 'default': 'bci'}),
        ('--subject', {'help': 'Subject ID you want to train', 'type': int, 'default': 1}),
        ('--seed', {'help': 'Random seed', 'type': int, 'default': 0}),
        ('--gpu', {'type': str, 'default': '0'}),
        ('--epochs', {'help': 'Number of training epochs', 'type': int, 'default': 500}),
        ('--shot', {'help': 'shot', 'type': int, 'default': 1}),
        ('--K', {'help': 'K', 'type': int, 'default': 3}),
        ('--alpha', {'help': 'alpha', 'type': float, 'default': 0.3}),
        ('--num_layers', {'help': 'Number of inception Layers - Options are 1,2 and 3', 'type': int, 'default': 2}),
        ('--hidden_dim', {'help': 'Hidden Dimension for the Inception backbone', 'type': int, 'default': 64}),
        ('--bs', {'help': 'Batch size', 'type': int, 'default': 128}),
        ('--lr', {'help': 'learning rate', 'type': float, 'default': 1e-4}),
        ('--wd', {'help': 'weight decay', 'type': float, 'default': 4e-3}),
        ('--dropout', {'help': 'dropout rate', 'type': float, 'default': 0.5}),
    )

    cli = argparse.ArgumentParser("RoGra")
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli

def _select_device(gpu):
    """Restrict CUDA to ``gpu`` and return the torch device to run on.

    ``gpu`` is written verbatim to CUDA_VISIBLE_DEVICES, so the first listed
    card is then addressed as ``cuda:0``.
    """
    # NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if CUDA has not been
    # initialised yet (torch does so lazily); keep this before any CUDA work.
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    if torch.cuda.is_available():
        print("--> Running on the GPU")
        return torch.device("cuda:0")
    print("--> Running on the CPU")
    return torch.device("cpu")


def _seed_everything(seed):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch with ``seed``."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def _load_distances(args):
    """Load the DTW distance array for ``args.dataset`` (per subject for EEG)."""
    if args.dataset in eeg_datasets:
        path = os.path.join('datasets', 'eeg', args.dataset, f'subject_{args.subject}_dtw.npy')
    else:
        path = os.path.join(data_dir, 'ucr_datasets_dtw', f'{args.dataset}.npy')
    return np.load(path)


def _load_split(args):
    """Return ``(X, y, train_idx, test_idx)`` for the dataset and shot count."""
    family = 'eeg' if args.dataset in eeg_datasets else 'ucr'
    path = os.path.join(data_dir, f'{family}_datasets_{args.shot}_shot', f'{args.dataset}.npy')
    return read_dataset_from_npy(path)


def _log_path(args):
    """Create (if needed) the run's log directory and return its log-file path."""
    out_dir = os.path.join(log_dir, f'{args.dataset}_subj{args.subject}')
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(out_dir, exist_ok=True)
    return os.path.join(
        out_dir,
        f'seed{args.seed}_shot{args.shot}_k{args.K}_a{args.alpha}'
        f'_lr{args.lr}_wd{args.wd}_dpout{args.dropout}.txt',
    )


def main():
    """Parse CLI args, prepare data, train RoGra and log the test accuracy."""
    args = argsparser().parse_args()
    device = _select_device(args.gpu)

    # Seed before create_dataset so any global-RNG randomness it uses is
    # reproducible (NOTE(review): its internals are not visible here), then
    # re-seed so the RNG state at training time is exactly what it was when
    # seeding happened only after dataset creation, regardless of how many
    # random draws create_dataset made.
    _seed_everything(args.seed)
    create_dataset(args, eeg_datasets)
    _seed_everything(args.seed)

    distances = _load_distances(args)
    out_path = _log_path(args)

    with open(out_path, 'w', encoding='utf-8') as f:
        logger = Logger(f)
        X, y, train_idx, test_idx = _load_split(args)

        acc = train(X, y, train_idx, test_idx, distances, device, logger, args)

        logger.log('--> {} Test Accuracy: {:5.4f}'.format(args.dataset, acc))
        logger.log(str(acc))


if __name__ == "__main__":
    main()
