import logging
from logging import Logger
import os
import os.path as osp
import torch
import traceback
import matplotlib.pyplot as plt
import csv   

from util.train_utils import find_last_checkpoint_path

from omegaconf import DictConfig, OmegaConf
import hydra
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning import seed_everything
from lightning.pytorch.tuner import Tuner

# Register an OmegaConf resolver named "get_method", backed by
# hydra.utils.get_method, so configs can use it in interpolations.
OmegaConf.register_new_resolver("get_method", hydra.utils.get_method)
# Module-level logger named after this module.
logger: Logger = logging.getLogger(__name__)
# Select PyTorch's 'high' float32 matmul precision mode at import time
# (see the torch docs for the precision/speed trade-off this implies).
torch.set_float32_matmul_precision('high')


def find_memory_per_batch_size(cfg: DictConfig, batch_size, data_module, find_max_batch_size=False):
    """Measure CUDA memory around one validation run at a given batch size.

    A new logger, model and ``Trainer`` are built from ``cfg`` on every call.

    Args:
        cfg: Hydra config. Reads ``seed``, ``io.base_output_path``, ``tag``,
            ``io.version``, ``task`` and ``trainer``.
        batch_size: Value written to ``data_module.cfg.batch_size`` before the
            run (the data module's config is mutated in place).
        data_module: Data module handed to ``trainer.validate``.
        find_max_batch_size: When true, run the Lightning batch-size tuner
            (``mode="binsearch"``) before measuring.

    Returns:
        ``(before_fit, after_fit, batch_size)`` where ``before_fit`` is the
        CUDA memory allocated just before validation, ``after_fit`` is the
        peak CUDA memory allocated up to the end of validation (both scaled
        by ``1e-9``, i.e. bytes -> GB), and ``batch_size`` is re-read from
        ``data_module.cfg.batch_size`` after the run.
    """
    seed_everything(cfg.seed)
    print('########################### BATCH SIZE: ',batch_size)

    data_module.cfg.batch_size = batch_size

    # TensorBoard logger
    tb_logger = TensorBoardLogger(
        save_dir=osp.expanduser(cfg.io.base_output_path),
        name=cfg.tag,
        version=cfg.io.version,
    )

    # PyTorch Lightning module
    model = hydra.utils.instantiate(cfg.task, cfg)

    # Trainer
    trainer = Trainer(
        **cfg.trainer,
        logger=tb_logger,
    )

    if find_max_batch_size:
        # NOTE(review): the tuner is not given `data_module`, yet the returned
        # batch size is read from `data_module.cfg.batch_size` -- confirm the
        # tuner can actually update that value.
        # NOTE(review): `Tuner` is imported from `lightning.pytorch` while
        # `Trainer` comes from `pytorch_lightning`; confirm the two packages
        # interoperate in the installed versions.
        tuner = Tuner(trainer)
        tuner.scale_batch_size(model, mode="binsearch")

    # Reset the peak counter *before* sampling: otherwise the "before" value
    # (and possibly the "after" value) would include peaks reached earlier,
    # e.g. by the tuner's trial runs above.
    torch.cuda.reset_peak_memory_stats()
    before_fit = torch.cuda.max_memory_allocated() * 1e-9
    trainer.validate(model, datamodule=data_module)
    after_fit = torch.cuda.max_memory_allocated() * 1e-9
    # Leave the counter reset for whoever measures next (same post-condition
    # as before, using the non-deprecated API).
    torch.cuda.reset_peak_memory_stats()

    batch_size = data_module.cfg.batch_size
    return before_fit, after_fit, batch_size

def loop_through_batches(cfg, data_module, batch_sizes):
    """Measure every batch size in turn and append one CSV row per result.

    Args:
        cfg: Config forwarded to ``find_memory_per_batch_size``; also supplies
            ``fname`` (CSV path, opened in append mode) and ``tag`` (written
            as the first column of each row).
        data_module: Data module forwarded to ``find_memory_per_batch_size``.
        batch_sizes: Iterable of batch sizes to measure, in order.

    Returns:
        ``(before_fit_list, after_fit_list, successful_batches)``: three
        parallel lists holding, for each measured batch size, the "before"
        reading, the "after" reading and the batch size itself.

    Any exception raised by a measurement propagates to the caller; rows
    written for earlier batch sizes remain in the CSV file.
    """
    before_fit_list = []
    after_fit_list = []
    successful_batches = []

    for b in batch_sizes:
        before_fit, after_fit, _ = find_memory_per_batch_size(cfg, b, data_module)
        before_fit_list.append(before_fit)
        after_fit_list.append(after_fit)
        successful_batches.append(b)

        # Reopen per row so each result is flushed to disk before the next
        # measurement starts. newline='' is what the csv module expects from
        # the file object; without it, platforms that translate newlines
        # produce blank lines between rows.
        with open(cfg.fname, 'a', newline='') as f:
            csv.writer(f).writerow([cfg.tag, b, before_fit, after_fit])

    return before_fit_list, after_fit_list, successful_batches
    
@hydra.main(config_path="./config", config_name="defaults", version_base="1.1")
def run(cfg: DictConfig):
    """Hydra entry point: sweep a fixed list of batch sizes and record memory use.

    Prints the PyTorch-Lightning version and the batch sizes, instantiates the
    data module from ``cfg.data_module`` and hands everything to
    ``loop_through_batches``. The lists it returns are not used here.
    """
    print(f"PyTorch-Lightning Version: {pl.__version__}")

    # Powers of two from 2 to 32; larger sizes (128 ... 2048) are left out
    # of the sweep.
    sizes_to_probe = [2 ** exponent for exponent in range(1, 6)]
    print(sizes_to_probe)

    # Data module
    dm = hydra.utils.instantiate(cfg.data_module)

    _before, _after, _completed = loop_through_batches(cfg, dm, sizes_to_probe)
    
if __name__ == "__main__":
    # Default HYDRA_FULL_ERROR to "1" unless the environment already sets it.
    os.environ.setdefault("HYDRA_FULL_ERROR", "1")
    run()


