import argparse
import logging
import datetime
import pytz
import torch
from torch.utils.data import Subset, DataLoader
from torch_geometric.data import Batch
import numpy as np
import os
from typing import Any, Optional
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.utilities.types import STEP_OUTPUT

from open_biomed.utils.config import Config, Struct
from open_biomed.utils.callbacks import RecoverCallback, GradientClip
from models.critique_sbdd import CritiqueSBDD
from dataset.csd_by_sample import CSDBySample, CSDFromMolJo
from trainers.utils import setup_accounting, MovingAverage


class CritiqueSBDDCollator:
    """Collate a list of per-sample dicts into a single batched dict.

    Each sample is expected to provide the keys ``"molecule"``, ``"pocket"``
    and ``"labels"``.
    """

    def __call__(self, inputs):
        """Batch the samples in ``inputs``.

        Args:
            inputs: Sequence of sample dicts produced by the dataset.

        Returns:
            A dict with the same three keys: ``"molecule"`` and ``"pocket"``
            are built with ``Batch.from_data_list`` (tracking a ``pos_batch``
            assignment vector via ``follow_batch=["pos"]``), and ``"labels"``
            is the per-sample label tensors stacked along a new leading
            dimension.
        """

        def gather(key):
            # Pull one field out of every sample, preserving sample order.
            return [sample[key] for sample in inputs]

        molecule_batch = Batch.from_data_list(gather("molecule"), follow_batch=["pos"])
        pocket_batch = Batch.from_data_list(gather("pocket"), follow_batch=["pos"])
        stacked_labels = torch.stack(gather("labels"), dim=0)
        return {
            "molecule": molecule_batch,
            "pocket": pocket_batch,
            "labels": stacked_labels,
        }
    
class TrainCritqueSBDD(pl.LightningModule):
    """Lightning module that fine-tunes a ``CritiqueSBDD`` model.

    The wrapped model is built from ``config.model`` and initialised from the
    checkpoint at ``config.model.pretrained_dir``. Training logs moving
    averages of the total loss and of the ``loss_affinity``, ``loss_qed`` and
    ``loss_sa`` terms returned by the model.
    """

    def __init__(self, config: Config):
        """Build the model and load the pretrained weights.

        Args:
            config: Experiment configuration; ``config.train``,
                ``config.eval`` and ``config.model`` are read here.
        """
        super().__init__()
        self.config = config
        self.train_cfg = config.train
        self.eval_cfg = config.eval
        self.model = CritiqueSBDD(config.model)
        # One smoothing window per logged quantity, in this order:
        # total loss, loss_affinity, loss_qed, loss_sa.
        self.loss_history = [MovingAverage(self.train_cfg.log_interval) for _ in range(4)]
        # Load onto the CPU so the checkpoint can be read regardless of the
        # device it was saved from; load_state_dict copies the values into the
        # module's own parameters afterwards.
        state_dict = torch.load(config.model.pretrained_dir, map_location="cpu")["state_dict"]
        # strict=False tolerates key mismatches; the printed result lists the
        # missing and unexpected keys so mismatches stay visible.
        print(self.load_state_dict(state_dict, strict=False))

    def forward(self, x):
        """Delegate to the wrapped model."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Run one optimisation step at a random per-molecule value of ``t``.

        Args:
            batch: Dict with ``"molecule"``, ``"pocket"`` and ``"labels"``.
            batch_idx: Index of the batch within the epoch (unused).

        Returns:
            The mean of the model's ``"loss"`` output.
        """
        molecule, pocket, labels = batch["molecule"], batch["pocket"], batch["labels"]
        graph_index = molecule["pos_batch"]
        # pos_batch maps every row of ``pos`` to its molecule, so the largest
        # index + 1 is the number of molecules in this batch.
        batch_size = int(graph_index.max()) + 1
        # Draw one uniform value in [0, 1) per molecule, then repeat it for
        # every row belonging to that molecule.
        t = torch.rand((batch_size, 1), device=molecule["pos"].device).index_select(0, graph_index)
        ret = self.model(pocket, molecule, labels, t)
        loss = torch.mean(ret["loss"])

        self.loss_history[0].add(loss.item())
        self.loss_history[1].add(torch.mean(ret["loss_affinity"]).item())
        self.loss_history[2].add(torch.mean(ret["loss_qed"]).item())
        self.loss_history[3].add(torch.mean(ret["loss_sa"]).item())
        self.log_dict(
            {
                "train/loss": self.loss_history[0].get_average(),
                "train/loss_affinity": self.loss_history[1].get_average(),
                "train/loss_qed": self.loss_history[2].get_average(),
                "train/loss_sa": self.loss_history[3].get_average(),
                "train/lr": self.trainer.optimizers[0].param_groups[0]["lr"],
            },
            on_step=True,
            prog_bar=True,
            batch_size=self.train_cfg.batch_size,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        """Evaluate the loss on an evenly spaced grid of ``t`` values.

        The grid is ``i / num_sample_steps`` for
        ``i = 0 .. num_sample_steps - 1``, shared by all molecules in the
        batch.

        Args:
            batch: Dict with ``"molecule"``, ``"pocket"`` and ``"labels"``.
            batch_idx: Index of the batch within the epoch (unused).

        Returns:
            A list with one float loss per grid point, in grid order.
        """
        molecule, pocket, labels = batch["molecule"], batch["pocket"], batch["labels"]
        graph_index = molecule["pos_batch"]
        batch_size = int(graph_index.max()) + 1
        num_steps = self.eval_cfg.num_sample_steps
        # Loop-invariant: a column of ones with one row per row of ``pos``.
        ones = torch.ones((batch_size, 1), device=molecule["pos"].device).index_select(0, graph_index)

        val_loss = []
        for i in range(num_steps):
            t = ones * i / num_steps
            loss = torch.mean(self.model(pocket, molecule, labels, t)["loss"])
            val_loss.append(loss.item())
        self.log_dict(
            {
                "val/loss": np.mean(val_loss),
            },
            on_step=True,
            prog_bar=True,
            # Pass the number of molecules explicitly instead of letting
            # Lightning guess a batch size from the nested batch dict.
            batch_size=batch_size,
        )
        return val_loss

    def configure_optimizers(self):
        """Create the Adam optimiser and a per-step cosine LR schedule.

        Returns:
            A Lightning optimiser config dict.
        """
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.train_cfg.lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.train_cfg.max_iters)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                # T_max is expressed in training iterations (the training
                # script sets max_iters = len(train_loader) * max_epochs), so
                # the scheduler must advance every step. Lightning's default
                # interval is "epoch", which would leave the LR almost flat.
                "interval": "step",
            },
        }

class EvalCritiqueSBDD(pl.Callback):
    """Callback that records the per-step loss trajectory of each eval epoch.

    For every validation/test batch it collects the list returned by the
    module's step (one loss per sampled step). At epoch end it averages each
    position over the batches and saves the result to
    ``<trainer.default_root_dir>/loss_traj_<mode>_<epoch>.pt``.
    """

    def __init__(self, config: Config):
        """Store the config and create empty per-step buffers.

        Args:
            config: Experiment configuration; ``config.eval.num_sample_steps``
                fixes the number of trajectory positions.
        """
        super().__init__()
        self.config = config
        # Overwritten by the epoch-start hooks; set here so the attribute
        # always exists.
        self.mode = "val"
        self.loss_traj = self._empty_traj()

    def _empty_traj(self):
        """Return a fresh list of empty buffers, one per sampled step."""
        return [[] for _ in range(self.config.eval.num_sample_steps)]

    def on_validation_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        """Tag subsequently saved trajectories as validation results."""
        self.mode = "val"

    def on_test_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        """Tag subsequently saved trajectories as test results."""
        self.mode = "test"

    def on_validation_batch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule, outputs: Optional[STEP_OUTPUT], batch: Any, batch_idx: int, dataloader_idx: int = 0):
        """Append this batch's per-step losses to the matching buffers.

        Args:
            outputs: Sequence with one loss per sampled step, or ``None`` if
                the step returned nothing (in which case the batch is skipped).
        """
        if outputs is None:
            return
        for i in range(self.config.eval.num_sample_steps):
            self.loss_traj[i].append(outputs[i])

    def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        """Average the buffers, save the trajectory, and reset the buffers."""
        mean_traj = [np.mean(step_losses) for step_losses in self.loss_traj]
        # exist_ok avoids the check-then-create race between processes.
        os.makedirs(trainer.default_root_dir, exist_ok=True)
        torch.save(mean_traj, os.path.join(trainer.default_root_dir, f'loss_traj_{self.mode}_{pl_module.current_epoch}.pt'))
        self.loss_traj = self._empty_traj()

    def on_test_batch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int = 0):
        """Collect test batches exactly like validation batches."""
        self.on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)

    def on_test_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        """Save and reset exactly as at the end of a validation epoch."""
        self.on_validation_epoch_end(trainer, pl_module)



def get_logger(cfg):
    """Create the Weights & Biases logger for this run.

    Ensures ``cfg.accounting.dir`` exists, then either resumes the run
    identified by ``cfg.wandb_resume_id`` (with ``resume='must'``) or starts a
    new run named ``sbdd_critique_<timestamp>``, where the timestamp is the
    current Asia/Shanghai time.

    Args:
        cfg: Configuration providing ``accounting.dir``, ``wandb_resume_id``
            and ``no_wandb`` (passed to the logger as ``offline``).

    Returns:
        The configured ``WandbLogger``.
    """
    os.makedirs(cfg.accounting.dir, exist_ok=True)
    # TODO save code
    resume_id = cfg.wandb_resume_id
    if resume_id is None:
        # start a new run
        stamp = datetime.datetime.now(pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d-%H:%M:%S")
        run_kwargs = {"name": "sbdd_critique" + f"_{stamp}"}
    else:
        run_kwargs = {"id": resume_id, "resume": "must"}
    # Settings shared by both the resumed and the fresh run.
    return WandbLogger(
        project="sbdd_critique",
        offline=cfg.no_wandb,
        save_dir=cfg.accounting.dir,
        **run_kwargs,
    )

def add_arguments(parser: argparse.ArgumentParser):
    """Register the training script's command-line options on ``parser``.

    Args:
        parser: Parser to extend in place.

    Returns:
        The same ``parser`` instance, for chaining.
    """
    # (flag, add_argument keyword arguments), in registration order.
    option_table = (
        ("--resume", {"action": "store_true"}),
        ("--debug", {"action": "store_true"}),
        ("--num_gpus", {"type": int, "default": 1}),
        ("--empty_folder", {"action": "store_true"}),
        ("--no_wandb", {"action": "store_true"}),
        ("--wandb_resume_id", {"type": str, "default": None}),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser

def _configure_logging():
    """Configure root logging with per-record Asia/Shanghai timestamps."""
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        # datefmt must be a strftime pattern. Passing an already rendered
        # timestamp would stamp every record with the process start time.
        datefmt="%Y-%m-%d-%H:%M:%S",
        level=logging.INFO,
    )
    shanghai = pytz.timezone("Asia/Shanghai")

    def to_shanghai(timestamp):
        # Convert a record's creation time (epoch seconds) to a struct_time
        # in Asia/Shanghai, independent of the machine's local time zone.
        return datetime.datetime.fromtimestamp(timestamp, shanghai).timetuple()

    for handler in logging.getLogger().handlers:
        if handler.formatter is not None:
            handler.formatter.converter = to_shanghai


def _build_callbacks(args, config):
    """Assemble the recovery, clipping, checkpointing and eval callbacks."""
    return [
        RecoverCallback(
            latest_ckpt=os.path.join(config.accounting.checkpoint_dir, "last.ckpt"),
            recover_trigger_loss=1e7,
            resume=args.resume,
        ),
        GradientClip(config.train.max_grad_norm),
        ModelCheckpoint(
            monitor="val/loss",
            every_n_epochs=config.train.ckpt_freq,
            dirpath=config.accounting.checkpoint_dir,
            filename="epoch{epoch:02d}-val_loss{val/loss:.4f}",
            save_top_k=5,
            mode="min",
            auto_insert_metric_name=False,
            save_last=True,
        ),
        EvalCritiqueSBDD(config),
    ]


def main():
    """Parse arguments, build data/model/trainer, then fit and test."""
    args = add_arguments(argparse.ArgumentParser()).parse_args()
    _configure_logging()

    if args.debug:
        config = Config("./configs/train/critique_sbdd_debug.yaml")
    else:
        config = Config("./configs/train/critique_sbdd.yaml")
    config = setup_accounting(args, config, "critique_sbdd")
    print(config)

    dataset = CSDBySample(config.dataset, "critique_sbdd")
    # Contiguous split: [0, train_cutoff) / [train_cutoff, val_cutoff) /
    # [val_cutoff, len(dataset)).
    train_indices = np.arange(0, config.dataset.train_cutoff)
    val_indices = np.arange(config.dataset.train_cutoff, config.dataset.val_cutoff)
    test_indices = np.arange(config.dataset.val_cutoff, len(dataset))
    pl.seed_everything(config.train.seed)

    # prepare datasets
    train_dataset = Subset(dataset, train_indices)
    val_dataset = Subset(dataset, val_indices)
    test_dataset = Subset(dataset, test_indices)
    logging.info(f"Train dataset size: {len(train_dataset)}")
    logging.info(f"Val dataset size: {len(val_dataset)}")
    logging.info(f"Test dataset size: {len(test_dataset)}")
    train_loader = DataLoader(train_dataset, batch_size=config.train.batch_size, shuffle=True, collate_fn=CritiqueSBDDCollator())
    val_loader = DataLoader(val_dataset, batch_size=config.eval.batch_size, shuffle=False, collate_fn=CritiqueSBDDCollator())
    test_loader = DataLoader(test_dataset, batch_size=config.eval.batch_size, shuffle=False, collate_fn=CritiqueSBDDCollator())
    # Total number of training iterations; must be set before the model is
    # built because the LR schedule reads it.
    config.train.max_iters = len(train_loader) * config.train.max_epochs

    # prepare model
    model = TrainCritqueSBDD(config)
    callbacks = _build_callbacks(args, config)
    wandb_logger = get_logger(config)
    trainer = pl.Trainer(
        default_root_dir=config.accounting.dir,
        max_epochs=config.train.max_epochs,
        callbacks=callbacks,
        logger=wandb_logger,
        num_sanity_val_steps=0,
        accelerator='gpu',
        devices=args.num_gpus,
    )
    trainer.fit(model, train_loader, val_loader)
    trainer.test(model, test_loader)


if __name__ == "__main__":
    main()