import pandas as pd

import os,sys

sys.path.append(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))))


from parse_args_tabular import parse_args
from tabular_data_utils.tabular_dataset import *
import synthetic_lang
from create_language import Language
from utils_treatment import load_configs, transform_treatment_ids, set_lang_data
import torch
from tab_models.tab_model import *
from utils_treatment_tab import load_tab_data

import structured_experiments.sw as sw
import structured_experiments.tcga as tcga


def _seed_everything(seed):
    """Seed Python's, NumPy's and PyTorch's RNGs and force deterministic cuDNN.

    Args:
        seed: integer seed applied to every random number generator.
    """
    # Imported explicitly rather than relying on the wildcard imports at the
    # top of the file to leak ``random`` / ``np`` into this module.
    import random

    import numpy as np

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def _load_checkpoint(trainer, path):
    """Load the state dict stored at ``path`` into both DQN networks of ``trainer``.

    The file is read from disk once and the same state dict is applied to the
    policy and target networks. ``load_state_dict`` copies the tensors into
    each module, so the two networks do not end up sharing parameters.

    Args:
        trainer: object exposing ``device`` and ``dqn.policy_net`` /
            ``dqn.target_net`` (as the ``dql_algorithm`` trainer does).
        path: path of a file written with ``torch.save(state_dict, path)``.
    """
    state_dict = torch.load(path, map_location=trainer.device)
    trainer.dqn.policy_net.load_state_dict(state_dict)
    trainer.dqn.target_net.load_state_dict(state_dict)


if __name__ == "__main__":
    args, parser = parse_args()

    # Dataset-specific CLI options can only be registered once the dataset
    # name is known; the full command line is then parsed again.
    if args.dataset_name == "sw":
        sw.add_params(parser)
    elif args.dataset_name == "tcga_str":
        tcga.add_params(parser)
    args = parser.parse_args()

    # Seed before touching the data so that any randomness inside
    # load_tab_data / create_dataset is reproducible across runs.
    _seed_everything(args.seed)

    (all_data, train_df, valid_df, test_df, lang, id_attr, outcome_attr,
     treatment_attr, count_outcome_attr, dose_attr, normalize_y,
     extra_info) = load_tab_data(args)
    train_dataset, valid_dataset, test_dataset, feat_range_mappings = create_dataset(
        args.dataset_name, all_data, train_df, valid_df, test_df, synthetic_lang,
        id_attr, outcome_attr, treatment_attr, synthetic_lang.DROP_FEATS,
        count_outcome_attr=count_outcome_attr, dose_attr=dose_attr,
        normalize_y=normalize_y, extra_info=extra_info,
        structured_treat=args.structured_treatment,
    )

    # In cat_and_cont_treatment mode the dataset id doubles as the seed.
    if args.cat_and_cont_treatment:
        args.seed = args.dataset_id
    # Re-seed so model construction and training start from a fixed RNG state
    # regardless of how many random draws data loading consumed.
    _seed_everything(args.seed)

    # NOTE(review): extra_info's layout is defined by load_tab_data; index 4 is
    # passed as the treatment collection for structured treatments — confirm.
    if args.structured_treatment:
        lang = set_lang_data(lang, train_dataset, extra_info[4])
    else:
        lang = set_lang_data(lang, train_dataset)

    numeric_count = len(train_dataset.num_cols)
    category_count = list(train_dataset.cat_cols)
    category_sum_count = train_dataset.cat_sum_count

    root_dir = os.path.dirname(os.path.realpath(__file__))
    args.root_dir = root_dir
    rl_config, model_config = load_configs(args, root_dir=root_dir)
    args.embed_size = model_config["hidden_size"]

    # The structured path evaluates on (extra_info[2], extra_info[3]) instead
    # of test_dataset; everything else about the trainer is identical.
    if args.structured_treatment:
        args.num_treatments = len(extra_info[4])
        eval_data = (extra_info[2], extra_info[3])
    else:
        eval_data = test_dataset

    trainer = dql_algorithm(
        train_dataset, valid_dataset, eval_data, id_attr, outcome_attr,
        treatment_attr, lang, args.lr, rl_config["gamma"], args.dropout_p,
        feat_range_mappings, args.program_max_len,
        rl_config["replay_memory_capacity"], rl_config, model_config,
        numeric_count, category_count, category_sum_count, args,
        topk_act=args.topk_act,
        batch_size=args.batch_size,
        modeldir=None, log_folder=args.log_folder,
    )

    if args.cached_model_name is not None:
        _load_checkpoint(trainer, args.cached_model_name)

    if args.structured_treatment:
        # NOTE(review): unlike the unstructured path, args.is_log is not
        # honoured here (model_best is never reloaded) — confirm intent.
        trainer.run_structured(
            train_dataset, valid_dataset,
            extra_info[2], extra_info[3], extra_info[4], extra_info[5],
        )
    else:
        if args.is_log:
            # Resume from the best model previously saved in the log folder.
            _load_checkpoint(trainer, os.path.join(args.log_folder, "model_best"))
        trainer.run(train_dataset, valid_dataset, test_dataset)
    