### project-local packages for metric computation
from analysis.rdkit_functions import compute_molecular_metrics
from mini_moses.metrics.metrics import compute_intermediate_statistics
from metrics.property_metric import TaskModel

import torch
import torch.nn as nn
import numpy as np

import os
import csv
import time
import pickle

def result_to_csv(path, dict_data):
    """Append one result row to the CSV file at ``path``.

    The ``log_name`` column is always written first, followed by the
    remaining keys of ``dict_data`` in their existing order. A header row is
    written when the file does not exist yet or is empty. The header of a
    non-empty existing file is never rewritten, so callers are expected to
    pass the same set of keys on every call for a given file.

    Args:
        path: Location of the CSV file to append to.
        dict_data: Mapping of column name to value. Must contain a
            ``log_name`` entry whose value is not ``None``. The mapping is
            not modified.

    Raises:
        ValueError: If ``log_name`` is missing from ``dict_data`` or is None.
    """
    if dict_data.get("log_name") is None:
        raise ValueError("The provided dictionary must contain a 'log_name' key.")

    # Build the column order without popping/re-inserting on the caller's
    # dict, which would silently reorder its keys.
    field_names = ["log_name"] + [key for key in dict_data if key != "log_name"]

    # An existing but empty file still needs a header to be readable later.
    needs_header = not os.path.exists(path) or os.path.getsize(path) == 0

    with open(path, "a", newline="") as file:
        writer = csv.DictWriter(file, fieldnames=field_names)
        if needs_header:
            writer.writeheader()
        writer.writerow(dict_data)


class SamplingMolecularMetrics(nn.Module):
    """Compute, save and log metrics for sampled molecules.

    On construction the object optionally loads (or computes and caches)
    intermediate statistics for a set of reference SMILES and prepares the
    ``task_evaluator`` mapping handed to ``compute_molecular_metrics``.
    Calling the module runs ``compute_molecular_metrics``, writes the
    resulting SMILES to a text file and appends all metrics to
    ``output.csv`` in the current working directory.
    """

    # Tasks whose evaluators are built from the training targets/SMILES.
    _FITTED_TASKS = ("cv-mu-alpha-homo-lumo-delta_eps", "f_osc")

    # Known task name -> ordered evaluator keys. ``None`` means the task has
    # no evaluator at all. Key order matters: it becomes the column order
    # of ``final_smiles.txt``.
    _TASK_PROPERTIES = {
        "QM9": None,
        "MOSES": None,
        "MOSES-sas": ("sas",),
        "MOSES-logp": ("logp",),
        "MOSES-qed": ("qed",),
        "MOSES-molw": ("molw",),
        "ZINC": None,
        "ZINC-sas": ("sas",),
        "ZINC-logp": ("logp",),
        "ZINC-qed": ("qed",),
        "ZINC-molw": ("molw",),
        "sas-logp-qed": ("sas", "logp", "qed"),
        "sas-logp": ("sas", "logp"),
        "logp-qed-sas": ("logp", "qed", "sas"),
    }

    # Tasks for which no ``TaskModel`` is instantiated; their evaluator
    # entries stay ``None``.
    _NO_MODEL_TASKS = frozenset(
        {
            "logp-qed",
            "MOSES-sas",
            "MOSES-logp",
            "MOSES-qed",
            "MOSES-molw",
            "ZINC-sas",
            "ZINC-logp",
            "ZINC-qed",
            "ZINC-molw",
            "sas-logp-qed",
            "sas-logp",
            "logp-qed-sas",
        }
    )

    # A metric is returned in ``total_errors`` when its name contains any of
    # these substrings.
    _REPORTED_METRIC_KEYWORDS = (
        "acc",
        "mae",
        "FCD",
        "diversity",
        "sim",
        "validity",
        "novelty",
        "uniqueness",
        "nspdk",
    )

    def __init__(
        self,
        dataset_infos,
        train_smiles,
        reference_smiles,
        n_jobs=1,
        device="cpu",
        batch_size=512,
        train_y=None,
    ):
        """Set up reference statistics and per-task evaluators.

        Args:
            dataset_infos: Object exposing ``task`` and ``active_atoms``, and
                optionally ``base_path``. When ``base_path`` is absent, the
                directory two levels above this file is used.
            train_smiles: Training SMILES, forwarded to the metric
                computation and to fitted task models.
            reference_smiles: SMILES used for the reference statistics, or
                ``None`` to skip them (``stat_ref`` is then ``None``).
            n_jobs: Forwarded to the statistics/metric computation.
            device: Forwarded to the statistics/metric computation.
            batch_size: Forwarded to the statistics/metric computation.
            train_y: Training targets, indexed as ``train_y[:, column]`` with
                one column per "-"-separated task component. Required for
                the tasks listed in ``_FITTED_TASKS``.

        Raises:
            TypeError: If the task requires ``train_y`` but it is ``None``.
        """
        super().__init__()
        self.task_name = dataset_infos.task
        self.dataset_infos = dataset_infos
        self.active_atoms = dataset_infos.active_atoms
        self.train_smiles = train_smiles
        self.train_y = train_y
        self.comput_config = {
            "n_jobs": n_jobs,
            "device": device,
            "batch_size": batch_size,
        }

        # Resolve the project root once so that the statistics cache and the
        # evaluator directory agree on the same fallback.
        base_path = getattr(
            dataset_infos,
            "base_path",
            os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")),
        )

        self.stat_ref = self._load_or_compute_reference_stats(
            reference_smiles, base_path
        )

        self.task_evaluator = self._initial_task_evaluator()
        if self.task_evaluator is not None:
            self._load_task_models(base_path)

    def _load_or_compute_reference_stats(self, reference_smiles, base_path):
        """Return the reference statistics, reading or filling the pickle cache."""
        if reference_smiles is None:
            return None

        print(
            f"--- Computing intermediate statistics for training for #{len(reference_smiles)} smiles ---"
        )
        start_time = time.time()
        stats_file_path = os.path.join(
            base_path, "data", "raw", f"{self.task_name}_stats_ref.pkl"
        )
        if os.path.exists(stats_file_path):
            # NOTE(security): unpickling can execute arbitrary code; the
            # cache directory must only contain trusted files.
            with open(stats_file_path, "rb") as f:
                stat_ref = pickle.load(f)
        else:
            stat_ref = compute_intermediate_statistics(
                reference_smiles, **self.comput_config
            )
            os.makedirs(os.path.dirname(stats_file_path), exist_ok=True)
            with open(stats_file_path, "wb") as f:
                pickle.dump(stat_ref, f)
        elapsed_time = time.time() - start_time
        print(
            f"--- End computing intermediate statistics: using {elapsed_time:.2f}s ---"
        )
        return stat_ref

    def _initial_task_evaluator(self):
        """Return the evaluator mapping for ``task_name`` before models are added."""
        if self.task_name in self._FITTED_TASKS:
            return {"meta_taskname": self.task_name}
        if self.task_name in self._TASK_PROPERTIES:
            properties = self._TASK_PROPERTIES[self.task_name]
            return None if properties is None else dict.fromkeys(properties)
        # Any other task name: keep the name and the two default entries.
        return {"meta_taskname": self.task_name, "sas": None, "scs": None}

    def _load_task_models(self, base_path):
        """Create a ``TaskModel`` per task component and store it in ``task_evaluator``."""
        fit_on_training_data = self.task_name in self._FITTED_TASKS
        if fit_on_training_data and self.train_y is None:
            raise TypeError(
                f"train_y must be provided for task '{self.task_name}', got None."
            )

        evaluator_dir = os.path.join(base_path, "data/evaluator")
        os.makedirs(evaluator_dir, exist_ok=True)

        if self.task_name in self._NO_MODEL_TASKS:
            return

        for column, cur_task in enumerate(self.task_name.split("-")):
            model_path = os.path.join(evaluator_dir, f"{cur_task}.joblib")
            if fit_on_training_data:
                evaluator = TaskModel(
                    model_path,
                    cur_task,
                    y=self.train_y[:, column],
                    x_smiles=self.train_smiles,
                )
            else:
                evaluator = TaskModel(model_path, cur_task)
            self.task_evaluator[cur_task] = evaluator

    def forward(self, molecules, targets, name, current_epoch, val_counter, test=False, ref_smiles=None):
        """Evaluate ``molecules``, persist the SMILES and append metrics to ``output.csv``.

        Args:
            molecules: Sampled molecules, passed to ``compute_molecular_metrics``.
            targets: A tensor, or a list of tensors that is concatenated
                along dimension 0.
            name: Run name; validation output goes to ``graphs/<name>`` under
                the current working directory.
            current_epoch: Used in the validation file name and log name.
            val_counter: Used in the validation file name and log name.
            test: When true, write ``final_smiles.txt`` in the current
                working directory instead of the per-epoch validation file.
            ref_smiles: Optional reference SMILES, one per generated SMILES.
                When ``None`` in test mode the reference column is left empty.

        Returns:
            A tuple ``(all_smiles, total_errors)`` where ``total_errors`` maps
            the selected metric names to Python floats.
        """
        if isinstance(targets, list):
            targets = torch.cat(targets, dim=0)
        targets_np = targets.detach().cpu().numpy()

        unique_smiles, all_smiles, all_metrics, targets_log = compute_molecular_metrics(
            molecules,
            targets_np,
            self.train_smiles,
            self.stat_ref,
            self.dataset_infos,
            self.task_evaluator,
            self.comput_config,
            test=test,
            ref_smiles=ref_smiles,
        )

        total_errors = {
            metric: float(value)
            for metric, value in all_metrics.items()
            if any(keyword in metric for keyword in self._REPORTED_METRIC_KEYWORDS)
        }

        if test:
            num_references = 0 if ref_smiles is None else len(ref_smiles)
            print("Number of reference smiles:", num_references)
            print("Number of smiles:", len(all_smiles))
            self._write_final_smiles(all_smiles, targets_log, ref_smiles)
            print("All smiles saved")
            log_name = "test"
        else:
            self._write_unique_smiles(unique_smiles, name, current_epoch, val_counter)
            log_name = "epoch" + str(current_epoch) + "_batch" + str(val_counter)

        # Log a copy so the metrics mapping returned by the computation is
        # left untouched.
        all_logs = {**all_metrics, "log_name": log_name}
        result_to_csv("output.csv", all_logs)

        return all_smiles, total_errors

    def _write_final_smiles(self, all_smiles, targets_log, ref_smiles):
        """Write the test-time SMILES table to ``final_smiles.txt``."""
        if self.task_evaluator is not None:
            task_names = [
                task
                for task in self.task_evaluator
                if task not in ("meta_taskname", "scs")
            ]
        else:
            task_names = []

        columns = [f"input_{task}" for task in task_names]
        columns += [f"output_{task}" for task in task_names]
        header = "smiles, " + ", ".join(columns) + ", reference"

        with open("final_smiles.txt", "w") as fp:
            fp.write(header + "\n")
            for i, smiles in enumerate(all_smiles):
                if targets_log is None:
                    fp.write("%s\n" % smiles)
                    continue
                values = [f"{targets_log['input_' + task][i]}" for task in task_names]
                values += [f"{targets_log['output_' + task][i]}" for task in task_names]
                reference = "" if ref_smiles is None else ref_smiles[i]
                fp.write(f"{smiles}, " + ", ".join(values) + f", {reference}" + "\n")

    def _write_unique_smiles(self, unique_smiles, name, current_epoch, val_counter):
        """Write the unique SMILES of a validation step, one per line."""
        result_path = os.path.join(os.getcwd(), f"graphs/{name}")
        os.makedirs(result_path, exist_ok=True)
        text_path = os.path.join(
            result_path,
            f"valid_unique_molecules_e{current_epoch}_b{val_counter}.txt",
        )
        with open(text_path, "w") as textfile:
            for smiles in unique_smiles:
                textfile.write(smiles + "\n")

    def reset(self):
        """Do nothing; the method body is empty."""
        pass

# Running this module directly does nothing: the guard body is empty.
if __name__ == "__main__":
    pass