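"""Run one causal discovery experiment on a stored perturbation dataset.

The active pipeline loads an ``.npz`` dataset, runs the method selected on the
command line, and appends the resulting accuracy metrics to
``output/results_new_soft_rest.csv``.

Originally, each run also created a directory named after the current datetime
to save the following (this part is currently disabled):
- log file of the training process
- experiment configurations
- data and ground truth
- final estimated solution
"""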

import argparse
import logging
import sys
import time

import igraph as ig
import numpy as np

from methods.gies import gies
from methods.jci_gsp import jci_gsp
from methods.method import Method
from methods.igsp import igsp, ut_igsp
from utils.config import save_yaml_config, get_args
from utils.data import simulate_dag, simulate_data
from utils.dir import create_dir, get_datetime_str
from utils.logging import setup_logger, get_system_info
from utils.metrics import count_dag_accuracy
from utils.utils import set_random_seed, plot_graphs
import csv 
import os 

def write_result(results, basics, file_name="output/results_new_soft_rest.csv"):
    """Append one row of experiment results to a CSV file.

    The row is ``basics`` followed by ``results``; on a key collision the value
    from ``results`` wins. A header row is written when the file does not exist
    yet or is empty. Columns follow the key order of the merged dict, so every
    call that writes to the same file should use the same keys in the same
    order.

    Args:
        results: Mapping written as result columns (in this script, the output
            of ``count_dag_accuracy``).
        basics: Mapping describing the run (perturbation, sample size, ...).
        file_name: Destination CSV path. Missing parent directories are
            created.
    """
    combined_data = {**basics, **results}

    parent = os.path.dirname(file_name)
    if parent:
        os.makedirs(parent, exist_ok=True)

    # A file that exists but is empty (e.g. left over from an interrupted run)
    # still needs a header; checking only for existence would skip it.
    needs_header = (not os.path.exists(file_name)
                    or os.path.getsize(file_name) == 0)

    # Plain append mode: the file is only written, never read back.
    with open(file_name, mode='a', newline='', encoding='utf-8') as file:
        writer = csv.DictWriter(file, fieldnames=list(combined_data))
        if needs_header:
            writer.writeheader()
        writer.writerow(combined_data)


# Default location of the perturbation datasets. Override it with the
# SELECTION_DATA_ROOT environment variable to run on another machine.
_DEFAULT_DATA_ROOT = "/Users/longkang.li/Desktop/selection/new_dataset"


def _load_perturbation_data(args):
    """Load the observational and per-node perturbation data for one run.

    The ``.npz`` file is read from
    ``<root>/<perturbation_type>/v_<nodenum>/<samplenum>/sample_<seed>/sample_<perturbation_type>.npz``.
    It must contain ``obs`` and one ``per_{i}`` array for every column ``i`` of
    ``obs``. It may optionally contain ``dag``.

    Returns:
        ``(all_data, interv_targets, dag_true)``.
        - ``all_data[0]`` is ``obs`` and ``all_data[i + 1]`` is ``per_{i}``.
        - ``interv_targets[i] == (i,)``.
        - ``dag_true`` is the stored ``dag`` array, or ``None`` if absent.
    """
    data_root = os.environ.get("SELECTION_DATA_ROOT", _DEFAULT_DATA_ROOT)
    path = os.path.join(
        data_root,
        args.perturbation_type,
        f"v_{args.nodenum}",
        f"{args.samplenum}",
        f"sample_{args.seed}",
        f"sample_{args.perturbation_type}.npz",
    )

    # NpzFile keeps the underlying zip file open; the context manager closes it.
    with np.load(path) as npz:
        obs = npz['obs']
        node_num = obs.shape[1]
        all_data = [obs] + [npz[f"per_{i}"] for i in range(node_num)]
        dag_true = npz['dag'] if 'dag' in npz.files else None

    interv_targets = [(i,) for i in range(node_num)]
    return all_data, interv_targets, dag_true


def _run_method(args, all_data, interv_targets):
    """Run the causal discovery method named by ``args.method_type``.

    Args:
        args: Parsed command-line arguments; ``CI_type`` and ``alpha`` are
            forwarded to the methods that take them.
        all_data: List of sample arrays, with observational data first.
        interv_targets: Intervention target tuples, one per entry after the
            first in ``all_data``.

    Returns:
        Whatever the selected method returns. ``main`` reads its ``'dag'``
        entry.

    Raises:
        ValueError: If ``args.method_type`` is not recognised.
    """
    method_type = args.method_type
    if method_type == 'ours':
        # all_data is a list of numpy arrays with shape (num_samples, num_variables).
        # Each array holds data from a different domain, and the sample sizes may
        # differ. all_data[0] must be the observational data (no interventions).
        return Method(all_data, args.CI_type, args.alpha).run_algo()
    if method_type == 'gies':
        return gies(all_data, interv_targets)
    if method_type == 'jci_gsp':
        return jci_gsp(all_data, args.alpha)
    if method_type == 'igsp':
        return igsp(all_data, interv_targets, args.alpha)
    if method_type == 'ut_igsp':
        return ut_igsp(all_data, args.alpha)
    raise ValueError("Unknown method type.")


def main():
    """Load one dataset, run the chosen method and append its accuracy to CSV."""
    _logger = logging.getLogger(__name__)

    # Get arguments parsed
    args = get_args()
    print(args)
    create_dir('output/')

    all_data, my_interv_targets, dag_true = _load_perturbation_data(args)

    params_est = _run_method(args, all_data, my_interv_targets)
    print("method success!")

    # Evaluation and saving stay best-effort so that a batch of runs keeps
    # going. Failures are reported instead of being silently discarded.
    if dag_true is None:
        _logger.warning("Dataset has no ground-truth 'dag'; skipping evaluation.")
        return
    try:
        results = count_dag_accuracy(dag_true, params_est['dag'])
        print("results: ", results)

        basics = {'perturbation': args.perturbation_type, 'samples': args.samplenum,
                  'variable': args.nodenum, 'method': args.method_type, 'seed': args.seed}
        write_result(results, basics)
    except Exception:
        _logger.exception("Failed to evaluate or save results for method %s",
                          args.method_type)


if __name__ == "__main__":  # run the experiment only when executed as a script
    main()