import argparse
import logging
import datetime
from typing import Dict, Any
import pytz
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from tqdm import tqdm
import torch
from torch.utils.data import Subset, DataLoader
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
import numpy as np

from open_biomed.utils.config import Config
from open_biomed.datasets.molecule_protein_dataset import CrossDocked
from open_biomed.models.molecule.molcraft import MolCRAFT
from open_biomed.utils.callbacks import RecoverCallback, GradientClip
from open_biomed.tasks.aidd_tasks.structure_based_drug_design import StructureBasedDrugDesignEvaluationCallback
from open_biomed.utils.featurizer import EnsembleFeaturizer
from open_biomed.utils.collator import EnsembleCollator, PygCollator

from dataset.csd_by_sample import CSDBySample
from models.molcraft import MolCRAFTWithCFG
from trainers.utils import setup_accounting, MovingAverage, get_logger, ValidationCallbackWithInterval


class SBDDModel(pl.LightningModule):
    """LightningModule that trains and evaluates a ``MolCRAFT`` model.

    Batches are dicts with ``"molecule"`` and ``"pocket"`` entries. Training and
    validation pick a time ``t`` per molecule, write the model's Bayesian-update
    inputs (``mu_pos``, ``theta_h``) into the molecule dict, and score the loss
    returned by ``MolCRAFT.forward_with_t``. Testing calls
    ``predict_structure_based_drug_design`` on the pocket alone.

    Args:
        config: experiment config. Reads ``config.model``, ``config.train``
            (``lr``, ``log_interval``, ``batch_size``, ``val_freq``,
            ``schedular.factor/patience/min_lr``) and ``config.eval``
            (``num_sample_steps``, ``batch_size``).
    """

    def __init__(self, config: Config):
        super().__init__()
        self.config = config
        self.train_cfg = config.train
        self.eval_cfg = config.eval
        self.model = MolCRAFT(config.model)
        # Loss-term name -> MovingAverage over the last ``train.log_interval``
        # values; entries are created lazily in ``training_step``.
        self.loss_history = {}

    def forward(self, x):
        return self.model(x)

    def _apply_bayesian_update(self, molecule, t):
        """Write the model inputs for time ``t`` into ``molecule`` (in place).

        Sets ``molecule["mu_pos"]`` to the first output of
        ``continuous_var_bayesian_update`` on ``molecule["pos"]`` and
        ``molecule["theta_h"]`` to ``discrete_var_bayesian_update`` on
        ``molecule["atom_feature"]``.
        """
        molecule["mu_pos"], _ = self.model.continuous_var_bayesian_update(t, molecule["pos"])
        molecule["theta_h"] = self.model.discrete_var_bayesian_update(
            t, molecule["atom_feature"], self.config.model.ligand_atom_feature_dim
        )

    def training_step(self, batch, batch_idx):
        """Return the mean loss at a random ``t``; log moving averages of every loss term."""
        molecule, pocket = batch["molecule"], batch["pocket"]
        # ``pos_batch`` is used as a per-atom molecule index: one t is drawn per
        # molecule (max index + 1 of them) and gathered out to the atoms.
        batch_size = molecule["pos_batch"].max() + 1
        t = torch.rand((batch_size, 1), device=molecule["pos"].device).index_select(0, molecule["pos_batch"])
        self._apply_bayesian_update(molecule, t)
        ret = self.model.forward_with_t(pocket, molecule, t)
        loss = torch.mean(ret["loss"])
        for key, value in ret.items():
            if key not in self.loss_history:
                self.loss_history[key] = MovingAverage(self.config.train.log_interval)
            self.loss_history[key].add(torch.mean(value).item())
        log_dict = {f"train/{key}": avg.get_average() for key, avg in self.loss_history.items()}
        self.log_dict(log_dict, on_step=True, on_epoch=True, prog_bar=True, batch_size=self.config.train.batch_size)
        return loss

    def validation_step(self, batch, batch_idx):
        """Average the loss over ``eval.num_sample_steps`` evenly spaced times.

        Uses ``t = i / num_sample_steps`` for ``i`` in ``range(num_sample_steps)``
        (t = 0 included, t = 1 excluded), logs the mean as ``val/loss`` and
        returns it.
        """
        molecule, pocket = batch["molecule"], batch["pocket"]
        num_steps = self.eval_cfg.num_sample_steps
        batch_size = molecule["pos_batch"].max() + 1
        # Per-atom tensor of ones; scaling it gives every atom the same t.
        ones = torch.ones((batch_size, 1), device=molecule["pos"].device).index_select(0, molecule["pos_batch"])
        val_loss = []
        for i in range(num_steps):
            t = ones * i / num_steps
            self._apply_bayesian_update(molecule, t)
            loss = torch.mean(self.model.forward_with_t(pocket, molecule, t)["loss"])
            val_loss.append(loss.item())
        mean_loss = np.mean(val_loss)
        self.log_dict(
            {"val/loss": mean_loss},
            on_step=True,
            prog_bar=True,
            batch_size=self.config.eval.batch_size,
        )
        return mean_loss

    def test_step(self, batch, batch_idx):
        """Return the model's pocket-conditioned predictions for this batch."""
        return self.model.predict_structure_based_drug_design(batch["pocket"])

    def configure_optimizers(self):
        """Adam plus ReduceLROnPlateau on ``val/loss``, stepped every ``train.val_freq`` steps."""
        optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.config.train.lr,
        )
        # NOTE: the config section is spelled ``schedular``; keep that key as-is.
        sched_cfg = self.config.train.schedular
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            factor=sched_cfg.factor,
            patience=sched_cfg.patience,
            min_lr=sched_cfg.min_lr,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val/loss",
                "interval": "step",
                "frequency": self.config.train.val_freq,
            },
        }

class SBDDModelWithCFG(SBDDModel):
    """``SBDDModel`` variant backed by ``MolCRAFTWithCFG``.

    Uses the same training/validation recipe as the parent, but every forward pass
    also receives ``batch["classifier_input"]`` as an extra conditioning tensor.
    At test time, a constant conditioning tensor is synthesized per pocket.
    """

    def __init__(self, config: Config):
        super().__init__(config)
        # NOTE(review): the parent constructor already built a plain MolCRAFT,
        # which is discarded and replaced here.
        self.model = MolCRAFTWithCFG(config.model)
        self.loss_history = {}

    def _apply_bayesian_update(self, molecule, t):
        """Write ``mu_pos`` / ``theta_h`` for time ``t`` into ``molecule`` in place."""
        mu_pos, _ = self.model.continuous_var_bayesian_update(t, molecule["pos"])
        molecule["mu_pos"] = mu_pos
        molecule["theta_h"] = self.model.discrete_var_bayesian_update(
            t, molecule["atom_feature"], self.config.model.ligand_atom_feature_dim
        )

    def training_step(self, batch, batch_idx):
        """Return the mean loss at a random per-molecule ``t`` and log smoothed loss terms."""
        molecule = batch["molecule"]
        pocket = batch["pocket"]
        condition = batch["classifier_input"]
        atom_owner = molecule["pos_batch"]
        num_mols = atom_owner.max() + 1
        t = torch.rand((num_mols, 1), device=molecule["pos"].device).index_select(0, atom_owner)
        self._apply_bayesian_update(molecule, t)
        outputs = self.model.forward_with_t(pocket, molecule, condition, t)
        total = torch.mean(outputs["loss"])
        for name, term in outputs.items():
            tracker = self.loss_history.get(name)
            if tracker is None:
                tracker = self.loss_history[name] = MovingAverage(self.config.train.log_interval)
            tracker.add(torch.mean(term).item())
        smoothed = {}
        for name, tracker in self.loss_history.items():
            smoothed[f"train/{name}"] = tracker.get_average()
        self.log_dict(
            smoothed,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            batch_size=self.config.train.batch_size,
        )
        return total

    def validation_step(self, batch, batch_idx):
        """Mean loss over ``t = k / num_sample_steps`` for each step ``k``; logged as ``val/loss``."""
        molecule = batch["molecule"]
        pocket = batch["pocket"]
        condition = batch["classifier_input"]
        steps = self.eval_cfg.num_sample_steps
        num_mols = molecule["pos_batch"].max() + 1
        per_atom_ones = torch.ones((num_mols, 1), device=molecule["pos"].device).index_select(0, molecule["pos_batch"])
        step_losses = []
        for step in range(steps):
            t = per_atom_ones * step / steps
            self._apply_bayesian_update(molecule, t)
            out = self.model.forward_with_t(pocket, molecule, condition, t)
            step_losses.append(torch.mean(out["loss"]).item())
        avg = np.mean(step_losses)
        self.log_dict(
            {"val/loss": avg},
            on_step=True,
            prog_bar=True,
            batch_size=self.config.eval.batch_size,
        )
        return avg

    def test_step(self, batch, batch_idx):
        """Predict for each pocket using a constant (batch, 3) conditioning tensor."""
        pocket = batch["pocket"]
        device = pocket["pos"].device
        num_mols = int(pocket["pos_batch"].max() + 1)
        if getattr(self.config.model, "discrete", False):
            # Discrete conditioning: long tensor filled with 8.
            # NOTE(review): meaning of 8 and of the 3 columns is defined by
            # MolCRAFTWithCFG; confirm there.
            condition = torch.full((num_mols, 3), 8, dtype=torch.long, device=device)
        else:
            condition = torch.ones((num_mols, 3), device=device)
        return self.model.predict_structure_based_drug_design(pocket, condition)

def build_log_formatter(tz) -> logging.Formatter:
    """Return a log formatter whose ``%(asctime)s`` is rendered in timezone ``tz``.

    Each record's own creation time is converted, so timestamps advance as the
    run progresses. (Previously the start time was baked into ``datefmt`` as a
    literal, freezing every timestamp.)

    Args:
        tz: a ``datetime.tzinfo`` such as ``pytz.timezone("Asia/Shanghai")``.
    """
    formatter = logging.Formatter(
        fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d-%H:%M:%S",
    )
    # Formatter.formatTime calls ``self.converter(record.created)`` expecting a
    # struct_time; an instance attribute overrides the default time.localtime.
    formatter.converter = lambda ts: datetime.datetime.fromtimestamp(ts, tz).timetuple()
    return formatter


def train_val_indices(num_samples: int, train_cutoff: int, debug: bool = False):
    """Return ``(train_indices, val_indices)`` ranges for a dataset of ``num_samples``.

    Normal mode: ``[0, train_cutoff)`` for training and ``[train_cutoff, num_samples)``
    for validation. Debug mode: the first 1000 samples for training and the
    last 100 for validation, both clamped to the dataset size.
    """
    if debug:
        return range(min(1000, num_samples)), range(max(0, num_samples - 100), num_samples)
    return range(train_cutoff), range(train_cutoff, num_samples)


def main():
    """Parse CLI args, build model/data/trainer, then train and/or test."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="./configs/train/molcraft_cfg.yaml")
    parser.add_argument("--resume", action="store_true")
    parser.add_argument("--num_gpus", type=int, default=1)
    parser.add_argument("--empty_folder", action="store_true")
    parser.add_argument("--no_wandb", action='store_true')
    parser.add_argument("--wandb_resume_id", type=str, default=None)
    parser.add_argument("--test_only", action="store_true")
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()

    handler = logging.StreamHandler()
    handler.setFormatter(build_log_formatter(pytz.timezone("Asia/Shanghai")))
    logging.basicConfig(level=logging.INFO, handlers=[handler])

    # Experiment name = config file name up to its first dot.
    exp_name = os.path.basename(args.config).split(".")[0]
    config = Config(args.config)
    config = setup_accounting(args, config, exp_name)
    config.exp_name = exp_name
    print(config)

    if "rft" in exp_name:
        logging.info("Current experiment is RFT")
        model = SBDDModel(config)
        featurizer = EnsembleFeaturizer({
            "pocket": model.model.featurizers["pocket"],
            "molecule": model.model.featurizers["molecule"],
        })
        collator = EnsembleCollator({
            "pocket": model.model.collators["pocket"],
            "molecule": model.model.collators["molecule"],
        })
    elif "cfg" in exp_name:
        logging.info("Current experiment is CFG")
        model = SBDDModelWithCFG(config)
        featurizer = EnsembleFeaturizer({
            "pocket": model.model.featurizers["pocket"],
            "molecule": model.model.featurizers["molecule"],
            "classifier_input": lambda x: x,
        })
        collator = EnsembleCollator({
            "pocket": model.model.collators["pocket"],
            "molecule": model.model.collators["molecule"],
            "classifier_input": lambda x: torch.stack(x, dim=0),
        })
    else:
        raise ValueError(f"Invalid experiment name: {exp_name}")

    dataset = CSDBySample(config.dataset, config.exp_name)
    train_cutoff = getattr(config.dataset, "train_cutoff", len(dataset) - 200)
    train_idx, val_idx = train_val_indices(len(dataset), train_cutoff, debug=args.debug)
    train_dataset = Subset(dataset, train_idx)
    val_dataset = Subset(dataset, val_idx)

    # NOTE(review): ``debug=True`` is hard-coded here regardless of --debug;
    # confirm whether the test split is meant to be the debug-sized one.
    csd_dataset = CrossDocked(
        Config.from_dict(
            path="./data",
            debug=True,
            pocket_only=True,
        ),
        featurizer=featurizer,
    )
    _, _, test_dataset = csd_dataset.split()

    train_loader = DataLoader(train_dataset, batch_size=config.train.batch_size, collate_fn=collator, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config.eval.batch_size, collate_fn=collator, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=config.eval.batch_size, collate_fn=collator, shuffle=False)

    wandb_logger = get_logger(config)
    trainer = Trainer(
        default_root_dir=config.accounting.dir,
        max_epochs=config.train.max_epochs,
        logger=wandb_logger,
        num_sanity_val_steps=0,
        accelerator='gpu',
        devices=args.num_gpus,
        callbacks=[
            # NOTE(review): the ModelCheckpoint below does not set save_last=True,
            # so it never writes last.ckpt; confirm RecoverCallback creates it.
            RecoverCallback(
                latest_ckpt=os.path.join(config.accounting.checkpoint_dir, "last.ckpt"),
                recover_trigger_loss=1e7,
                resume=args.resume,
            ),
            ModelCheckpoint(
                monitor="val/loss",
                every_n_epochs=config.train.ckpt_freq,
                dirpath=config.accounting.checkpoint_dir,
                filename="epoch{epoch:02d}-val_loss{val/loss:.4f}",
                # The filename already names its metrics, and "val/loss" contains
                # "/", which would otherwise create an extra sub-folder.
                auto_insert_metric_name=False,
                save_top_k=5,
            ),
            GradientClip(
                max_grad_norm=config.train.max_grad_norm,
            ),
            ValidationCallbackWithInterval(val_freq=config.train.val_freq),
            StructureBasedDrugDesignEvaluationCallback()
        ],
    )
    if args.test_only:
        trainer.test(model, test_loader)
    else:
        trainer.fit(model, train_loader, val_loader)
        trainer.test(model, test_loader)


if __name__ == "__main__":
    main()