"""
Based on: https://github.com/kuleshov-group/mdlm

"""

import os

import fsspec
import hydra
import lightning as L
import omegaconf
import rich.syntax
import rich.tree
import torch

import json
from datetime import datetime

import dataloader
import diffusion_with_images
import utils
import tokenizers

import matplotlib.pyplot as plt 
import numpy as np
import seaborn as sns
from PIL import Image
import torchvision.transforms as transforms

from cc3m_dataset import CC3MDataset
from torch.utils.data import DataLoader
from lightning.pytorch.loggers import CSVLogger
import logging

import models 
import transformers

import torch.distributed as dist
from datetime import timedelta

import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import destroy_process_group 
import itertools

from lightning.pytorch.strategies import DDPStrategy
from StepByStepSamplingTester import StepByStepSamplingTester

# Plot appearance: start from matplotlib's default style, then switch the
# colour cycle to seaborn's "husl" palette.
plt.style.use('default')
sns.set_palette("husl")

# When launched under SLURM, copy the scheduler's per-task variables into
# the RANK / LOCAL_RANK / WORLD_SIZE variables (any existing values are
# overwritten).
if 'SLURM_PROCID' in os.environ:
  for _target_var, _slurm_var in (
      ('RANK', 'SLURM_PROCID'),
      ('LOCAL_RANK', 'SLURM_LOCALID'),
      ('WORLD_SIZE', 'SLURM_NTASKS'),
  ):
    os.environ[_target_var] = os.environ[_slurm_var]

# Custom OmegaConf resolvers, usable from the YAML configs as ${name:args}.
for _resolver_name, _resolver_fn in (
    ('cwd', os.getcwd),
    ('device_count', torch.cuda.device_count),
    # NOTE(review): `eval` executes arbitrary Python taken from the config
    # text; only load configs from trusted sources.
    ('eval', eval),
    # Ceiling division for positive integers.
    ('div_up', lambda x, y: (x + y - 1) // y),
):
  omegaconf.OmegaConf.register_new_resolver(_resolver_name, _resolver_fn)
  

def _dataset_location(config, key):
  """Return the dataset path or name stored at `key` in `config`.

  Args:
    config: OmegaConf config to read from.
    key: Dotted config key, e.g. 'train_ds.path'.

  Returns:
    The configured value.

  Raises:
    ValueError: If the key is absent or unset.
  """
  location = omegaconf.OmegaConf.select(config, key)
  if location is None:
    raise ValueError(
      f"Missing config entry '{key}': set it to the dataset path or name "
      'that should be passed to CC3MDataset.')
  return location


def _resolve_ckpt_path(config, logger):
  """Return the checkpoint to resume from, or None to train from scratch."""
  ckpt_cfg = config.checkpointing
  if (ckpt_cfg.resume_from_ckpt
      and ckpt_cfg.resume_ckpt_path is not None
      and utils.fsspec_exists(ckpt_cfg.resume_ckpt_path)):
    logger.info('Resuming from checkpoint: %s', ckpt_cfg.resume_ckpt_path)
    return ckpt_cfg.resume_ckpt_path
  logger.info('No checkpoint found, training from scratch')
  return None


def _build_cc3m_dataset(config, tokenizer, split, path_key, img_size):
  """Build the CC3MDataset for `split` with the settings used for training."""
  return CC3MDataset(
    dataset_path_or_name=_dataset_location(config, path_key),
    tokenizer=tokenizer,
    split=split,
    max_length=config.model.length,
    size=img_size,
    augmentation=False,
    image_key="jpg",
    caption_key="txt",
    image_transform=None,
  )


def _distributed_settings():
  """Pick (strategy, accelerator, devices, num_nodes) for the Trainer.

  Under SLURM the topology is read from the scheduler's environment
  variables; otherwise all locally visible GPUs are used, falling back to a
  single CPU device when CUDA is unavailable.
  """
  if 'SLURM_PROCID' in os.environ:
    strategy = DDPStrategy(
      find_unused_parameters=True,
      gradient_as_bucket_view=True,  # Memory optimization
    )
    num_nodes = int(os.environ.get('SLURM_JOB_NUM_NODES', '1'))
    devices = int(os.environ.get('SLURM_NTASKS_PER_NODE', '4'))
    print(f"SLURM detected: num_nodes={num_nodes}, devices={devices}")
    return strategy, 'gpu', devices, num_nodes

  has_cuda = torch.cuda.is_available()
  num_nodes = 1
  devices = torch.cuda.device_count() if has_cuda else 1
  print(f"Non-SLURM: num_nodes={num_nodes}, devices={devices}")
  return ('ddp_find_unused_parameters_true',
          'gpu' if has_cuda else 'cpu',
          devices,
          num_nodes)


def _train(config, csv_logger, logger, tokenizer):
  """Fit the image-conditioned diffusion model on CC3M with Lightning.

  Args:
    config: Hydra/OmegaConf run configuration.
    csv_logger: Lightning logger handed to the Trainer.
    logger: Standard `logging` logger used for progress messages.
    tokenizer: Tokenizer passed to the datasets and the model.

  Raises:
    ValueError: If a dataset location is missing from the config.
  """
  logger.info('Starting Training.')

  ckpt_path = _resolve_ckpt_path(config, logger)

  # Lightning callbacks
  callbacks = []
  if 'callbacks' in config:
    callbacks = [hydra.utils.instantiate(callback)
                 for callback in config.callbacks.values()]

  # Set the current device based on LOCAL_RANK before creating datasets
  has_cuda = torch.cuda.is_available()
  if has_cuda:
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    torch.cuda.set_device(local_rank)
    print(f"Process {os.environ.get('RANK', 0)}: Set CUDA device to {local_rank}")

  # NOTE(review): the dataset locations were an unfilled `*` placeholder
  # (a syntax error). They are now read from `train_ds.path` and
  # `val_ds.path`; confirm these key names against the config schema.
  train_dataset = _build_cc3m_dataset(
    config, tokenizer, split="train",
    path_key='train_ds.path', img_size=config.train_ds.img_size)
  valid_dataset = _build_cc3m_dataset(
    config, tokenizer, split="validation",
    path_key='val_ds.path', img_size=config.val_ds.img_size)

  # Create dataloaders - Lightning will handle distributed sampling automatically
  train_ds = DataLoader(
    train_dataset,
    batch_size=config.loader.batch_size,
    num_workers=0,
    pin_memory=config.loader.pin_memory,
    shuffle=True,
  )
  valid_ds = DataLoader(
    valid_dataset,
    batch_size=config.loader.batch_size,
    num_workers=0,
    pin_memory=config.loader.pin_memory,
    shuffle=False,
  )

  utils.print_batch(train_ds, valid_ds, tokenizer)

  # Build the model and load the pretrained weights. Only move it to CUDA
  # when a GPU exists, so the CPU fallback in _distributed_settings() is
  # actually reachable instead of crashing here.
  diffusion = diffusion_with_images.Diffusion(config=config, tokenizer=tokenizer)
  if has_cuda:
    diffusion = diffusion.to('cuda')
  model = utils.load_from_checkpoint(
    config=config, logger=logger, tokenizer=tokenizer, model=diffusion)

  strategy, accelerator, devices, num_nodes = _distributed_settings()

  # Create trainer with proper DDP configuration
  trainer = L.Trainer(
    default_root_dir=os.getcwd(),
    callbacks=callbacks,
    strategy=strategy,
    logger=csv_logger,
    gradient_clip_val=1.0,
    devices=devices,
    accelerator=accelerator,
    num_nodes=num_nodes,
    val_check_interval=0.25,
    precision="16-mixed",
    # Synchronize batch norm across GPUs; skipped on the CPU fallback.
    sync_batchnorm=(accelerator == 'gpu'),
    enable_checkpointing=True,
  )

  # Print trainer info
  print("Trainer created with:")
  print(f"  - Strategy: {trainer.strategy}")
  print(f"  - Devices: {trainer.num_devices}")
  print(f"  - Nodes: {trainer.num_nodes}")
  print(f"  - Global rank: {trainer.global_rank}")
  print(f"  - Local rank: {trainer.local_rank}")
  print(f"  - World size: {trainer.world_size}")

  trainer.fit(model, train_ds, valid_ds, ckpt_path=ckpt_path)

def _test(config, logger, tokenizer, *, num_samples=5, num_steps=2000):
  """Run step-by-step sampling visualisations on random validation images.

  Loads the pretrained diffusion model, then feeds `num_samples` randomly
  drawn validation samples to `StepByStepSamplingTester`. Failures while
  building the dataset or sampling are logged with their traceback and do
  not propagate (best-effort evaluation).

  Args:
    config: Hydra/OmegaConf run configuration.
    logger: Standard `logging` logger used for error reporting.
    tokenizer: Tokenizer passed to the dataset, model and tester.
    num_samples: Number of validation samples to visualise.
    num_steps: Number of sampling steps requested from the tester.

  Returns:
    None.

  Raises:
    ValueError: If the validation dataset location is missing from config.
  """
  # Fall back to CPU instead of crashing when no GPU is present.
  device = 'cuda' if torch.cuda.is_available() else 'cpu'

  # NOTE(review): the dataset location was an unfilled `*` placeholder
  # (a syntax error). It is now read from `val_ds.path`; confirm this key
  # name against the config schema. Checked before the model is built so a
  # misconfiguration fails fast and loudly.
  dataset_path = omegaconf.OmegaConf.select(config, 'val_ds.path')
  if dataset_path is None:
    raise ValueError(
      "Missing config entry 'val_ds.path': set it to the dataset path or "
      'name that should be passed to CC3MDataset.')

  # Load the pretrained model
  diffusion = diffusion_with_images.Diffusion(
    config=config, tokenizer=tokenizer).to(device)
  model = utils.load_from_checkpoint(
    config=config, logger=logger, tokenizer=tokenizer, model=diffusion)
  model.eval()
  model.to(device)

  # Create tester
  tester = StepByStepSamplingTester(model, tokenizer)

  try:
    test_dataset = CC3MDataset(
      dataset_path_or_name=dataset_path,
      tokenizer=tokenizer,
      split="validation",
      max_length=config.model.length,
      size=config.val_ds.img_size,
      augmentation=False,
      image_key="jpg",
      caption_key="txt",
      image_transform=None,
    )

    # Build the loader once and take the first `num_samples` batches of a
    # single shuffled pass, rather than re-creating the loader (and
    # re-shuffling the whole dataset) for every sample. The drawn samples
    # are therefore distinct.
    test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=True)
    for sample_batch in itertools.islice(test_dataloader, num_samples):
      real_image = sample_batch['images']
      ground_truth = (
        sample_batch['caption'][0] if 'caption' in sample_batch else None)
      tester.test_with_step_visualization(
        real_image, ground_truth, num_steps=num_steps, stop_at_eos=True)

  except Exception:
    # Best-effort evaluation: keep the full traceback in the log rather
    # than printing only the message.
    logger.exception('Step-by-step sampling test failed')
    return None

@hydra.main(version_base=None, config_path='configs', config_name='config_cc3m')
def main(config):
  """Hydra entry point: seed, print the config, then train or test.

  `config.mode == 'train'` runs `_train`; every other mode runs `_test`.
  """
  L.seed_everything(config.seed)
  utils.print_config(config, resolve=True, save_cfg=True)

  # Module-level Python logger for progress messages.
  log = logging.getLogger(__name__)
  log.setLevel(logging.INFO)

  # CSV logger handed to the Lightning Trainer in training mode.
  metrics_logger = CSVLogger(save_dir="logs", name="mdlm")

  tokenizer = dataloader.get_tokenizer(config)

  if config.mode != 'train':
    _test(config, log, tokenizer)
    return

  _train(config, metrics_logger, log, tokenizer)


if __name__ == '__main__':
  main()
