import os
import lightning as L
import torch

import json
from datetime import datetime

import utils
import tokenizers

import matplotlib.pyplot as plt 
import numpy as np
import seaborn as sns
from PIL import Image
import torchvision.transforms as transforms

from lightning.pytorch.loggers import CSVLogger
import logging
import models 
import transformers

import torch.distributed as dist
from datetime import timedelta

import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import destroy_process_group 
import itertools

class StepByStepSamplingTester:
    """Diagnostic harness that runs the model's sampler one step at a time and records each step.

    For every sampling step it records how many positions still equal
    ``model.mask_index``, how many tokens the update changed, the decoded text
    so far and the model's top-5 predictions for the first few masked
    positions. It can then write plots and JSON reports to disk.

    The wrapped ``model`` must provide everything used below: ``eval()``,
    ``config.model.length``, ``mask_index``, ``sampler``, ``_sample_prior``,
    ``noise``, ``forward``, the ``_analytic_update`` / ``_ddpm_caching_update``
    / ``_ddpm_update`` samplers, ``use_image_conditioning``, ``image_encoder``,
    ``vocab_size`` and ``parameterization``.
    """

    def __init__(self, model, tokenizer, device='cuda'):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device

        # CLIP normalization constants (per-channel mean/std), shaped (3, 1, 1)
        # so they broadcast over a channel-first image.
        self.clip_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(3, 1, 1)
        self.clip_std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(3, 1, 1)

        # Token ids treated as "end of caption"; computed once per tokenizer.
        self.eos_token_ids = self._get_eos_tokens()

    def denormalize_clip_image(self, img_tensor):
        """Undo CLIP normalization so an image can be displayed.

        Args:
            img_tensor: a single image; assumed channel-first (3, H, W) so it
                broadcasts against the (3, 1, 1) constants — TODO confirm.

        Returns:
            A CPU tensor with values clamped to [0, 1].
        """
        img_cpu = img_tensor.cpu()
        img_denorm = img_cpu * self.clip_std + self.clip_mean
        return torch.clamp(img_denorm, 0, 1)

    @staticmethod
    def _timestamp():
        """Return a timestamp string used in report file names.

        Microsecond resolution keeps artifacts that are written back-to-back
        (e.g. the with-image and without-image runs) from sharing a name and
        overwriting each other.
        """
        return datetime.now().strftime("%Y%m%d_%H%M%S_%f")

    def _single_token_id(self, text):
        """Return the id of ``text`` if the tokenizer encodes it as exactly one token, else ``None``.

        Best effort: any tokenizer error simply yields ``None``.
        """
        try:
            encoded = self.tokenizer.encode(text, add_special_tokens=False)
            return encoded[0] if len(encoded) == 1 else None
        except Exception:
            return None

    def _get_eos_tokens(self):
        """Collect the token ids at which generated captions are truncated.

        Candidates are the tokenizer's ``eos_token_id`` (when set), the period
        token, and the ``<|endoftext|>`` / ``</s>`` markers. A probe string is
        used only if it encodes to exactly one token. Taking the first piece of
        a multi-token encoding (e.g. ``"<"`` from ``"</s>"``) would turn an
        ordinary token into a false EOS.

        Returns:
            Sorted list of unique token ids.
        """
        eos_tokens = set()

        eos_id = getattr(self.tokenizer, 'eos_token_id', None)
        if eos_id is not None:
            eos_tokens.add(eos_id)

        for probe in (".", "<|endoftext|>", "</s>"):
            token_id = self._single_token_id(probe)
            if token_id is not None:
                eos_tokens.add(token_id)

        return sorted(eos_tokens)

    def _truncate_at_first_eos(self, tokens, text):
        """Cut a generation at its first *meaningful* EOS token.

        An EOS at index ``i`` is used only if ``i > 2`` (at least three tokens
        precede it) and the re-decoded prefix ``tokens[:i]`` has at least two
        whitespace-separated words. Otherwise the search moves on to the next
        EOS.

        Args:
            tokens: list of token ids.
            text: decoded form of ``tokens``; returned unchanged when nothing
                is truncated.

        Returns:
            ``(tokens, text, eos_index)``: the possibly truncated token list
            (EOS excluded), its text, and the index of the EOS that was used,
            or ``None`` when nothing was truncated.
        """
        if not text or not text.strip():
            return tokens, text, None

        for i, token_id in enumerate(tokens):
            # Ignore non-EOS tokens and EOS tokens in the first three positions.
            if i <= 2 or token_id not in self.eos_token_ids:
                continue
            truncated_tokens = tokens[:i]  # EOS itself is dropped
            truncated_text = self.tokenizer.decode(truncated_tokens, skip_special_tokens=True).strip()
            if len(truncated_text.split()) >= 2:
                return truncated_tokens, truncated_text, i

        # No usable EOS: return the input unchanged.
        return tokens, text, None

    def _top_predictions(self, model_output, masked_positions, num_masks, max_positions=3, k=5):
        """Decode the top-``k`` predictions for the first ``max_positions`` masked positions.

        Returns:
            One entry per shown position, each a list of ``(token_text, prob)``
            pairs; ``[]`` when nothing is masked.
        """
        if num_masks == 0:
            return []
        # softmax over forward()'s output; if that output is already normalized
        # log-probabilities this reproduces the same distribution.
        masked_probs = torch.softmax(model_output[masked_positions], dim=-1)
        top = torch.topk(masked_probs, k=min(k, masked_probs.shape[-1]), dim=-1)

        predictions = []
        for pos_idx in range(min(max_positions, num_masks)):
            pos_tokens = [self.tokenizer.decode([tok.item()]) for tok in top.indices[pos_idx]]
            pos_probs = [prob.item() for prob in top.values[pos_idx]]
            predictions.append(list(zip(pos_tokens, pos_probs)))
        return predictions

    def _sampling_step(self, x, t, dt, image_features):
        """Apply one update of the model's configured sampler and return the new token tensor."""
        if self.model.sampler == 'analytic':
            return self.model._analytic_update(x, t, dt, image_features)
        if self.model.sampler == 'ddpm_cache':
            # NOTE(review): None is passed as the last argument every step and
            # the first return value is discarded, so any cache this sampler
            # supports is never reused here — confirm that is intended.
            _, x_new = self.model._ddpm_caching_update(x, t, dt, image_features, None)
            return x_new
        return self.model._ddpm_update(x, t, dt, image_features)

    def _display_text(self, tokens_row, stop_at_eos):
        """Decode one token row for logging, truncated at EOS when ``stop_at_eos`` is set."""
        text = self.tokenizer.decode(tokens_row, skip_special_tokens=True)
        if stop_at_eos and text and text.strip():
            _, truncated_text, _ = self._truncate_at_first_eos(tokens_row.cpu().tolist(), text)
            if truncated_text:
                text = truncated_text
        return text

    @torch.no_grad()
    def generate_caption_with_steps(self, image_tensor=None, num_steps=25, stop_at_eos=True):
        """Sample one caption from the all-mask prior while recording every step.

        Args:
            image_tensor: optional image batch. It is encoded only when the
                model has ``use_image_conditioning`` set and an
                ``image_encoder``.
            num_steps: number of sampler updates (must be >= 1). Time runs
                linearly from 1.0 to 1e-3.
            stop_at_eos: when true, the logged text and the final caption are
                truncated at the first meaningful EOS (see
                ``_truncate_at_first_eos``). Sampling itself always runs all
                ``num_steps``: stopping early could leave masked positions in
                front of the EOS.

        Returns:
            ``(final_caption, final_tokens, step_data)``. ``final_tokens`` is a
            1-D tensor (truncated at EOS if that happened). ``step_data`` holds
            one dict per step.

        Raises:
            ValueError: if ``num_steps < 1``.
        """
        if num_steps < 1:
            raise ValueError(f"num_steps must be >= 1, got {num_steps}")

        self.model.eval()

        # Extract image features if available
        image_features = None
        if image_tensor is not None and self.model.use_image_conditioning and self.model.image_encoder is not None:
            image_features = self.model.image_encoder(image_tensor.to(self.device))
            print(f"Image features extracted: shape={image_features.shape}, norm={image_features.norm():.4f}")

        batch_size = 1
        seq_len = self.model.config.model.length
        mask_index = self.model.mask_index

        # Start from the prior (expected to be all MASK tokens)
        x = self.model._sample_prior(batch_size, seq_len).to(self.device)

        step_data = []
        eps = 1e-3
        timesteps_lin = torch.linspace(1.0, eps, num_steps + 1, device=self.device)
        dt = (1.0 - eps) / num_steps  # matches the spacing of timesteps_lin
        print_every = max(1, num_steps // 5)
        eos_reported = False

        for i in range(num_steps):
            t_current = timesteps_lin[i] * torch.ones(x.shape[0], device=self.device)
            t_for_update = t_current.unsqueeze(-1)

            # State before this step's update
            masked_positions = (x == mask_index)
            num_masks = masked_positions.sum().item()
            num_revealed = seq_len - num_masks

            # Report (once) the first step at which an EOS token is visible.
            if stop_at_eos and not eos_reported and num_revealed > 0:
                current_tokens = x[0].cpu().tolist()
                eos_pos = next(
                    (j for j, token_id in enumerate(current_tokens) if token_id in self.eos_token_ids),
                    None,
                )
                if eos_pos is not None:
                    eos_reported = True
                    print(f"EOS token first revealed at step {i+1}, position {eos_pos} "
                          f"(sampling continues; text is truncated at EOS)")

            # Model predictions at this step (diagnostics only)
            sigma_t, _ = self.model.noise(t_current)
            model_output = self.model.forward(x, sigma_t, image_features=image_features)
            top_token_texts = self._top_predictions(model_output, masked_positions, num_masks)

            x_new = self._sampling_step(x, t_for_update, dt, image_features)

            current_text = self._display_text(x[0], stop_at_eos)
            step_data.append({
                'step': i + 1,
                'timestep': t_current.item(),
                'sigma': sigma_t.item(),
                'num_masks': num_masks,
                'num_revealed': num_revealed,
                'num_masks_after': (x_new == mask_index).sum().item(),
                'changes_made': (x != x_new).sum().item(),
                'current_tokens': x.cpu().tolist()[0],
                'current_text': current_text,
                'top_predictions': top_token_texts,
            })

            if (i + 1) % print_every == 0 or i == 0:
                print(f"Step {i+1:3d}/{num_steps}: {num_masks:2d} masks left, text: '{current_text}'")

            x = x_new

        # Final decoding with EOS truncation
        final_caption = self.tokenizer.decode(x[0], skip_special_tokens=True).strip()
        if stop_at_eos and final_caption:
            truncated_tokens, truncated_caption, eos_pos = self._truncate_at_first_eos(x[0].cpu().tolist(), final_caption)
            if eos_pos is not None and truncated_caption:
                print(f"Final caption truncated at EOS position {eos_pos}")
                final_caption = truncated_caption
                x = torch.tensor([truncated_tokens], device=x.device)

        return final_caption, x[0], step_data

    @staticmethod
    def _final_mask_count(step_data):
        """Return the masks remaining after the last recorded step (0 when there are no steps).

        Records without ``num_masks_after`` fall back to the pre-update count.
        """
        if not step_data:
            return 0
        last = step_data[-1]
        return last.get('num_masks_after', last['num_masks'])

    def create_sampling_visualization(self, step_data, image_tensor, final_caption, output_dir):
        """Render a summary figure of one sampling run and save it as a PNG.

        Panels: the input image (only when ``image_tensor`` is given), masks
        remaining and tokens changed per step, and the decoded text at about
        ten evenly spaced steps followed by the final caption.

        Args:
            step_data: step records from ``generate_caption_with_steps``.
            image_tensor: image batch or ``None``. Element ``[0]`` is shown;
                it is assumed to be CLIP-normalized (3, H, W) — TODO confirm.
            final_caption: caption shown after the text evolution.
            output_dir: existing directory to write into.

        Returns:
            The file name (not the full path) of the saved PNG.
        """
        fig = plt.figure(figsize=(20, 12))
        try:
            gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3)

            # 1. Image (if available)
            if image_tensor is not None:
                ax_img = fig.add_subplot(gs[0, 0])
                img_display = self.denormalize_clip_image(image_tensor[0])
                ax_img.imshow(img_display.permute(1, 2, 0).numpy())
                ax_img.set_title('Input Image', fontweight='bold')
                ax_img.axis('off')

            # 2. Progress over time
            ax_progress = fig.add_subplot(gs[0, 1:])
            steps = [s['step'] for s in step_data]
            num_masks = [s['num_masks'] for s in step_data]
            changes = [s['changes_made'] for s in step_data]

            masks_line, = ax_progress.plot(steps, num_masks, 'b-o', label='MASK tokens remaining', markersize=3)
            ax_progress2 = ax_progress.twinx()
            changes_line, = ax_progress2.plot(steps, changes, 'r-s', label='Changes made', markersize=3)
            # The two series use separate y-axes, so build one legend from both.
            ax_progress.legend(handles=[masks_line, changes_line], loc='upper right')

            ax_progress.set_xlabel('Sampling Step')
            ax_progress.set_ylabel('MASK Tokens Remaining', color='b')
            ax_progress2.set_ylabel('Changes Made', color='r')
            ax_progress.set_title('Sampling Progress')
            ax_progress.grid(True, alpha=0.3)

            # 3. Text evolution: about 10 evenly spaced steps plus the final caption
            ax_text = fig.add_subplot(gs[1, :])
            stride = max(1, len(step_data) // 10)
            text_evolution = [f"Step {step['step']:2d}: {step['current_text']}" for step in step_data[::stride]]
            text_display = '\n'.join(text_evolution) + f"\n\nFINAL: {final_caption}"

            ax_text.text(0.05, 0.95, text_display, transform=ax_text.transAxes,
                         verticalalignment='top', fontsize=10, fontfamily='monospace',
                         bbox=dict(boxstyle="round,pad=0.5", facecolor="lightblue", alpha=0.7))
            ax_text.set_title('Text Evolution During Sampling')
            ax_text.axis('off')

            filename = f"sampling_steps_{self._timestamp()}.png"
            fig.savefig(os.path.join(output_dir, filename), dpi=300, bbox_inches='tight')
        finally:
            # Always release the figure, even if drawing or saving failed.
            plt.close(fig)

        return filename

    def save_detailed_json(self, step_data, image_tensor, final_caption, ground_truth, output_dir):
        """Write a JSON report of one sampling run and return its path.

        The report holds the final caption, ground truth, a subset of the
        model's configuration, every step record, aggregate statistics, and
        about 20 evenly spaced text snapshots. ``image_tensor`` is accepted
        for signature symmetry but is not used.
        """
        timestamp = self._timestamp()
        final_masks = self._final_mask_count(step_data)

        detailed_log = {
            'timestamp': timestamp,
            'final_caption': final_caption,
            'ground_truth': ground_truth,
            'model_config': {
                'sequence_length': self.model.config.model.length,
                'vocab_size': self.model.vocab_size,
                'mask_index': self.model.mask_index,
                'sampler_type': self.model.sampler,
                'parameterization': self.model.parameterization
            },
            'sampling_steps': step_data,
            'analysis': {
                'total_steps': len(step_data),
                'initial_masks': step_data[0]['num_masks'] if step_data else 0,
                'final_masks': final_masks,
                'tokens_revealed': step_data[0]['num_masks'] - final_masks if step_data else 0,
                'total_changes': sum(s['changes_made'] for s in step_data),
                'final_text_length': len(final_caption.split()) if final_caption else 0
            }
        }

        # Roughly 20 evenly spaced snapshots of the text (every len//20-th step)
        stride = max(1, len(step_data) // 20)
        detailed_log['text_evolution'] = [
            {
                'step': step['step'],
                'text': step['current_text'],
                'masks_remaining': step['num_masks']
            }
            for i, step in enumerate(step_data)
            if i % stride == 0
        ]

        json_path = os.path.join(output_dir, f"detailed_sampling_{timestamp}.json")
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(detailed_log, f, indent=2)

        return json_path

    def test_with_step_visualization(self, real_image, ground_truth_caption=None, num_steps=25, stop_at_eos=True):
        """Sample with and without the image, and save plots, JSON reports and a comparison.

        Everything is written to ``step_by_step_sampling_results`` (created if
        missing).

        Returns:
            Dict with captions, tokens and steps for both runs, plus the
            produced file names/paths.
        """
        output_dir = "step_by_step_sampling_results"
        os.makedirs(output_dir, exist_ok=True)

        print("="*60)
        print("STEP-BY-STEP SAMPLING TEST")
        print("="*60)

        # Test 1: With image
        print("\n1. GENERATING WITH IMAGE:")
        caption_with_image, tokens_with, steps_with = self.generate_caption_with_steps(
            real_image, num_steps=num_steps, stop_at_eos=stop_at_eos
        )
        viz_file_with = self.create_sampling_visualization(
            steps_with, real_image, caption_with_image, output_dir
        )
        json_file_with = self.save_detailed_json(
            steps_with, real_image, caption_with_image, ground_truth_caption, output_dir
        )

        # Test 2: Without image (unconditional baseline)
        print("\n2. GENERATING WITHOUT IMAGE:")
        caption_without_image, tokens_without, steps_without = self.generate_caption_with_steps(
            None, num_steps=num_steps, stop_at_eos=stop_at_eos
        )
        viz_file_without = self.create_sampling_visualization(
            steps_without, None, caption_without_image, output_dir
        )
        json_file_without = self.save_detailed_json(
            steps_without, None, caption_without_image, ground_truth_caption, output_dir
        )

        self.create_comparison_summary(
            steps_with, steps_without, caption_with_image, caption_without_image,
            ground_truth_caption, output_dir
        )

        print("\n" + "="*60)
        print("RESULTS SUMMARY:")
        print("="*60)
        print(f"With image:    '{caption_with_image}'")
        print(f"Without image: '{caption_without_image}'")
        print(f"Ground truth:  '{ground_truth_caption}'")
        print(f"\nFiles saved to: {output_dir}/")
        print(f"- Visualization (with image): {viz_file_with}")
        print(f"- Detailed JSON (with image): {json_file_with}")
        print(f"- Visualization (without image): {viz_file_without}")  
        print(f"- Detailed JSON (without image): {json_file_without}")

        return {
            'with_image': {
                'caption': caption_with_image,
                'tokens': tokens_with,
                'steps': steps_with
            },
            'without_image': {
                'caption': caption_without_image,
                'tokens': tokens_without,
                'steps': steps_without
            },
            'files': {
                'viz_with': viz_file_with,
                'json_with': json_file_with,
                'viz_without': viz_file_without,
                'json_without': json_file_without
            }
        }

    def create_comparison_summary(self, steps_with, steps_without, caption_with, 
                                caption_without, ground_truth, output_dir):
        """Compare the with-image and without-image runs and save the result as JSON.

        ``word_overlap_ratio`` is the Jaccard similarity of the lower-cased
        word sets of the two captions. ``conditioning_effective`` is the
        heuristic ``overlap < 0.6``.

        Returns:
            The comparison dict that was written.
        """
        timestamp = self._timestamp()

        with_tokens = len(caption_with.split()) if caption_with else 0
        without_tokens = len(caption_without.split()) if caption_without else 0

        # Jaccard overlap of the two captions' word sets
        overlap = 0
        if caption_with and caption_without:
            with_words = set(caption_with.lower().split())
            without_words = set(caption_without.lower().split())
            union = with_words | without_words
            if union:
                overlap = len(with_words & without_words) / len(union)

        comparison_data = {
            'timestamp': timestamp,
            'ground_truth': ground_truth,
            'results': {
                'with_image': {
                    'caption': caption_with,
                    'word_count': with_tokens,
                    'total_steps': len(steps_with),
                    'final_masks': self._final_mask_count(steps_with)
                },
                'without_image': {
                    'caption': caption_without,
                    'word_count': without_tokens,
                    'total_steps': len(steps_without),
                    'final_masks': self._final_mask_count(steps_without)
                }
            },
            'comparison': {
                'word_overlap_ratio': overlap,
                'conditioning_effective': overlap < 0.6,
                'word_count_difference': abs(with_tokens - without_tokens)
            }
        }

        with open(os.path.join(output_dir, f"comparison_summary_{timestamp}.json"), 'w', encoding='utf-8') as f:
            json.dump(comparison_data, f, indent=2)

        return comparison_data

    # Kept for backward compatibility
    @torch.no_grad()
    def generate_caption_with_eos_truncation(self, image_tensor=None, num_steps=25):
        """Sample a caption, then truncate it at the first meaningful EOS.

        Sampling runs with ``stop_at_eos=False`` so that the "full generation"
        really is untruncated. The single truncation happens here, which makes
        the returned ``eos_found`` / ``eos_pos`` accurate. (With
        ``stop_at_eos=True`` the caption would already be cut, and no EOS
        would ever be found here.)

        Returns:
            ``(caption, tokens, step_data, eos_found, eos_pos)``. ``tokens`` is
            a 1-D tensor on the device of the sampled tokens.
        """
        full_caption, full_tokens, step_data = self.generate_caption_with_steps(
            image_tensor, num_steps, stop_at_eos=False
        )

        print(f"\nFull generation: '{full_caption}'")
        print(f"Full tokens: {full_tokens.tolist()}")

        truncated_tokens, truncated_caption, eos_pos = self._truncate_at_first_eos(
            full_tokens.tolist(), full_caption
        )

        if eos_pos is not None:
            print(f"Found EOS at position {eos_pos}, truncating...")
            print(f"Truncated caption: '{truncated_caption}'")
            return (truncated_caption, torch.tensor(truncated_tokens, device=full_tokens.device),
                    step_data, True, eos_pos)

        print("No EOS token found, using full generation")
        return full_caption, full_tokens, step_data, False, None
