import os
#os.environ['CUDA_VISIBLE_DEVICES'] = '3'
import sys
sys.path.append('/playpen-raid/Author/LucidAtlas/')
import matplotlib.pyplot as plt
plt.rcParams['figure.dpi']= 300
import argparse
from pipeline.load import *
import torch.utils.data as data_utils
import model.networks.basics.workspace as ws
import pipeline

def build_arg_parser():
    """Build the command-line parser for this driver script.

    Every ``--<stage>`` switch is an integer flag: any non-zero value enables
    that stage.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser(description="Train a LucidAtlas autodecoder")
    parser.add_argument(
        "--experiment",
        "-e",
        dest="experiment_directory",
        default="/playpen-raid/Author/LucidAtlas/configs/ToyData/cv5_sc_toy/lucidatlas_full.json",
        # Alternative config:
        # "/playpen-raid/Author/LucidAtlas/configs/OASISBrain/cv5_sc/lucidatlas_full.json"
        help="Path to the experiment specification JSON file; it is passed "
             + "to every pipeline stage.",
    )
    # NOTE(review): --checkpoint is parsed but not passed to any stage below.
    parser.add_argument(
        "--checkpoint",
        "-c",
        dest="checkpoint",
        default="latest",
        help="The checkpoint weights to use. This can be a number indicating an epoch "
        + "or 'latest' for the latest weights (this is the default)",
    )
    parser.add_argument(
        "--train",
        dest="whether_train",
        default=0,
        type=int,
        help="whether to train from scratch",
    )
    parser.add_argument(
        "--test",
        dest="whether_test",
        default=0,
        type=int,
        help="whether to run the full evaluation (train/test prediction, "
             + "growth velocity, OOD and feature interpretation)",
    )
    parser.add_argument(
        "--ood",
        dest="whether_ood",
        default=0,
        type=int,
        help="whether to run out-of-distribution evaluation on the test set",
    )
    parser.add_argument(
        "--ind",
        dest="whether_ind",
        default=0,
        type=int,
        help="whether to compute growth-velocity results (distribution and mean)",
    )
    parser.add_argument(
        "--intr",
        dest="whether_intr",
        default=0,
        type=int,
        help="whether to evaluate feature interpretation on the test set",
    )
    parser.add_argument(
        "--vis",
        dest="whether_vis",
        default=0,
        type=int,
        help="whether to run the full set of visualizations",
    )
    # NOTE(review): defaults to 1, so these plots run on every invocation
    # unless `--vis_intr 0` is passed explicitly.
    parser.add_argument(
        "--vis_intr",
        dest="whether_vis_intr",
        default=1,
        type=int,
        help="whether to run the 1-D feature-interpretation visualizations",
    )
    # type=int so a CLI-supplied fold reaches `cv_idx` as an int, matching
    # the int default (previously "-f 2" produced the string "2").
    parser.add_argument(
        "--fold",
        "-f",
        dest="which_fold",
        default=0,
        type=int,
        help="index of the cross-validation fold to use",
    )
    return parser


def _run_growth_velocity(specs, fold):
    """Run both growth-velocity analyses (distribution, then mean) for one fold."""
    pipeline.growth_velocity_with_distribution(specs, cv_idx=fold)
    pipeline.growth_velocity_with_mean(specs, cv_idx=fold)


def _run_feat_interpret(specs, fold, combos):
    """Run test-set feature-interpretation evaluation for each (interpret, n_cov) pair, in order."""
    for interpret, n_cov in combos:
        pipeline.pred_and_eval_feat_interpret(specs_filename=specs,
                                              which_set='test',
                                              cv_idx=fold,
                                              type_of_interpret=interpret,
                                              num_of_covariates=n_cov,
                                              )


def main(argv=None):
    """Parse ``argv`` and run each enabled pipeline stage for one CV fold.

    Enabled stages run in this fixed order: train, 1-D interpretation plots
    (--vis_intr), full visualizations (--vis), OOD evaluation (--ood),
    growth velocity (--ind), full evaluation (--test), and feature
    interpretation (--intr). A line containing ``1`` is printed at the end.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    args = build_arg_parser().parse_args(argv)
    specs = args.experiment_directory
    fold = args.which_fold

    if args.whether_train:
        pipeline.train_per_fold(specs, cv_idx=fold)

    if args.whether_vis_intr:
        pipeline.visualize_interpret_feat(specs, cv_idx=fold, type_of_interpret='indp', type_of_vis='1d')
        pipeline.visualize_interpret_feat(specs, cv_idx=fold, type_of_interpret='', type_of_vis='1d')

    if args.whether_vis:
        pipeline.visualize_correlation(specs, cv_idx=fold)
        for interpret, vis in (('global', '1d'), ('global', '2d'), ('indp', '2d'),
                               ('local', '2d'), ('indp', '1d'), ('local', '1d')):
            pipeline.visualize_interpret_feat(specs, cv_idx=fold, type_of_interpret=interpret, type_of_vis=vis)
        for interpret in ('global', 'indp'):
            pipeline.visualize_interpret_feat_with_ood(specs, cv_idx=fold, type_of_interpret=interpret,
                                                       type_of_vis='1d_with_ood')
        # Disabled (previously commented-out) alternatives: pipeline.visualize_global,
        # visualize_local, visualize_3d_global, visualize_3din2d_global, visualize_indp.

    if args.whether_ood:
        pipeline.pred_and_eval_ood_dec(specs, which_set='test', cv_idx=fold)

    if args.whether_ind:
        _run_growth_velocity(specs, fold)

    if args.whether_test:
        for split in ('train', 'test'):
            pipeline.pred_and_eval(specs, which_set=split, cv_idx=fold)
        _run_growth_velocity(specs, fold)
        pipeline.pred_and_eval_ood_dec(specs, which_set='test', cv_idx=fold)
        _run_feat_interpret(specs, fold, [('global', 'n-1'), ('indp', 'n-1'),
                                          ('global', '1'), ('indp', '1')])

    if args.whether_intr:
        _run_feat_interpret(specs, fold, [('global', '1'), ('indp', '1'),
                                          ('global', 'n-1'), ('indp', 'n-1')])

    # Completion marker printed to stdout.
    print('1')


if __name__ == "__main__":
    main()




