import argparse
import os.path
import pickle
from parsers.parser import Parser
from utils.plot import plot_graphs_list_eps
import os
import io
import random
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from rdkit import Chem
from rdkit.Chem import Draw, AllChem, rdDepictor
from rdkit.Chem.Draw import MolDrawOptions

def visualization_results(training_graphs, sampled_graphs, dataset):
    """Plot the training graphs and the sampled graphs of one dataset.

    Each pickle file is loaded in turn and its content is handed to
    ``plot_graphs_list_eps`` with ``max_num=16`` and
    ``save_dir='visualization/<dataset>'``. The titles passed along are
    ``<dataset>-train.eps`` for the training graphs and
    ``<dataset>-sample.eps`` for the sampled graphs.

    Args:
        training_graphs: Path of the pickle file holding the training graphs.
        sampled_graphs: Path of the pickle file holding the sampled graphs.
        dataset: Dataset name; used in the output directory and plot titles.
    """
    output_dir = f'visualization/{dataset}'
    jobs = (
        (training_graphs, f'{dataset}-train.eps'),
        (sampled_graphs, f'{dataset}-sample.eps'),
    )

    for pickle_path, plot_title in jobs:
        # NOTE: unpickling can execute arbitrary code, so only point this at
        # pickle files from a trusted source.
        with open(pickle_path, 'rb') as handle:
            graphs = pickle.load(handle)

        plot_graphs_list_eps(graphs=graphs, title=plot_title, max_num=16,
                             save_dir=output_dir)


def read_smiles_from_file(filename):
    """Read one SMILES string per line from a text file.

    Leading and trailing whitespace is stripped from every line, and lines
    that are empty after stripping are skipped.

    Args:
        filename: Path of the text file to read.

    Returns:
        A list of the non-empty, stripped lines in file order.
    """
    with open(filename, 'r') as handle:
        stripped_lines = (raw_line.strip() for raw_line in handle)
        # Consume the generator while the file is still open.
        return [entry for entry in stripped_lines if entry]


def _savefig_format_for(output_filename, supported_formats, fallback='eps'):
    """Pick the output format matching the extension of ``output_filename``.

    Returns the lower-cased extension when it is contained in
    ``supported_formats``. In every other case (no extension, an extension
    that is not supported, or a target that is not a path, such as a file
    object) ``fallback`` is returned.
    """
    if not isinstance(output_filename, (str, os.PathLike)):
        return fallback
    extension = os.path.splitext(os.fsdecode(output_filename))[1].lstrip('.').lower()
    return extension if extension in supported_formats else fallback


def create_beautiful_molecule_grid(smiles_list, n_cols=4, n_rows=5,
                                   output_filename='beautiful_molecules.png',
                                   mol_size=(250, 200),
                                   figure_size=(6, 4),
                                   random_selection=True):
    """Draw up to ``n_rows * n_cols`` molecules into one grid figure and save it.

    At most ``n_rows * n_cols`` SMILES strings are picked (a random sample
    when ``random_selection`` is true and more are available, otherwise the
    leading ones). Strings that RDKit cannot parse are dropped, so the grid
    may contain fewer molecules than requested. If none can be parsed, a
    message is printed and nothing is written.

    The file format follows the extension of ``output_filename`` when
    matplotlib's canvas supports it; otherwise EPS is written, which is what
    this function always produced before.

    Args:
        smiles_list: Sequence of SMILES strings.
        n_cols: Number of grid columns.
        n_rows: Maximum number of grid rows.
        output_filename: Destination passed to ``Figure.savefig``.
        mol_size: ``(width, height)`` handed to ``Draw.MolToImage``.
        figure_size: ``figsize`` of the matplotlib figure.
        random_selection: Sample randomly instead of taking the first entries.

    Returns:
        None.
    """
    capacity = n_rows * n_cols
    if random_selection and len(smiles_list) > capacity:
        selected_smiles = random.sample(smiles_list, capacity)
    else:
        selected_smiles = smiles_list[:capacity]

    valid_mols = []
    for smi in selected_smiles:
        try:
            mol = Chem.MolFromSmiles(smi)
            if mol is not None:
                rdDepictor.Compute2DCoords(mol)
                valid_mols.append(mol)
        except Exception as e:
            # Best effort: one bad entry must not abort the whole grid.
            print(f"Error: {smi}, : {str(e)}")

    if not valid_mols:
        print("There are no valid smiles")
        return

    # Only allocate as many rows as the parsed molecules actually need.
    actual_rows = min(n_rows, (len(valid_mols) + n_cols - 1) // n_cols)

    fig = plt.figure(figsize=figure_size, facecolor='white')
    try:
        fig.subplots_adjust(top=0.85)

        # NOTE(review): Draw.DrawingOptions is RDKit's legacy options class and
        # its stock elemDict appears to be keyed by atomic number rather than
        # element symbol - confirm against the installed RDKit version that
        # these settings are accepted by MolToImage and take effect.
        drawer_opts = Draw.DrawingOptions()
        drawer_opts.bondLineWidth = 2
        drawer_opts.atomLabelFontSize = 14
        drawer_opts.includeAtomNumbers = False
        drawer_opts.elemDict = {
            'O': (0.94, 0.25, 0.15),
            'N': (0.0, 0.47, 0.8),
            'S': (0.9, 0.75, 0.1),
            'F': (0.2, 0.7, 0.3),
            'Cl': (0.0, 0.8, 0.5),
            'Br': (0.6, 0.1, 0.6),
            'I': (0.47, 0.0, 0.73)
        }

        # valid_mols never exceeds n_rows * n_cols (the selection above is
        # capped), so every molecule has a grid cell.
        for position, mol in enumerate(valid_mols, start=1):
            ax = fig.add_subplot(actual_rows, n_cols, position)
            img = Draw.MolToImage(mol, size=mol_size, options=drawer_opts,
                                  kekulize=True, fitImage=True)
            ax.imshow(img)
            ax.axis('off')

        fig.tight_layout()

        # Match the written format to the file extension so that e.g. the
        # default '.png' name no longer receives EPS data.
        output_format = _savefig_format_for(
            output_filename, fig.canvas.get_supported_filetypes())
        fig.savefig(output_filename, dpi=300, bbox_inches='tight',
                    format=output_format, facecolor=fig.get_facecolor())
        print(f"Saving at: {output_filename}")
    finally:
        # Release this specific figure even when drawing or saving fails.
        plt.close(fig)


def visualize_molecules_from_file(filename, output_filename, n_cols=4, n_rows=5, random_selection=True):
    """Read SMILES strings from ``filename`` and render them as a molecule grid.

    When ``filename`` does not exist, a message is printed and nothing else
    happens. Otherwise the file is read with ``read_smiles_from_file`` and the
    result is forwarded to ``create_beautiful_molecule_grid``.

    Args:
        filename: Path of the text file containing the SMILES strings.
        output_filename: Destination of the rendered grid image.
        n_cols: Number of grid columns.
        n_rows: Number of grid rows.
        random_selection: Forwarded to ``create_beautiful_molecule_grid``.
    """
    if not os.path.exists(filename):
        print(f"File is unavailable: {filename}")
        return

    molecules = read_smiles_from_file(filename)
    create_beautiful_molecule_grid(
        molecules,
        n_cols=n_cols,
        n_rows=n_rows,
        output_filename=output_filename,
        random_selection=random_selection,
    )


if __name__ == '__main__':
    # Graph datasets: exactly one preset is active at a time. To visualise
    # another dataset, replace the three values below with one of these:
    #
    #   'planar', '../data/planar.pkl',
    #             '../samples/pkl/planar/test/Sep19-12-02-22-sample.pkl'
    #   'sbm',    '../data/sbm.pkl',
    #             '../samples/pkl/sbm/sbm_train/Sep19-15-49-22-sample.pkl'
    dataset_name, training_file, testing_file = (
        'tree',
        '../data/tree.pkl',
        '../samples/pkl/tree/test/Sep22-12-57-55-sample.pkl',
    )

    visualization_results(
        training_graphs=training_file,
        sampled_graphs=testing_file,
        dataset=dataset_name,
    )

    # Molecule datasets (disabled): set the paths to text files with one
    # SMILES string per line and uncomment to render molecule grids.
    #
    # qm9_file = "xxx.txt"
    # zinc250k_file = "xxx.txt"
    #
    # if os.path.exists(qm9_file):
    #     visualize_molecules_from_file(qm9_file, output_filename='qm9_samples.eps', random_selection=True, n_cols=6, n_rows=4)
    # else:
    #     print(f"File is unavailable: {qm9_file}")
    #
    # if os.path.exists(zinc250k_file):
    #     visualize_molecules_from_file(zinc250k_file, output_filename='zinc250k_samples.eps', n_cols=6, n_rows=4, random_selection=True)
    # else:
    #     print(f"File is unavailable: {zinc250k_file}")