import functools
import json
import os
import queue

# Set tokenizer parallelism to false to avoid warnings in multiprocessing
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

import fsspec
import hydra
import lightning as L
import omegaconf
import rich.syntax
import rich.tree
import torch
from tqdm import tqdm
import numpy as np
import random
from collections import defaultdict
import multiprocessing as mp

import dataloader
import diffusion
import utils
from safetensors.torch import load_file

import mauve

# import setproctitle
# setproctitle.setproctitle("python main.py")



# Custom OmegaConf resolvers referenced from the Hydra configs:
#   'cwd'          -> os.getcwd()
#   'device_count' -> torch.cuda.device_count()
#   'eval'         -> Python's builtin eval; executes config text, so configs must be trusted
#   'div_up'       -> (x + y - 1) // y, i.e. ceil(x / y) for positive integers
omegaconf.OmegaConf.register_new_resolver(name='cwd', resolver=os.getcwd)
omegaconf.OmegaConf.register_new_resolver(
  name='device_count', resolver=torch.cuda.device_count)
omegaconf.OmegaConf.register_new_resolver(name='eval', resolver=eval)
omegaconf.OmegaConf.register_new_resolver(
  name='div_up', resolver=lambda x, y: (x + y - 1) // y)


@functools.lru_cache(maxsize=1)
def _lazy_get_spacy_tokenizer():
  """Return a spaCy tokenizer if available, otherwise None.

  Loaded lazily to avoid startup overhead and a hard dependency on the spaCy model.
  The result (including the None fallback) is cached, so ``en_core_web_sm`` is loaded
  at most once per process even though ``compute_diversity`` requests a tokenizer on
  every call.
  """
  try:
    import spacy  # type: ignore
    # Load small English model; if unavailable, this raises and we fall back to None.
    nlp = spacy.load("en_core_web_sm")
    return nlp.tokenizer
  except Exception:
    # Deliberate best-effort: any failure (spaCy missing, model not downloaded,
    # incompatible versions) means callers use whitespace tokenization instead.
    return None


def _iter_ngrams(tokens, n):
  """Lazily yield every contiguous window of ``n`` tokens as a tuple.

  Nothing is yielded when ``n`` is non-positive or exceeds ``len(tokens)``.
  Pure Python, no external dependencies.
  """
  if n <= 0 or len(tokens) < n:
    return
  # Zipping n progressively shifted views stops at the shortest one,
  # which yields exactly len(tokens) - n + 1 windows.
  shifted_views = [tokens[offset:] for offset in range(n)]
  yield from zip(*shifted_views)


def compute_diversity(all_texts_list):
  """Measure corpus-level n-gram repetition and an aggregate diversity score.

  Each text is tokenized with spaCy when ``_lazy_get_spacy_tokenizer`` provides a
  tokenizer, otherwise with ``str.split``. For each order n in (2, 3, 4),
  ``'{n}gram_repetition'`` is ``1 - unique_ngrams / total_ngrams`` pooled over all
  texts (0.0 when there are no n-grams). ``'diversity'`` is the product of
  ``1 - repetition`` over those orders.

  Returns:
    dict mapping metric name to float.
  """
  orders = (2, 3, 4)
  spacy_tokenizer = _lazy_get_spacy_tokenizer()

  def _tokenize(text):
    # Whitespace split is the dependency-free fallback.
    if spacy_tokenizer is None:
      return text.split()
    return [str(tok) for tok in spacy_tokenizer(text)]

  tokenized_texts = [_tokenize(text) for text in all_texts_list]

  metrics = {}
  for order in orders:
    distinct = set()
    n_total = 0
    for toks in tokenized_texts:
      grams = list(_iter_ngrams(toks, order))
      distinct.update(grams)
      n_total += len(grams)
    metrics[f"{order}gram_repetition"] = (
      1.0 - (len(distinct) / float(n_total)) if n_total else 0.0)

  aggregate = 1.0
  for order in orders:
    aggregate *= 1.0 - metrics[f"{order}gram_repetition"]
  metrics["diversity"] = aggregate

  return metrics


def _sampling_worker(rank,
                     config_container,
                     steps,
                     num_batches,
                     semi_ar,
                     stride_length,
                     num_strides,
                     disable_ema,
                     out_queue):
  """GPU worker that generates text samples.

  Intended to run in a spawned subprocess. Each worker seeds ``random``, NumPy and
  torch with ``config.seed + rank`` (``0 + rank`` if the config has no seed), so that
  workers draw different samples while runs stay reproducible.

  Args:
    rank: CUDA device index to use in this worker; also offsets the RNG seed.
    config_container: Serializable config container (from OmegaConf.to_container).
    steps: Diffusion sampling steps.
    num_batches: How many batches to sample in this worker.
    semi_ar: Whether to use semi-autoregressive sampler if available.
    stride_length, num_strides: Semi-AR related args (ignored otherwise).
    disable_ema: If True, disable EMA before sampling.
    out_queue: Multiprocessing queue that receives exactly one tuple:
      ``(rank, texts, reference_texts)`` on success, or
      ``(rank, {'error': str, 'traceback': str}, [])`` on failure.
  """
  try:
    # Local imports to avoid pickling issues
    import random as _random

    import numpy as _np
    import omegaconf as _oc
    import torch as _torch

    # Pin this worker to its GPU if available. Best effort: the model is still moved
    # to cuda:<rank> explicitly below.
    if _torch.cuda.is_available():
      try:
        _torch.cuda.set_device(rank)
      except Exception:
        pass

    _config = _oc.OmegaConf.create(config_container)

    # torch's default generators start from the same fixed seed in every fresh
    # (spawned) process. Without this, every GPU would draw identical noise and
    # return duplicate samples, skewing the diversity/MAUVE metrics. Offsetting by
    # rank keeps workers distinct yet reproducible.
    base_seed = _config.get('seed', None)
    worker_seed = (0 if base_seed is None else int(base_seed)) + rank
    _random.seed(worker_seed)
    _np.random.seed(worker_seed % (2 ** 32))
    _torch.manual_seed(worker_seed)  # also seeds all CUDA devices

    _tokenizer = dataloader.get_tokenizer(_config)
    _device = _torch.device(f'cuda:{rank}')
    _model = _load_from_checkpoint(config=_config, tokenizer=_tokenizer, device=_device)
    if disable_ema:
      _model.ema = None
    _model = _model.to(_device)

    local_text_samples = []
    local_reference_texts = []

    for _ in tqdm(range(num_batches), desc=f"Sampling on GPU {rank}"):
      if semi_ar and hasattr(_model, 'restore_model_and_semi_ar_sample'):
        _, intermediate_samples, _ = _model.restore_model_and_semi_ar_sample(
          stride_length=stride_length,
          num_strides=num_strides,
          dt=1 / steps)
        # The last intermediate step holds the final decoded texts.
        text_samples = intermediate_samples[-1]
        local_text_samples.extend(text_samples)
      else:
        samples, reference_texts = _model.restore_model_and_sample(
          num_steps=steps)
        text_samples = _model.tokenizer.batch_decode(samples)
        local_text_samples.extend(text_samples)
        local_reference_texts.extend(reference_texts)

    out_queue.put((rank, local_text_samples, local_reference_texts))
  except Exception as _e:
    # Always report back, so the parent's result loop receives one message per worker.
    import traceback as _tb
    out_queue.put((rank, {'error': f"{_e.__class__.__name__}: {_e}", 'traceback': _tb.format_exc()}, []))

def _load_from_checkpoint(config, tokenizer, device=None):
  """Build or restore a ``diffusion.Diffusion`` model.

  If ``config.backbone`` contains ``'hf'``, a ``Diffusion`` is constructed directly
  from the config (this function does not read ``config.eval.checkpoint_path``). It is
  moved to ``device``, or to ``'cuda'`` when ``device`` is None.

  Otherwise the checkpoint at ``config.eval.checkpoint_path`` is restored via
  ``Diffusion.load_from_checkpoint`` with ``strict=False``. All warnings are
  suppressed while loading.

  Args:
    config: Run config.
    tokenizer: Tokenizer forwarded to the model.
    device: Optional target device. For checkpoints it is also passed as
      ``map_location``, so tensors are materialised directly on it. Without it,
      they land on whatever device they were saved from, e.g. every parallel worker
      loading onto GPU 0 first.
  """
  if 'hf' in config.backbone:
    target_device = 'cuda' if device is None else device
    return diffusion.Diffusion(
      config, tokenizer=tokenizer).to(target_device)

  import warnings

  # Only override Lightning's default placement when the caller asked for a device.
  load_kwargs = {} if device is None else {'map_location': device}
  with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    return diffusion.Diffusion.load_from_checkpoint(
      config.eval.checkpoint_path,
      tokenizer=tokenizer,
      config=config, strict=False, device=device, **load_kwargs)


@L.pytorch.utilities.rank_zero_only
def _print_config(
  config: omegaconf.DictConfig,
  resolve: bool = True,
  save_cfg: bool = True) -> None:
  """Render the config as a Rich tree with one branch per top-level key.

  Nested sections are shown as YAML; leaf values are shown via ``str``.

  Args:
    config (DictConfig): Configuration composed by Hydra.
    resolve (bool): Whether to resolve interpolations when rendering sections.
    save_cfg (bool): Also write the tree to
      ``<checkpointing.save_dir>/config_tree.txt``.
  """
  dim = 'dim'
  root = rich.tree.Tree('CONFIG', style=dim, guide_style=dim)

  for key in config.keys():
    section = config.get(key)
    if isinstance(section, omegaconf.DictConfig):
      rendered = omegaconf.OmegaConf.to_yaml(section, resolve=resolve)
    else:
      rendered = str(section)
    node = root.add(key, style=dim, guide_style=dim)
    node.add(rich.syntax.Syntax(rendered, 'yaml'))

  rich.print(root)
  if not save_cfg:
    return
  out_path = '{}/config_tree.txt'.format(config.checkpointing.save_dir)
  with fsspec.open(out_path, 'w') as fp:
    rich.print(root, file=fp)


@L.pytorch.utilities.rank_zero_only
def _print_batch(train_ds, valid_ds, tokenizer, k=64):
  """Print one batch per split: its ``input_ids`` shape, then the first and last ``k``
  token ids of the first row, each shown decoded and raw."""
  for split_name, loader in (('train', train_ds), ('valid', valid_ds)):
    print(f'Printing {split_name} dataloader batch.')
    input_ids = next(iter(loader))['input_ids']
    print('Batch input_ids.shape', input_ids.shape)
    head_ids = input_ids[0, :k]
    tail_ids = input_ids[0, -k:]
    for label, ids in ((f'First {k} tokens:', head_ids),
                       (f'Last {k} tokens:', tail_ids)):
      print(label, tokenizer.decode(ids))
      print('ids:', ids)


def _get_validation_texts(config, tokenizer, max_samples, seed=42):
  """Get validation texts, subsampled to match the generated sample count.

  Seeds ``numpy``, ``random`` and ``torch`` globally with ``seed``, so both the
  validation loader built here and the subsampling are reproducible. Batches are then
  decoded until at least ``2 * max_samples`` texts are collected (or the loader is
  exhausted). Finally ``max_samples`` texts are drawn without replacement.

  Args:
    config: Run config passed to ``dataloader.get_dataloaders``.
    tokenizer: Tokenizer used to decode ``input_ids``.
    max_samples: Number of texts to return (fewer if the loader runs out).
    seed: Seed for the global RNGs and the validation loader.

  Returns:
    list of decoded validation texts.
  """
  # Fix seed for reproducibility
  np.random.seed(seed)
  random.seed(seed)
  torch.manual_seed(seed)

  _, valid_ds = dataloader.get_dataloaders(
    config, tokenizer, skip_train=True, valid_seed=seed)

  validation_texts = []
  for batch in tqdm(valid_ds, desc="Collecting validation texts"):
    validation_texts.extend(tokenizer.batch_decode(batch['input_ids']))
    # Oversample 2x so the random subset below is drawn from a wider pool,
    # but stop early rather than decoding the whole split.
    if len(validation_texts) >= max_samples * 2:
      break

  if len(validation_texts) > max_samples:
    # Sample indices rather than the strings themselves. np.random.choice on a list of
    # str first builds a fixed-width unicode array, costing longest-text x count memory
    # and stripping trailing NUL characters. Choosing from range(len) consumes the same
    # RNG stream, so the selected subset is unchanged.
    chosen = np.random.choice(len(validation_texts), size=max_samples, replace=False)
    validation_texts = [validation_texts[i] for i in chosen]

  return validation_texts


def _compute_mauve(generated_texts, reference_texts):
  """Compute the MAUVE score of generated texts against reference texts.

  Args:
    generated_texts: Model samples (passed as MAUVE's ``q_text``).
    reference_texts: Reference texts (passed as MAUVE's ``p_text``).

  Returns:
    ``mauve.compute_mauve(...).mauve``, or None when either list is empty or
    MAUVE is unavailable.
  """
  # An empty side leaves MAUVE nothing to featurize. Skip instead of crashing the
  # evaluation after the (expensive) sampling has already finished.
  if not generated_texts or not reference_texts:
    print("Empty generated or reference texts. Skipping MAUVE computation.")
    return None
  # Defensive: the module-level `import mauve` is unconditional, so this only
  # matters if that import is ever made optional.
  if mauve is None:
    print("MAUVE not available. Skipping MAUVE computation.")
    return None

  result = mauve.compute_mauve(
    p_text=reference_texts,
    q_text=generated_texts,
    device_id=0 if torch.cuda.is_available() else -1,
    verbose=False
  )
  return result.mauve


def compute_number_of_parameters(module):
  """Return the total element count over ``module.parameters()``."""
  total = 0
  for param in module.parameters():
    total += param.numel()
  return total


def clean_sample(sample: str, *, tokens=('[PAD]',)) -> str:
  """Remove filler tokens from a decoded sample and trim surrounding whitespace.

  By default only ``[PAD]`` is removed; ``<|endoftext|>`` markers are left in place.
  Pass ``tokens`` to strip additional markers.

  Args:
    sample: Decoded text.
    tokens: Substrings to delete everywhere in ``sample``, applied in order.

  Returns:
    The cleaned, whitespace-stripped text.
  """
  for token in tokens:
    sample = sample.replace(token, '')
  return sample.strip()


def _run_parallel_sampling(config, logger, num_gpus, stride_length, num_strides):
  """Spread sampling batches over one spawned ``_sampling_worker`` per GPU and gather results.

  Batches are split as evenly as possible; the first ``total % num_gpus`` ranks get one
  extra. Worker failures are logged together with their traceback.

  Returns:
    (texts, reference_texts), concatenated in the order workers report back.

  Raises:
    RuntimeError: if workers failed and none of them produced any samples.
  """
  total_batches = int(config.sampling.num_sample_batches)
  per_gpu, remainder = divmod(total_batches, num_gpus)
  batches_per_rank = [per_gpu + (1 if i < remainder else 0) for i in range(num_gpus)]

  # Build a serializable config copy for workers
  cfg_container = omegaconf.OmegaConf.to_container(config, resolve=True)

  # CUDA cannot be re-initialised in a forked child, so force 'spawn'.
  try:
    if mp.get_start_method(allow_none=True) != 'spawn':
      mp.set_start_method('spawn', force=True)
  except RuntimeError:
    pass

  result_q = mp.Queue()
  procs = []
  for rank, nb in enumerate(batches_per_rank):
    if nb == 0:
      continue
    p = mp.Process(
      target=_sampling_worker,
      args=(rank, cfg_container, config.sampling.steps, nb,
            False, stride_length, num_strides, config.eval.disable_ema, result_q),
      daemon=True)
    p.start()
    procs.append(p)

  texts, reference_texts, worker_errors = [], [], []
  received = 0
  # Drain the queue before joining: a child that has queued data does not terminate
  # until that data is flushed, so joining first could deadlock.
  while received < len(procs):
    any_alive = any(p.is_alive() for p in procs)
    try:
      # Poll with a timeout so a worker killed without reporting (OOM killer,
      # segfault) cannot hang the parent forever. Once every worker has exited,
      # whatever they sent is already in the pipe, so a short final wait suffices.
      rank, payload, refs = result_q.get(timeout=60 if any_alive else 5)
    except queue.Empty:
      if any_alive:
        continue
      worker_errors.append(
        ('unknown', f'{len(procs) - received} worker(s) exited without reporting results', ''))
      break
    if isinstance(payload, dict) and 'error' in payload:
      worker_errors.append((rank, payload['error'], payload.get('traceback', '')))
    else:
      texts.extend(payload)
      reference_texts.extend(refs)
    received += 1

  for p in procs:
    p.join()

  for rank, error, tb in worker_errors:
    logger.warning(f"Sampling worker {rank} failed: {error}\n{tb}")
  if worker_errors and not texts:
    raise RuntimeError('Parallel sampling produced no samples; see worker errors above.')
  return texts, reference_texts


def _dump_samples(generated_texts, validation_texts, reference_texts):
  """Write one UTF-8 file per text, relative to the current working directory.

  Layout: ``generated_samples/sample_<i>.txt`` and ``validation_samples/sample_<i>.txt``.
  When reference texts exist, also ``reference_texts/sample_<i>.txt`` and
  ``paired_samples/pair_<i>.txt``. None references are written as empty strings.
  All four directories are always created.
  """
  generated_dir = 'generated_samples'
  validation_dir = 'validation_samples'
  reference_dir = 'reference_texts'
  paired_dir = 'paired_samples'
  for directory in (generated_dir, validation_dir, reference_dir, paired_dir):
    os.makedirs(directory, exist_ok=True)

  def _write(path, text):
    # Explicit encoding: generated text may not be representable in the locale default.
    with open(path, 'w', encoding='utf-8') as f:
      f.write(text)

  for idx, text in enumerate(generated_texts):
    _write(os.path.join(generated_dir, f'sample_{idx}.txt'), text)
  for idx, text in enumerate(validation_texts):
    _write(os.path.join(validation_dir, f'sample_{idx}.txt'), text)

  if not reference_texts:
    return
  # Some entries can be None if null conditioning was used
  for idx, text in enumerate(reference_texts):
    _write(os.path.join(reference_dir, f'sample_{idx}.txt'), '' if text is None else text)
  for idx, gen_text in enumerate(generated_texts):
    ref_text = reference_texts[idx] if idx < len(reference_texts) else None
    if ref_text is None:
      ref_text = ''
    content = f"sample:\n\n{gen_text}\n\n=== reference\n\n{ref_text}"
    _write(os.path.join(paired_dir, f'pair_{idx}.txt'), content)


def _print_percent_metrics(title, metrics):
  """Print ``title`` followed by each metric (sorted by key), scaled by 100."""
  print(title)
  for key in sorted(metrics.keys()):
    value = metrics[key]
    if isinstance(value, float):
      print(f'  {key}: {value * 100:.6f}')
    else:
      print(f'  {key}: {value * 100}')


def generate_samples(config, logger, tokenizer):
  """Generate samples, dump them to disk, and compute evaluation metrics.

  Samples are drawn across all GPUs via spawned workers when more than one GPU is
  visible and semi-AR sampling is off; otherwise they are drawn in-process. Validation
  texts of matching count are collected. Both sets are cleaned and written to disk.
  Generative perplexity, MAUVE and n-gram diversity are then computed and printed.

  Returns:
    (mauve_score, diversity, gen_ppl). ``mauve_score`` is None when MAUVE is skipped.
    ``gen_ppl`` is None for semi-autoregressive sampling, where it is not computed.
  """
  logger.info('Generating samples.')
  model = _load_from_checkpoint(config=config,
                                tokenizer=tokenizer)
  model.gen_ppl_metric.reset()
  if config.eval.disable_ema:
    logger.info('Disabling EMA.')
    model.ema = None
  stride_length = config.sampling.stride_length
  num_strides = config.sampling.num_strides
  all_text_samples = []
  all_reference_texts = []

  # If multiple GPUs are available, distribute sampling batches across devices
  num_gpus = torch.cuda.device_count()
  if (not config.sampling.semi_ar) and num_gpus > 1:
    logger.info(f"Parallel sampling enabled across {num_gpus} GPUs.")
    all_text_samples, all_reference_texts = _run_parallel_sampling(
      config, logger, num_gpus, stride_length, num_strides)
    # Compute generative perplexity once on aggregated texts
    if all_text_samples:
      model.compute_generative_perplexity(all_text_samples)
  else:
    for _ in tqdm(range(config.sampling.num_sample_batches)):
      if config.sampling.semi_ar:
        _, intermediate_samples, _ = model.restore_model_and_semi_ar_sample(
          stride_length=stride_length,
          num_strides=num_strides,
          dt=1 / config.sampling.steps)
        text_samples = intermediate_samples[-1]
        all_text_samples.extend(text_samples)
        # Note: Samples generated using semi-ar method
        # need to be processed before computing generative perplexity
        # since these samples contain numerous <|endoftext|> tokens
        # and diffusion.compute_generative_perplexity() discards
        # any text after the first EOS token.
      else:
        samples, reference_texts = model.restore_model_and_sample(
          num_steps=config.sampling.steps)
        text_samples = model.tokenizer.batch_decode(samples)
        all_text_samples.extend(text_samples)
        all_reference_texts.extend(reference_texts)
        model.compute_generative_perplexity(text_samples)

  # Get validation texts matching the number of generated samples
  validation_texts = _get_validation_texts(config, tokenizer, max_samples=len(all_text_samples))
  logger.info(f'Collected {len(validation_texts)} validation texts.')

  all_text_samples = [clean_sample(sample) for sample in all_text_samples]
  validation_texts = [clean_sample(text) for text in validation_texts]

  _dump_samples(all_text_samples, validation_texts, all_reference_texts)

  # Generative perplexity is only accumulated for the non-semi-AR path above.
  gen_ppl = None
  if not config.sampling.semi_ar:
    gen_ppl = model.gen_ppl_metric.compute().item()
    print(f'Generative perplexity: {gen_ppl:.2f}')

  logger.info('Computing MAUVE metric...')
  mauve_score = _compute_mauve(
    generated_texts=all_text_samples,
    reference_texts=validation_texts
  )
  if mauve_score is not None:
    print(f'MAUVE score: {mauve_score * 100:.4f}')

  logger.info('Computing diversity metrics...')
  diversity_metrics = compute_diversity(all_text_samples)
  _print_percent_metrics('Diversity metrics:', diversity_metrics)

  logger.info('Computing diversity metrics for validation texts...')
  validation_diversity_metrics = compute_diversity(validation_texts)
  _print_percent_metrics('Validation diversity metrics:', validation_diversity_metrics)

  logger.info(f'Number of parameters: {compute_number_of_parameters(model):,}')
  return mauve_score, diversity_metrics['diversity'], gen_ppl

def test_condition_embedding_dependence(config, logger, tokenizer, num_test_batches=5, std_dev=1.0):
  """Measure how strongly the backbone's output depends on the condition embedding.

  Loads the model and takes one fixed validation batch at sigma = 0. The backbone is
  run twice: with an all-zeros condition and with a Gaussian condition of standard
  deviation ``std_dev``. Prints KL(random || zero), the Jensen-Shannon divergence
  between the two predicted token distributions, and the MSE between the logits.

  Args:
    config: Run config. Reads ``loader.eval_batch_size`` and ``eval.disable_ema``,
      plus ``model.cond_dim_embedding`` when the model has no ``text_embedder``.
    logger: Logger for progress messages.
    tokenizer: Tokenizer used for the model and validation loader.
    num_test_batches: Currently unused (a single validation batch is evaluated);
      kept for call compatibility.
    std_dev: Standard deviation of the random condition embedding.
  """
  import math

  F = torch.nn.functional
  logger.info('Testing condition embedding dependence (simple version).')

  model = _load_from_checkpoint(config=config, tokenizer=tokenizer)
  if config.eval.disable_ema:
    logger.info('Disabling EMA.')
    model.ema = None

  backbone = model.backbone
  was_training = backbone.training
  backbone.eval()

  text_embedder = getattr(model, 'text_embedder', None)
  cond_dim_embedding = (text_embedder.cond_dim if text_embedder is not None
                        else config.model.cond_dim_embedding)

  # Get some fixed inputs from validation data
  _, valid_ds = dataloader.get_dataloaders(config, tokenizer, skip_train=True, valid_seed=42)
  batch = next(iter(valid_ds))
  fixed_input = batch['input_ids'][:config.loader.eval_batch_size].to(model.device)
  # The loader may yield fewer rows than eval_batch_size. Size sigma and the
  # conditions to the rows actually present so the shapes agree.
  batch_size = fixed_input.shape[0]

  # Fixed sigma (timestep)
  fixed_sigma = torch.zeros(batch_size, device=model.device)

  try:
    with torch.no_grad():
      zero_condition = torch.zeros(batch_size, cond_dim_embedding, device=model.device)
      random_condition = std_dev * torch.randn(
        batch_size, cond_dim_embedding, device=model.device)

      logits_zero = backbone(fixed_input, fixed_sigma, zero_condition)
      logits_random = backbone(fixed_input, fixed_sigma, random_condition)

      # Log-probabilities directly: more accurate than log(softmax + eps).
      logp_zero = torch.log_softmax(logits_zero, dim=-1)
      logp_random = torch.log_softmax(logits_random, dim=-1)

      # F.kl_div(input, target) computes KL(target || input). Both are given as
      # log-probabilities here (log_target=True).
      kl_div = F.kl_div(logp_zero, logp_random, reduction='batchmean', log_target=True)

      # JS(P, Q) = 0.5 * KL(P || M) + 0.5 * KL(Q || M) with M = (P + Q) / 2, so the
      # mixture goes in the `input` slot. log M = logaddexp(log P, log Q) - log 2.
      log_m = torch.logaddexp(logp_zero, logp_random) - math.log(2.0)
      js_div = (
        0.5 * F.kl_div(log_m, logp_zero, reduction='batchmean', log_target=True)
        + 0.5 * F.kl_div(log_m, logp_random, reduction='batchmean', log_target=True))

      mse = F.mse_loss(logits_zero, logits_random)
  finally:
    # Restore the previous train/eval mode even if a forward pass fails.
    backbone.train(was_training)

  print(f"\n{'='*50}")
  print("CONDITION EMBEDDING DEPENDENCE TEST")
  print(f"{'='*50}")
  print(f"KL Divergence (random || zero): {kl_div.item():.6f}")
  print(f"JS Divergence: {js_div.item():.6f}")
  print(f"MSE: {mse.item():.12f}")
  print(f"{'='*50}")


def _ppl_eval(config, logger, tokenizer):
  """Validate a restored model on the validation split with a Hydra-built trainer.

  Loads the checkpoint (optionally dropping EMA) and builds the optional W&B logger
  and configured callbacks. It then instantiates trainer/strategy from the config and
  runs ``trainer.validate`` on a validation loader seeded with ``config.seed``.
  """
  logger.info('Starting Zero Shot Eval.')

  model = _load_from_checkpoint(config=config,
                                tokenizer=tokenizer)
  if config.eval.disable_ema:
    logger.info('Disabling EMA.')
    model.ema = None

  if config.get('wandb', None) is None:
    wandb_logger = None
  else:
    wandb_logger = L.pytorch.loggers.WandbLogger(
      config=omegaconf.OmegaConf.to_object(config),
      **config.wandb)

  callbacks = (
    [hydra.utils.instantiate(cb_cfg) for _, cb_cfg in config.callbacks.items()]
    if 'callbacks' in config else [])

  root_dir = os.getcwd()
  strategy = hydra.utils.instantiate(config.strategy)
  trainer = hydra.utils.instantiate(
    config.trainer,
    default_root_dir=root_dir,
    callbacks=callbacks,
    strategy=strategy,
    logger=wandb_logger)

  _, valid_loader = dataloader.get_dataloaders(
    config, tokenizer, skip_train=True, valid_seed=config.seed)
  trainer.validate(model, valid_loader)


def _train(config, logger, tokenizer):
  """Train a ``diffusion.Diffusion`` model with a Hydra-instantiated Lightning trainer.

  When ``checkpointing.resume_from_ckpt`` is set and ``resume_ckpt_path`` exists, the
  weights in that file are loaded into the fresh model with ``strict=False`` before
  ``trainer.fit``. The file may be a ``.ckpt`` ``state_dict`` or ``.safetensors``; the
  number of missing/unexpected keys is logged. The path itself is not passed to
  ``trainer.fit``, so Lightning does not restore optimizer/loop state from it.

  Raises:
    ValueError: if the resume checkpoint has an unrecognised extension.
  """
  logger.info('Starting Training.')
  wandb_logger = None
  if config.get('wandb', None) is not None:
    wandb_logger = L.pytorch.loggers.WandbLogger(
      config=omegaconf.OmegaConf.to_object(config),
      ** config.wandb)

  if (config.checkpointing.resume_from_ckpt
      and config.checkpointing.resume_ckpt_path is not None
      and utils.fsspec_exists(
        config.checkpointing.resume_ckpt_path)):
    ckpt_path = config.checkpointing.resume_ckpt_path
  else:
    ckpt_path = None

  # Lightning callbacks
  callbacks = []
  if 'callbacks' in config:
    for _, callback in config.callbacks.items():
      callbacks.append(hydra.utils.instantiate(callback))

  train_ds, valid_ds = dataloader.get_dataloaders(
    config, tokenizer)
  _print_batch(train_ds, valid_ds, tokenizer)

  model = diffusion.Diffusion(
    config, tokenizer=valid_ds.tokenizer)

  trainer = hydra.utils.instantiate(
    config.trainer,
    default_root_dir=os.getcwd(),
    callbacks=callbacks,
    strategy=hydra.utils.instantiate(config.strategy),
    logger=wandb_logger)

  if ckpt_path:
    if ckpt_path.endswith('ckpt'):
      print(f"Loading checkpoint from {ckpt_path}")
      # map_location='cpu': the model has not been moved to a device by this function,
      # and this avoids allocating on (or requiring) the device the file was saved from.
      # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
      state_dict = torch.load(ckpt_path, map_location='cpu')["state_dict"]
    elif ckpt_path.endswith('safetensors'):
      print(f"Loading safetensors from {ckpt_path}")
      state_dict = load_file(ckpt_path)
    else:
      raise ValueError(f"Unknown checkpoint format for {ckpt_path}")
    load_result = model.load_state_dict(state_dict, strict=False)
    # strict=False is deliberate, but silently dropping keys hides mismatches; report them.
    missing = getattr(load_result, 'missing_keys', None) or []
    unexpected = getattr(load_result, 'unexpected_keys', None) or []
    if missing or unexpected:
      logger.info(
        f'Non-strict load: {len(missing)} missing and {len(unexpected)} unexpected keys.')
  else:
    print('Training from scratch')
  trainer.fit(model, train_ds, valid_ds)


def _append_metrics_jsonl(config, record):
  """Append ``record`` as one JSON line to the metrics file, creating parent dirs if needed.

  Uses ``config.eval.metrics_file`` when set and non-empty; otherwise uses
  ``metrics.jsonl`` next to this source file.
  """
  metrics_file = None
  if 'eval' in config and isinstance(config.eval, omegaconf.DictConfig):
    metrics_file = config.eval.get('metrics_file', None)
  if not metrics_file:
    metrics_file = os.path.join(os.path.dirname(__file__), 'metrics.jsonl')

  os.makedirs(os.path.dirname(metrics_file) or '.', exist_ok=True)
  # Append mode creates the file if it does not exist yet.
  with open(metrics_file, 'a', encoding='utf-8') as f:
    f.write(json.dumps(record) + '\n')


@hydra.main(version_base=None, config_path='configs',
            config_name='config')
def main(config):
  """Main entry point: seed, print the config, then dispatch on ``config.mode``.

  Modes:
    - ``'sample_eval'``: generates and scores samples, then appends the metrics as a
      JSON line to the metrics file.
    - ``'ppl_eval'`` and ``'condition_dependence_test'``: run those evaluations.
    - Any other mode: trains.
  """
  L.seed_everything(config.seed)
  _print_config(config, resolve=True, save_cfg=True)

  logger = utils.get_logger(__name__)
  tokenizer = dataloader.get_tokenizer(config)

  if config.mode == 'sample_eval':
    mauve_score, diversity, gen_ppl = generate_samples(config, logger, tokenizer)
    # _compute_mauve can return None, and gen_ppl is guarded the same way in case
    # sampling did not produce it. Only format metrics that are present.
    if mauve_score is not None:
      logger.info(f"MAUVE score: {mauve_score * 100:.4f}")
    logger.info(f"Diversity: {diversity * 100:.4f}")
    if gen_ppl is not None:
      logger.info(f"Generative perplexity: {gen_ppl:.2f}")

    # float() turns numpy scalars into plain floats so json.dumps always accepts them.
    _append_metrics_jsonl(config, {
      'mauve': None if mauve_score is None else float(mauve_score),
      'diversity': float(diversity),
      'gen_ppl': None if gen_ppl is None else float(gen_ppl),
      'seed': config.seed,
      'checkpoint': config.eval.checkpoint_path,
    })
  elif config.mode == 'ppl_eval':
    _ppl_eval(config, logger, tokenizer)
  elif config.mode == 'condition_dependence_test':
    test_condition_embedding_dependence(config, logger, tokenizer)
  else:
    _train(config, logger, tokenizer)


if __name__ == '__main__':
  # The @hydra.main decorator composes the config and passes it in, so main() takes no args here.
  main()