import itertools
import os
import numpy as np
import torch
import torch.nn.functional as F
from typing import List, Tuple, Optional

import fsspec
import hydra
import lightning as L
import omegaconf
import rich.syntax
import rich.tree

import dataloader
import diffusion
import utils

# Register custom OmegaConf resolvers at import time so that config files
# can reference them by name ('cwd', 'device_count', 'eval', 'div_up').

# 'cwd': resolves to the current working directory (os.getcwd()).
omegaconf.OmegaConf.register_new_resolver(
  'cwd', os.getcwd)
# 'device_count': resolves to torch.cuda.device_count().
omegaconf.OmegaConf.register_new_resolver(
  'device_count', torch.cuda.device_count)
# 'eval': passes its argument to Python's builtin eval().
# NOTE(review): this executes arbitrary expressions taken from the config;
# only load configs from trusted sources.
omegaconf.OmegaConf.register_new_resolver(
  'eval', eval)
# 'div_up': (x + y - 1) // y, i.e. ceiling division for positive integers.
omegaconf.OmegaConf.register_new_resolver(
  'div_up', lambda x, y: (x + y - 1) // y)


def _load_from_checkpoint(config, tokenizer):
  """Build the diffusion model for evaluation.

  When the substring 'hf' occurs in `config.backbone`, a fresh
  `diffusion.Diffusion` is constructed and moved to 'cuda'. Otherwise the
  model is restored from `config.eval.checkpoint_path`.

  Raises:
    FileNotFoundError: the backbone is not an 'hf' one and the checkpoint
      path does not exist.
  """
  wants_hf_backbone = 'hf' in config.backbone
  if wants_hf_backbone:
    # HuggingFace backbones are instantiated directly rather than restored
    # from a local checkpoint file.
    hf_model = diffusion.Diffusion(config, tokenizer=tokenizer)
    return hf_model.to('cuda')

  ckpt_path = config.eval.checkpoint_path
  if os.path.exists(ckpt_path):
    return diffusion.Diffusion.load_from_checkpoint(
      ckpt_path, tokenizer=tokenizer, config=config)

  raise FileNotFoundError(
    f"Checkpoint file not found: {ckpt_path}. "
    f"If you're trying to use a HuggingFace model, make sure to set backbone=hf_dit"
  )


@L.pytorch.utilities.rank_zero_only
def _print_config(
  config: omegaconf.DictConfig,
  resolve: bool = True,
  save_cfg: bool = True) -> None:
  """Render the config as a Rich tree, print it, and optionally save it.

  Args:
    config: Configuration whose top-level keys become tree branches.
    resolve: Forwarded to `OmegaConf.to_yaml` for nested DictConfig sections.
    save_cfg: When true, the rendered tree is also written to
      `<config.checkpointing.save_dir>/config_tree.txt`.
  """
  dim = 'dim'
  root = rich.tree.Tree('CONFIG', style=dim, guide_style=dim)

  for name in config.keys():
    node = root.add(name, style=dim, guide_style=dim)
    section = config.get(name)
    # Nested sections are dumped as YAML; scalar values fall back to str().
    if isinstance(section, omegaconf.DictConfig):
      rendered = omegaconf.OmegaConf.to_yaml(section, resolve=resolve)
    else:
      rendered = str(section)
    node.add(rich.syntax.Syntax(rendered, 'yaml'))

  rich.print(root)
  if not save_cfg:
    return

  target = f'{config.checkpointing.save_dir}/config_tree.txt'
  with fsspec.open(target, 'w') as fp:
    rich.print(root, file=fp)


def compute_entropy_reward(model, particles: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """
    Map each particle's entropy to a reward in [0, 1] (low entropy -> high reward).

    The reward is ``clamp((H_max - H) / H_max, 0, 1)`` where ``H`` comes from
    ``model._compute_mask_conditional_entropy(particles, t)`` and
    ``H_max = log2(model.vocab_size)``.

    Args:
        model: object exposing ``vocab_size`` and
            ``_compute_mask_conditional_entropy``
        particles: tensor of shape (num_particles, seq_len)
        t: time step tensor of shape (num_particles, 1)

    Returns:
        tensor of rewards, one per entropy value returned by the model
    """
    # NOTE(review): normalising by log2(vocab_size) presumes the model reports
    # entropy in bits -- confirm against _compute_mask_conditional_entropy.
    with torch.no_grad():
        entropy = model._compute_mask_conditional_entropy(particles, t)

        vocab = torch.tensor(
            model.vocab_size, dtype=torch.float32, device=particles.device
        )
        entropy_ceiling = torch.log2(vocab)

        # Fraction of the maximum entropy that has been "removed".
        normalized = (entropy_ceiling - entropy) / entropy_ceiling
        return normalized.clamp(0.0, 1.0)


def compute_potential_scores(
    rewards: torch.Tensor,
    lambda_weight: float,
    potential_type: str = 'max'
) -> torch.Tensor:
    """
    Turn rewards into potential scores used for resampling.

    Args:
        rewards: tensor of shape (num_particles,)
        lambda_weight: scale applied to the rewards
        potential_type: 'max' -> exp(lambda * r); 'mean' -> lambda * r + 1

    Returns:
        scores: tensor with the same shape as ``rewards``

    Raises:
        ValueError: if ``potential_type`` is neither 'max' nor 'mean'
    """
    if potential_type == 'max':
        # Exponential potential: G_t = exp(lambda * r_t)
        return torch.exp(lambda_weight * rewards)

    if potential_type == 'mean':
        # Affine potential: G_t = lambda * r_t + 1
        return lambda_weight * rewards + 1.0

    raise ValueError(f"Unknown potential type: {potential_type}")


def resample_particles(
    particles: torch.Tensor,
    scores: torch.Tensor,
    num_particles: int
) -> torch.Tensor:
    """
    Resample particles with probability proportional to their potential scores.

    ``scores`` are expected to already be potentials (e.g. ``exp(lambda * r)``),
    so they are normalised directly. Applying a softmax here would exponentiate
    them a second time and distort the resampling distribution.

    Args:
        particles: tensor of shape (num_particles, seq_len)
        scores: tensor of shape (num_particles,) with non-negative potentials
        num_particles: number of particles to draw (with replacement)

    Returns:
        resampled_particles: tensor of shape (num_particles, seq_len)
    """
    # Negative entries cannot be probabilities; treat them as zero weight.
    weights = scores.float().clamp_min(0.0)
    total = weights.sum()

    if not torch.isfinite(total) or total <= 0:
        # Degenerate potentials (all zero, NaN or overflowed): fall back to
        # uniform resampling instead of letting torch.multinomial raise.
        probabilities = torch.full_like(weights, 1.0 / weights.numel())
    else:
        probabilities = weights / total

    indices = torch.multinomial(probabilities, num_particles, replacement=True)
    return particles[indices]


def _fk_resampling_probabilities(potentials: torch.Tensor) -> torch.Tensor:
    """Normalise non-negative potentials into a probability vector (uniform if degenerate)."""
    weights = potentials.float().clamp_min(0.0)
    total = weights.sum()
    if not torch.isfinite(total) or total <= 0:
        return torch.full_like(weights, 1.0 / weights.numel())
    return weights / total


def _fk_advance_particles(model, particles, t, dt, num_particles):
    """Apply one reverse-diffusion update to every particle, one particle at a time."""
    # 'ddpm_cache' deliberately shares the plain DDPM update: the cached
    # variant is not used when several particles are tracked.
    if model.sampler in ('ddpm', 'ddpm_cache'):
        update = model._ddpm_update
    else:
        update = model._analytic_update

    # Slices keep the batch dimension so each update sees shape (1, seq_len).
    advanced = [
        update(particles[i:i + 1], t[i:i + 1], dt)
        for i in range(num_particles)
    ]
    return torch.cat(advanced, dim=0)


@torch.no_grad()
def fk_steering_mdlm(
    model,
    num_particles: int,
    num_steps: int,
    resample_interval: int,
    lambda_weight: float,
    potential_type: str = 'max',
    eps: float = 1e-5
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
    """
    FK steering algorithm for MDLM using state entropy as reward.

    Particles are drawn from ``model._sample_prior`` and advanced for
    ``num_steps`` reverse-diffusion steps. Every ``resample_interval`` steps
    (except after the last one) they are resampled with probability
    proportional to their potential. The whole procedure runs under
    ``torch.no_grad()``.

    Args:
        model: MDLM model
        num_particles: number of particles to maintain (>= 1)
        num_steps: total number of diffusion steps (>= 1)
        resample_interval: how often to resample, every N steps (>= 1)
        lambda_weight: temperature parameter for potential function
        potential_type: type of potential function ('max' or 'mean')
        eps: minimum time value

    Returns:
        best_sample: the particle with the highest final reward, shape (1, seq_len)
        all_particles: tensor holding all final particles
        detailed_logs: dictionary containing detailed tracking information;
            ``particle_lineage`` holds exactly ``num_steps + 1`` entries

    Raises:
        ValueError: if ``num_particles``, ``num_steps`` or
            ``resample_interval`` is smaller than 1
    """
    if num_particles < 1:
        raise ValueError(f"num_particles must be >= 1, got {num_particles}")
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")
    if resample_interval < 1:
        raise ValueError(f"resample_interval must be >= 1, got {resample_interval}")

    device = model.device
    seq_len = model.config.model.length

    # Step 1: Initialize particles from prior distribution
    particles = model._sample_prior(num_particles, seq_len).to(device)

    # Initialize detailed tracking
    detailed_logs = {
        'particle_entropies': np.zeros((num_steps + 1, num_particles)),  # Shape: (steps, particles)
        'particle_rewards': np.zeros((num_steps + 1, num_particles)),
        'particle_ids': np.arange(num_particles)[None, :].repeat(num_steps + 1, axis=0),
        'resampling_steps': [],
        'resampling_info': [],
        'timesteps': [],
        # Entry k maps each particle slot to the initial particle it descends
        # from after k steps.
        'particle_lineage': [list(range(num_particles))],
    }

    timesteps = torch.linspace(1, eps, num_steps + 1, device=device)
    dt = (1 - eps) / num_steps

    print(f"Starting FK steering with {num_particles} particles, {num_steps} steps")
    print(f"Resampling every {resample_interval} steps with lambda={lambda_weight}")

    # Step 2: Iterative generation and resampling
    for step in range(num_steps):
        t_val = timesteps[step]
        t = t_val * torch.ones(num_particles, 1, device=device)
        detailed_logs['timesteps'].append(t_val.item())

        # Log entropy and reward of the state *before* this step's update.
        current_entropies = model._compute_mask_conditional_entropy(particles, t)
        current_rewards = compute_entropy_reward(model, particles, t)
        detailed_logs['particle_entropies'][step] = current_entropies.detach().cpu().numpy()
        detailed_logs['particle_rewards'][step] = current_rewards.detach().cpu().numpy()

        particles = _fk_advance_particles(model, particles, t, dt, num_particles)

        # Resampling is skipped after the final step: the final selection
        # below already picks the best particle.
        is_last_step = step == num_steps - 1
        if (step + 1) % resample_interval != 0 or is_last_step:
            # Always carry the lineage forward so that the log has exactly one
            # entry per step.
            detailed_logs['particle_lineage'].append(
                list(detailed_logs['particle_lineage'][-1]))
            continue

        print(f"Resampling at step {step + 1}/{num_steps}")

        # Rewards are evaluated at the time the particles have just reached.
        t_resample = timesteps[step + 1] * torch.ones(num_particles, 1, device=device)
        resample_entropies = model._compute_mask_conditional_entropy(particles, t_resample)
        resample_rewards = compute_entropy_reward(model, particles, t_resample)

        potential_scores = compute_potential_scores(
            resample_rewards, lambda_weight, potential_type
        )

        pre_resample_entropies = resample_entropies.detach().cpu().numpy()
        pre_resample_rewards = resample_rewards.detach().cpu().numpy()

        # The scores are already potentials G_t, so sample proportionally to
        # them; a softmax would exponentiate a second time.
        probabilities = _fk_resampling_probabilities(potential_scores)
        selected_indices = torch.multinomial(probabilities, num_particles, replacement=True)
        particles = particles[selected_indices]

        # Update particle lineage tracking
        current_lineage = detailed_logs['particle_lineage'][-1]
        new_lineage = [current_lineage[idx] for idx in selected_indices.tolist()]
        detailed_logs['particle_lineage'].append(new_lineage)

        # Record resampling information
        detailed_logs['resampling_steps'].append(step + 1)
        detailed_logs['resampling_info'].append({
            'selected_indices': selected_indices.cpu().numpy(),
            'probabilities': probabilities.detach().cpu().numpy(),
            'pre_resample_entropies': pre_resample_entropies,
            'pre_resample_rewards': pre_resample_rewards,
            'particle_lineage': new_lineage.copy()
        })

        # Print resampling statistics
        print(f"  Entropy rewards - Mean: {resample_rewards.mean():.4f}, Std: {resample_rewards.std():.4f}")
        print(f"  Min/Max rewards: {resample_rewards.min():.4f}/{resample_rewards.max():.4f}")
        print(f"  Selected particles: {selected_indices.cpu().numpy()}")

    # Record final timestep
    detailed_logs['timesteps'].append(eps)

    # Log the state reached after the last update (before optional denoising).
    t_final = timesteps[-1] * torch.ones(particles.shape[0], 1, device=device)
    final_entropies = model._compute_mask_conditional_entropy(particles, t_final)
    final_rewards = compute_entropy_reward(model, particles, t_final)
    detailed_logs['particle_entropies'][num_steps] = final_entropies.detach().cpu().numpy()
    detailed_logs['particle_rewards'][num_steps] = final_rewards.detach().cpu().numpy()

    # Final denoising step
    if model.config.sampling.noise_removal:
        final_particles = []
        for i in range(num_particles):
            particle = particles[i:i + 1]
            t_particle = t_final[i:i + 1]

            if model.sampler == 'analytic':
                denoised = model._denoiser_update(particle, t_particle)
            else:
                unet_conditioning = model.noise(t_particle)[0]
                denoised = model.forward(particle, unet_conditioning).argmax(dim=-1)

            final_particles.append(denoised)

        particles = torch.cat(final_particles, dim=0)

    # Select best sample based on the reward of the returned particles
    t_final = eps * torch.ones(particles.shape[0], 1, device=device)
    final_rewards = compute_entropy_reward(model, particles, t_final)
    best_idx = torch.argmax(final_rewards)
    best_sample = particles[best_idx:best_idx + 1]

    # Add summary statistics to logs
    detailed_logs['final_rewards'] = final_rewards.detach().cpu().numpy()
    detailed_logs['best_particle_idx'] = best_idx.item()
    detailed_logs['num_particles'] = num_particles
    detailed_logs['num_steps'] = num_steps
    detailed_logs['resample_interval'] = resample_interval
    detailed_logs['lambda_weight'] = lambda_weight

    print(f"Selected best sample with reward: {final_rewards[best_idx]:.4f}")
    print(f"Final reward statistics - Mean: {final_rewards.mean():.4f}, Std: {final_rewards.std():.4f}")

    return best_sample, particles, detailed_logs


def generate_samples_fk_steering(config, logger, tokenizer):
    """Generate samples using FK steering with entropy-based rewards.

    Loads the model, swaps in EMA weights when available, runs
    ``config.sampling.num_sample_batches`` FK-steering batches, prints the
    decoded samples and entropy statistics, and pickles each batch's logs to
    ``fk_steering_logs/fk_steering_batch_<k>.pkl``.

    The original (non-EMA) weights and train mode are restored even when
    sampling raises.

    Args:
        config: run configuration; reads ``eval``, ``sampling`` and
            ``fk_steering`` sections
        logger: logger used for progress messages
        tokenizer: tokenizer passed to the model loader

    Returns:
        list with the decoded text of every particle from every batch
    """
    import pickle

    logger.info('Generating samples with FK steering.')

    model = _load_from_checkpoint(config=config, tokenizer=tokenizer)
    model.gen_ppl_metric.reset()

    if config.eval.disable_ema:
        logger.info('Disabling EMA.')
        model.ema = None

    # FK steering parameters from config
    num_particles = config.fk_steering.get('num_particles', 8)
    resample_interval = config.fk_steering.get('resample_interval', 50)
    lambda_weight = config.fk_steering.get('lambda_weight', 5.0)
    potential_type = config.fk_steering.get('potential_type', 'max')

    num_batches = config.sampling.num_sample_batches
    log_dir = "fk_steering_logs"
    all_samples = []

    # Prepare model for sampling: stash the current weights, then load EMA ones.
    if model.ema:
        model.ema.store(itertools.chain(
            model.backbone.parameters(),
            model.noise.parameters()))
        model.ema.copy_to(itertools.chain(
            model.backbone.parameters(),
            model.noise.parameters()))

    model.backbone.eval()
    model.noise.eval()

    try:
        # Generate samples using FK steering
        for batch_idx in range(num_batches):
            print(f"\n=== FK Steering Batch {batch_idx + 1}/{num_batches} ===")

            # Sampling never needs gradients; avoid building autograd graphs.
            with torch.no_grad():
                best_sample, all_particles, detailed_logs = fk_steering_mdlm(
                    model=model,
                    num_particles=num_particles,
                    num_steps=config.sampling.steps,
                    resample_interval=resample_interval,
                    lambda_weight=lambda_weight,
                    potential_type=potential_type
                )

            # Decode samples
            text_samples = model.tokenizer.batch_decode(all_particles)
            best_text = model.tokenizer.decode(best_sample[0])

            all_samples.extend(text_samples)

            # Compute generative perplexity
            model.compute_generative_perplexity(text_samples)

            # Print results for this batch
            print(f"\nBest sample from batch {batch_idx + 1}:")
            print(f"'{best_text}'")

            print(f"\nAll samples from batch {batch_idx + 1}:")
            for i, text in enumerate(text_samples):
                print(f"  Particle {i + 1}: '{text}'")

            # Print entropy evolution statistics
            entropy_array = detailed_logs['particle_entropies']
            print(f"\nEntropy evolution statistics:")
            print(f"  Initial entropy: {entropy_array[0].mean():.4f} ± {entropy_array[0].std():.4f}")
            print(f"  Final entropy: {entropy_array[-1].mean():.4f} ± {entropy_array[-1].std():.4f}")
            print(f"  Average reduction: {(entropy_array[0].mean() - entropy_array[-1].mean()):.4f}")
            print(f"  Number of resampling events: {len(detailed_logs['resampling_steps'])}")

            # Save detailed logs for visualization. Each batch's logs live on
            # disk only; they are not accumulated in memory.
            os.makedirs(log_dir, exist_ok=True)
            log_file = os.path.join(log_dir, f"fk_steering_batch_{batch_idx + 1}.pkl")
            with open(log_file, 'wb') as f:
                pickle.dump(detailed_logs, f)
            print(f"  Detailed logs saved to: {log_file}")
    finally:
        # Restore model state even if sampling failed part-way through.
        if model.ema:
            model.ema.restore(itertools.chain(
                model.backbone.parameters(),
                model.noise.parameters()))
        model.backbone.train()
        model.noise.train()

    # Print overall statistics
    print(f"\n=== Overall FK Steering Results ===")
    print(f'Generated {len(all_samples)} total samples')
    print(f'Generative perplexity: {model.gen_ppl_metric.compute().item():.4f}')

    # Print FK steering configuration
    print(f"\nFK Steering Configuration:")
    print(f"  Number of particles: {num_particles}")
    print(f"  Resample interval: {resample_interval}")
    print(f"  Lambda weight: {lambda_weight}")
    print(f"  Potential type: {potential_type}")
    print(f"  Total steps: {config.sampling.steps}")

    return all_samples


@hydra.main(version_base=None, config_path='configs', config_name='config')
def main(config):
    """Main entry point for FK steering sampling.

    Ensures an ``fk_steering`` section exists (inserting defaults when it is
    missing), validates its values, checks the run mode, then seeds, prints
    the config and runs FK-steering sample generation.

    Raises:
        ValueError: on invalid FK steering parameters or when ``config.mode``
            is not 'fk_sample_eval'
    """
    # Same defaults that generate_samples_fk_steering falls back to.
    defaults = {
        'num_particles': 8,
        'resample_interval': 50,
        'lambda_weight': 5.0,
        'potential_type': 'max',
    }

    # Ensure FK steering configuration exists
    if 'fk_steering' not in config or config.fk_steering is None:
        print("Warning: FK steering configuration not found. Using default values.")
        # Hydra hands over the config in struct mode, where adding a new key
        # raises unless the struct flag is temporarily lifted.
        with omegaconf.open_dict(config):
            config.fk_steering = omegaconf.DictConfig(dict(defaults))

    # Validate FK steering parameters; keys missing from a partially
    # specified section fall back to the defaults instead of raising.
    fk_config = config.fk_steering
    num_particles = fk_config.get('num_particles', defaults['num_particles'])
    resample_interval = fk_config.get('resample_interval', defaults['resample_interval'])
    lambda_weight = fk_config.get('lambda_weight', defaults['lambda_weight'])
    potential_type = fk_config.get('potential_type', defaults['potential_type'])

    if num_particles < 1:
        raise ValueError(f"num_particles must be >= 1, got {num_particles}")
    if resample_interval < 1:
        raise ValueError(f"resample_interval must be >= 1, got {resample_interval}")
    if lambda_weight <= 0:
        raise ValueError(f"lambda_weight must be > 0, got {lambda_weight}")
    if potential_type not in ('max', 'mean'):
        raise ValueError(f"Unknown potential type: {potential_type}")

    print(f"FK Steering Configuration:")
    print(f"  num_particles: {num_particles}")
    print(f"  resample_interval: {resample_interval}")
    print(f"  lambda_weight: {lambda_weight}")
    print(f"  potential_type: {potential_type}")

    # Fail fast on an unsupported mode, before seeding or loading the tokenizer.
    if config.mode != 'fk_sample_eval':
        raise ValueError(f"Mode {config.mode} not supported in FK steering. Use 'fk_sample_eval'.")

    L.seed_everything(config.seed)
    _print_config(config, resolve=True, save_cfg=True)

    logger = utils.get_logger(__name__)
    tokenizer = dataloader.get_tokenizer(config)

    generate_samples_fk_steering(config, logger, tokenizer)


# Run the Hydra-decorated entry point only when executed as a script.
if __name__ == '__main__':
    main() 