import torch
import lightning as L
import yaml
import os
import argparse
from easydict import EasyDict
from torch_geometric.data import DataLoader
import torch.distributed as dist
import imageio
import wandb

from experiments.diffusion import OurDiffusion
from experiments.data_load.data_loader import get_datasets
from experiments.models import EGTN, BasicES, EGInterpolator, Embedding

from misc.visualize_mols import plot_ours


class Model(L.LightningModule):
    """Lightning module that trains a denoiser network with ``OurDiffusion``.

    The denoiser class is chosen by ``config.denoiser.type`` and the diffusion
    object is built from ``config.diffusion``. Only the denoiser's parameters
    are handed to the optimizer.
    """

    def __init__(self, config):
        """Build the denoiser and the diffusion object from ``config``.

        Args:
            config: Attribute-style configuration (``main`` passes an
                ``EasyDict``). This class reads the ``denoiser``,
                ``diffusion``, ``optim``, ``dataset`` and ``eval`` sections.
        """
        super().__init__()
        self.config = config
        self.denoiser = self._init_denoiser()
        print(self.denoiser)
        self.diffusion = OurDiffusion(**config.diffusion)
        # Number of training batches skipped because a rank ran out of memory.
        self.oom_batch_total = 0

    def _init_denoiser(self):
        """Instantiate the denoiser selected by ``config.denoiser.type``.

        The whole ``config.denoiser`` section is forwarded as keyword
        arguments to the chosen class.

        Raises:
            NotImplementedError: If the configured type is not one of
                ``'egtn'``, ``'basic_es'`` or ``'interpolator'``.
        """
        denoiser_classes = {
            'egtn': EGTN,
            'basic_es': BasicES,
            'interpolator': EGInterpolator,
        }
        denoiser_type = self.config.denoiser.type
        if denoiser_type not in denoiser_classes:
            raise NotImplementedError(
                f"Unknown denoiser type {denoiser_type!r}; "
                f"expected one of {sorted(denoiser_classes)}"
            )
        return denoiser_classes[denoiser_type](**self.config.denoiser)

    def configure_optimizers(self):
        """Return an Adam optimizer over the denoiser parameters.

        Learning rate and weight decay come from ``config.optim``.
        """
        return torch.optim.Adam(
            self.denoiser.parameters(),
            lr=self.config.optim.lr,
            weight_decay=self.config.optim.weight_decay,
        )

    def on_train_epoch_start(self):
        """Put the denoiser in training mode at the start of each epoch."""
        self.denoiser.train()

    def _build_cond_mask(self):
        """Build the boolean conditioning mask over the time dimension.

        The mask has ``config.dataset.expected_time_dim`` entries, all False
        when ``config.denoiser.conditioning`` is ``'none'``. Otherwise the
        first entry is True, and the last entry as well when the mode is
        ``'interpolation'``.
        """
        mode = self.config.denoiser.conditioning
        cond_mask = torch.zeros(
            self.config.dataset.expected_time_dim, dtype=torch.bool
        )
        if mode != 'none':
            cond_mask[0] = True
            if mode == 'interpolation':
                cond_mask[-1] = True
        return cond_mask

    def _compute_loss(self, batch):
        """Compute the un-reduced diffusion training loss for one batch.

        Args:
            batch: Batch object exposing ``pos``, ``x``, ``x_features``,
                ``edge_index``, ``edge_attr`` and ``batch`` (plus
                ``original_frames`` when conditioning is used). It is moved
                to ``self.device`` first.

        Returns:
            The ``"loss"`` entry of ``self.diffusion.training_losses(...)``;
            callers in this class reduce it with ``.mean()``.
        """
        batch = batch.to(self.device)
        model_kwargs = {
            "h": batch.x,
            "f": batch.x_features,
            "edge_index": batch.edge_index,
            "edge_attr": batch.edge_attr,
            "batch": batch.batch,
        }

        if self.config.dataset.type == 'trajectory':
            denoiser_type = self.config.denoiser.type
            # The interpolator always receives a mask (possibly all False);
            # EGTN only receives one when conditioning is enabled.
            wants_mask = denoiser_type == 'interpolator' or (
                denoiser_type == 'egtn'
                and self.config.denoiser.conditioning != 'none'
            )
            if wants_mask:
                model_kwargs['cond_mask'] = self._build_cond_mask()
                model_kwargs['original_frames'] = batch.original_frames

        return self.diffusion.training_losses(
            model=self.denoiser,
            x_start=batch.pos,
            t=None,
            model_kwargs=model_kwargs,
        )["loss"]

    def training_step(self, batch, batch_idx):
        """Run one training step, skipping the batch on CUDA out-of-memory.

        If any rank hits an out-of-memory ``RuntimeError``, every rank
        returns a zero dummy loss for this step so that all ranks stay in
        lockstep. Any other ``RuntimeError`` is re-raised.

        Returns:
            The mean training loss, or a zero scalar that requires grad when
            the batch was skipped.
        """
        loss = None
        had_oom = False
        try:
            loss = self._compute_loss(batch).mean()
        except RuntimeError as e:
            if "out of memory" not in str(e):
                raise
            had_oom = True

        # Recovery happens outside the except block: while the exception is
        # being handled its traceback keeps the failed frames (and their
        # tensors) alive, so freeing the cache there reclaims little.
        if had_oom:
            print(f"[OOM] Rank {self.global_rank} had OOM at step {batch_idx}")
            torch.cuda.empty_cache()

        # --- Sync OOM flags across all ranks ---
        # Every rank must take part in this collective on every step, not
        # only the ranks that failed; otherwise the failing rank waits
        # forever for peers that never call all_reduce.
        any_oom = had_oom
        if (
            self.trainer.world_size > 1
            and torch.distributed.is_available()
            and torch.distributed.is_initialized()
        ):
            oom_flag = torch.tensor([int(had_oom)], device=self.device)
            torch.distributed.all_reduce(
                oom_flag, op=torch.distributed.ReduceOp.SUM
            )
            any_oom = oom_flag.item() > 0

        if any_oom:
            # One or more ranks had OOM -> all return dummy loss
            self.oom_batch_total += 1
            # Logged under its own key: re-logging "train_loss" with flags
            # different from the normal path is rejected by Lightning, and a
            # NaN value would poison the epoch-level average.
            self.log("train_oom_batches", float(self.oom_batch_total),
                     on_step=True, on_epoch=False, prog_bar=False)
            return torch.tensor(0.0, device=self.device, requires_grad=True)

        # Otherwise, safe to return real loss
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True,
                 batch_size=batch.batch.max().item() + 1)
        return loss

    def on_validation_epoch_start(self):
        """Put the denoiser in eval mode at the start of validation."""
        self.denoiser.eval()

    def _log_val_loss(self, batch):
        """Compute the mean loss on ``batch`` and log it as ``val_loss``."""
        val_loss = self._compute_loss(batch).mean()
        self.log("val_loss", round(val_loss.item(), 6),
                 on_step=False, on_epoch=True, prog_bar=True,
                 batch_size=batch.batch.max().item() + 1)

    def validation_step(self, batch, batch_idx):
        """Log the epoch-aggregated validation loss for one batch."""
        self._log_val_loss(batch)

    def validation_step_original(self, batch, batch_idx):
        """Validation step that also samples and visualises molecules.

        Besides logging ``val_loss``, the first
        ``config.eval.num_sample_batches`` batches are sampled with
        ``diffusion.p_sample_loop``. Rank zero renders one PNG per graph
        under ``config.eval.samples_output_dir/epoch_<n>`` and sends the
        images to the logger.
        """
        self._log_val_loss(batch)

        if batch_idx >= self.config.eval.num_sample_batches:
            return

        # Sampling runs on every rank; only rendering and image logging
        # below are restricted to rank zero.
        samples = self.diffusion.p_sample_loop(
            model=self.denoiser,
            shape=list(batch.pos.shape),
            model_kwargs={
                "h": batch.x,
                "f": batch.x_features,
                "edge_index": batch.edge_index,
                "edge_attr": batch.edge_attr,
                "batch": batch.batch,
            },
            progress=True,
        )
        if self.trainer.global_rank != 0:
            return

        # decode samples into molecules and visualize them,
        # into the path indexed by current epoch idx
        cur_samples_output_dir = os.path.join(
            self.config.eval.samples_output_dir,
            f"epoch_{self.current_epoch}",
        )
        os.makedirs(cur_samples_output_dir, exist_ok=True)
        image_files = []
        num_graphs = batch.batch.max().item() + 1
        for sample_idx in range(num_graphs):
            node_mask = batch.batch == sample_idx
            cur_atom_number = batch.x[node_mask]
            # Keep index 0 of the last axis; the original comments mark the
            # result as [N, 3].
            cur_position = samples[node_mask][..., 0]
            save_file = os.path.join(
                cur_samples_output_dir,
                f"sample_{batch_idx}_{sample_idx}.png"
            )
            molecule = (
                cur_atom_number.detach().cpu(),  # [N]
                cur_position.detach().cpu(),  # [N, 3]
            )
            plot_ours(
                molecule=molecule,
                output_path=save_file,
                dataset_name=self.config.dataset.dataset.lower(),
                remove_h=False,
            )
            image_files.append(save_file)
        print(f'Logging images to wandb on rank {self.trainer.global_rank}')
        self.logger.log_image(
            key="validation_samples",
            images=image_files,
            caption=[f'Sample {i}' for i in range(len(image_files))],
        )

    def on_validation_epoch_end(self):
        """No end-of-validation work is performed."""
        pass


def _load_pretrained_weights(model, ckpt_path, freeze=False):
    """Load matching weights from a checkpoint into ``model`` and report them.

    Args:
        model: Module whose parameters are overwritten where names match.
        ckpt_path: Path of a checkpoint containing a ``"state_dict"`` entry
            (Lightning checkpoint layout is assumed).
        freeze: When true, every parameter whose name appears in the
            checkpoint has ``requires_grad`` set to False.
    """
    print(f"Loading spatial layer weights from {ckpt_path}")
    # Load onto CPU so the checkpoint does not depend on the device it was
    # saved from; load_state_dict copies values into the existing parameters.
    ckpt = torch.load(ckpt_path, map_location='cpu')
    ckpt_state_dict = ckpt["state_dict"]  # assumes Lightning format

    # 1) Snapshot all model params before loading
    original_params = {
        name: param.detach().clone()
        for name, param in model.named_parameters()
    }

    # 2) Load checkpoint non-strictly and report mismatches
    incompat = model.load_state_dict(ckpt_state_dict, strict=False)
    print("Missing keys:   ", incompat.missing_keys)
    print("Unexpected keys:", incompat.unexpected_keys)

    # 3) Compare before vs after to find truly updated parameters
    #    (allclose tolerates tiny floating-point differences)
    updated_layers = [
        name
        for name, param in model.named_parameters()
        if name in ckpt_state_dict
        and not torch.allclose(param.detach(), original_params[name], atol=1e-6)
    ]

    # 4) Print results
    if updated_layers:
        print("\n✅ The following layers were updated from checkpoint:")
        for name in updated_layers:
            print(f"  • {name}")
    else:
        print("\n⚠️  No parameters changed (check key names or values).")

    # 5) Optionally freeze every parameter that is present in the checkpoint
    if freeze:
        frozen_layers = []
        for name, param in model.named_parameters():
            if name in ckpt_state_dict:
                param.requires_grad = False
                frozen_layers.append(name)
        print("\n🧊 The following layers have been frozen from gradient updates:")
        for name in frozen_layers:
            print(f"  • {name}")


def main():
    """Train a ``Model`` from the YAML config given by ``--config``.

    Steps: parse the config, create the wandb logger, build the model
    (optionally loading/freezing pretrained weights), build the datasets and
    loaders, then fit with a Lightning ``Trainer`` that checkpoints and
    validates every 4 epochs.
    """
    # Get the parser
    print("Parsing the config")
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='configs/temp.yaml')
    args = parser.parse_args()

    # Extract the config values; the context manager closes the file handle.
    with open(args.config, 'r') as config_file:
        config = EasyDict(yaml.safe_load(config_file))
    print(config)
    logger = L.pytorch.loggers.WandbLogger(
        **config.wandb
    )

    # Initialize the Model
    print("Initializing the Model")
    model = Model(config)
    # Optional keys are read with .get so that a config that omits them
    # behaves like one that sets them to a falsy value.
    pretrain_ckpt = config.denoiser.get('pretrain_ckpt')
    if pretrain_ckpt:
        _load_pretrained_weights(
            model,
            pretrain_ckpt,
            freeze=bool(config.denoiser.get('freeze_spatial')),
        )

    # Initialize the Dataset
    print("Getting the Datasets")
    train_dataset, val_dataset = get_datasets(config)
    # No truncation is applied here, so the message reports the full sizes.
    print(f"Dataset sizes - Train: {len(train_dataset)}, Val: {len(val_dataset)}")

    # Initialize the DataLoader
    print("Initializing the DataLoaders")
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=config.train.batch_size,
        shuffle=True,
        num_workers=8
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=config.eval.batch_size,
        shuffle=False,
        num_workers=8
    )

    # Initialize the Checkpoint Callback
    print("Initializing the Checkpoint Callback")
    checkpoint_callback = L.pytorch.callbacks.ModelCheckpoint(
        filename='{epoch}-{step}',
        save_top_k=-1,  # -1 keeps every checkpoint that is written
        every_n_epochs=4,  # Write a checkpoint every 4 epochs
        save_last=False,  # Do not additionally maintain a `last.ckpt`
        dirpath=f'checkpoints/{config.wandb.project}/{config.wandb.name}',
        monitor='val_loss',
    )

    print("Initializing the Trainer")
    trainer = L.Trainer(
        accumulate_grad_batches=config.train.accum_grad,
        max_epochs=config.train.max_epochs,
        accelerator='cuda',
        gradient_clip_val=1.0,
        log_every_n_steps=20,
        check_val_every_n_epoch=4,  # matches the checkpoint interval above
        callbacks=[checkpoint_callback],
        logger=logger,
    )

    print("Training the Model")
    trainer.fit(
        model,
        train_dataloader,
        val_dataloader,
        # A missing or empty value starts training from scratch.
        ckpt_path=config.train.get('resume_from_checkpoint') or None,
    )


# Run training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
