#!/usr/bin/env -S uv run

import functools
import logging
import os
from functools import partial

import dotenv
import tqdm

# Configure root logging before the heavy imports below. Note that
# logging.basicConfig does nothing if the root logger already has handlers
# (e.g. ones installed by an imported library), so it is placed first.
logging.basicConfig(format="[%(asctime)s][%(name)s][%(levelname)s] - %(message)s", level=logging.INFO)
py_logger = logging.getLogger("jamun")

import torch

# Permit reduced-precision (e.g. TF32) algorithms for float32 matmuls on
# hardware that supports them; trades a little precision for throughput.
torch.set_float32_matmul_precision("high")

import e3nn
import e3tools.nn

# Turn off TorchScript scripting of e3nn's fx-generated code. This is set before
# the jamun imports; NOTE(review): presumably so torch.compile (enabled on the
# Denoiser below) traces plain Python modules instead — confirm the motivation.
e3nn.set_optimization_defaults(jit_script_fx=False)

import jamun
import jamun.data
import jamun.distributions
import jamun.model
import jamun.model.arch
import jamun.model.embedding

# Load environment variables (notably JAMUN_DATA_PATH) from a .env file. The
# relative path is resolved against the current working directory, not against
# this script's location.
dotenv.load_dotenv("../.env", verbose=True)
JAMUN_DATA_PATH = os.getenv("JAMUN_DATA_PATH")
if not JAMUN_DATA_PATH:
    # Without this check the dataset root below would silently become a literal
    # "None/timewarp/..." path and fail much later during data loading.
    raise RuntimeError("JAMUN_DATA_PATH is not set; define it in the environment or in ../.env.")

# Device. The profiling section relies on CUDA-only APIs (cudart, NVTX), so fail
# fast here instead of after the datasets and model have been built.
if not torch.cuda.is_available():
    raise RuntimeError("This profiling script requires a CUDA-capable GPU.")
device = torch.device("cuda:0")

# Training split: trajectories found under the 2AA-1-large/train directory whose
# trajectory/topology files match the regexes below. NOTE(review): subsample=20
# presumably keeps every 20th frame, and filter_codes presumably restricts which
# peptide codes are loaded — confirm against parse_datasets_from_directory.
train_root = f"{JAMUN_DATA_PATH}/timewarp/2AA-1-large/train/"
train_datasets = jamun.data.parse_datasets_from_directory(
    root=train_root,
    traj_pattern="^(.*)-traj-arrays.npz",
    pdb_pattern="^(.*)-traj-state0.pdb",
    subsample=20,
    as_iterable=False,
    max_datasets=100,
    filter_codes=["AA"],
)
datasets = {"train": train_datasets}

# Wrap the datasets in a data module (batch size 32, two loader workers) and run
# its setup with no specific stage.
datamodule = jamun.data.MDtrajDataModule(datasets=datasets, batch_size=32, num_workers=2)
datamodule.setup(None)

# Sub-module factories for the E3Conv architecture. Each is a partially applied
# constructor handed to E3Conv below.
radial_edge_embedder_factory = partial(
    jamun.model.embedding.RadialEdgeEmbedder,
    radial_edge_attr_dim=32,
    basis="gaussian",
    cutoff=True,
    max_radius=1.0,
)
bond_edge_embedder_factory = partial(jamun.model.embedding.BondEdgeEmbedder, bond_edge_attr_dim=32)
atom_embedder_factory = partial(
    jamun.model.embedding.ResidueAtomEmbedder,
    atom_type_embedding_dim=8,
    atom_code_embedding_dim=8,
    residue_code_embedding_dim=32,
    residue_index_embedding_dim=8,
    use_residue_sequence_index=False,
    num_atom_types=20,
    max_sequence_length=10,
    num_atom_codes=10,
    num_residue_types=25,
)
output_head_factory = partial(e3tools.nn.EquivariantMLP, irreps_hidden_list=["120x0e + 32x1e"])

# Architecture factory: a 5-layer E3Conv with output irreps "1x1e" (a single
# l=1, even-parity feature), built from the factories above.
arch = partial(
    jamun.model.arch.E3Conv,
    irreps_out="1x1e",
    irreps_hidden="120x0e + 32x1e",
    irreps_sh="1x0e + 1x1e",
    n_layers=5,
    radial_edge_embedder_factory=radial_edge_embedder_factory,
    bond_edge_embedder_factory=bond_edge_embedder_factory,
    atom_embedder_factory=atom_embedder_factory,
    hidden_layer_factory=e3tools.nn.SeparableConvBlock,
    output_head_factory=output_head_factory,
)

# Build a throwaway instance solely to report how many parameters the model has.
num_params = sum(param.numel() for param in arch().parameters())
py_logger.info(f"Number of params: {num_params}")

# Optimizer factory (Adam, lr=1e-2, no weight decay). The Denoiser receives the
# factory rather than an optimizer instance.
optim = functools.partial(torch.optim.Adam, lr=1e-2, weight_decay=0.0)

# Noise-level distribution; presumably always yields sigma = 0.04.
sigma_distribution = jamun.distributions.ConstantSigma(sigma=0.04)

# Options forwarded to torch.compile by the Denoiser (use_torch_compile=True).
compile_options = {"fullgraph": True, "dynamic": True, "mode": "default"}

denoiser = jamun.model.Denoiser(
    arch=arch,
    optim=optim,
    sigma_distribution=sigma_distribution,
    lr_scheduler_config=None,
    max_radius=1.0,
    average_squared_distance=0.332,
    add_fixed_noise=False,
    add_fixed_ones=False,
    align_noisy_input_during_training=True,
    align_noisy_input_during_evaluation=True,
    mean_center=True,
    mirror_augmentation_rate=0.0,
    use_torch_compile=True,
    torch_compile_kwargs=compile_options,
)

# Move the model to the GPU, then ask it for its optimizer. configure_optimizers
# returns a mapping here; only its "optimizer" entry is used by this script.
denoiser = denoiser.to(device)
optimizer_config = denoiser.configure_optimizers()
opt = optimizer_config["optimizer"]

# Warmup: a few unprofiled optimizer steps so that one-off costs (presumably
# torch.compile tracing and lazy CUDA initialization) are paid before the
# profiled loop starts.
n_warmup = 10

for i, batch in tqdm.tqdm(enumerate(datamodule.train_dataloader()), total=n_warmup, desc="Warmup"):
    # enumerate counts from 0, so index n_warmup is the first batch past the budget.
    if i >= n_warmup:
        break

    # Forward on the device copy of the batch, then backprop.
    batch = batch.to(device)
    out = denoiser.training_step(batch, i)
    loss = out["loss"]
    loss.backward()

    # Apply the update and clear gradients for the next step.
    opt.step()
    opt.zero_grad()


# Actual training: n_actual optimizer steps inside a CUDA profiler capture range
# (cudaProfilerStart/Stop). Each iteration and each of its phases is annotated
# as an NVTX range.
n_actual = 100

# CUDA work is launched asynchronously: drain anything still queued from the
# warmup so it is not attributed to the profiled region.
torch.cuda.synchronize(device)
torch.cuda.cudart().cudaProfilerStart()
try:
    for i, batch in tqdm.tqdm(enumerate(datamodule.train_dataloader()), total=n_actual, desc="Training"):
        if i == n_actual:
            break

        # Context-managed ranges are popped even if a step raises, keeping the
        # NVTX range stack balanced.
        with torch.cuda.nvtx.range(f"iteration {i}"):
            # The host-to-device copy gets its own range instead of falling
            # outside every annotated region.
            with torch.cuda.nvtx.range("to_device"):
                batch = batch.to(device)

            with torch.cuda.nvtx.range("forward"):
                out = denoiser.training_step(batch, i)

            with torch.cuda.nvtx.range("backward"):
                loss = out["loss"]
                loss.backward()

            with torch.cuda.nvtx.range("step"):
                opt.step()
                opt.zero_grad()

    # Let the final iteration's kernels finish inside the capture range.
    torch.cuda.synchronize(device)
finally:
    # Always close the capture range, even if training raised.
    torch.cuda.cudart().cudaProfilerStop()
