import argparse, csv, time

from src.iperi import IPeri
from src.client import Client
import src.utils as utils
from src.dataset import Dataset
import ges

import networkx as nx
import numpy as np

from tqdm import tqdm

def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value -- including ``"False"`` -- becomes ``True``. This
    converter interprets the usual textual spellings instead.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    # The empty string maps to False, matching what ``bool('')`` returned.
    if token in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(argv=None):
    """Parse the command-line options of the HI-Peri experiment script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            the arguments are taken from ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace holding one attribute per option defined below.
    """
    parser = argparse.ArgumentParser(description="HI-Peri causal discovery")

    # Shared settings for boolean options: accept "--flag True/False" (the
    # historical usage) as well as a bare "--flag", which means True.
    bool_flag = dict(type=_str2bool, nargs='?', const=True, default=False)

    parser.add_argument('--n_clients', type=int, default=10, help='Number of clients')
    parser.add_argument('--n_samples_client', type=int, default=200, help='Number of samples')
    parser.add_argument('--n_variables', type=int, default=10, help='Number of variables')
    parser.add_argument('--vertical_split', help='Use vertical_split data', **bool_flag)
    parser.add_argument('--horizontal_split', help='Use uneven sample split among clients', **bool_flag)
    parser.add_argument('--data_type', type=str, default='obs')
    parser.add_argument('--cd_function', type=str, default='pc', help='Causal discovery function for clients (pc, lingam, ges)')
    parser.add_argument('--scoring_function', type=str, default='bic', help='Scoring function for clients (bic, bdeu)')
    # NOTE(review): the help text lists "gaussian" but the default is 'normal';
    # confirm which spelling the dataset generator actually expects.
    parser.add_argument('--noise_distribution', type=str, default='normal', help='Type of noise (gaussian, uniform)')
    parser.add_argument('--seed', type=int, default=1846, help='Random seed')
    parser.add_argument('--linear', help='Use linear causal discovery', **bool_flag)
    parser.add_argument('--masked', help='Use masked causal discovery, if False standard PERI', **bool_flag)
    parser.add_argument('--max_iters', type=int, default=1, help='Maximum number of iterations for HI-Peri')
    parser.add_argument('--save', help='Save generated data', **bool_flag)
    return parser.parse_args(argv)


if __name__ == "__main__":
    # Script entry point. Steps, in order:
    #   1. parse options and seed the run,
    #   2. build the dataset and write one CSV per client,
    #   3. run the chosen causal-discovery function on every client,
    #   4. aggregate the clients with IPeri,
    #   5. optionally append the evaluation metrics to results.csv.
    args = parse_args()

    utils.set_determine(args.seed)

    # LiNGAM runs always use uniform noise here, overriding --noise_distribution
    # (presumably because LiNGAM assumes non-Gaussian noise).
    if args.cd_function == 'lingam':
        args.noise_distribution = 'uniform'

    # Real-world datasets fix the graph type, variable count and client count,
    # overriding whatever was passed on the command line.
    graph_types = 'erdos_renyi'
    if args.data_type == 'sachs':
        graph_types = 'sachs'
        args.n_variables = 11  # Sachs dataset has 11 variables
        args.n_clients = 6

    if args.data_type == 'causalchambers':
        graph_types = 'causalchambers'
        # NOTE(review): the previous comment claimed "10 variables" while the
        # code sets 20 -- confirm the correct count for this dataset.
        args.n_variables = 20
        args.n_clients = 20

    # NOTE(review): --vertical_split is parsed but not honoured; the value is
    # hard-coded to False here and in the saved results row. Confirm intent.
    dataset = Dataset(
        graph_type=graph_types,
        n_samples_client=args.n_samples_client,
        n_clients=args.n_clients,
        n_variables=args.n_variables,
        vertical_split=False,
        horizontal_split=args.horizontal_split,
        noise_distribution=args.noise_distribution,
        seed=args.seed,
    )

    # save=True is required: each Client below is given the path of a CSV file
    # inside the dataset folder. The first returned value is not used by this
    # script.
    _, cpdag, ucpdag, graphs = dataset.generate(data_type=args.data_type, save=True)

    graph = dataset.graph
    experiment_name = dataset.folder_name

    clients = []
    client_graphs = []
    # Per-client orientation F1 against the reference graph (kept under the
    # historical "error_client" column name in the results file).
    error_client_graphs = []
    t0 = time.time()
    for i in tqdm(range(args.n_clients)):
        client = Client(
            name=f'client_{i}',
            data=f'{experiment_name}/client_{i}.csv',
            cd_function=args.cd_function,
            scoring_function=args.scoring_function,
            masked=args.masked,
            linear=args.linear,
        )
        clients.append(client)
        client_graphs.append(client.graph)

        # For 'struct' data each client is scored against the CPDAG of its
        # own graph; otherwise all clients share the global CPDAG.
        if args.data_type != 'struct':
            ground_graph = cpdag
        else:
            ground_graph = ges.utils.dag_to_cpdag(graphs[i])

        error_client_graphs.append(utils.f1_orientation(client.graph, ground_graph))

    iperi = IPeri(
        n_variables=len(graph.nodes),
        clients=clients
    )

    union_graph = utils.union_graph(nx.to_numpy_array(graph), client_graphs)
    server_graph, estimated_cpdag = iperi.fit(max_iters=args.max_iters)
    t1 = time.time()

    # ---- Optional diagnostics (disabled); uncomment to print them ----

    # print("Union graph: \n", union_graph)

    # print("SHD union graph: ", utils.shd(union_graph, nx.to_numpy_array(graph)))
    # print("F1 union graph:  ", utils.f1_orientation(union_graph, nx.to_numpy_array(graph)))

    # print("DAG vs CPDAG comparison:")
    # print("CPDAG: \n", cpdag)
    # print("SHD: ", utils.shd(nx.to_numpy_array(graph), cpdag))
    # print("SHD skeleton: ", utils.shd_skeleton(nx.to_numpy_array(graph), cpdag))
    # print("F1:  ", utils.f1_orientation(nx.to_numpy_array(graph), cpdag))
    # print("F1 skeleton:  ", utils.f1_skeleton(nx.to_numpy_array(graph), cpdag))

    # print("DAG vs U-CPDAG comparison:")
    # print("U-CPDAG: \n", ucpdag)
    # print("SHD: ", utils.shd(nx.to_numpy_array(graph), ucpdag))
    # print("SHD skeleton: ", utils.shd_skeleton(nx.to_numpy_array(graph), ucpdag))
    # print("F1:  ", utils.f1_orientation(nx.to_numpy_array(graph), ucpdag))
    # print("F1 skeleton:  ", utils.f1_skeleton(nx.to_numpy_array(graph), ucpdag))

    # print("**** CPDAG ****")
    # print("Learned CPDAG: \n", estimated_cpdag)
    # print("True graph: \n", nx.to_numpy_array(graph))
    # print("SHD: ", utils.shd(estimated_cpdag, nx.to_numpy_array(graph)))
    # print("SHD skeleton: ", utils.shd_skeleton(estimated_cpdag, nx.to_numpy_array(graph)))
    # print("F1:  ", utils.f1_orientation(estimated_cpdag, nx.to_numpy_array(graph)))
    # print("F1 skeleton:  ", utils.f1_skeleton(estimated_cpdag, nx.to_numpy_array(graph)))
    # print("SHD CPDAG: ", utils.shd(estimated_cpdag, cpdag))
    # print("F1 CPDAG:  ", utils.f1_orientation(estimated_cpdag, cpdag))

    # print("**** U-CPDAG ****")
    # print("Learned graph: \n", server_graph)
    # print("True graph: \n", nx.to_numpy_array(graph))
    # print("SHD: ", utils.shd(server_graph, nx.to_numpy_array(graph)))
    # print("SHD skeleton: ", utils.shd_skeleton(server_graph, nx.to_numpy_array(graph)))
    # print("F1:  ", utils.f1_orientation(server_graph, nx.to_numpy_array(graph)))
    # print("F1 skeleton:  ", utils.f1_skeleton(server_graph, nx.to_numpy_array(graph)))
    # print("SHD U-CPDAG: ", utils.shd(server_graph, ucpdag))
    # print("F1 U-CPDAG:  ", utils.f1_orientation(server_graph, ucpdag))

    # print("Union graph vs U-CPDAG comparison:")
    # print("SHD: ", utils.shd(ucpdag, union_graph))
    # print("F1:  ", utils.f1_orientation(ucpdag, union_graph))
    # print("Union graph vs DAG")
    # print("SHD union graph: ", utils.shd(union_graph, nx.to_numpy_array(graph)))
    # print("F1 union graph:  ", utils.f1_orientation(union_graph, nx.to_numpy_array(graph)))

    # print("Error Client:", np.mean(error_client_graphs))

    if args.save:
        # Build the ground-truth adjacency matrix once instead of once per
        # metric. This assumes the metric helpers do not modify their inputs.
        true_adj = nx.to_numpy_array(graph)

        # Key order defines the CSV column order; keep it stable so new rows
        # line up with any existing results.csv (including the historical
        # "f1_cpda_est" spelling).
        result = {
            "n_samples_client": args.n_samples_client,
            "n_clients": args.n_clients,
            "n_variables": args.n_variables,
            "vertical_split": False,
            "horizontal_split": args.horizontal_split,
            "data_type": args.data_type,
            "cd_function": args.cd_function,
            "scoring_function": args.scoring_function,
            "linear": args.linear,
            "masked": args.masked,
            "noise_distribution": args.noise_distribution,
            "seed": args.seed,
            "shd": utils.shd(server_graph, true_adj),
            # "shd_skeleton": utils.shd_skeleton(server_graph, nx.to_numpy_array(graph)),
            "f1": utils.f1_orientation(server_graph, true_adj),
            # "f1_skeleton": utils.f1_skeleton(server_graph, nx.to_numpy_array(graph)),
            "shd_cpdag_est": utils.shd(estimated_cpdag, true_adj),
            "f1_cpda_est": utils.f1_orientation(estimated_cpdag, true_adj),
            "shd_union": utils.shd(union_graph, true_adj),
            "f1_union": utils.f1_orientation(union_graph, true_adj),
            "shd_ucpdag": utils.shd(server_graph, ucpdag),
            "f1_ucpdag": utils.f1_orientation(server_graph, ucpdag),
            "shd_cpdag": utils.shd(estimated_cpdag, cpdag),
            "f1_cpdag": utils.f1_orientation(estimated_cpdag, cpdag),
            "shd_ucpdag_cpdag": utils.shd(cpdag, ucpdag),
            "f1_ucpdag_cpdag": utils.f1_orientation(cpdag, ucpdag),
            "error_client": np.mean(error_client_graphs),
            "time": t1 - t0
        }

        # Append one row per run; the header is written only when the file is
        # still empty (position 0 right after opening in append mode).
        with open('results.csv', mode='a', newline='') as file:
            writer = csv.DictWriter(file, fieldnames=list(result))
            if file.tell() == 0:
                writer.writeheader()
            writer.writerow(result)