### packages for molecular metric evaluation
from analysis.rdkit_functions import compute_molecular_metrics
from metrics.property_metric import TaskModel

import torch
import torch.nn as nn

import time
from mini_moses.metrics.metrics import compute_intermediate_statistics

import os
import csv

def result_to_csv(path, dict_data):
    """Append one row of results to a CSV file.

    ``log_name`` is always written as the first column; the remaining
    columns follow in the iteration order of ``dict_data``. A header row is
    written when the file does not exist yet or exists but is empty.

    NOTE: the header of a non-empty existing file is not checked against the
    columns of the new row; callers are expected to pass the same keys on
    every call for a given ``path``.

    Args:
        path: Destination CSV file. Created if missing, appended to otherwise.
        dict_data: Mapping of column name to value. Must contain a
            ``"log_name"`` entry whose value is not ``None``. The mapping is
            not modified.

    Raises:
        ValueError: If ``"log_name"`` is missing from ``dict_data`` or is
            ``None``.
    """
    log_name = dict_data.get("log_name")
    if log_name is None:
        raise ValueError("The provided dictionary must contain a 'log_name' key.")

    # Build the row in a fresh dict so the caller's mapping keeps both its
    # contents and its key order.
    row = {"log_name": log_name}
    row.update(
        (key, value) for key, value in dict_data.items() if key != "log_name"
    )

    # An existing zero-byte file still needs a header, otherwise the first
    # data row would later be mistaken for the column names.
    needs_header = not os.path.exists(path) or os.path.getsize(path) == 0

    with open(path, "a", newline="") as file:
        writer = csv.DictWriter(file, fieldnames=list(row))
        if needs_header:
            writer.writeheader()
        writer.writerow(row)


class SamplingMolecularMetrics(nn.Module):
    """Evaluate sampled molecules and persist the SMILES and metric logs.

    On construction the module optionally precomputes reference statistics
    (via ``compute_intermediate_statistics``) and builds one ``TaskModel``
    evaluator per task named in ``dataset_infos.task``. Calling the module
    delegates the metric computation to ``compute_molecular_metrics``, writes
    the resulting SMILES to a text file and appends the metrics to
    ``output.csv``.
    """

    def __init__(
        self,
        dataset_infos,
        train_smiles,
        reference_smiles,
        n_jobs=1,
        device="cpu",
        batch_size=512,
    ):
        """Set up reference statistics and per-task evaluators.

        Args:
            dataset_infos: Object exposing ``task`` (a "-"-separated string of
                task names), ``active_atoms`` and ``base_path``.
            train_smiles: Stored and forwarded unchanged to
                ``compute_molecular_metrics`` on every call.
            reference_smiles: Sized collection used to compute the reference
                statistics, or ``None`` to skip that step (``stat_ref`` is
                then ``None``).
            n_jobs: Forwarded to ``compute_intermediate_statistics`` and kept
                in ``comput_config``.
            device: Forwarded to ``compute_intermediate_statistics`` and kept
                in ``comput_config``.
            batch_size: Forwarded to ``compute_intermediate_statistics`` and
                kept in ``comput_config``.
        """
        super().__init__()
        self.task_name = dataset_infos.task
        self.dataset_infos = dataset_infos
        self.active_atoms = dataset_infos.active_atoms
        self.train_smiles = train_smiles

        if reference_smiles is not None:
            print(
                f"--- Computing intermediate statistics for training for #{len(reference_smiles)} smiles ---"
            )
            start_time = time.time()
            self.stat_ref = compute_intermediate_statistics(
                reference_smiles, n_jobs=n_jobs, device=device, batch_size=batch_size
            )
            elapsed_time = time.time() - start_time
            print(
                f"--- End computing intermediate statistics: using {elapsed_time:.2f}s ---"
            )
        else:
            self.stat_ref = None

        # Compute settings handed to compute_molecular_metrics in forward().
        self.comput_config = {
            "n_jobs": n_jobs,
            "device": device,
            "batch_size": batch_size,
        }

        # 'sas' and 'scs' start without an evaluator; every task listed in
        # dataset_infos.task gets its own TaskModel below.
        self.task_evaluator = {
            'meta_taskname': dataset_infos.task,
            'sas': None,
            'scs': None,
        }

        # All evaluator files share one directory, so create it once up
        # front. str.split always yields at least one item, so the loop below
        # runs at least once and the directory is needed in every case.
        evaluator_dir = os.path.join(dataset_infos.base_path, "data/evaluator")
        os.makedirs(evaluator_dir, exist_ok=True)
        for cur_task in dataset_infos.task.split("-"):
            print('loading evaluator for task', cur_task)
            model_path = os.path.join(evaluator_dir, f"{cur_task}.joblib")
            self.task_evaluator[cur_task] = TaskModel(model_path, cur_task)

    def forward(self, molecules, targets, name, current_epoch, val_counter, test=False):
        """Compute metrics for ``molecules`` and write the results to disk.

        Args:
            molecules: Sampled molecules, passed through unchanged to
                ``compute_molecular_metrics``.
            targets: A tensor, or a list of tensors that is concatenated along
                dim 0, converted to a NumPy array before evaluation.
            name: Run name; used as the sub-directory of ``graphs/`` for
                validation output.
            current_epoch: Used in the validation file name and the CSV
                ``log_name``.
            val_counter: Used in the validation file name and the CSV
                ``log_name``.
            test: When true, write ``final_smiles.txt`` in the current working
                directory and log under the name ``"test"``; otherwise write
                the unique SMILES under ``graphs/<name>/``.

        Returns:
            The ``all_smiles`` value returned by ``compute_molecular_metrics``.
        """
        if isinstance(targets, list):
            targets = torch.cat(targets, dim=0)
        targets_np = targets.detach().cpu().numpy()

        unique_smiles, all_smiles, all_metrics, targets_log = compute_molecular_metrics(
            molecules,
            targets_np,
            self.train_smiles,
            self.stat_ref,
            self.dataset_infos,
            self.task_evaluator,
            self.comput_config,
        )

        if test:
            self._write_test_smiles(all_smiles, targets_log)
        else:
            self._write_unique_smiles(unique_smiles, name, current_epoch, val_counter)

        # The metrics mapping itself is tagged and logged (no copy is made).
        all_logs = all_metrics
        if test:
            all_logs["log_name"] = "test"
        else:
            all_logs["log_name"] = (
                "epoch" + str(current_epoch) + "_batch" + str(val_counter)
            )

        result_to_csv("output.csv", all_logs)
        return all_smiles

    def _write_test_smiles(self, all_smiles, targets_log):
        """Write every SMILES, with per-task inputs/outputs if available, to final_smiles.txt."""
        # 'meta_taskname' holds the combined task string, not a task, and
        # 'scs' is excluded from the report; every other key becomes a column.
        task_names = [
            task
            for task in self.task_evaluator
            if task not in ('meta_taskname', 'scs')
        ]
        columns = [f"input_{task}" for task in task_names] + [
            f"output_{task}" for task in task_names
        ]

        with open("final_smiles.txt", "w") as fp:
            # The header is written even when targets_log is None.
            fp.write("smiles, " + ", ".join(columns) + "\n")
            for i, smiles in enumerate(all_smiles):
                if targets_log is not None:
                    values = [f"{targets_log[column][i]}" for column in columns]
                    fp.write(f"{smiles}, " + ", ".join(values) + "\n")
                else:
                    fp.write("%s\n" % smiles)
        print("All smiles saved")

    def _write_unique_smiles(self, unique_smiles, name, current_epoch, val_counter):
        """Write the unique SMILES, one per line, under graphs/<name>/ in the cwd."""
        result_path = os.path.join(os.getcwd(), f"graphs/{name}")
        os.makedirs(result_path, exist_ok=True)
        text_path = os.path.join(
            result_path,
            f"valid_unique_molecules_e{current_epoch}_b{val_counter}.txt",
        )
        # Context manager guarantees the handle is closed even if a write fails.
        with open(text_path, "w") as textfile:
            for smiles in unique_smiles:
                textfile.write(smiles + "\n")

    def reset(self):
        """Do nothing; this class keeps no per-evaluation state to clear."""
        pass

# Script entry-point guard: executing this module directly currently does nothing.
if __name__ == "__main__":
    pass