"""
Console logger utilities.
Copied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.py
Copied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging
Copied from https://github.com/kuleshov-group/mdlm

"""
import os

import logging
import math

import fsspec
import lightning
import lightning as L
import torch
from timm.scheduler import CosineLRScheduler
from torch.distributed import init_process_group
import datetime
import omegaconf

import rich.syntax
import rich.tree

import itertools
import models


def fsspec_exists(filename):
  """Return whether `filename` exists, as reported by its fsspec filesystem."""
  # url_to_fs returns a (filesystem, path) pair; only the filesystem is used.
  filesystem = fsspec.core.url_to_fs(filename)[0]
  return filesystem.exists(filename)


def fsspec_listdir(dirname):
  """Return the listing of `dirname` from the fsspec filesystem that owns it."""
  # url_to_fs returns a (filesystem, path) pair; only the filesystem is used.
  filesystem = fsspec.core.url_to_fs(dirname)[0]
  return filesystem.ls(dirname)


def fsspec_mkdirs(dirname, exist_ok=True):
  """Create `dirname` (including parents) through its fsspec filesystem.

  Args:
    dirname: Path or URL of the directory to create.
    exist_ok: Forwarded to the filesystem's `makedirs`.
  """
  # url_to_fs returns a (filesystem, path) pair; only the filesystem is used.
  filesystem = fsspec.core.url_to_fs(dirname)[0]
  filesystem.makedirs(dirname, exist_ok=exist_ok)


def print_nans(tensor, name):
  """Print `name` followed by `tensor` when the tensor holds at least one NaN."""
  contains_nan = torch.isnan(tensor).any()
  if contains_nan:
    print(name, tensor)


class CosineDecayWarmupLRScheduler(
  CosineLRScheduler,
  torch.optim.lr_scheduler._LRScheduler):
  """Wrapper around timm.scheduler.CosineLRScheduler.

  Lets callers invoke `scheduler.step()` with no epoch argument by keeping
  an internal counter, which also makes resuming possible.
  Adapted from:
    https://github.com/HazyResearch/hyena-dna/blob/main/src/utils/optim/schedulers.py
  """

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    # Counter of the last epoch/update seen; primed by the step(0) below.
    self._last_epoch = -1
    self.step(epoch=0)

  def step(self, epoch=None):
    """Advance the schedule.

    Args:
      epoch: Explicit epoch/update index to jump to. When None, the
        internal counter is incremented by one instead.
    """
    self._last_epoch = (
      self._last_epoch + 1 if epoch is None else epoch)
    # Lightning always calls step() (the per-epoch entry point). When the
    # scheduler interval is "step", that would apply the wrong update, so
    # dispatch to step_update() unless the schedule is expressed in epochs.
    if not self.t_in_epochs:
      super().step_update(num_updates=self._last_epoch)
      return
    super().step(epoch=self._last_epoch)


class LoggingContext:
  """Context manager for selective logging.

  While the context is active the logger's level is optionally overridden
  and an extra handler is optionally attached; both changes are undone on
  exit.

  Args:
    logger: The `logging.Logger` to modify temporarily.
    level: Level to set for the duration of the context, or None to leave
      the level untouched.
    handler: Handler to attach for the duration of the context, or None.
    close: Whether to close `handler` when the context exits.
  """
  def __init__(self, logger, level=None, handler=None, close=True):
    self.logger = logger
    self.level = level
    self.handler = handler
    self.close = close

  def __enter__(self):
    if self.level is not None:
      # Remember the previous level so __exit__ can restore it.
      self.old_level = self.logger.level
      self.logger.setLevel(self.level)
    # Compare against None rather than relying on truthiness, so a handler
    # object that happens to be falsy is still attached.
    if self.handler is not None:
      self.logger.addHandler(self.handler)
    # Return the context so `with LoggingContext(...) as ctx:` binds it.
    return self

  def __exit__(self, et, ev, tb):
    if self.level is not None:
      self.logger.setLevel(self.old_level)
    if self.handler is not None:
      self.logger.removeHandler(self.handler)
      if self.close:
        self.handler.close()
    # Implicit None return: exceptions raised inside the block propagate.


def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
  """Build a python logger that is friendly to multi-GPU runs.

  Args:
    name: Name passed to `logging.getLogger`.
    level: Level set on the returned logger.

  Returns:
    The logger, with each of its emitting methods wrapped in Lightning's
    `rank_zero_only` decorator.
  """
  rank_zero_only = lightning.pytorch.utilities.rank_zero_only

  logger = logging.getLogger(name)
  logger.setLevel(level)

  # Wrap every logging method with the rank-zero decorator; otherwise each
  # GPU process in a multi-GPU setup would emit its own copy of every log.
  method_names = ('debug', 'info', 'warning', 'error',
                  'exception', 'fatal', 'critical')
  for method_name in method_names:
    original_method = getattr(logger, method_name)
    setattr(logger, method_name, rank_zero_only(original_method))

  return logger


class Sampler:
  """Base class for samplers that perturb logits and discretize them.

  Subclasses provide the noise and the hard/soft samples; `sample` combines
  them so the returned value equals the hard sample while gradients flow
  through the soft sample.
  """
  def __init__(self, shape):
    self.shape = shape

  def _sampling_noise(self):
    # Hook for subclasses; the base class produces no noise.
    return None

  def _hard_sample(self, logits):
    # Hook for subclasses; the base class produces no hard sample.
    return None

  def _soft_sample(self, logits):
    return 0

  def sample(self, logits):
    """Add noise to `logits` and return the straight-through sample."""
    batch_size = logits.shape[0]
    # Keep only as many noise rows as there are rows in `logits`.
    noise = self._sampling_noise()[:batch_size, :]
    perturbed = logits + noise.to(
      dtype=logits.dtype, device=logits.device)
    hard = self._hard_sample(perturbed)
    soft = self._soft_sample(perturbed)
    # Forward value is `hard`; the detached term contributes no gradient.
    return soft + (hard - soft).detach()


class TopKSampler(Sampler):
  """Sampler whose hard sample marks the `k` largest logits per row.

  Args:
    k: Number of entries selected per row.
    shape: Trailing shape of the noise drawn by `_sampling_noise`.
    gamma_tau: Multiplier applied to the generated noise.
  """
  def __init__(self, k, shape, gamma_tau=1.0):
    super().__init__(shape)
    self.k = k
    self.gamma_tau = gamma_tau
    self.num_betas = 10
    # One Gamma(1/k, 1) variable per (beta index, *shape) entry.
    concentration = 1 / k * torch.ones(self.num_betas, * self.shape)
    self.sampler = torch.distributions.gamma.Gamma(concentration, 1.0)

  def _sampling_noise(self):
    gamma_draws = self.sampler.sample()
    denominators = torch.arange(1, self.num_betas + 1, 1,
                                dtype=torch.float32)
    # Two trailing axes are added, so this presumes a 2-D `shape`.
    beta = (self.k / denominators)[:, None, None]
    assert beta.ndim == gamma_draws.ndim
    total = torch.sum(gamma_draws / beta, axis=0)
    # NOTE(review): the constant 10.0 presumably mirrors num_betas - confirm.
    total = total - math.log(10.0)
    return self.gamma_tau * (total / self.k)

  def _hard_sample(self, logits):
    assert logits.ndim == 2
    ascending, _ = torch.sort(logits, dim=-1)
    # The k-th largest value of each row acts as the selection threshold.
    kth_largest = ascending[:, - self.k][:, None]
    return (logits >= kth_largest).type(logits.dtype)

  def _soft_sample(self, logits):
    # Center each row, then scale it to unit norm.
    centered = logits - torch.mean(logits, dim=-1,
                                   keepdim=True)
    row_norm = torch.norm(centered, dim=-1,
                          keepdim=True)
    return centered / row_norm


class DeterministicTopK(TopKSampler):
  """Top-k selection without sampling noise.

  Args:
    k: Number of entries selected per row.
  """
  def __init__(self, k):
    super().__init__(k, shape=(1, 1))

  def _sampling_noise(self):
    return 0

  def discreize(self, x):
    """Return the straight-through top-k sample of `x` (no noise added)."""
    hard_sample = self._hard_sample(x)
    soft_sample = self._soft_sample(x)
    # Forward value is the hard sample; gradients flow through the soft one.
    return soft_sample + (hard_sample - soft_sample).detach()

  # Correctly spelled alias; the original (misspelled) name is kept so
  # existing callers continue to work.
  discretize = discreize

  def sample(self, logits):
    """Noise-free counterpart of the inherited `sample`.

    The inherited implementation subscripts the result of
    `_sampling_noise`, which is the plain int 0 here and would raise a
    TypeError, so delegate to `discreize` instead.
    """
    return self.discreize(logits)

class GumbelSampler(Sampler):
  """Sampler with Gumbel noise, one-hot hard samples and softmax soft samples.

  Args:
    shape: Shape of the noise tensor drawn by `_sampling_noise`.
    temperature: Divisor applied to the logits before the softmax.
  """

  def __init__(self, shape, temperature=1.0):
    super().__init__(shape)
    self.temperature = temperature

  def _sampling_noise(self):
    # -log(-log(u)) with u ~ U[0, 1); the 1e-10 terms keep both logs finite.
    return - (1e-10 - (
      torch.rand(* self.shape) + 1e-10).log()).log()

  def _hard_sample(self, logits):
    """Return a one-hot tensor marking the argmax along the last dim.

    The previous implementation asserted 2-D input but then indexed with
    three subscripts, so it could never succeed. Keeping the reduced dim
    makes the scatter work for 2-D input (and any other rank >= 1).
    """
    indices = torch.argmax(logits, dim=-1, keepdim=True)
    ones = torch.ones_like(logits[..., :1])
    return torch.scatter(torch.zeros_like(logits), -1, indices, ones)

  def _soft_sample(self, logits):
    return torch.nn.functional.softmax(
      logits / self.temperature, dim=-1)


class BinarySampler(GumbelSampler):
  """Straight-through sampler for binary variables given their probabilities."""

  def sample(self, probs):
    """Return a 0/1 sample of `probs` whose gradient follows the soft sample."""
    target = dict(dtype=probs.dtype, device=probs.device)
    # Two independent noise draws: the positive one first, then the negative.
    noise_pos = self._sampling_noise().to(**target)
    noise_neg = self._sampling_noise().to(**target)
    noise_ratio = torch.exp(noise_neg - noise_pos)
    soft = probs / (probs + (1 - probs) * noise_ratio)
    hard = (probs * (1 + noise_ratio)
            > 1).to(probs.dtype)
    # Forward value is `hard`; the detached term contributes no gradient.
    return soft + (hard - soft).detach()


class GaussianSampler:
  """Draws Gaussian samples parameterized by the two halves of the input."""
  def __init__(self):
    self.softplus = torch.nn.Softplus()

  def sample(self, x):
    """Sample from the Gaussian encoded in the 2-D tensor `x`.

    The first half of the last dim is the mean; the softplus of the second
    half is square-rooted to obtain the standard deviation.
    """
    assert x.ndim == 2
    half = x.shape[-1] // 2
    mean, raw_scale = x[:, :half], x[:, half:]
    std = torch.sqrt(self.softplus(raw_scale))
    eps = torch.randn_like(mean)
    return mean + std * eps
    
def ddp_setup(rank: int, world_size: int, *,
              master_addr: str = "localhost",
              master_port="6000",
              backend: str = "nccl",
              timeout_minutes: float = 1):
  """Initialize the torch.distributed process group for this process.

  Sets the MASTER_ADDR / MASTER_PORT environment variables, selects CUDA
  device `rank` and calls `init_process_group`. The keyword-only arguments
  default to the values that used to be hard-coded, so existing two-argument
  calls behave exactly as before.

  Args:
    rank: Unique identifier of each process
    world_size: Total number of processes
    master_addr: Value written to the MASTER_ADDR environment variable.
    master_port: Value written to MASTER_PORT (str or int). Override it to
      avoid port clashes between jobs that share a host.
    backend: Backend name passed to `init_process_group`.
    timeout_minutes: Process-group timeout, in minutes.
  """
  os.environ["MASTER_ADDR"] = master_addr
  # Environment values must be strings; accept ints for convenience.
  os.environ["MASTER_PORT"] = str(master_port)
  torch.cuda.set_device(rank)
  timeout = datetime.timedelta(minutes=timeout_minutes)
  init_process_group(
    backend=backend, rank=rank, world_size=world_size, timeout=timeout)
  
  
"""
Our own functions

""" 

def _copy_leading_rows(dst, src):
  """Copy the rows `dst` and `src` share along dim 0; return how many."""
  shared = min(dst.shape[0], src.shape[0])
  dst[:shared] = src[:shared]
  return shared


def _resize_vocab_layers(model, logger):
  """Resize the backbone's token embedding and output projection if needed.

  The required size is `max(model.mask_index, model.vocab_size - 1) + 1`.
  Rows shared by the old and new tensors are copied; extra embedding/weight
  rows are drawn from N(0, 0.02) and extra bias entries are zeroed. Both
  growing and shrinking are supported.
  """
  backbone = model.backbone
  vocab_embed = getattr(backbone, 'vocab_embed', None)
  embedding = getattr(vocab_embed, 'embedding', None)

  # The embedding is either a raw Parameter or an nn.Embedding-like module.
  current_size = None
  embed_dim = None
  if isinstance(embedding, torch.nn.Parameter):
    current_size, embed_dim = embedding.shape[0], embedding.shape[1]
  elif hasattr(embedding, 'num_embeddings'):
    current_size = embedding.num_embeddings
    embed_dim = embedding.embedding_dim

  # +1 because token indices start at 0.
  required_size = max(model.mask_index, model.vocab_size - 1) + 1

  if current_size is None or current_size == required_size:
    logger.info("No embedding resizing needed")
    return

  logger.info(
    f"Resizing vocabulary layers from {current_size} to {required_size}")

  # ---- input embeddings ----
  if isinstance(embedding, torch.nn.Parameter):
    old_weight = embedding.data
    new_weight = torch.zeros(
      required_size, embed_dim,
      dtype=old_weight.dtype, device=old_weight.device)
    with torch.no_grad():
      kept = _copy_leading_rows(new_weight, old_weight)
      if required_size > kept:
        # New tokens (including the mask token) start from small noise.
        torch.nn.init.normal_(new_weight[kept:], mean=0.0, std=0.02)
    vocab_embed.embedding = torch.nn.Parameter(new_weight)
  else:
    padding_idx = getattr(embedding, 'padding_idx', None)
    if padding_idx is not None and padding_idx >= required_size:
      # The padding row no longer exists after shrinking.
      padding_idx = None
    # Build the replacement on the old layer's device/dtype so the model
    # does not end up with parameters spread over different devices.
    new_embedding = torch.nn.Embedding(
      required_size, embed_dim, padding_idx=padding_idx,
      device=embedding.weight.device, dtype=embedding.weight.dtype)
    with torch.no_grad():
      kept = _copy_leading_rows(new_embedding.weight, embedding.weight)
      if required_size > kept:
        torch.nn.init.normal_(
          new_embedding.weight[kept:], mean=0.0, std=0.02)
    vocab_embed.embedding = new_embedding

  # ---- output projection ----
  output_layer = getattr(backbone, 'output_layer', None)
  old_linear = getattr(output_layer, 'linear', None)
  if old_linear is not None:
    new_linear = torch.nn.Linear(
      old_linear.in_features, required_size,
      bias=old_linear.bias is not None,
      device=old_linear.weight.device, dtype=old_linear.weight.dtype)
    with torch.no_grad():
      kept = _copy_leading_rows(new_linear.weight, old_linear.weight)
      if required_size > kept:
        torch.nn.init.normal_(new_linear.weight[kept:], mean=0.0, std=0.02)
      if new_linear.bias is not None:
        kept_bias = _copy_leading_rows(new_linear.bias, old_linear.bias)
        new_linear.bias[kept_bias:] = 0.0
    output_layer.linear = new_linear

  # Keep the model's vocab_size in sync with the actual embedding size.
  model.vocab_size = required_size


def _attach_ema(config, model):
  """Create a fresh EMA over the backbone, noise and (optional) image encoder."""
  tracked = list(itertools.chain(
    model.backbone.parameters(),
    model.noise.parameters()))
  image_encoder = getattr(model, 'image_encoder', None)
  if (getattr(model, 'use_image_conditioning', False)
      and image_encoder is not None):
    tracked.extend(image_encoder.parameters())
  model._ema_params = tracked
  model.ema = models.ema.ExponentialMovingAverage(
    tracked, decay=config.training.ema)


def load_from_checkpoint(config, logger, tokenizer, model):
  """Prepare `model` for training, optionally restoring it from a checkpoint.

  Steps:
    1. Resize the vocabulary embedding / output layer so that every token id
       up to `max(mask_index, vocab_size - 1)` is addressable (required when
       using a different tokenizer, e.g. ClipTokenizer).
    2. If `config.checkpointing.resume_ckpt_path` is unset or missing, only
       (re)create the EMA when `config.training.ema > 0` and return.
    3. Otherwise load the weights non-strictly, report key/shape
       discrepancies, rebuild the EMA (restoring its state when the
       checkpoint has one) and restore the fit-loop progress counters.

  The EMA tracks the same parameter set (backbone, noise and, when image
  conditioning is enabled, the image encoder) on both paths.

  Args:
    config: Config exposing `checkpointing.resume_ckpt_path` and
      `training.ema`.
    logger: Logger used for progress and warning messages.
    tokenizer: Unused; kept for interface compatibility.
    model: The model to modify in place.

  Returns:
    The same `model` object.
  """
  logger.info(f"Model created with vocab_size: {model.vocab_size}")
  logger.info(f"Mask token ID: {model.mask_index}")

  _resize_vocab_layers(model, logger)

  use_ema = config.training.ema > 0
  ckpt_path = config.checkpointing.resume_ckpt_path

  # ==== NO CHECKPOINT ====
  if not ckpt_path or not os.path.exists(ckpt_path):
    logger.warning(f"Checkpoint path does not exist: {ckpt_path}")
    if use_ema:
      _attach_ema(config, model)
      model.ema.move_shadow_params_to_device(model.device)
    return model

  # ==== CHECKPOINT ====
  logger.info(f"Loading checkpoint from {ckpt_path}")
  # NOTE: torch.load unpickles the file; only load checkpoints you trust.
  # Loading to CPU works on hosts without a GPU; load_state_dict copies the
  # tensors onto whatever device the model's parameters live on.
  checkpoint = torch.load(ckpt_path, map_location='cpu')
  state_dict = checkpoint.get('state_dict', checkpoint)

  # Report shape mismatches up front: strict=False tolerates missing and
  # unexpected keys, but load_state_dict still raises on size mismatches.
  own = model.state_dict()
  mismatched = [
    (key, tuple(value.shape), tuple(own[key].shape))
    for key, value in state_dict.items()
    if key in own and own[key].shape != value.shape]
  if mismatched:
    logger.warning(
      f"Shape mismatch (key, checkpoint shape, model shape): {mismatched}")

  load_result = model.load_state_dict(state_dict, strict=False)
  missing = list(getattr(load_result, 'missing_keys', []))
  unexpected = list(getattr(load_result, 'unexpected_keys', []))
  if missing:
    logger.warning(f"Keys missing from checkpoint: {missing}")
  if unexpected:
    logger.warning(f"Unexpected keys in checkpoint: {unexpected}")

  # Re-initialize EMA after loading weights
  if use_ema:
    _attach_ema(config, model)
    # If EMA state exists in checkpoint, load it, else keep fresh EMA
    if 'ema' in checkpoint:
      try:
        model.ema.load_state_dict(checkpoint['ema'])
      except Exception as e:  # best effort: fall back to the fresh EMA
        logger.warning(f"Warning: Could not load EMA state dict: {e}")
    # Move EMA to correct device
    model.ema.move_shadow_params_to_device(model.device)
  else:
    model.ema = None
    model._ema_params = []

  # Restore training progress if available
  if 'loops' in checkpoint:
    try:
      fit_loop = checkpoint['loops']['fit_loop']
      model.fast_forward_epochs = (
        fit_loop['epoch_progress']['current']['completed'])
      model.fast_forward_batches = (
        fit_loop['epoch_loop.batch_progress']['current']['completed'])
    except (KeyError, TypeError) as err:
      logger.warning(f"Could not restore training progress: {err!r}")

  return model
  
@L.pytorch.utilities.rank_zero_only
def print_config(
  config: omegaconf.DictConfig,
  resolve: bool = True,
  save_cfg: bool = True) -> None:
  """Render a DictConfig as a Rich tree, print it and optionally save it.

  Args:
    config: Configuration to display; each top-level key becomes a branch.
    resolve: Forwarded to `OmegaConf.to_yaml` for DictConfig sections.
    save_cfg: When true, also write the tree to `config_tree.txt` under
      `config.checkpointing.save_dir`.
  """
  dim = 'dim'
  tree = rich.tree.Tree('CONFIG', style=dim, guide_style=dim)

  for key in config.keys():
    node = tree.add(key, style=dim, guide_style=dim)

    section = config.get(key)
    # Nested configs are dumped as YAML; anything else is stringified.
    if isinstance(section, omegaconf.DictConfig):
      rendered = omegaconf.OmegaConf.to_yaml(section, resolve=resolve)
    else:
      rendered = str(section)

    node.add(rich.syntax.Syntax(rendered, 'yaml'))

  rich.print(tree)
  if not save_cfg:
    return

  target = '{}/config_tree.txt'.format(config.checkpointing.save_dir)
  with fsspec.open(target, 'w') as fp:
    rich.print(tree, file=fp)
      
@L.pytorch.utilities.rank_zero_only
def print_batch(train_ds, valid_ds, tokenizer, k=64):
  """Print the first and last `k` token ids (and their decoding) of one batch.

  One batch is pulled from each of `train_ds` and `valid_ds`, in that
  order, and only the first example of each batch is shown.
  """
  for loader in (train_ds, valid_ds):
    input_ids = next(iter(loader))['input_ids']
    head, tail = input_ids[0, :k], input_ids[0, -k:]
    print(f'First {k} tokens:', tokenizer.decode(head))
    print('ids:', head)
    print(f'Last {k} tokens:', tokenizer.decode(tail))
    print('ids:', tail)

