import argparse
import ctypes
import multiprocessing as mp
import os
import pickle
import traceback

import pandas as pd
import pytorch_lightning as pl
# `sys` is used below to extend sys.path so that the 'Models' package is visible for import
import sys
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from time import time
from tqdm import tqdm

sys.path.insert(1, '../')

from Models.pl_gnn import LitGNNModel
from Models.pl_rnn import LitRNNModel
from data.DataPrepare import load_and_prep_df


def parse_args(argv=None):
    """
    Parse the command-line options of this evaluation script.

    :param argv: optional list of argument strings to parse; ``None`` (the default)
                 lets argparse read ``sys.argv[1:]`` as before.
    :return: ``argparse.Namespace`` with the attributes ``checkpoint_folder``
             (str or None), ``force`` (bool), ``gnn`` (bool) and ``save_every`` (int).
    """
    parser = argparse.ArgumentParser()
    # parser.add_argument("--log_dir", type=str, default='../logs/')
    parser.add_argument("--checkpoint_folder", type=str)
    parser.add_argument("--force", action='store_true')
    parser.add_argument("--gnn", action='store_true')
    # ``type=int`` matters: without it a value given on the command line stays a string
    # and never compares equal to the integer step counter used for periodic saving.
    parser.add_argument('--save_every', type=int, default=10, help='Save to file every number of steps')
    # parser.add_argument('--processes', help='Save to file every number of steps', default=10)
    return parser.parse_args(argv)


def plot_prediction(y_test, y_hat):
    """
    Draw reference data and prediction for the column at index 1 in one figure and show it.

    :param y_test: reference values; indexed as ``y_test[:, 1]``
    :param y_hat: predicted values; indexed as ``y_hat[:, 1]``
    :return: None
    """
    # Both curves are drawn the same way and differ only in their legend label.
    for series, label in ((y_test, 'data'), (y_hat, 'prediction')):
        plt.plot(series[:, 1], label=label)
    plt.legend()
    plt.show()


def get_x_y_from_df(_df):
    """
    Split a frame into inputs and targets that are shifted by one row.

    :param _df: DataFrame that contains a 'venues' column (or column group)
    :return: tuple ``(x, y)`` of arrays; ``x`` holds every row except the last one,
             ``y`` holds the 'venues' values from the second row onwards
    """
    inputs = _df[:-1]            # all rows but the last
    targets = _df['venues'][1:]  # 'venues' values starting at the second row
    return inputs.values, targets.values


def get_tabular_data(filename='data/dataset_bounded.csv', start_year=2019, end_year=2021):
    """
    Load the dataset with and without visitor normalization and collect the test arrays.

    ``load_and_prep_df`` is called twice with identical arguments apart from
    ``normalize_visitors`` (first ``False``, then ``True``).

    :param filename: dataset file handed to ``load_and_prep_df``
    :param start_year: forwarded to ``load_and_prep_df``
    :param end_year: forwarded to ``load_and_prep_df``
    :return: dict with the keys 'x_test', 'y_test', 'x_test_norm', 'y_test_norm',
             'num_venues', 'train_min' and 'train_max'
    """

    def _load(normalize):
        # Single place that spells out the arguments shared by both loads.
        return load_and_prep_df(filename=filename, normalize_visitors=normalize, start_year=start_year,
                                end_year=end_year)

    # Un-normalized split: source of the raw test arrays and of the train statistics.
    train, test, _, _ = _load(False)
    x_test, y_test = get_x_y_from_df(test)

    # Normalized split: only its test part is used.
    _, test_normalized, _, _ = _load(True)
    x_test_norm, y_test_norm = get_x_y_from_df(test_normalized)

    train_venues = train['venues']
    result = {
        'x_test': x_test,
        'y_test': y_test,
        'x_test_norm': x_test_norm,
        'y_test_norm': y_test_norm,
        'num_venues': len(test.venues.columns),
        'train_min': train_venues.min().values,
        'train_max': train_venues.max().values,
    }
    return result


if __name__ == '__main__':
    # The --gnn switch decides which log directory is read from and written to.
    args = parse_args()
    log_dir = '../logs_gnn/' if args.gnn else '../logs/'
    all_data = get_tabular_data()

    # Table of experiments; its index labels are used as experiment folder paths below.
    df = pd.read_csv(log_dir + 'tb_data.csv', index_col=0)

    # Evaluate just the requested folder when one is given, otherwise every table entry.
    folders = [args.checkpoint_folder] if args.checkpoint_folder else df.index

    # Continue from an existing output file if there is one, else start with an empty
    # frame that has the same columns as the experiment table.
    errors_csv = log_dir + 'tb_data_errors.csv'
    if os.path.isfile(errors_csv):
        df_out = pd.read_csv(errors_csv, index_col=0)
    else:
        df_out = pd.DataFrame(columns=df.columns)

    # Shared Multiprocessing Objects
    shared_df = mp.Value(ctypes.py_object)
    shared_df.value = df_out
    shared_count = mp.Value(ctypes.c_int, 0)


    def process_folder(idx):
        """
        Evaluate the checkpoint of a single experiment folder and record its errors.

        The row ``idx`` of the experiment table ``df`` is copied into the shared result
        frame, a checkpoint file is located for it, the model (GNN or RNN, depending on
        ``args.gnn``) is loaded and ``trainer.validate`` is run twice: once with MSE and
        once with L1 as the model loss. The reported ``loss/val`` values and the elapsed
        wall-clock time go into the columns ``mse``, ``mae`` and ``time``.

        Folders already present in the output are skipped unless ``args.force`` is set.
        Exceptions raised during evaluation are printed and do not abort the run.
        After every ``args.save_every`` folders that reach the end of this function the
        result frame is written to ``tb_data_errors.csv`` inside ``log_dir``.

        :param idx: index label of the experiment in ``df``; it doubles as the path of
                    the experiment folder
        :return: None
        """
        # ``args.save_every`` may arrive as a string (argparse only converts values when a
        # ``type`` is declared); coerce it so the periodic-save check below can trigger.
        save_every = int(args.save_every)

        _df = df.loc[idx]
        if idx in df_out.index and not args.force:
            tqdm.write(f"Skipping existing: {idx}")
            return

        results = shared_df.value
        # Register the experiment's row first; the metric columns are filled in below.
        results.loc[idx] = _df
        try:
            experiment_folder = _df.name
            checkpoint_folder = experiment_folder + '/checkpoints'
            if not os.path.isdir(checkpoint_folder):
                checkpoint_folder = experiment_folder + '/../checkpoints'
                if not os.path.isdir(checkpoint_folder):
                    tqdm.write("! Unable to find checkpoints for: " + experiment_folder)
                    return

            # The version is taken from the last path component: first as a whole, then
            # only from the part after its last underscore.
            version_str = experiment_folder.split('/')[-1]
            for candidate in (version_str, version_str.split('_')[-1]):
                try:
                    version = int(candidate)
                    break
                except ValueError:
                    continue
            else:
                tqdm.write('! Unable to get Version from path: ' + experiment_folder)
                return

            # Preferred checkpoint names, in order of priority.
            default_ckpt = os.path.join(checkpoint_folder, 'epoch=299-step=5400.ckpt')
            candidates = (
                os.path.join(checkpoint_folder, f'epoch=299-step=5400-v{version}.ckpt'),
                default_ckpt,
                '../' + default_ckpt,
            )
            checkpoint_file = next((path for path in candidates if os.path.isfile(path)), None)
            if checkpoint_file is None:
                # Fall back to the last entry of the folder. os.listdir returns entries in
                # arbitrary order, so sort to make the choice deterministic, and tolerate
                # an empty folder instead of failing with an IndexError.
                entries = sorted(os.listdir(checkpoint_folder))
                if entries:
                    fallback = os.path.join(checkpoint_folder, entries[-1])
                    if os.path.isfile(fallback):
                        checkpoint_file = fallback
            if checkpoint_file is None:
                tqdm.write("! Cant eval model because no checkpoint file found")
                return
            tqdm.write("evaluating " + checkpoint_file)

            # Run Validation
            def _flag(key):
                # Missing CSV cells are read as NaN, which is truthy in Python; treat
                # them as "not set" rather than as an enabled flag.
                value = _df.get(key, False)
                return bool(value) if pd.notna(value) else False

            norm_visitors = _flag('norm_visitors') or _flag('normalize_data')
            prediction_file = os.path.join(log_dir, 'eval', idx.replace(log_dir, ''))
            model_cls = LitGNNModel if args.gnn else LitRNNModel
            model = model_cls.load_from_checkpoint(checkpoint_file, norm_visitors=norm_visitors, strict=False,
                                                   prediction_file=prediction_file)
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            model.to(device)
            model.eval()
            trainer = pl.Trainer(detect_anomaly=True,
                                 auto_lr_find=False,
                                 devices=1,
                                 accelerator=device,
                                 )

            # Validate once per metric by swapping the loss function of the model;
            # the recorded time covers both validation passes.
            start_time = time()
            for column, loss_fn in (('mse', F.mse_loss), ('mae', F.l1_loss)):
                model.loss = loss_fn
                loss = trainer.validate(model)
                results.loc[idx, column] = loss[0]['loss/val']
            results.loc[idx, 'time'] = time() - start_time

        except Exception:
            # Best effort: report the failure and carry on with the next folder.
            tqdm.write("Error while evaluating {}:".format(idx))
            tqdm.write(traceback.format_exc())

        # Periodically persist intermediate results.
        shared_count.value += 1
        if shared_count.value >= save_every:
            if not results.empty:
                results.to_csv(log_dir + 'tb_data_errors.csv')
                tqdm.write("Saved to file")
            shared_count.value = 0


    # with mp.Pool(processes=args.processes) as p:
    #     max_ = len(folders)
    #     with tqdm(total=max_) as pbar:
    #         for _ in p.imap_unordered(process_folder, folders):
    #             pbar.update()

    # Process the experiments sequentially while showing a progress bar.
    progress = tqdm(folders)
    for folder in progress:
        process_folder(folder)

    # Final write of the result frame once all folders have been processed.
    final_results = shared_df.value
    final_results.to_csv(log_dir + 'tb_data_errors.csv')

    print()
    print("Done.")
