import pickle
from attr import field
import torch
import csv

import ruamel.yaml
import itertools

from utils.logger import set_log
from utils.loader import load_ckpt, load_data, load_seed, load_eval_settings
from evaluation.stats import eval_graph_list
from utils.mol_utils import load_smiles, canonicalize_smiles, mols_to_nx, smiles_to_mols
from projop.project_bisection import satisfies
import networkx as nx
from parsers.config import get_config, get_constraint_config
from moses.metrics.metrics import get_all_metrics

from multiprocessing import Process
import os

import time

from evaluation.filter_constr import filtermap_constrained_graphs, filtermap_constrained_smiles
import argparse
import sys

# ---- Global evaluation settings, shared by every worker process ----
TOTAL_MOLS = 10000   # denominator of evaluate_mol's validity rescaling (num_mols / TOTAL_MOLS);
                     # presumably the number of molecules sampled per run — confirm
MIN_CONSTR_P = 0     # minimum percentage of the test set that must satisfy a constraint
                     # before that setting is evaluated at all
seed = 42            # seed handed to get_config when building the sampling config
device = 'cpu'       # device forwarded to load_ckpt and the evaluate* routines
nprocs = 5           # number of worker processes launched by the __main__ driver

# Command line. Parsed at import time, so importing this module parses the
# current process argv. --eq sets constr_config["eq"] for every run and switches
# the CSV file stem used by run_setting.
parser = argparse.ArgumentParser()
parser.add_argument("--eq", action='store_true')
args = parser.parse_args()

if args.eq:
    fname = "results_bisecteq"
else:
    fname = "results_bisect5"

def evaluate (log_folder_name, log_name, configt, config, constr_config, device='cpu'):
    """Score pickled generated graphs for a generic (non-molecular) graph dataset.

    Args:
        log_folder_name: sub-folder of ``./samples/pkl`` holding the samples.
        log_name: sample file stem; the file read is
            ``./samples/pkl/{log_folder_name}/{log_name}.pkl``.
        configt: config stored in the checkpoint; supplies the test data (via
            ``load_data``) and ``data.max_node_num`` / ``data.max_feat_num``.
        config: sampling config; only ``config.data.data`` is read, to choose
            the evaluation methods and kernels.
        constr_config: constraint config forwarded to the test-set constraint
            filter and to ``satisfies``.
        device: unused; kept for signature parity with ``evaluate_mol``.

    Returns:
        ``{}`` when the sample file does not exist or when fewer than
        ``MIN_CONSTR_P`` percent of the test graphs satisfy the constraint.
        Otherwise, the dict returned by ``eval_graph_list`` plus
        ``'constr_val'``: the fraction of generated graphs for which
        ``satisfies`` holds.
    """
    samples_path = f'./samples/pkl/{log_folder_name}/{log_name}.pkl'
    # Check for samples first so settings that were never sampled skip the
    # comparatively expensive test-set load.
    if not os.path.exists(samples_path):
        return {}

    _, test_graph_list = load_data(configt, get_graph_list=True)

    # NOTE: pickle can execute arbitrary code on load; only point this at
    # locally generated sample files.
    with open(samples_path, 'rb') as f:
        gen_graph_list = pickle.load(f)

    methods, kernels = load_eval_settings(config.data.data)
    test_constr = filtermap_constrained_graphs (test_graph_list, configt, constr_config=constr_config)
    # Skip constraints that too small a share of the test set satisfies.
    if test_constr.sum() < MIN_CONSTR_P * len(test_constr)/100:
        return {}
    # Compare the generated graphs only against constraint-satisfying test graphs.
    test_graph_list = [graph for constr, graph in zip(test_constr, test_graph_list) if constr]
    result_dict = eval_graph_list(test_graph_list, gen_graph_list, methods=methods, kernels=kernels)

    # Pack the generated graphs into zero-padded
    # (num_graphs, max_node_num, max_node_num) adjacency tensors.
    max_nodes = configt.data.max_node_num
    adjs = torch.zeros(len(gen_graph_list), max_nodes, max_nodes)
    for i, G in enumerate(gen_graph_list):
        nG = G.number_of_nodes()
        adjs[i, :nG, :nG] = torch.tensor(nx.adjacency_matrix(G).todense())
    # Node features stay all-zero. Presumably these generic-graph constraints
    # depend only on the adjacency — confirm.
    xs = torch.zeros (len(gen_graph_list), max_nodes, configt.data.max_feat_num)
    result_dict['constr_val'] = satisfies(xs, adjs, constr_config).sum().item()/len(adjs)
    return result_dict

def evaluate_mol (log_dir, log_name, configt, config, constr_config, device='cpu'):
    """Score a file of generated SMILES for a molecular dataset (qm9 / zinc250k).

    Args:
        log_dir: directory holding the generated SMILES file.
        log_name: file stem; the file read is ``{log_dir}/{log_name}.txt``,
            one SMILES per line.
        configt: config stored in the checkpoint; supplies the dataset name and
            ``data.max_node_num`` / ``data.max_feat_num``.
        config: sampling config; ``config.sample.seed`` seeds the evaluation.
        constr_config: constraint config forwarded to the test-set constraint
            filter and to ``satisfies``.
        device: forwarded to MOSES ``get_all_metrics``.

    Returns:
        ``{}`` when the SMILES file is missing or when fewer than
        ``MIN_CONSTR_P`` percent of the test molecules satisfy the constraint.
        Otherwise, a dict with these keys:
        ``valid`` (MOSES validity rescaled by ``num_mols / TOTAL_MOLS``),
        ``unique``, ``fcd``, ``novelty``, ``nspdk``, ``num_mols`` and
        ``constr_val`` (the fraction of generated molecules for which
        ``satisfies`` holds).
    """
    smiles_path = f"{log_dir}/{log_name}.txt"
    if not os.path.exists(smiles_path):
        print (f"{log_dir}/{log_name} not found")
        return {}
    load_seed(config.sample.seed)
    train_smiles, test_smiles = load_smiles(configt.data.data, file_ext='_can')

    # Strip only the newline. Slicing off the last character would truncate
    # the final SMILES when the file does not end with a newline.
    with open(smiles_path, 'r') as f:
        gen_smiles = [line.rstrip('\n') for line in f]

    gen_mols = smiles_to_mols (gen_smiles)
    num_mols = len(gen_mols)
    gen_graph_list = mols_to_nx (gen_mols)

    # Test graphs are filtered with the same mask as test_smiles.
    # Assumes the pickle is in the same order as the test SMILES — confirm.
    with open(f'data/{configt.data.data.lower()}_test_nx.pkl', 'rb') as f:
        test_graph_list = pickle.load(f)

    test_constr = filtermap_constrained_smiles (test_smiles, configt, constr_config=constr_config)
    # Skip constraints that too small a share of the test set satisfies.
    if test_constr.sum() < MIN_CONSTR_P * len(test_constr)/100:
        return {}
    test_smiles = [smiles for constr, smiles in zip(test_constr, test_smiles) if constr]
    test_graph_list = [graph for constr, graph in zip(test_constr, test_graph_list) if constr]

    scores = get_all_metrics(gen=gen_smiles, k=len(gen_smiles), device=device, n_jobs=8,
                             test=test_smiles, train=train_smiles)
    scores_nspdk = eval_graph_list(test_graph_list, gen_graph_list, methods=['nspdk'])['nspdk']
    # Previously the rescaled validity was computed and then immediately
    # overwritten by the raw score (missing `else`); the rescaling now applies.
    result_dict = {
        'valid': scores['valid'] * num_mols / TOTAL_MOLS,
        'unique': scores[f'unique@{len(gen_smiles)}'],
        'fcd': scores['FCD/Test'],
        'novelty': scores['Novelty'],
        'nspdk': scores_nspdk,
        'num_mols': num_mols,
    }

    # Zero-padded bond matrices: the edge attribute 'label' is used as the weight.
    max_nodes = configt.data.max_node_num
    adjs = torch.zeros(len(gen_graph_list), max_nodes, max_nodes)
    for i, G in enumerate(gen_graph_list):
        nG = G.number_of_nodes()
        adjs[i, :nG, :nG] = torch.tensor(nx.adjacency_matrix(G, weight='label').todense())
    # One-hot atom features indexed by the node 'label'. An atom symbol outside
    # this map raises KeyError.
    atom_id_map = {'C': 0, 'N': 1, 'O': 2, 'F': 3, 'P': 4, 'S': 5, 'Cl': 6, 'Br': 7, 'I': 8}
    xs = torch.zeros (num_mols, max_nodes, configt.data.max_feat_num)
    for i, G in enumerate(gen_graph_list):
        xs[i, torch.arange(len(G.nodes)), [atom_id_map[x['label']] for x in G.nodes().values()]] = 1
    result_dict['constr_val'] = satisfies(xs, adjs, constr_config).sum().item()/len(adjs)
    return result_dict

def run_setting(dataset_constraints, methods, method_params):
    """Evaluate every requested setting and append one CSV row per successful run.

    For each ``(dataset, constraint)`` pair, every ``(method, method_param)``
    schedule is combined with every constraint value listed under ``dataset``
    in ``config/constraints/master_{constraint}.yaml``. Each combination works
    as follows:

    * The matching sample log name is rebuilt from the constraint config.
    * The sample is scored with ``evaluate_mol`` (qm9 / zinc250k) or
      ``evaluate`` (other datasets).
    * The row is appended to ``all_results/all_mol_{fname}.csv`` or
      ``all_results/all_{fname}.csv``, respectively.

    Settings whose evaluation raises or returns an empty dict are skipped.

    Args:
        dataset_constraints: iterable of ``(dataset, constraint)`` name pairs.
        methods: schedule names (e.g. ``'poly'``, ``'fixed'``, ``'none'``),
            paired element-wise with ``method_params``.
        method_params: one list of schedule parameters per entry of ``methods``.
    """
    mol_datasets = ('qm9', 'zinc250k')
    # A single YAML loader is reused for every master constraint file.
    yaml = ruamel.yaml.YAML()
    with open (f"all_results/all_{fname}.csv", "a") as wf, open (f"all_results/all_mol_{fname}.csv", "a") as mwf:
        writer = csv.DictWriter(wf, fieldnames=['dataset', 'constraint', 'param', 'method', 'method_param',
                                                'degree', 'cluster', 'orbit', 'spectral', 'constr_val'])
        mwriter = csv.DictWriter(mwf, fieldnames=['dataset', 'constraint', 'param', 'method', 'method_param',
                                                'valid', 'unique', 'fcd', 'novelty', 'nspdk', 'num_mols', 
                                                'constr_val'])
        for dataset, constraint in dataset_constraints:
            is_mol = dataset in mol_datasets
            config = get_config (f'sample_{dataset}', seed)
            constr_config = get_constraint_config (f'{constraint}_{dataset}' if is_mol else f'{constraint}')
            constr_config["eq"] = args.eq
            # Close the master file instead of leaking the handle.
            with open(f"config/constraints/master_{constraint}.yaml", 'r') as yf:
                constraint_vals = yaml.load(yf)
            all_params = constraint_vals[dataset]
            for method, method_param in zip(methods, method_params):
                for param_vals in all_params:
                    print (dataset, constraint, method, method_param, param_vals)
                    constr_config["schedule"] = {'gamma': method, 'params': method_param}

                    results_dict = {'dataset': dataset, 'constraint': constraint, 'method': method,
                                    'method_param': '|'.join(map(str, method_param))}

                    if constraint in ['valency', 'atomCount']:
                        # param_vals is a per-atom list. Label it e.g. "C9N1O1F0".
                        constr_config['params'] = [param_vals]
                        if dataset == 'qm9':
                            atoms = ["C", "N", "O", "F"]
                        else:
                            atoms = ["C", "N", "O", "F", 'P', 'S', 'Cl', 'Br', 'I']
                        results_dict['param'] = ''.join([atom + str(count) for atom, count in zip(atoms, param_vals)])
                    elif constraint in ['nedges', 'nedgesl2']:
                        constr_config['params'] = ['zeros', param_vals]
                        results_dict['param'] = param_vals
                    elif constraint.startswith(('molWeight', 'prop')):
                        # Only the target value (second entry) of the preset params is swept.
                        constr_config['params'][1] = param_vals
                        results_dict['param'] = param_vals
                    else:
                        constr_config["params"] = [param_vals]
                        results_dict['param'] = param_vals

                    # -------- Load checkpoint --------
                    # NOTE(review): config depends only on `dataset`, so this reload
                    # could likely be hoisted out of the inner loops — confirm
                    # load_ckpt has no per-call side effects first.
                    ckpt_dict = load_ckpt(config, device)
                    configt = ckpt_dict['config']
                    log_folder_name, log_dir, _ = set_log(configt, 
                                                        constraint=((constr_config.constraint + ("-eq" if constr_config.eq else "")) \
                                                                     if method != 'none' else 'None'),
                                                        is_train=False)
                    # Rebuild the sample file name; this must match the naming used when sampling.
                    log_name = f"{config.ckpt}-{constr_config.method.op}-{constr_config.method.gamma if method != 'none' else 0}"
                    log_name += f"-bisect" if method != 'none' else ''
                    log_name += f"-{constr_config.schedule.gamma}{','.join(map(str, constr_config.schedule.params))}"
                    log_name += f"-{constr_config.burnin}"
                    log_name += f"-{constr_config.add_diff_step}"
                    param_strs = map (str, constr_config.params if method != 'none' else [])
                    log_name += f"-{','.join(param_strs)}-{constr_config.rounding}"
                    log_name = log_name.replace(".", "p")

                    try:
                        if is_mol:
                            out_dict = evaluate_mol (log_dir, log_name, configt, config, constr_config, device=device)
                        else:
                            out_dict = evaluate (log_folder_name, log_name, configt, config, constr_config, device=device)
                    except Exception as err:
                        # Best-effort sweep: skip this setting, but report why instead of
                        # silently swallowing it (a bare except also trapped Ctrl-C).
                        print (f"skipping {dataset}/{constraint}/{log_name}: {err!r}")
                        continue
                    if not out_dict:
                        continue

                    results_dict.update(out_dict)

                    if is_mol:
                        mwriter.writerow(results_dict)
                    else:
                        writer.writerow(results_dict)
                    wf.flush()
                    mwf.flush()
                    

if __name__ == '__main__':
    datasets = ['community_small', 'ego_small', 'grid', 'enzymes']
    constraints = ["nedges", "ntriangles", "maxDegree"]
    
    mol_datasets = ['qm9', 'zinc250k']
    mol_constraints = ['molWeight']

    dataset_constraints = list(itertools.product(datasets, constraints)) + \
                          list(itertools.product (mol_datasets, mol_constraints)) + \
                          [("qm9", "prop-gap"), ("qm9", "prop-homo"), ("qm9", "prop-lumo")] + \
                          [("zinc250k", "prop-logP"), ("zinc250k", "prop-qed"), ("zinc250k", "prop-SAS")]
    # dataset_constraints = [('community_small', 'nedges')]
    dataset_constraints = list(itertools.product (mol_datasets, mol_constraints))
    # dataset_constraints = [('enzymes', 'nedges'), ('enzymes', 'ntriangles'), ('enzymes', 'maxDegree')]
    # dataset_constraints += [("qm9", "propIn-homo"), ("qm9", "propIn-mu")]

    poly_pows = [1, 5] #[1, 3, 5]
    gamma_inits = [0, 1e-1] #[0, 1e-4, 1e-1]
    poly_pows = [0, 0.1, 0.2, 0.5]
    gamma_inits = [0] 
    method_params = [[x, y] for x, y in itertools.product (gamma_inits, poly_pows)]
    methods = ['poly'] * (len(gamma_inits)*len(poly_pows))
    method_params += [[1]] #[[0], [1]]
    methods += ['fixed'] #, 'fixed'] # the 0 got kinda wrong as it clamped to valid values at each stage. 
    method_params += [[0]]
    methods += ['none'] 

    # method_params = [[0]]
    # methods = ['none'] 

    # with open (f"all_results/all_{fname}.csv", "w+") as wf, open (f"all_results/all_mol_{fname}.csv", "w+") as mwf:
    #     writer = csv.DictWriter(wf, fieldnames=['dataset', 'constraint', 'param', 'method', 'method_param',
    #                                             'degree', 'cluster', 'orbit', 'spectral', 'constr_val'])
    #     mwriter = csv.DictWriter(mwf, fieldnames=['dataset', 'constraint', 'param', 'method', 'method_param',
    #                                             'valid', 'unique', 'fcd', 'novelty', 'nspdk', 'num_mols', 
    #                                             'constr_val'])
    #     writer.writeheader()
    #     mwriter.writeheader()
    # run_setting (dataset_constraints, methods, method_params)

    n_total = len(dataset_constraints)
    proc_nsettings = int (n_total/nprocs)
    procs = []
    for i in range(nprocs):
        i_off = i*proc_nsettings
        if (i == nprocs - 1):
            idc_settings = dataset_constraints[i*proc_nsettings:]
        else:
            idc_settings = dataset_constraints[i*proc_nsettings: (i+1)*proc_nsettings]
        iproc = Process(target=run_setting, 
                        args=(idc_settings, methods, method_params))
        iproc.Daemon = True
        iproc.start()
        procs.append(iproc)

    for iproc in procs:
        iproc.join()
    