import itertools
import math
import os
import typing
from dataclasses import dataclass
from pathlib import Path

import hydra.utils
import lightning as L
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics
import transformers
import torchvision.models as vision_models

import dataloader
import models
import noise_schedule
import utils
import logging
import hashlib
from typing import Dict, Tuple, Iterable

from logging.handlers import RotatingFileHandler

from transformers import CLIPVisionModel, CLIPProcessor

import csv
import matplotlib.pyplot as plt
from PIL import Image as PILImage
import json
from datetime import datetime

import torchvision.transforms as transforms
from ImageEncoder import ImageEncoder
import models

def _unsqueeze(x, reference):
    """Return a view of ``x`` padded with trailing size-1 dimensions.

    Enough singleton dimensions are appended for ``x`` to have as many
    dimensions as ``reference``. When ``reference`` does not have more
    dimensions than ``x``, nothing is appended and the view keeps the
    shape of ``x``.
    """
    missing_dims = len(reference.shape) - len(x.shape)
    # A non-positive count multiplies out to an empty tuple, i.e. no padding.
    trailing_ones = (1,) * missing_dims
    return x.view(*x.shape, *trailing_ones)
    
def _sample_categorical(categorical_probs):
    """Randomly pick one index along the last dimension of the input.

    Uniform noise ``u`` with the same shape as ``categorical_probs`` is
    drawn, each probability is divided by ``1e-10 - log(u + 1e-10)``, and
    the position of the largest quotient along the last dimension is
    returned.
    """
    uniform_noise = torch.rand_like(categorical_probs)
    # The 1e-10 offsets keep the log finite and the divisor non-zero.
    gumbel_norm = 1e-10 - torch.log(uniform_noise + 1e-10)
    perturbed = categorical_probs / gumbel_norm
    return torch.argmax(perturbed, dim=-1)

def tensor_sha(t: torch.Tensor, max_bytes: int = 1024) -> str:
    """Return the SHA-1 hex digest of the leading raw bytes of ``t``.

    The tensor is read in row-major (C) order and at most ``max_bytes``
    bytes are hashed, so this is a cheap fingerprint of the start of the
    tensor rather than a checksum of its whole content.

    Args:
        t: Tensor to fingerprint. It is detached first, so it may require
            grad and may live on any device.
        max_bytes: Number of leading bytes fed to the hash.

    Returns:
        The 40-character hexadecimal SHA-1 digest.
    """
    flat = t.detach().reshape(-1)
    if max_bytes >= 0:
        # Keep only the elements needed to cover ``max_bytes`` so a large
        # tensor is not copied to the CPU and serialised in full.
        needed_elems = -(-max_bytes // flat.element_size())  # ceil division
        flat = flat[:needed_elems]
    # Reinterpret the storage as bytes before handing it to NumPy: NumPy has
    # no bfloat16 dtype, so converting such a tensor directly would raise.
    raw = flat.cpu().contiguous().view(torch.uint8).numpy().tobytes()
    return hashlib.sha1(raw[:max_bytes]).hexdigest()

def load_hf_into_dit_and_report(
    dit: nn.Module,
    hf_state: Dict[str, torch.Tensor],
    ignore_prefixes: Iterable[str] = ("image_adapter.", "image_embedder.", "image_encoder."),
    try_map_head_and_embed: bool = True,
    logger: typing.Optional[logging.Logger] = None,
) -> Dict[str, typing.Any]:
    """Copy shape-compatible tensors from ``hf_state`` into ``dit`` and report.

    Every entry of ``hf_state`` whose key exists in ``dit.state_dict()`` with
    an identical shape is copied into the model. All other entries are only
    recorded in the returned report. ``hf_state`` itself is never modified.

    Args:
        dit: Module that receives the weights.
        hf_state: Mapping from parameter name to tensor.
        ignore_prefixes: Keys starting with any of these prefixes are skipped
            on both sides: they are neither copied nor reported.
        try_map_head_and_embed: When true, the known head / token-embedding
            key names are additionally offered under the model's
            ``output_layer.linear.weight`` and ``vocab_embed.embedding`` keys,
            provided the shapes agree.
        logger: Receives a one-line summary via ``logger.info``; the summary
            is printed when no logger is given.

    Returns:
        A dict with the following keys:

        * ``matched``: keys that were copied.
        * ``mismatched``: ``(key, source_shape, model_shape)`` tuples for keys
          present on both sides with different shapes.
        * ``missing_in_own``: model keys with no entry in the source mapping.
        * ``unexpected_in_hf``: source keys the model does not have.
        * ``changed_hash_count``: number of copied tensors whose
          ``tensor_sha`` fingerprint differs before and after the copy.
        * ``hash_records``: ``(key, hash_before, hash_after, changed)`` tuples.
    """
    log = logger.info if logger else print
    own = dit.state_dict()

    # Work on a shallow copy so the alias keys added below never leak into
    # the caller's mapping.
    hf_state = dict(hf_state)
    # Materialise once: the prefixes are consulted for every key, which would
    # exhaust a one-shot iterable after the first check.
    prefixes = tuple(ignore_prefixes)

    if try_map_head_and_embed:
        aliases = (
            (
                "output_layer.linear.weight",
                ("lm_head.weight", "transformer.lm_head.weight"),
            ),
            (
                "vocab_embed.embedding",
                ("transformer.wte.weight", "wte.weight", "embeddings.word_embeddings.weight"),
            ),
        )
        for target, candidates in aliases:
            if target not in own:
                continue
            for cand in candidates:
                # When several candidates fit, the last one listed wins.
                if cand in hf_state and hf_state[cand].shape == own[target].shape:
                    hf_state[target] = hf_state[cand]

    def _ignored(k: str) -> bool:
        return k.startswith(prefixes)

    matched, mismatched, missing_in_own, unexpected = [], [], [], []
    before_after_hash = []

    # Source side: copy what fits, classify the rest.
    for k, v in hf_state.items():
        if _ignored(k):
            continue
        if k not in own:
            unexpected.append(k)
        elif own[k].shape != v.shape:
            mismatched.append((k, tuple(v.shape), tuple(own[k].shape)))
        else:
            # Fingerprint before and after to show the content really changed.
            h_before = tensor_sha(own[k])
            own[k].copy_(v)
            h_after = tensor_sha(own[k])
            matched.append(k)
            before_after_hash.append((k, h_before, h_after, h_before != h_after))

    # Model side: keys the source never provided.
    missing_in_own = [k for k in own if not _ignored(k) and k not in hf_state]

    # Load the modified state back into the module.
    dit.load_state_dict(own)

    changed = sum(1 for *_, differs in before_after_hash if differs)

    log(
        f"HF weight load: {len(matched)} matched ({changed} changed), "
        f"{len(mismatched)} shape-mismatched, {len(missing_in_own)} missing, "
        f"{len(unexpected)} unexpected"
    )

    return {
        "matched": matched,
        "mismatched": mismatched,
        "missing_in_own": missing_in_own,
        "unexpected_in_hf": unexpected,
        "changed_hash_count": changed,
        "hash_records": before_after_hash,
    }

def copy_dit_block_weights(source_block, target_block, logger):
    """Copy every shape-compatible state entry from one block to another.

    An entry of ``source_block.state_dict()`` is copied when
    ``target_block`` has an entry under the same key with the same shape.
    Every other entry is left untouched. ``logger`` is accepted but not
    used here.

    Returns:
        The list of keys that were copied, in source state-dict order.
    """
    src_entries = source_block.state_dict()
    dst_entries = target_block.state_dict()

    transferred = []
    for name, src_tensor in src_entries.items():
        if name not in dst_entries:
            continue
        dst_tensor = dst_entries[name]
        if dst_tensor.shape != src_tensor.shape:
            continue
        dst_tensor.copy_(src_tensor)
        transferred.append(name)

    # Non-strict: the target is allowed to own keys the source lacks.
    target_block.load_state_dict(dst_entries, strict=False)
    return transferred
                
def cross_attention_enhancement(use_image_conditioning, dit_image_feature_dim, config, logger, backbone):
    """Swap selected backbone blocks for cross-attention variants.

    Nothing happens unless ``use_image_conditioning`` is truthy and
    ``dit_image_feature_dim`` is not ``None``. Otherwise each even index
    from 2 up to ``config.model.n_blocks`` (exclusive) that exists in
    ``backbone.blocks`` is replaced by a freshly built
    ``models.dit_v2.DDiTBlockWithCrossAttn``. The matching weights of the
    original block are copied into the new block before the swap.

    Returns:
        The same ``backbone`` object, modified in place.
    """
    if not use_image_conditioning or dit_image_feature_dim is None:
        return backbone

    model_cfg = config.model
    # Every second layer, starting with index 2, gets cross-attention.
    cross_attn_layer_indices = list(range(2, model_cfg.n_blocks, 2))
    logger.info(f"Adding cross-attention to layers: {cross_attn_layer_indices}")

    for layer_idx in cross_attn_layer_indices:
        # The config may announce more blocks than the backbone really has.
        if layer_idx >= len(backbone.blocks):
            continue

        logger.info(f"Enhancing layer {layer_idx} with cross-attention")

        replacement = models.dit_v2.DDiTBlockWithCrossAttn(
            dim=model_cfg.hidden_size,
            n_heads=model_cfg.n_heads,
            cond_dim=model_cfg.cond_dim,
            mlp_ratio=getattr(model_cfg, "mlp_ratio", 4.0),
            dropout=model_cfg.dropout,
            use_cross_attn=True,
            image_feature_dim=dit_image_feature_dim,
        )

        # Carry the existing weights over, then put the new block in place.
        copy_dit_block_weights(backbone.blocks[layer_idx], replacement, logger)
        backbone.blocks[layer_idx] = replacement

    return backbone

def save_loss_history(train_loss_history, val_loss_history, output_dir="loss_logs"):
    """Write the training and validation loss histories as CSV files.

    ``output_dir`` is created when missing. ``train_loss.csv`` and
    ``val_loss.csv`` are (over)written inside it, each with a
    ``step,loss`` header followed by one row per history entry. The step
    is the entry's position in the sequence.
    """
    os.makedirs(output_dir, exist_ok=True)

    outputs = (
        (os.path.join(output_dir, "train_loss.csv"), train_loss_history),
        (os.path.join(output_dir, "val_loss.csv"), val_loss_history),
    )

    for csv_path, history in outputs:
        with open(csv_path, 'w', newline='') as handle:
            rows = csv.writer(handle)
            rows.writerow(["step", "loss"])
            rows.writerows([step, value] for step, value in enumerate(history))
            
        
def log_generation_quality(self, generated_texts, ground_truths, log):
    """Log simple surface statistics of generated captions.

    Three values are reported through ``log(name, value, on_epoch=True)``:

    * ``val/caption_length``: mean number of whitespace-separated words
      per generated text.
    * ``val/vocab_diversity``: number of distinct lower-cased words across
      all generated texts.
    * ``val/word_overlap``: mean Jaccard overlap of the lower-cased word
      sets of each (generated, ground-truth) pair.

    Empty inputs are reported as zero instead of raising
    ``ZeroDivisionError``.

    Args:
        self: Not used by this function.
        generated_texts: Sequence of generated strings.
        ground_truths: Sequence of reference strings, paired with
            ``generated_texts`` by position. Surplus entries on either side
            are ignored.
        log: Callable that receives each metric.
    """
    # Text length statistics (should be reasonable for captions)
    lengths = [len(text.split()) for text in generated_texts]
    # max(..., 1) keeps the mean defined (as 0) when there are no texts.
    avg_length = sum(lengths) / max(len(lengths), 1)

    # Vocabulary diversity
    all_words = set().union(*(text.lower().split() for text in generated_texts))
    vocab_diversity = len(all_words)

    # Simple word overlap with ground truth
    overlaps = []
    for gen, gt in zip(generated_texts, ground_truths):
        gen_words = set(gen.lower().split())
        gt_words = set(gt.lower().split())
        overlaps.append(len(gen_words & gt_words) / max(len(gen_words | gt_words), 1))

    # zip() may yield no pairs even when one of the two inputs is non-empty.
    avg_overlap = sum(overlaps) / max(len(overlaps), 1)

    # Log metrics that correlate with caption quality
    log('val/caption_length', avg_length, on_epoch=True)
    log('val/vocab_diversity', vocab_diversity, on_epoch=True)
    log('val/word_overlap', avg_overlap, on_epoch=True)
        
