import argparse
import torch
import os
import datetime
import pytz
import numpy as np
from tqdm import tqdm
from open_biomed.utils.config import Struct
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.types import STEP_OUTPUT
from typing import Any

class MovingAverage:
    """Fixed-size sliding window that reports the mean of its most recent values.

    Once ``size`` values have been added, each new value evicts the oldest one.
    """

    def __init__(self, size) -> None:
        # Maximum number of values kept in the window; ``None`` keeps every value.
        self.size = size
        # Values in insertion order, oldest first.
        self.values = []

    def add(self, value):
        """Append ``value``, discarding the oldest entries if the window overflows."""
        self.values.append(value)
        if self.size is None:
            return
        # Trim by the actual overflow instead of popping exactly one element:
        # this also handles ``size == 0`` (the old ``pop(0)`` raised IndexError
        # on an empty list) and a list that was extended from outside.
        overflow = len(self.values) - self.size
        if overflow > 0:
            del self.values[:overflow]

    def get_average(self):
        """Return the mean of the values in the window as a 0-d tensor.

        An empty window yields ``tensor(0.)``.
        """
        if not self.values:
            # 0-d, matching the shape returned by ``Tensor.mean()`` below.
            return torch.tensor(0.0)
        window = torch.tensor(self.values)
        # ``Tensor.mean`` rejects integer/bool dtypes, so promote them first.
        if not (window.is_floating_point() or window.is_complex()):
            window = window.float()
        return window.mean()

def setup_accounting(args: argparse.Namespace, config: Struct, task: str):
    """Create the per-task output directory tree and record it on ``config``.

    Lays out ``./lightning_logs/<task>/`` with ``val_outputs/`` and
    ``test_outputs/`` sub-directories (``checkpoints/`` is only recorded, not
    created). The paths are stored under ``config.accounting``, and the config
    is dumped to ``config.yaml`` in that directory. The config is saved before
    the wandb flags are copied over from ``args``.

    Args:
        args: Parsed CLI arguments; reads ``empty_folder``, ``no_wandb`` and
            ``wandb_resume_id``.
        config: Experiment configuration; mutated in place.
        task: Task name, used as the sub-directory under ``./lightning_logs``.

    Returns:
        The same ``config`` object.
    """
    import shutil

    accounting_dir = f"./lightning_logs/{task}"
    if args.empty_folder:
        # Delete previous results without going through a shell: the former
        # ``os.system("rm -r ...")`` broke on paths containing spaces and would
        # execute shell metacharacters embedded in ``task``. Errors (e.g. the
        # directory not existing yet) are ignored, as they were before.
        shutil.rmtree(accounting_dir, ignore_errors=True)

    dump_config_path = os.path.join(accounting_dir, "config.yaml")
    checkpoint_dir = os.path.join(accounting_dir, "checkpoints")
    val_output_dir = os.path.join(accounting_dir, "val_outputs")
    test_output_dir = os.path.join(accounting_dir, "test_outputs")
    # exist_ok=True already makes these safe to call on re-runs.
    for directory in (accounting_dir, val_output_dir, test_output_dir):
        os.makedirs(directory, exist_ok=True)

    accounting_cfg = {
        "dir": accounting_dir,
        "dump_config_path": dump_config_path,
        "checkpoint_dir": checkpoint_dir,
        "val_output_dir": val_output_dir,
        "test_output_dir": test_output_dir,
    }
    config.accounting = Struct(**accounting_cfg)
    config.save2yaml(dump_config_path)
    config.no_wandb = args.no_wandb
    config.wandb_resume_id = args.wandb_resume_id
    return config

def get_logger(cfg):
    """Build a ``WandbLogger`` that writes under ``cfg.accounting.dir``.

    If ``cfg.wandb_resume_id`` is set, the existing run with that id is resumed
    (``resume='must'``). Otherwise a new run is started, named
    ``<exp_name>_<Asia/Shanghai timestamp>``. ``cfg.no_wandb`` switches the
    logger to offline mode.
    """
    os.makedirs(cfg.accounting.dir, exist_ok=True)
    # TODO save code
    shared_kwargs = {
        "project": cfg.exp_name,
        "offline": cfg.no_wandb,
        "save_dir": cfg.accounting.dir,
    }
    if cfg.wandb_resume_id is None:
        # Start a fresh run, tagged with a local wall-clock timestamp.
        stamp = datetime.datetime.now(pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d-%H:%M:%S")
        return WandbLogger(name=f"{cfg.exp_name}_{stamp}", **shared_kwargs)
    return WandbLogger(id=cfg.wandb_resume_id, resume='must', **shared_kwargs)

class ValidationCallbackWithInterval(pl.Callback):
    """Run a manual validation pass every ``val_freq`` optimizer steps.

    After each training batch, if ``trainer.global_step`` is a multiple of
    ``val_freq`` (which includes step 0), every batch from
    ``trainer.val_dataloaders`` is passed through ``pl_module.validation_step``
    under ``torch.no_grad()``. The mean loss is then printed.
    """

    def __init__(self, val_freq: int):
        super().__init__()
        self.val_freq = val_freq

    @staticmethod
    def _loss_to_float(loss: STEP_OUTPUT):
        """Reduce one ``validation_step`` output to a Python float, or ``None`` if absent."""
        if loss is None:
            return None
        if isinstance(loss, dict):
            # Lightning convention: a dict step output carries its loss under "loss".
            loss = loss.get("loss")
            if loss is None:
                return None
        if isinstance(loss, torch.Tensor):
            # .item() brings the value to the host; np.mean cannot consume CUDA tensors.
            return loss.detach().float().mean().item()
        return float(loss)

    def on_train_batch_end(
        self, 
        trainer: Trainer, 
        pl_module: pl.LightningModule, 
        outputs: STEP_OUTPUT, 
        batch: Any, 
        batch_idx: int
    ) -> None:
        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)

        if trainer.global_step % self.val_freq != 0:
            return

        model = pl_module.model
        was_training = model.training
        model.eval()
        all_losses = []
        try:
            with torch.no_grad():
                batches = tqdm(trainer.val_dataloaders, desc="Validating")
                for val_batch_idx, val_batch in enumerate(batches):
                    val_batch = pl_module.transfer_batch_to_device(val_batch, pl_module.device, 0)
                    # Pass the validation batch index, not the training batch_idx.
                    loss = self._loss_to_float(pl_module.validation_step(val_batch, val_batch_idx))
                    if loss is not None:
                        all_losses.append(loss)
        finally:
            # Restore the previous mode even if validation raises.
            model.train(was_training)

        # Avoid np.mean([]) (RuntimeWarning) when the loader yielded nothing.
        mean_loss = float(np.mean(all_losses)) if all_losses else float("nan")
        print("Validation loss at step", trainer.global_step, ":", mean_loss)
