from typing import Optional, Any, List
import os
import torch
from datasets import Dataset
import wandb
from transformers import TrainerCallback, PreTrainedModel, PreTrainedTokenizer
from dataclasses import dataclass
from data_utils import load_data

@dataclass()
class EvaluationConfig:
    """Settings that control how ``ModelEvaluator`` generates and scores answers.

    Attributes:
        max_new_tokens: Upper bound on tokens generated per attempt.
        max_retries: Extra attempts allowed after a wrong answer in feedback mode.
        with_feedback: Whether wrong answers are fed back with a 0/1 verdict and retried.
    """

    max_new_tokens: int = 6  # forwarded to ``model.generate``
    max_retries: int = 1  # only consulted when ``with_feedback`` is True
    with_feedback: bool = False  # False -> plain generation, no verdict suffix

class ModelEvaluator:
    """Score a causal language model by exact match on generated answers.

    The model is prompted with a query, only the newly generated tokens are
    decoded, and the whitespace-stripped result is compared with the reference
    response. With ``config.with_feedback`` enabled, a wrong attempt is fed back
    to the model tagged with a ``0`` verdict and retried up to
    ``config.max_retries`` times.
    """

    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        config: Optional[EvaluationConfig] = None
    ):
        self.model = model
        self.tokenizer = tokenizer
        # Fall back to default settings when no config is supplied.
        self.config = config or EvaluationConfig()

    def _generate_once(self, input_text: str) -> str:
        """Run a single generation for ``input_text``; return only the new text, stripped."""
        # A single prompt never needs padding. Requesting it (padding=True)
        # raises for tokenizers that have no pad token.
        inputs = self.tokenizer(
            input_text,
            return_tensors="pt",
            truncation=True,
        ).to(self.model.device)

        # Everything after the prompt length is newly generated text.
        input_len = inputs["input_ids"].shape[1]

        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            # Give generate() an explicit pad id when the tokenizer lacks one.
            pad_token_id = self.tokenizer.eos_token_id

        outputs = self.model.generate(
            inputs["input_ids"],
            attention_mask=inputs.get("attention_mask"),
            max_new_tokens=self.config.max_new_tokens,
            pad_token_id=pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        )

        return self.tokenizer.decode(
            outputs[0][input_len:],
            skip_special_tokens=True,
        ).strip()

    def _generate(
        self,
        input_text: str,
        expected_response: Optional[str] = None,
        retry_count: int = 0
    ) -> str:
        """Generate an answer, optionally running the feedback/retry loop.

        Args:
            input_text: Prompt fed to the model.
            expected_response: Reference answer; feedback mode needs it to judge
                attempts. When None, no feedback is applied.
            retry_count: Retries already used (internal recursion counter).

        Returns:
            Without feedback (or when ``expected_response`` is None), the
            generated text. With feedback: ``"<generated> 1"`` if the attempt
            matches ``expected_response``. Otherwise the model is re-prompted
            with ``"<input_text> <generated> 0"`` until ``config.max_retries``
            retries are used, and the last attempt is returned as
            ``"<generated> 0"``.
        """
        generated = self._generate_once(input_text)

        if not self.config.with_feedback or expected_response is None:
            return generated

        if self._validate_generation(generated, expected_response):
            return f"{generated} 1"

        if retry_count < self.config.max_retries:
            # Show the model its wrong attempt, tagged with a 0 verdict, and retry.
            return self._generate(
                f"{input_text} {generated} 0",
                expected_response,
                retry_count + 1,
            )

        return f"{generated} 0"

    @staticmethod
    def _strip_feedback_marker(text: str) -> str:
        """Remove one trailing standalone ``0``/``1`` feedback token, if present.

        The token must be separated from the preceding text by whitespace.
        Answers such as ``"10"`` or ``"1"`` are therefore left intact, while
        ``"42 1"`` becomes ``"42 "``.
        """
        text = text.rstrip()
        body, last = text[:-1], text[-1:]
        if last in ("0", "1") and body[-1:].isspace():
            return body
        return text

    def _validate_generation(self, prediction: str, expected_response: Optional[str]) -> bool:
        """Return True if ``prediction`` exactly matches ``expected_response``.

        Surrounding whitespace is ignored. One trailing whitespace-separated
        ``0``/``1`` feedback token is removed from ``prediction`` first. The
        previous ``rstrip("01")`` stripped every trailing 0/1 character and
        so corrupted answers such as ``"10"``.

        Note: a reference whose own last token is a standalone ``0``/``1``
        (e.g. ``"x 1"``) is not matched by an identical prediction.

        Returns False when ``expected_response`` is None.
        """
        if expected_response is None:
            return False
        return self._strip_feedback_marker(prediction).strip() == expected_response.strip()

    def evaluate_sample(self, query: str, expected_response: str) -> bool:
        """Generate an answer for ``query`` and report whether it is correct."""
        generated = self._generate(query, expected_response)

        if self.config.with_feedback and expected_response is not None:
            # _generate already judged the final attempt and appended its
            # verdict (" 1" correct / " 0" wrong). Re-validating the suffixed
            # text could disagree with that verdict.
            return generated.endswith(" 1")

        return self._validate_generation(generated, expected_response)

    def evaluate_dataset(self, dataset: Dataset) -> float:
        """Return exact-match accuracy over ``dataset``, or 0.0 if it is empty.

        Each sample must provide ``"query"`` and ``"response"`` fields. The
        model is switched to eval mode for the run, and its previous
        train/eval mode is restored afterwards.
        """
        total = len(dataset)
        if total == 0:
            # Avoid ZeroDivisionError on an empty test set.
            return 0.0

        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                correct = sum(
                    self.evaluate_sample(sample["query"], sample["response"])
                    for sample in dataset
                )
        finally:
            # Leave the model in the mode the caller had it in (e.g. training).
            self.model.train(was_training)

        return correct / total

class EvaluationCallback(TrainerCallback):
    """Trainer callback that periodically measures exact-match accuracy on test files.

    Evaluation runs every ``eval_steps`` global steps, and once more when
    training ends unless that step was just evaluated. It runs only on the
    process whose ``RANK`` environment variable is 0 (or unset). Accuracy is
    printed and, when a wandb run is active, logged to wandb. Optionally, a few
    example predictions per test file are appended to a log file.
    """

    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        test_file_paths: List[str],
        eval_steps: int = 1000,
        logging_file_dir: Optional[str] = None,
        eval_config: Optional[EvaluationConfig] = None
    ):
        """
        Args:
            model: Model to evaluate.
            tokenizer: Tokenizer matching ``model``.
            test_file_paths: Files read with ``load_data``. Each sample must
                provide ``"query"`` and ``"response"`` fields.
            eval_steps: Evaluate every this many global steps. A value <= 0
                disables periodic evaluation; the end-of-training run still happens.
            logging_file_dir: Directory for the ``eval_<name>.log`` prediction
                log. None disables prediction logging.
            eval_config: Generation/feedback settings for the evaluator.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.test_file_paths = test_file_paths
        self.eval_steps = eval_steps
        self.logging_file_dir = logging_file_dir

        # Only the rank-0 process evaluates; RANK is unset in single-process runs.
        self.is_main_process = int(os.environ.get("RANK", 0)) == 0

        self.evaluator = ModelEvaluator(
            model=model,
            tokenizer=tokenizer,
            config=eval_config or EvaluationConfig()
        )

        # Used in log file names and wandb metric keys.
        self.name = "with_feedback" if self.evaluator.config.with_feedback else "without_feedback"
        if self.logging_file_dir:
            self.logging_file_path = os.path.join(self.logging_file_dir, f"eval_{self.name}.log")
        else:
            self.logging_file_path = None

        # Global step of the most recent evaluation; prevents evaluating a step twice.
        self._last_eval_step: Optional[int] = None

    def _log_predictions(
        self,
        dataset: Dataset,
        num_examples: int = 10,
        header: Optional[str] = None,
    ) -> None:
        """Append predictions for the first ``num_examples`` samples to the log file.

        Does nothing when no log file is configured. Creates the log directory
        if it does not exist.

        Args:
            dataset: The evaluation dataset.
            num_examples: Number of leading examples to log (default: 10).
            header: Optional line written before the examples, e.g. step and file.
        """
        if not self.logging_file_path:
            return

        num_examples = min(num_examples, len(dataset))
        os.makedirs(os.path.dirname(self.logging_file_path) or ".", exist_ok=True)
        with_feedback = self.evaluator.config.with_feedback

        with open(self.logging_file_path, "a", encoding="utf-8") as f:
            if header:
                f.write(f"{header}\n")
            for idx in range(num_examples):
                sample = dataset[idx]
                generated = self.evaluator._generate(sample["query"], sample["response"])

                if with_feedback and generated[-2:] in (" 0", " 1"):
                    # In feedback mode _generate appends exactly " 0"/" 1" as its
                    # verdict. Remove only that suffix, never digits of the answer.
                    answer, correct = generated[:-2], generated.endswith(" 1")
                else:
                    answer = generated
                    correct = self.evaluator._validate_generation(generated, sample["response"])

                f.write(
                    f"Example {idx}:\n"
                    f"Query: {sample['query']}\n"
                    f"Expected: {sample['response']}\n"
                    f"Generated: {answer}\n"
                    f"Correct: {correct}\n"
                    f"{'='*50}\n"
                )

    def _evaluate_file(self, test_file_path: str, state: Any) -> None:
        """Evaluate one test file; log examples, print and report its accuracy."""
        # os.path handles OS separators and keeps multi-dot names distinct
        # (e.g. "a.v1.json" vs "a.v2.json"), so wandb keys do not collide.
        test_file_name = os.path.splitext(os.path.basename(test_file_path))[0]

        test_dataset = load_data(test_file_path)
        if len(test_dataset) == 0:
            print(f"Step {state.global_step}: Skipping empty test file {test_file_name}")
            return

        if self.logging_file_path:
            self._log_predictions(
                test_dataset,
                header=f"Step {state.global_step} | {test_file_name} | {self.name}",
            )

        accuracy = self.evaluator.evaluate_dataset(test_dataset)

        print(
            f"Step {state.global_step}: "
            f"Evaluation on {test_file_name} {self.name} "
            f"Accuracy: {accuracy:.2%}"
        )
        # wandb.log raises if no run was initialised; the printed result above
        # still reports accuracy in that case.
        if wandb.run is not None:
            wandb.log({
                "step": state.global_step,
                f"test_accuracy_{test_file_name}_{self.name}": accuracy
            })

    def evaluate(self, state) -> None:
        """Evaluate every test file at ``state.global_step``."""
        print(f"Running evaluation {self.name} at step {state.global_step}...")
        self._last_eval_step = state.global_step

        # Generate in eval mode (no dropout) for both the logged examples and the
        # accuracy pass. Restore the previous mode so training resumes unaffected.
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                for test_file_path in self.test_file_paths:
                    self._evaluate_file(test_file_path, state)
        finally:
            self.model.train(was_training)

    def on_step_end(self, args: Any, state: Any, control: Any, **kwargs) -> None:
        """Evaluate model at specified intervals."""
        if (self.is_main_process and
                self.eval_steps > 0 and          # guards the modulo below
                state.global_step > 0 and
                state.global_step % self.eval_steps == 0):
            self.evaluate(state)

    def on_train_end(self, args: Any, state: Any, control: Any, **kwargs) -> None:
        """Final evaluation at the end of training, unless this step was just evaluated."""
        if self.is_main_process and state.global_step != self._last_eval_step:
            self.evaluate(state)