import csv
import hashlib
import json
import os
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch

from ...extras.logging import get_logger

# Module-level logger, named after this module, used by everything in this file.
logger = get_logger(__name__)


@dataclass
class SampleMetrics:
    """Metrics captured for a single training sample at one recording point.

    Instances are created by ``SampleMonitor.record_sample_metrics``; the two
    trailing fields are optional and default to ``None``.
    """

    # Identifier of the sample, as supplied by the caller.
    sample_id: str
    # Category label used to group samples when computing statistics.
    sample_type: str
    # Epoch and step that were current on the monitor when the sample was recorded.
    epoch: int
    step: int
    # Loss value recorded for this sample.
    loss: float
    # Mean softmax probability assigned to the target tokens of this sample.
    confidence: float
    # Gradient norm for this sample, when the caller provides one.
    gradient_norm: Optional[float] = None
    # Short hash string derived from the sample id.
    content_hash: Optional[str] = None


class SampleMonitor:
    """Collect per-sample training metrics and persist them to disk.

    Every recorded sample is kept in memory, both in one flat list and grouped
    by sample type. New records are appended to ``sample_metrics.jsonl`` in the
    output directory whenever the current step is a multiple of
    ``save_interval``. Aggregate statistics can be computed per sample type or
    per epoch, and ``save_training_summary`` writes a JSON summary file.
    """

    def __init__(self, output_dir: str, save_interval: int = 100):
        """Create the output directory and initialise empty metric stores.

        Args:
            output_dir: Directory that receives the JSONL and summary files;
                created (including parents) if it does not exist.
            save_interval: Records are flushed to disk when the current step is
                a multiple of this value. A value <= 0 disables the automatic
                flush; ``save_training_summary`` still flushes.
        """
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.save_interval = save_interval

        # Flat, chronological list of every recorded sample.
        self.sample_metrics: List[SampleMetrics] = []
        # The same records, grouped by their sample_type.
        self.type_metrics: Dict[str, List[SampleMetrics]] = defaultdict(list)

        self.current_epoch = 0
        self.current_step = 0

        # Index into sample_metrics of the first record not yet written to disk.
        self._last_saved_idx = 0

        self.metrics_file = self.output_dir / "sample_metrics.jsonl"
        self.summary_file = self.output_dir / "training_summary.json"

        logger.info(f"SampleMonitor initialized. Output dir: {self.output_dir}")

    def update_epoch_step(self, epoch: int, step: int):
        """Update the current epoch and step stamped onto new records."""
        self.current_epoch = epoch
        self.current_step = step

    def record_sample_metrics(
        self,
        sample_ids: List[str],
        sample_types: List[str],
        losses: torch.Tensor,
        logits: torch.Tensor,
        labels: torch.Tensor,
        gradient_norms: Optional[List[float]] = None
    ):
        """Record one ``SampleMetrics`` entry per sample of a batch.

        Args:
            sample_ids: Identifier of each sample in the batch.
            sample_types: Type label of each sample, parallel to ``sample_ids``.
            losses: Per-sample losses, either a tensor or a plain sequence of
                numbers.
            logits: Model outputs passed to ``_calculate_confidence``.
            labels: Target ids passed to ``_calculate_confidence``.
            gradient_norms: Optional per-sample gradient norms.

        If the inputs disagree on the batch size, only the common prefix is
        recorded and a warning is logged, so monitoring never aborts training.
        """
        confidences = self._calculate_confidence(logits, labels)

        if isinstance(losses, torch.Tensor):
            # One host transfer for the whole batch instead of one .item()
            # synchronisation per sample; reshape also accepts 0-dim losses.
            loss_values = losses.detach().reshape(-1).tolist()
        else:
            loss_values = [float(value) for value in losses]

        batch_size = min(
            len(sample_ids), len(sample_types), len(loss_values), len(confidences)
        )
        if batch_size != len(sample_ids):
            logger.warning(
                f"Batch size mismatch: {len(sample_ids)} sample ids, "
                f"{len(sample_types)} sample types, {len(loss_values)} losses, "
                f"{len(confidences)} confidences; recording first {batch_size}"
            )

        for i in range(batch_size):
            gradient_norm = None
            # `is not None` (not truthiness) so tensor inputs do not raise.
            if gradient_norms is not None and i < len(gradient_norms):
                raw_norm = gradient_norms[i]
                gradient_norm = None if raw_norm is None else float(raw_norm)

            sample_id = sample_ids[i]
            sample_type = sample_types[i]
            metrics = SampleMetrics(
                sample_id=sample_id,
                sample_type=sample_type,
                epoch=self.current_epoch,
                step=self.current_step,
                loss=loss_values[i],
                confidence=confidences[i],
                gradient_norm=gradient_norm,
                content_hash=self._hash_sample(sample_id)
            )

            self.sample_metrics.append(metrics)
            self.type_metrics[sample_type].append(metrics)

        # Guard against save_interval <= 0, which would otherwise divide by zero.
        if self.save_interval > 0 and self.current_step % self.save_interval == 0:
            self._save_metrics_incremental()

    def _calculate_confidence(self, logits: torch.Tensor, labels: torch.Tensor) -> List[float]:
        """Compute the mean target-token probability of each sample.

        For every sample, positions whose label is ``-100`` are ignored; the
        softmax probability of the label at each remaining position is averaged.
        A sample without any valid position gets a confidence of ``0.0``.

        NOTE(review): logits and labels are compared position by position, so
        any shift between them is assumed to be done by the caller - confirm.

        Args:
            logits: [batch_size, seq_len, vocab_size]
            labels: [batch_size, seq_len]

        Returns:
            List of confidence scores, one per sample.
        """
        confidences: List[float] = []

        # Monitoring only: make sure no autograd graph is built for the softmax.
        with torch.no_grad():
            detached_logits = logits.detach()
            device_labels = labels.to(detached_logits.device)

            for sample_logits, sample_labels in zip(detached_logits, device_labels):
                valid_mask = sample_labels != -100
                if not valid_mask.any():
                    confidences.append(0.0)
                    continue

                probs = torch.softmax(sample_logits[valid_mask], dim=-1)
                target_labels = sample_labels[valid_mask]

                # Pick, for each valid position, the probability of its label.
                target_probs = probs.gather(1, target_labels.unsqueeze(-1)).squeeze(-1)
                confidences.append(target_probs.mean().item())

        return confidences

    def _hash_sample(self, sample_id: str) -> str:
        """Return a short decimal hash (0..999999) of ``sample_id``.

        The built-in ``hash()`` is salted per process for strings, so it would
        give a different value on every run; a SHA-256 digest keeps the hash
        stable across runs and machines.
        """
        digest = hashlib.sha256(str(sample_id).encode("utf-8")).hexdigest()
        return str(int(digest, 16) % 1000000)

    @staticmethod
    def _build_record(metrics: "SampleMetrics") -> Dict[str, Any]:
        """Convert one metrics entry into the dict written as a JSONL row."""
        gradient_norm = metrics.gradient_norm
        return {
            'sample_id': metrics.sample_id,
            'sample_type': metrics.sample_type,
            'epoch': metrics.epoch,
            'step': metrics.step,
            'loss': round(metrics.loss, 6),
            'confidence': round(metrics.confidence, 6),
            'gradient_norm': round(gradient_norm, 6) if gradient_norm is not None else None,
            'content_hash': metrics.content_hash,
            'timestamp': datetime.now().isoformat()
        }

    @staticmethod
    def _mean_std(values: List[float]) -> Tuple[float, float]:
        """Return ``(mean, standard deviation)`` of ``values`` as Python floats."""
        return float(np.mean(values)), float(np.std(values))

    def _save_metrics_incremental(self):
        """Append all not-yet-saved records to the JSONL metrics file.

        Best effort: a record that cannot be serialised is skipped with a
        warning, and an I/O failure is logged without raising so that training
        continues. After an I/O failure the same records are retried on the
        next call.
        """
        start_idx = self._last_saved_idx
        new_metrics = self.sample_metrics[start_idx:]

        if not new_metrics:
            return

        # Serialise first, write once: a failing row cannot leave the file with
        # a partially written batch that would be duplicated on retry.
        lines = []
        for metrics in new_metrics:
            try:
                lines.append(json.dumps(self._build_record(metrics), ensure_ascii=False) + '\n')
            except Exception as row_error:
                logger.warning(f"Failed to write record for sample {metrics.sample_id}: {row_error}")

        try:
            with open(self.metrics_file, 'a', encoding='utf-8') as f:
                f.writelines(lines)
        except Exception as e:
            logger.warning(f"Failed to save sample metrics: {e}")
            return

        self._last_saved_idx = start_idx + len(new_metrics)
        logger.info(f"Saved {len(lines)} sample metrics to JSONL")

    def get_type_statistics(self) -> Dict[str, Dict[str, float]]:
        """Return aggregate loss/confidence statistics for each sample type.

        Returns:
            Mapping from sample type to a dict with ``count``, mean and standard
            deviation of loss and confidence, and the means over the most
            recent 100 records (``latest_avg_*``). Types without records are
            omitted.
        """
        stats = {}

        for sample_type, metrics_list in self.type_metrics.items():
            if not metrics_list:
                continue

            losses = [m.loss for m in metrics_list]
            confidences = [m.confidence for m in metrics_list]
            avg_loss, std_loss = self._mean_std(losses)
            avg_confidence, std_confidence = self._mean_std(confidences)

            stats[sample_type] = {
                'count': len(metrics_list),
                'avg_loss': avg_loss,
                'std_loss': std_loss,
                'avg_confidence': avg_confidence,
                'std_confidence': std_confidence,
                # Slicing already yields the whole list when it is shorter than 100.
                'latest_avg_loss': float(np.mean(losses[-100:])),
                'latest_avg_confidence': float(np.mean(confidences[-100:]))
            }

        return stats

    def get_epoch_statistics(self, epoch: int) -> Dict[str, Dict[str, float]]:
        """Return per-type loss/confidence statistics for a single epoch.

        Args:
            epoch: Epoch number to filter the recorded samples by.

        Returns:
            Mapping from sample type to ``count`` and mean / standard deviation
            of loss and confidence; empty when nothing was recorded for
            ``epoch``.
        """
        losses_by_type: Dict[str, List[float]] = defaultdict(list)
        confidences_by_type: Dict[str, List[float]] = defaultdict(list)

        for metrics in self.sample_metrics:
            if metrics.epoch != epoch:
                continue
            losses_by_type[metrics.sample_type].append(metrics.loss)
            confidences_by_type[metrics.sample_type].append(metrics.confidence)

        result = {}
        for sample_type, losses in losses_by_type.items():
            avg_loss, std_loss = self._mean_std(losses)
            avg_confidence, std_confidence = self._mean_std(confidences_by_type[sample_type])
            result[sample_type] = {
                'count': len(losses),
                'avg_loss': avg_loss,
                'std_loss': std_loss,
                'avg_confidence': avg_confidence,
                'std_confidence': std_confidence
            }

        return result

    def save_training_summary(self):
        """Flush pending records and write the JSON training summary.

        Returns:
            The summary dict that was written to ``self.summary_file``.
        """
        self._save_metrics_incremental()

        # NOTE(review): the "+ 1" presumably assumes zero-based epoch/step
        # counters - confirm. With no samples recorded both totals are 1.
        summary = {
            'total_samples': len(self.sample_metrics),
            'total_epochs': max((m.epoch for m in self.sample_metrics), default=0) + 1,
            'total_steps': max((m.step for m in self.sample_metrics), default=0) + 1,
            'sample_types': list(self.type_metrics.keys()),
            'type_statistics': self.get_type_statistics(),
            'created_at': datetime.now().isoformat(),
            'files': {
                'sample_metrics': str(self.metrics_file),
                'training_summary': str(self.summary_file)
            }
        }

        with open(self.summary_file, 'w', encoding='utf-8') as f:
            json.dump(summary, f, indent=2, ensure_ascii=False)

        logger.info(f"Training summary saved to {self.summary_file}")
        return summary

    def export_for_analysis(self, analysis_dir: str):
        """Write one ``<sample_type>_metrics.json`` file per sample type.

        Each file holds a list of ``{epoch, step, loss, confidence}`` dicts in
        recording order.

        Args:
            analysis_dir: Target directory; created (including parents) if it
                does not exist.
        """
        analysis_path = Path(analysis_dir)
        analysis_path.mkdir(parents=True, exist_ok=True)

        for sample_type, metrics_list in self.type_metrics.items():
            type_data = [
                {
                    'epoch': metrics.epoch,
                    'step': metrics.step,
                    'loss': metrics.loss,
                    'confidence': metrics.confidence
                }
                for metrics in metrics_list
            ]

            # A path separator inside the type name would point at a
            # non-existent subdirectory, so flatten it into the file name.
            safe_name = str(sample_type).replace("/", "_").replace("\\", "_")
            output_file = analysis_path / f"{safe_name}_metrics.json"
            with open(output_file, 'w', encoding='utf-8') as f:
                json.dump(type_data, f, indent=2)

        logger.info(f"Analysis data exported to {analysis_path}")


def create_sample_monitor(output_dir: str, save_interval: int = 100) -> SampleMonitor:
    """Build a ``SampleMonitor`` that writes into ``output_dir``.

    Args:
        output_dir: Directory handed to the monitor for its output files.
        save_interval: Step interval handed to the monitor for periodic saves.

    Returns:
        The newly constructed monitor.
    """
    monitor = SampleMonitor(output_dir=output_dir, save_interval=save_interval)
    return monitor