### packages for visualization
from analysis.rdkit_functions import compute_molecular_metrics
import torch.nn as nn
import os
import csv
import time

import pandas as pd

def result_to_csv(path, dict_data):
    """Append one row of results to a CSV file, creating it if needed.

    The ``log_name`` entry is always written as the first column; the
    remaining columns follow the insertion order of ``dict_data``. A header
    row is written when the target file does not exist yet or is empty.

    Args:
        path: Location of the CSV file to append to.
        dict_data: Mapping of column name to value. Must contain a
            non-``None`` ``"log_name"`` entry. The mapping is not modified.

    Raises:
        ValueError: If ``"log_name"`` is missing from ``dict_data`` or is
            ``None``.
    """
    # Read instead of pop so the caller's dictionary is left untouched
    # (neither reordered nor stripped of its "log_name" key).
    log_name = dict_data.get("log_name")
    if log_name is None:
        raise ValueError("The provided dictionary must contain a 'log_name' key.")

    field_names = ["log_name"] + [key for key in dict_data if key != "log_name"]

    # An existing but empty file still needs a header, otherwise the first
    # data row would be mistaken for the column names by CSV readers.
    needs_header = not os.path.exists(path) or os.path.getsize(path) == 0

    with open(path, "a", newline="") as file:
        writer = csv.DictWriter(file, fieldnames=field_names)
        if needs_header:
            writer.writeheader()
        writer.writerow(dict_data)

class SamplingMolecularMetrics(nn.Module):
    """Score sampled molecules and persist the outcome to disk.

    Each ``forward`` call delegates the metric computation to
    ``compute_molecular_metrics``, dumps the generated SMILES next to their
    targets into a text file, and appends the metric values as one row of
    ``output.csv`` via ``result_to_csv``.
    """

    def __init__(
        self,
        dataset_infos,
        known_smiles,
        reference_smiles=None,
        n_jobs=1,
        device="cpu",
        batch_size=512,
    ):
        """Store the evaluation settings and load the source SMILES table.

        Args:
            dataset_infos: Object exposing ``task_name`` and
                ``mol_file_path`` (a CSV file with a ``smiles`` column).
            known_smiles: Forwarded unchanged to the metric computation.
            reference_smiles: Accepted but not used by this constructor.
            n_jobs: Stored in the compute configuration.
            device: Stored in the compute configuration.
            batch_size: Stored in the compute configuration.
        """
        super().__init__()
        self.task_name = dataset_infos.task_name
        self.dataset_infos = dataset_infos
        self.known_smiles = known_smiles

        # Placeholder for reference-smiles statistics; never filled in here.
        self.stat_ref = None
        self.comput_config = dict(n_jobs=n_jobs, device=device, batch_size=batch_size)

        smiles_table = pd.read_csv(dataset_infos.mol_file_path, engine='pyarrow')
        self.source_smiles_list = smiles_table['smiles'].tolist()

        self.task_evaluator = {
            'meta_taskname': dataset_infos.task_name,
            'sas': None,
            'scs': None,
        }

    def forward(self, molecules, targets, target_ids, target_to_similar_smiles, run_name, current_epoch, val_counter, test=False):
        """Evaluate sampled molecules, then write the SMILES dump and CSV log.

        Args:
            molecules: Sampled molecules, passed to the metric computation.
            targets: Targets, passed to the metric computation.
            target_ids: Per-sample ids; used as keys into
                ``target_to_similar_smiles`` and as indices into the source
                SMILES list.
            target_to_similar_smiles: Lookup from target id to similar SMILES.
            run_name: Name of the sub-folder under ``graphs/`` used for
                validation dumps.
            current_epoch: Epoch number used in file and log names.
            val_counter: Validation counter used in file and log names.
            test: When truthy, write to ``final_smiles.txt`` and log as
                ``"test"`` instead of using the per-epoch names.

        Returns:
            A pair ``(generated, target)`` taken from the ``'generation'``
            and ``'target'`` entries of the computed SMILES dictionary.
        """
        similar_per_sample = []
        for tid in target_ids:
            similar_per_sample.append(target_to_similar_smiles[tid])

        _unique, smiles_by_kind, metrics, _targets_log = compute_molecular_metrics(
            molecules,
            targets,
            similar_per_sample,
            self.known_smiles,
            self.stat_ref,
            self.dataset_infos,
            self.task_evaluator,
            self.comput_config,
        )

        if not test:
            # Validation dumps are grouped per run under ./graphs/<run_name>.
            out_dir = os.path.join(os.getcwd(), f"graphs/{run_name}")
            os.makedirs(out_dir, exist_ok=True)
            dump_name = f"valid_e{current_epoch}_b{val_counter}.txt"
            dump_path = os.path.join(out_dir, dump_name)
        else:
            dump_path = "final_smiles.txt"

        with open(dump_path, "w") as handle:
            header = "generated,novel,target,target_id,target_source"
            handle.write(header + "\n")
            for idx, generated in enumerate(smiles_by_kind['generation']):
                target = smiles_by_kind['target'][idx]
                # Samples beyond the end of the novelty list count as not novel.
                has_flag = idx < len(smiles_by_kind['valid_novel'])
                novel_check = smiles_by_kind['valid_novel'][idx] if has_flag else False
                target_id = target_ids[idx]
                source_target = self.source_smiles_list[target_id]
                handle.write(f"{generated},{novel_check},{target},{target_id},{source_target}\n")
            print("All smiles saved")

        # The metrics dictionary doubles as the CSV row; tag it with a name.
        if test:
            log_label = "test"
        else:
            log_label = "epoch" + str(current_epoch) + "_batch" + str(val_counter)
        metrics["log_name"] = log_label

        result_to_csv("output.csv", metrics)
        return smiles_by_kind['generation'], smiles_by_kind['target']

    def reset(self):
        """Do nothing; this object keeps no state that needs resetting."""
        return None

# Script entry-point guard: executing this file directly performs no action,
# since no standalone behavior is defined here.
if __name__ == "__main__":
    pass