import os
from datetime import datetime
from pathlib import Path
import pytest
import torch
import shutil
import torch
import warnings
from wavesAI.data.language_modeling import NextStepPredictionData
from wavesAI.model.mamba_optimized_v2 import MambaSeq2Seq
from wavesAI.task.sequence_modeling import SequenceModeling
from lightning import Trainer
import logging
import socket
from datetime import datetime, timedelta

from lightning.pytorch.callbacks import RichProgressBar

from rich import pretty

# strftime format for the timestamp part of exported profiler file names.
TIME_FORMAT_STR: str = "%b_%d_%H_%M_%S"


def trace_handler(prof: torch.profiler.profile, *, device: str = "cuda:0") -> None:
    """Export a Chrome trace and a memory-timeline plot for a profiler window.

    Meant to be passed as ``on_trace_ready`` to ``torch.profiler.profile``.
    Both files are written to the current working directory and share the
    prefix ``<hostname>_<timestamp>``, where the timestamp is formatted with
    ``TIME_FORMAT_STR``:

    * ``<prefix>.json.gz`` -- Chrome trace (``prof.export_chrome_trace``).
    * ``<prefix>.html``    -- memory timeline (``prof.export_memory_timeline``).

    Args:
        prof: The profiler whose collected results are exported.
        device: Device whose memory timeline is plotted. Defaults to
            ``"cuda:0"``, the value this function always used before.

    Note:
        ``TIME_FORMAT_STR`` has one-second resolution and no year, so two
        exports within the same second overwrite each other.
        NOTE(review): the memory timeline presumably needs the profiler to be
        created with ``profile_memory=True`` (plus ``record_shapes`` /
        ``with_stack``) -- confirm against the PyTorch profiler docs.
    """
    # Compute the prefix once so both files are guaranteed to pair up.
    host_name = socket.gethostname()
    timestamp = datetime.now().strftime(TIME_FORMAT_STR)
    file_prefix = f"{host_name}_{timestamp}"

    # Chrome trace (gzip-compressed, selected by the file suffix).
    prof.export_chrome_trace(f"{file_prefix}.json.gz")

    # Memory timeline plot for the requested device.
    prof.export_memory_timeline(f"{file_prefix}.html", device=device)


if __name__ == "__main__":
    # Lightning's default_root_dir for this throwaway run; deleted in `finally`.
    log_dir = Path("test_logs_tmp")
    # File that receives the CUDA allocator snapshot at the end of the run.
    snapshot_path = "my_snapshot.pickle"

    # Enable memory history, which adds tracebacks and event history to
    # snapshots. This is done before `try` so that, if it fails, the `finally`
    # below does not attempt to dump a never-started recording and mask the
    # original error with a second one.
    torch.cuda.memory._record_memory_history()
    try:
        # Both environment variables are required; a missing one raises KeyError.
        data_root = Path(os.environ["DATA_ROOT"])
        scratch_path = Path(os.environ["SCRATCH_PATH"])

        ## data
        data_module = NextStepPredictionData(
            data_dir="/datasets/ai/slim-pajama/SlimPajama-627B",
            cache_dir=str(scratch_path / "huggingface/cache"),
            dataset_name=str(data_root / "wikitext/wikitext-2-v1"),
            tokenizer_dir=str(scratch_path / "huggingface/wikitext-wordlevel/tokenizer.json"),
            max_length=2048,
            batch_size=16,
            num_workers=1,
        )

        # Optional profiling: wrap the code below in this context manager so
        # that `trace_handler` exports a Chrome trace and a memory timeline.
        # with torch.profiler.profile(
        #         activities=[
        #             torch.profiler.ProfilerActivity.CPU,
        #             torch.profiler.ProfilerActivity.CUDA,
        #         ],
        #         schedule=torch.profiler.schedule(wait=0, warmup=0, active=6, repeat=1),
        #         record_shapes=True,
        #         profile_memory=True,
        #         with_stack=True,
        #         on_trace_ready=trace_handler,
        # ) as prof:

        ## model
        # Alternative configuration:
        # net = MambaSeq2Seq(32, 32, 75001, 75001, flags="aouc",
        #              d_state=32, dt_rank=32, x_factor=1.0, x_bias_factor=1.0, constant_B=True, constant_C=True)
        net = MambaSeq2Seq(32, d_state=32, output_vocab_size=75001, layers="m|a|m", vocab_size=75001)

        model = SequenceModeling(
            net,
            lambda x: torch.optim.AdamW(x, lr=1e-3, weight_decay=0.0),
            lambda x: torch.optim.lr_scheduler.CosineAnnealingLR(x, T_max=5000),
            scheduler_interval="epoch",
            ignore_index=data_module.PAD_TOKEN,
        )

        data_module.prepare_data()
        data_module.setup()

        # Short smoke-test run: 3 train batches, 1 val batch, 1 epoch.
        trainer = Trainer(
            limit_train_batches=3,
            limit_val_batches=1,
            default_root_dir=str(log_dir),
            max_epochs=1,
            detect_anomaly=True,
            callbacks=[RichProgressBar()],
        )

        trainer.fit(model, datamodule=data_module)

        print("Fit completed")
    finally:
        try:
            torch.cuda.memory._dump_snapshot(snapshot_path)
            # Stop recording allocator history now that the snapshot is saved.
            torch.cuda.memory._record_memory_history(enabled=None)
        finally:
            # Runs even if dumping the snapshot fails, so the temporary log
            # directory is never left behind.
            if log_dir.exists():
                shutil.rmtree(log_dir)
