from dataclasses import dataclass
from email.generator import Generator
from pickletools import int4
from torch_geometric.data import Data, Batch
from torch import Tensor, LongTensor
from typing import Dict, List, Tuple, Any, Optional
from argparse import Namespace
import os.path as path
from datetime import datetime
from guacamol.distribution_learning_benchmark import DistributionLearningBenchmarkResult
from tensorboardX import SummaryWriter

@dataclass
class train_data:
    """ The training data of a single step (one example for the step-wise decoder). """
    # Graph input for this step (a torch_geometric ``Data`` object).
    graph: Data
    # Presumably the index of the atom being queried in ``graph`` at this step — TODO confirm.
    query_atom: int
    # Candidate indices considered for cyclization; presumably atom indices in ``graph`` — TODO confirm.
    cyclize_cand: List[int]
    # Two-integer target for this step; the meaning of each element is not visible here — TODO confirm.
    label: Tuple[int, int]

@dataclass
class mol_train_data:
    """ The preprocessed training data of one molecule. """
    # Graph of the whole molecule (a torch_geometric ``Data`` object).
    mol_graph: Data
    # Property values of the molecule; presumably one entry per predicted property — TODO confirm.
    props: List[float]
    # Target for the first decoding step; presumably an index — TODO confirm.
    start_label: int
    # The per-step training examples derived from this molecule.
    train_data_list: List[train_data]
    # Motifs of this molecule as integers; presumably vocabulary ids — TODO confirm.
    motif_list: List[int]

@dataclass
class batch_train_data:
    """ The preprocessed training data of a batch of molecules. """
    batch_smiles: List[str]
    batch_mols_graphs: Batch
    batch_props: Tensor
    batch_start_labels: LongTensor
    motifs_list: LongTensor
    batch_train_graphs: Batch
    mol_idx: LongTensor
    graph_idx: LongTensor
    query_idx: LongTensor
    cyclize_cand_idx: LongTensor
    motif_conns_idx: LongTensor
    labels: LongTensor

    def cuda(self):
        """Move every tensor and graph field to the current CUDA device.

        ``batch_smiles`` is a plain list of strings and stays on the host.
        ``Tensor.cuda()`` returns a tensor instead of moving the original in
        place, so every result is assigned back to its field.

        Returns:
            ``self``, so callers can write ``batch = batch.cuda()``.
        """
        self.batch_mols_graphs = self.batch_mols_graphs.cuda()
        self.batch_props = self.batch_props.cuda()
        self.batch_start_labels = self.batch_start_labels.cuda()
        # Previously the only tensor field left behind on the CPU.
        self.motifs_list = self.motifs_list.cuda()
        self.batch_train_graphs = self.batch_train_graphs.cuda()
        self.mol_idx = self.mol_idx.cuda()
        self.graph_idx = self.graph_idx.cuda()
        self.query_idx = self.query_idx.cuda()
        self.cyclize_cand_idx = self.cyclize_cand_idx.cuda()
        self.motif_conns_idx = self.motif_conns_idx.cuda()
        self.labels = self.labels.cuda()
        return self

@dataclass
class ModelParams:
    """ Hyperparameters of the model. """
    atom_embed_size: List[int]
    edge_embed_size: int
    motif_embed_size: int
    hidden_size: int
    latent_size: int
    depth: int
    motif_depth: int
    # A dropout rate is a fraction, so the annotation is float (was int).
    dropout: float
    vocab_processed_path: str
    train_file: str
    num_props: int
    greedy: bool
    beam_top: int
    temperature: float

    def __repr__(self) -> str:
        # Human-readable table; vocab_processed_path and train_file are not listed.
        return f"""
        Model Parameters:
        atom_embed_size         |       {self.atom_embed_size}
        edge_embed_size         |       {self.edge_embed_size}
        motif_embed_size        |       {self.motif_embed_size}
        hidden_size             |       {self.hidden_size}
        latent_size             |       {self.latent_size}
        depth                   |       {self.depth}
        motif_depth             |       {self.motif_depth}
        dropout                 |       {self.dropout}
        num_props               |       {self.num_props}
        greedy                  |       {self.greedy}
        beam_top                |       {self.beam_top}
        temperature             |       {self.temperature}
        """

    @classmethod
    def from_arguments(cls, args: Namespace) -> "ModelParams":
        """Build model hyperparameters from parsed command-line ``args``.

        Most fields are copied from the attribute of the same name.
        ``vocab_processed_path`` becomes ``data_dir/preprocess_dir/vocab_processed_path``
        and ``train_file`` becomes ``data_dir/train_file``. ``num_props`` is
        fixed at 4 and not read from ``args``.

        Raises:
            AttributeError: if ``args`` lacks one of the attributes read here.
        """
        return cls(
            atom_embed_size = args.atom_embed_size,
            edge_embed_size = args.edge_embed_size,
            motif_embed_size = args.motif_embed_size,
            hidden_size = args.hidden_size,
            latent_size = args.latent_size,
            depth = args.depth,
            motif_depth = args.motif_depth,
            dropout = args.dropout,
            vocab_processed_path = path.join(args.data_dir, args.preprocess_dir, args.vocab_processed_path),
            train_file = path.join(args.data_dir, args.train_file),
            num_props = 4,
            greedy = args.greedy,
            beam_top = args.beam_top,
            temperature = args.temperature,
        )

@dataclass
class TrainingParams:
    """ Optimizer, learning-rate annealing and KL-weight (beta) schedule settings. """
    lr: float
    lr_anneal_iter: int
    lr_anneal_rate: float
    grad_clip_norm: float
    epoch: int
    beta_schedule_mode: str
    beta_warmup: int
    beta_min: float
    beta_max: float
    beta_anneal_period: int
    beta_num_cycles: int
    prop_weight: float

    def __repr__(self) -> str:
        return f"""
        Training Parameters:
        lr                      |       {self.lr}
        lr_anneal_iter          |       {self.lr_anneal_iter}
        lr_anneal_rate          |       {self.lr_anneal_rate}
        grad_clip_norm          |       {self.grad_clip_norm}
        epoch                   |       {self.epoch}
        beta_schedule_mode      |       {self.beta_schedule_mode}
        beta_warmup             |       {self.beta_warmup}
        beta_min                |       {self.beta_min}
        beta_max                |       {self.beta_max}
        beta_anneal_period      |       {self.beta_anneal_period}
        beta_num_cycles         |       {self.beta_num_cycles}
        prop_weight             |       {self.prop_weight}
        """

    @staticmethod
    def from_arguments(args: Namespace) -> "TrainingParams":
        """Create a TrainingParams whose every field is read from ``args``.

        Each field takes the value of the ``args`` attribute with the same
        name; an AttributeError propagates if one is missing.
        """
        # The field names coincide with the argparse destinations, so walk the
        # dataclass fields (in declaration order) instead of naming each one.
        field_names = TrainingParams.__dataclass_fields__
        values = {name: getattr(args, name) for name in field_names}
        return TrainingParams(**values)

@dataclass
class PathTool:
    """Resolved filesystem locations for one run.

    Inputs and preprocessed artefacts live under ``data_dir``; logs,
    checkpoints and generated output live under ``output_dir``. Build
    instances with ``PathTool.from_arguments(args)``.
    """
    data_dir: str
    output_dir: str
    preprocess_dir: str
    log_dir: str
    model_save_dir: str
    model_save_path: str
    motifs_embed_save_path: str
    best_model_save_path: str
    best_motifs_embed_save_path: str
    train_file: str
    valid_file: str
    operation_path: str
    vocab_path: str
    mols_pkl_path: str
    train_processed_path: str
    valid_processed_path: str
    vocab_processed_path: str
    generate_path: str
    json_output_path: str
    operation_learning_log_path: str
    vocab_construct_log_path: str
    job_name: str
    tensorboard_dir: str
    load_model_path: Optional[str]=None

    @classmethod
    def from_arguments(cls, args: Namespace) -> "PathTool":
        """Resolve every run path from parsed command-line ``args``.

        ``load_model_path`` is joined under ``data_dir`` when
        ``args.load_from_data`` is truthy, otherwise under ``output_dir``. It is
        ``None`` when ``args.load_model_path`` is ``None``, and
        ``load_from_data`` is only read in the non-``None`` case.

        ``tensorboard_dir`` is
        ``<args.tensorboard_dir>/<MM-DD>/<MM-DD-HH:MM:SS>/<MM-DD-HH:MM:SS>-<job_name>``.
        The ``job_name`` field itself keeps the raw ``args.job_name``.

        Raises:
            AttributeError: if ``args`` lacks one of the attributes read here.
        """
        if args.load_model_path is not None:
            base_dir = args.data_dir if args.load_from_data else args.output_dir
            load_model_path = path.join(base_dir, args.load_model_path)
        else:
            load_model_path = None
        # Read the clock once: two separate now() calls could straddle midnight
        # and give a date directory that disagrees with the time directory.
        now = datetime.now()
        date_str = now.strftime("%m-%d")
        # NOTE(review): ':' in this stamp is not a legal path character on Windows.
        time_str = now.strftime("%m-%d-%H:%M:%S")
        tb_job_name = time_str + "-" + args.job_name
        return cls(
            data_dir = args.data_dir,
            output_dir = args.output_dir,
            preprocess_dir = path.join(args.data_dir, args.preprocess_dir),
            log_dir = path.join(args.output_dir, args.log_dir),
            model_save_dir = path.join(args.output_dir, args.model_save_dir),
            model_save_path = path.join(args.output_dir, args.model_save_dir, "model.ckpt.log"),
            motifs_embed_save_path = path.join(args.output_dir, args.model_save_dir, "motifs_embed.ckpt.log"),
            best_model_save_path = path.join(args.output_dir, args.model_save_dir, "model.ckpt.best"),
            best_motifs_embed_save_path = path.join(args.output_dir, args.model_save_dir, "motifs_embed.ckpt.best"),
            train_file = path.join(args.data_dir, args.train_file),
            valid_file = path.join(args.data_dir, args.valid_file),
            operation_path = path.join(args.data_dir, args.preprocess_dir, args.operation_path),
            vocab_path = path.join(args.data_dir, args.preprocess_dir, args.vocab_path),
            mols_pkl_path = path.join(args.data_dir, args.preprocess_dir, args.mols_pkl_path),
            train_processed_path = path.join(args.data_dir, args.preprocess_dir, args.train_processed_path),
            valid_processed_path = path.join(args.data_dir, args.preprocess_dir, args.valid_processed_path),
            vocab_processed_path = path.join(args.data_dir, args.preprocess_dir, args.vocab_processed_path),
            # Note: the source attribute is spelled ``generated_path``.
            generate_path = path.join(args.output_dir, args.generated_path),
            json_output_path = path.join(args.output_dir, args.json_output_path),
            operation_learning_log_path = path.join(args.output_dir, args.log_dir, args.operation_learning_log_path),
            vocab_construct_log_path = path.join(args.output_dir, args.log_dir, args.vocab_construct_log_path),
            job_name = args.job_name,
            tensorboard_dir = path.join(args.tensorboard_dir, date_str, time_str, tb_job_name),
            load_model_path = load_model_path,
        )

@dataclass
class Decoder_Output:
    """Losses and accuracies produced by one decoder pass; every field defaults to None."""
    decoder_loss: Any = None
    start_loss: Any = None
    query_loss: Any = None
    # "tart_acc" is a misspelling of "start_acc". The field keeps its original
    # name so positional and ``tart_acc=`` construction keep working; use the
    # correctly spelled ``start_acc`` property defined in this class instead.
    tart_acc: Any = None
    start_topk_acc: Any = None
    query_acc: Any = None
    query_topk_acc: Any = None

    @property
    def start_acc(self) -> Any:
        """Correctly spelled alias for the ``tart_acc`` field."""
        return self.tart_acc

    @start_acc.setter
    def start_acc(self, value: Any) -> None:
        self.tart_acc = value


@dataclass
class BenchmarkResults:
    """ Results of the five GuacaMol distribution-learning benchmarks for one evaluation. """
    # Each field is a guacamol ``DistributionLearningBenchmarkResult``; only its ``score`` is read below.
    validity: DistributionLearningBenchmarkResult
    uniqueness: DistributionLearningBenchmarkResult
    novelty: DistributionLearningBenchmarkResult
    kl_div: DistributionLearningBenchmarkResult
    fcd: DistributionLearningBenchmarkResult

    def __repr__(self):
        """Render the five scores as a one-row ASCII table, each to three decimals."""
        # Column widths in the literal are hand-aligned; editing the text changes the printed table.
        return f"""
        ==============================================================
        | Metrics | Validity | Uniqueness | Novelty | KL Div |  FCD  |
        --------------------------------------------------------------
        | Scores  |  {self.validity.score:.3f}   |   {self.uniqueness.score:.3f}    |  {self.novelty.score:.3f}  | {self.kl_div.score:.3f}  | {self.fcd.score:.3f} |
        ==============================================================
        """

@dataclass
class VAE_Output:
    """Losses and accuracies from one VAE forward pass.

    As annotated, loss fields hold tensors and accuracy fields hold floats;
    any field may be left as ``None``.
    """
    total_loss: Optional[Tensor] = None
    kl_div: Optional[Tensor] = None
    decoder_loss: Optional[Tensor] = None
    start_loss: Optional[Tensor] = None
    query_loss: Optional[Tensor] = None
    pred_loss: Optional[Tensor] = None
    start_acc: Optional[float] = None
    start_topk_acc: Optional[float] = None
    query_acc: Optional[float] = None
    query_topk_acc: Optional[float] = None

    def cpu(self) -> None:
        """Move every tensor-valued field to host memory, in place.

        ``Tensor.cpu()`` returns a tensor rather than moving the original, so
        the result is stored back on the field. Non-tensor fields (floats,
        ``None``) are left untouched instead of raising AttributeError.
        """
        # Snapshot the items so reassigning attributes cannot disturb iteration.
        for name, value in list(vars(self).items()):
            if isinstance(value, Tensor):
                setattr(self, name, value.cpu())

    def print_results(self, total_step: int, lr:float, beta: float) -> None:
        """Print a one-line progress summary for training step ``total_step``.

        ``kl_div``, ``decoder_loss``, ``pred_loss`` and the four accuracy fields
        must be set; formatting ``None`` with a float spec raises TypeError.
        The "top10" labels are fixed text, whatever k the *_topk_acc fields used.
        """
        print(
            f"[Step {total_step:5d}] | Loss. KL: {self.kl_div:3.3f}, "
            f"decoder_loss: {self.decoder_loss:3.3f}, pred_loss: {self.pred_loss:2.5f} "
            f"| Start_acc. top1: {self.start_acc:.3f}, top10: {self.start_topk_acc:.3f} "
            f"| Query_acc. top1: {self.query_acc:.3f}, top10: {self.query_topk_acc:.3f} "
            f"| Params. lr: {lr:.6f}, beta: {beta:.6f}."
        )

    @staticmethod
    def _mean(outputs: List["VAE_Output"], attr: str):
        """Arithmetic mean of field ``attr`` over ``outputs`` (ZeroDivisionError if empty)."""
        return sum(getattr(output, attr) for output in outputs) / len(outputs)

    @staticmethod
    def log_epoch_results(epoch: int, epoch_loss_list: List["VAE_Output"], dev_loss_list: List["VAE_Output"], benchmark_results: Optional["BenchmarkResults"]=None):
        """Print mean training/dev metrics for ``epoch`` and return the dev mean total loss.

        Args:
            epoch: epoch number shown in the header.
            epoch_loss_list: per-batch training outputs; decoder_loss, kl_div and
                the four accuracy fields are averaged.
            dev_loss_list: per-batch dev outputs; total_loss is averaged as well.
            benchmark_results: optional benchmark scores. When given, a
                Benchmark section is printed after the metrics table.

        Returns:
            Mean ``total_loss`` over ``dev_loss_list`` (a tensor when the
            per-batch losses are tensors).

        Raises:
            ZeroDivisionError: if either list is empty.
        """
        mean = VAE_Output._mean
        epoch_mean_decoder_loss = mean(epoch_loss_list, "decoder_loss")
        epoch_mean_kl_div = mean(epoch_loss_list, "kl_div")
        epoch_mean_start_acc = mean(epoch_loss_list, "start_acc")
        epoch_mean_start_topk_acc = mean(epoch_loss_list, "start_topk_acc")
        epoch_mean_query_acc = mean(epoch_loss_list, "query_acc")
        epoch_mean_query_topk_acc = mean(epoch_loss_list, "query_topk_acc")

        dev_mean_total_loss = mean(dev_loss_list, "total_loss")
        dev_mean_decoder_loss = mean(dev_loss_list, "decoder_loss")
        dev_mean_kl_div = mean(dev_loss_list, "kl_div")
        dev_mean_start_acc = mean(dev_loss_list, "start_acc")
        dev_mean_start_topk_acc = mean(dev_loss_list, "start_topk_acc")
        dev_mean_query_acc = mean(dev_loss_list, "query_acc")
        dev_mean_query_topk_acc = mean(dev_loss_list, "query_topk_acc")

        print(f"""
        ==========================================================
        Epoch {epoch}
        ----------------------------------------------------------
        Training:                   
        epoch_mean_decoder_loss     |   {epoch_mean_decoder_loss:.6f}
        epoch_mean_kl_div           |   {epoch_mean_kl_div:.6f}
        epoch_mean_start_acc        |   {epoch_mean_start_acc:.6f}
        epoch_mean_start_topk_acc   |   {epoch_mean_start_topk_acc:.6f}
        epoch_mean_query_acc        |   {epoch_mean_query_acc:.6f}
        epoch_mean_query_topk_acc   |   {epoch_mean_query_topk_acc:.6f}
        ----------------------------------------------------------
        Dev:
        dev_mean_total_loss         |   {dev_mean_total_loss:.6f}
        dev_mean_decoder_loss       |   {dev_mean_decoder_loss:.6f}
        dev_mean_kl_div             |   {dev_mean_kl_div:.6f}
        dev_mean_start_acc          |   {dev_mean_start_acc:.6f}
        dev_mean_start_topk_acc     |   {dev_mean_start_topk_acc:.6f}
        dev_mean_query_acc          |   {dev_mean_query_acc:.6f}
        dev_mean_query_topk_acc     |   {dev_mean_query_topk_acc:.6f}
        ============================================================
        """)
        # This section used to be a bare (non-f) string expression that was
        # never printed, so benchmark_results was silently ignored.
        if benchmark_results is not None:
            print(f"""
        ----------------------------------------------------------
        Benchmark:
        Validity                    |   {benchmark_results.validity.score:.3f}
        Uniqueness                  |   {benchmark_results.uniqueness.score:.3f}
        Novelty                     |   {benchmark_results.novelty.score:.3f}
        KL Divergence               |   {benchmark_results.kl_div.score:.3f}
        FCD                         |   {benchmark_results.fcd.score:.3f}
        ============================================================
        """)
        return dev_mean_total_loss

    def log_tb_results(self, prefix: str, total_step: int, tb: "SummaryWriter", beta, lr) -> None:
        """Write this output's scalars to TensorBoard at ``total_step``.

        Loss curves and the beta/lr hyperparameters are logged only when
        ``prefix == "Train"``. Accuracies are always logged under
        ``"{prefix}_Accuracy/..."``.
        """
        if prefix == "Train":
            tb.add_scalar("Loss/Total_Loss", self.total_loss, total_step)
            tb.add_scalar("Loss/Decoder_loss", self.decoder_loss, total_step)
            tb.add_scalar("Loss/KL_div", self.kl_div, total_step)
            tb.add_scalar("Loss/Start_loss", self.start_loss, total_step)
            tb.add_scalar("Loss/Query_loss", self.query_loss, total_step)
            tb.add_scalar("Loss/Prop_pred_loss", self.pred_loss, total_step)

            tb.add_scalar("Hyperparameters/beta", beta, total_step)
            tb.add_scalar("Hyperparameters/lr", lr, total_step)

        tb.add_scalar(f"{prefix}_Accuracy/Start_acc", self.start_acc, total_step)
        tb.add_scalar(f"{prefix}_Accuracy/Start_top10_acc", self.start_topk_acc, total_step)
        tb.add_scalar(f"{prefix}_Accuracy/Query_acc", self.query_acc, total_step)
        tb.add_scalar(f"{prefix}_Accuracy/Query_top10_acc", self.query_topk_acc, total_step)