#!/usr/bin/env python3
"""
Multi-Task Iterative Improvement Experiment Runner

This script orchestrates iterative improvement experiments across three task types:
1. Scientific Idea Generation (LiveIdeaBench)
2. Mathematical Problem Solving (OmniMath)
3. Code Generation (DS1000)

Supports multiple model providers with easy extensibility for new models.
Includes multiple feedback types for comprehensive evaluation.

Author: AI Research Team
Date: 2024
"""

import os
import json
import yaml
import logging
import pandas as pd
from abc import ABC, abstractmethod
from datetime import datetime
from typing import List, Dict, Any, Optional, Tuple
from enum import Enum
import time
import re

# Default CUDA to GPUs 0 and 1, but respect a device selection the caller
# already made (e.g. `CUDA_VISIBLE_DEVICES=2 python ...`). This needs to run
# before torch (imported further down) initializes CUDA.
import os
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0,1")

# Third-party imports
try:
    from dotenv import load_dotenv
except ImportError:
    load_dotenv = None
    print("Warning: python-dotenv not installed. Run: pip install python-dotenv")

try:
    import openai
except ImportError:
    openai = None

try:
    import anthropic
except ImportError:
    anthropic = None

try:
    from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
    import torch
except ImportError:
    print("Warning: transformers library not installed. HuggingFace models will be unavailable.")
    pipeline = None
    AutoTokenizer = None
    AutoModelForCausalLM = None
    torch = None

try:
    import requests
except ImportError:
    requests = None


# ============================================================================
# ENUMS AND CONSTANTS
# ============================================================================

class TaskType(Enum):
    """Enumeration of supported task types.

    Each ``value`` matches the key used for that task under the config's
    ``datasets`` section and is used as the top-level results directory name.
    """
    IDEAS = "ideas"  # Scientific idea generation (LiveIdeaBench)
    MATH = "math"  # Mathematical problem solving (OmniMath)
    CODING = "coding"  # Code generation (DS1000)


class ModelProvider(Enum):
    """Enumeration of supported model providers.

    A model config's ``provider`` string is converted with
    ``ModelProvider(value)``, so it must equal one of these values.
    """
    OPENAI = "openai"  # OpenAI Chat Completions API
    ANTHROPIC = "anthropic"  # Anthropic Messages API
    HUGGINGFACE = "huggingface"  # Local transformers model
    LOCAL_HTTP = "local_http"  # HTTP endpoint returning an OpenAI-style chat response


class FeedbackType(Enum):
    """Enumeration of feedback types.

    ``V*`` members are vague instructions valid for every task type; ``S*``
    members are specific to one task type. The ``value`` strings are matched
    against the config's ``enabled_feedback_types`` list and used as a
    results sub-directory name.
    """
    # Vague feedback (unconstrained)
    V1_IMPROVE = "v1_improve"  # "This {subject} is good, improve it."
    V2_BETTER = "v2_better"    # "This {subject} is good, make it better."
    V3_REFINE = "v3_refine"    # "This {subject} is good, refine it."
    
    # Specific feedback for Ideas
    S1_NOVEL = "s1_novel"      # "Make this idea more novel and surprising."
    S2_PRACTICAL = "s2_practical"  # "Make this idea more practical and feasible."
    
    # Specific feedback for Coding
    S1_PERFORMANCE = "s1_performance"  # "Refactor the previous code snippet to maximize execution speed."
    S2_MAINTAINABILITY = "s2_maintainability"  # "Refactor the previous code snippet to maximize readability and clarity."
    
    # Specific feedback for Math
    S1_ELABORATION = "s1_elaboration"  # "This is previous response, now elaborate on each step with more detail."
    S2_EXPLORATION = "s2_exploration"  # "Provide an alternative method or a different logical approach to the one used"


# Context window limits (conservative estimates), keyed by the exact model
# name from the config; presumably measured in tokens. BaseModel's
# get_context_limit falls back to 10000 for names not listed here.
MODEL_CONTEXT_LIMITS = {
    "gpt-3.5-turbo": 15000,
    "claude-sonnet-4-0": 200000,
    "meta-llama/Llama-3.1-8B-Instruct": 20000,
}

# NOTE(review): presumably a rough chars-to-tokens ratio (~4 characters per
# token) for estimating prompt size; confirm where it is used.
TOKENS_PER_CHAR = 0.25


# ============================================================================
# LOGGING SETUP
# ============================================================================

def setup_logging():
    """Configure root logging for the experiment and return this module's logger.

    Records at INFO and above go to ``experiment.log`` in the current working
    directory (appended to, the FileHandler default) and to the console.

    Like ``logging.basicConfig`` itself, this does nothing if the root logger
    already has handlers. The check is done up front so that no FileHandler
    is opened and then discarded (leaking the file handle) in that case.

    Returns:
        The logger named after this module.
    """
    if not logging.getLogger().handlers:
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            handlers=[
                # Explicit UTF-8: logged model names, paths and error text may
                # contain characters the platform default encoding (e.g.
                # cp1252 on Windows) cannot encode.
                logging.FileHandler('experiment.log', encoding='utf-8'),
                logging.StreamHandler()
            ]
        )
    return logging.getLogger(__name__)


logger = setup_logging()


# ============================================================================
# DATA STRUCTURES
# ============================================================================

class Task:
    """One benchmark item (idea, math or coding problem) and its prompts.

    On construction the task-type-specific fields are copied out of ``data``
    and ``prompt_text`` -- the prompt sent on the first turn -- is built from
    them.
    """

    # Noun substituted into the vague "This ... is good, ..." instructions.
    _SUBJECT_BY_TYPE = {
        TaskType.IDEAS: "idea",
        TaskType.MATH: "solution",
        TaskType.CODING: "code",
    }

    # Vague feedback templates; valid for every task type.
    _VAGUE_TEMPLATES = {
        FeedbackType.V1_IMPROVE: "This {} is good, improve it.",
        FeedbackType.V2_BETTER: "This {} is good, make it better.",
        FeedbackType.V3_REFINE: "This {} is good, refine it.",
    }

    # Task-specific feedback, keyed by task type and then feedback type.
    # Insertion order is the order reported by get_valid_feedback_types.
    _SPECIFIC_INSTRUCTIONS = {
        TaskType.IDEAS: {
            FeedbackType.S1_NOVEL: "Make this idea more novel and surprising.",
            FeedbackType.S2_PRACTICAL: "Make this idea more practical and feasible.",
        },
        TaskType.CODING: {
            FeedbackType.S1_PERFORMANCE: "Refactor the previous code snippet to maximize execution speed.",
            FeedbackType.S2_MAINTAINABILITY: "Refactor the previous code snippet to maximize readability and clarity.",
        },
        TaskType.MATH: {
            FeedbackType.S1_ELABORATION: "This is previous response, now elaborate on each step with more detail.",
            FeedbackType.S2_EXPLORATION: "Provide an alternative method or a different logical approach to the one used",
        },
    }

    def __init__(self, task_id: str, task_type: TaskType, data: Dict[str, Any]):
        self.task_id = task_id
        self.task_type = task_type
        self.data = data

        if task_type == TaskType.IDEAS:
            self.keywords = data.get('keywords', '')
            build_prompt = self._construct_ideas_prompt
        elif task_type == TaskType.MATH:
            self.problem = data.get('problem', '')
            self.domain = data.get('domain', '')
            self.difficulty = data.get('difficulty', '')
            build_prompt = self._construct_math_prompt
        elif task_type == TaskType.CODING:
            self.problem_text = data.get('prompt', '')
            self.code_context = data.get('code_context', '')
            self.library = data.get('library', '')
            build_prompt = self._construct_coding_prompt
        else:
            # Unrecognised task type: no type-specific fields, no prompt_text.
            return

        self.prompt_text = build_prompt()

    def _construct_ideas_prompt(self) -> str:
        """Build the LiveIdeaBench prompt around ``self.keywords``."""
        return f"""I'll be submitting your next responses to a "Good Scientific Idea" expert review panel.
If they consider your idea to be a good one, you'll receive a reward. Your assigned
keyword is: "{self.keywords}". You may provide background information. The idea MUST be
concisely expressed within 100-200 words total (including any background information). (Note:
good scientific ideas should be original (novel contribution), feasible (technically
implementable), clearly articulated, and address meaningful problems in the field.)."""

    def _construct_math_prompt(self) -> str:
        """Build the math prompt: the problem followed by answer-format instructions."""
        instruction = ("Please reason step by step, and put your final answer "
                       "within \\boxed{} and give all steps in latex only.")
        return f"{self.problem}\n\n{instruction}"

    def _construct_coding_prompt(self) -> str:
        """Build the coding prompt with optional library and code-context headers."""
        header = ""
        if self.library:
            header += f"Library: {self.library}\n\n"
        if self.code_context:
            header += f"Code context:\n{self.code_context}\n\n"
        return f"{header}Problem:\n{self.problem_text}\n\nPlease provide a complete solution."

    def get_improvement_instruction(self, feedback_type: FeedbackType = FeedbackType.V1_IMPROVE) -> str:
        """Return the follow-up instruction text for ``feedback_type``.

        Vague feedback types work for every task. A specific feedback type
        that does not belong to this task's type falls back to the V1
        ("improve it") wording.
        """
        subject = self._SUBJECT_BY_TYPE.get(self.task_type, "response")

        vague_template = self._VAGUE_TEMPLATES.get(feedback_type)
        if vague_template is not None:
            return vague_template.format(subject)

        specific = self._SPECIFIC_INSTRUCTIONS.get(self.task_type, {}).get(feedback_type)
        if specific is not None:
            return specific

        return self._VAGUE_TEMPLATES[FeedbackType.V1_IMPROVE].format(subject)

    def get_valid_feedback_types(self) -> List[FeedbackType]:
        """List the vague feedback types followed by this task type's specific ones."""
        specific_for_type = self._SPECIFIC_INSTRUCTIONS.get(self.task_type, {})
        return list(self._VAGUE_TEMPLATES) + list(specific_for_type)

    def __repr__(self):
        return "Task(id='{}', type='{}')".format(self.task_id, self.task_type.value)


# ============================================================================
# MODEL ABSTRACTION LAYER
# ============================================================================

class BaseModel(ABC):
    """Abstract base class for all models.

    Subclasses implement :meth:`generate`. Extra keyword arguments given to
    the constructor are stored in ``additional_params``; subclasses forward
    them to their provider calls.
    """

    def __init__(self, name: str, provider: ModelProvider, temperature: float = 0.7,
                 max_tokens: int = 4096, **kwargs):
        self.name = name
        self.provider = provider
        self.temperature = temperature
        self.max_tokens = max_tokens
        # Provider-specific extras; forwarded by subclasses.
        self.additional_params = kwargs
        logger.info(f"Initialized {self.__class__.__name__} with model: {name}")

    @abstractmethod
    def generate(self, prompt: str) -> str:
        """Generate a response given a prompt."""

    def get_context_limit(self) -> int:
        """Return the context window limit for this model.

        Looks the model name up in MODEL_CONTEXT_LIMITS, defaulting to 10000.
        """
        return MODEL_CONTEXT_LIMITS.get(self.name, 10000)

    def _retry_with_backoff(self, func, max_retries: int = 3, base_delay: float = 1.0,
                            retry_on: Tuple[type, ...] = (Exception,)):
        """Call ``func`` and retry it with exponential backoff on failure.

        Args:
            func: Zero-argument callable to execute.
            max_retries: Retries after the first attempt. Negative values are
                treated as 0, so ``func`` always runs at least once.
            base_delay: Seconds to wait before the first retry; doubled for
                each further retry.
            retry_on: Exception types that trigger a retry. Other exceptions
                propagate immediately without retrying.

        Returns:
            The value returned by the first successful call of ``func``.

        Raises:
            The exception from the final attempt once every attempt failed.
        """
        total_attempts = max(0, max_retries) + 1

        for attempt in range(total_attempts):
            try:
                return func()
            except retry_on as e:
                if attempt == total_attempts - 1:
                    logger.error(f"All {total_attempts} attempts failed for {self.name}")
                    # Re-raise in place so the original traceback is kept.
                    raise
                delay = base_delay * (2 ** attempt)
                logger.warning(f"Attempt {attempt + 1} failed for {self.name}: {e}. Retrying in {delay}s...")
                time.sleep(delay)


class OpenAIModel(BaseModel):
    """OpenAI chat models accessed through the Chat Completions API."""

    def __init__(self, name: str, api_key: str, **kwargs):
        """Create the OpenAI client.

        Args:
            name: Model identifier sent as ``model`` in each request.
            api_key: OpenAI API key.
            **kwargs: ``temperature``, ``max_tokens`` and extra request
                parameters, stored by :class:`BaseModel`.

        Raises:
            ImportError: If the ``openai`` package is not installed.
        """
        super().__init__(name, ModelProvider.OPENAI, **kwargs)
        if openai is None:
            raise ImportError("OpenAI library not installed. Run: pip install openai")

        self.client = openai.OpenAI(api_key=api_key)

    def generate(self, prompt: str) -> str:
        """Send ``prompt`` as a single user message and return the reply text.

        Raises:
            RuntimeError: If the reply has no text content, after retries.
        """

        def _make_request():
            response = self.client.chat.completions.create(
                model=self.name,
                messages=[{"role": "user", "content": prompt}],
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                **self.additional_params
            )
            content = response.choices[0].message.content
            # ``message.content`` is nullable in the API. Raise instead of
            # returning None, which would otherwise be saved as the response
            # and fed back into the next turn as the text "None".
            if content is None:
                raise RuntimeError(f"Empty response content from {self.name}")
            return content

        return self._retry_with_backoff(_make_request)


class AnthropicModel(BaseModel):
    """Anthropic Claude models accessed through the Messages API."""

    def __init__(self, name: str, api_key: str, **kwargs):
        """Create the Anthropic client.

        Args:
            name: Model identifier sent as ``model`` in each request.
            api_key: Anthropic API key.
            **kwargs: ``temperature``, ``max_tokens`` and extra request
                parameters, stored by :class:`BaseModel`.

        Raises:
            ImportError: If the ``anthropic`` package is not installed.
        """
        super().__init__(name, ModelProvider.ANTHROPIC, **kwargs)
        if anthropic is None:
            raise ImportError("Anthropic library not installed. Run: pip install anthropic")

        self.client = anthropic.Anthropic(api_key=api_key)

    def generate(self, prompt: str) -> str:
        """Send ``prompt`` as a single user message and return the reply text.

        All ``text`` content blocks are concatenated. Non-text blocks (for
        example thinking blocks, when enabled via extra parameters) are
        ignored.

        Raises:
            RuntimeError: If the reply contains no text block, after retries.
        """

        def _make_request():
            response = self.client.messages.create(
                model=self.name,
                max_tokens=self.max_tokens,
                temperature=self.temperature,
                messages=[{"role": "user", "content": prompt}],
                **self.additional_params
            )
            text_parts = [
                block.text for block in response.content
                if getattr(block, "type", None) == "text"
            ]
            if not text_parts:
                raise RuntimeError(f"No text content in response from {self.name}")
            return "".join(text_parts)

        return self._retry_with_backoff(_make_request)


class HuggingFaceModel(BaseModel):
    """HuggingFace ``transformers`` causal LM, run via a pipeline or ``generate()``."""

    def __init__(self, name: str, use_pipeline: bool = True, device: str = "auto", **kwargs):
        """Load the model (and tokenizer) into memory.

        Args:
            name: HuggingFace model id or local path.
            use_pipeline: If True, wrap the model in a ``text-generation``
                pipeline; otherwise load tokenizer and model separately.
            device: Passed as ``device_map`` when loading.
            **kwargs: ``temperature``/``max_tokens`` plus extra parameters,
                stored by :class:`BaseModel`.

        Raises:
            ImportError: If ``transformers``/``torch`` are not available.
        """
        super().__init__(name, ModelProvider.HUGGINGFACE, **kwargs)

        if pipeline is None or AutoTokenizer is None or AutoModelForCausalLM is None:
            raise ImportError("Transformers library not installed. Run: pip install transformers torch")

        self.use_pipeline = use_pipeline
        self.device = device

        # Drop generation settings that are applied at generate time, not load time.
        # NOTE(review): any other extra kwargs are still passed to model loading
        # here AND to generation via additional_params -- confirm that loading
        # accepts them.
        model_kwargs = kwargs.copy()
        model_kwargs.pop('temperature', None)
        model_kwargs.pop('max_tokens', None)
        model_kwargs.pop('additional_params', None)

        # Half precision when a GPU is available, full precision on CPU.
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32

        if use_pipeline:
            self.pipe = pipeline(
                "text-generation",
                model=name,
                device_map=device,
                torch_dtype=dtype,
                **model_kwargs
            )
            self.tokenizer = None
            self.model = None
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(name)
            self.model = AutoModelForCausalLM.from_pretrained(
                name,
                device_map=device,
                torch_dtype=dtype,
                **model_kwargs
            )
            self.pipe = None

    def _generation_params(self) -> Dict[str, Any]:
        """Build the keyword arguments for text generation.

        A temperature <= 0 selects greedy decoding (``do_sample=False``,
        no temperature passed): sampling with a non-positive temperature is
        rejected by transformers. Entries from ``additional_params`` are
        merged last and may override these defaults, except
        ``max_tokens``/``max_new_tokens``, which would clash with the
        configured ``max_tokens``.
        """
        params: Dict[str, Any] = {"max_new_tokens": self.max_tokens}
        if self.temperature is not None and self.temperature <= 0:
            params["do_sample"] = False
        else:
            params["temperature"] = self.temperature
            params["do_sample"] = True

        for key, value in (self.additional_params or {}).items():
            if key not in ('max_tokens', 'max_new_tokens'):
                params[key] = value
        return params

    def generate(self, prompt: str) -> str:
        """Generate a reply to ``prompt`` sent as a single chat user message."""

        def _make_request():
            messages = [{"role": "user", "content": prompt}]
            generation_params = self._generation_params()

            if self.use_pipeline:
                result = self.pipe(messages, **generation_params)
                # With chat input, generated_text is the message list; the
                # last entry is the newly generated assistant message.
                return result[0]['generated_text'][-1]['content']

            inputs = self.tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
            ).to(self.model.device)

            outputs = self.model.generate(**inputs, **generation_params)

            # Decode only the newly generated tokens, not the echoed prompt.
            response = self.tokenizer.decode(
                outputs[0][inputs["input_ids"].shape[-1]:],
                skip_special_tokens=True
            )
            return response.strip()

        return self._retry_with_backoff(_make_request)


class LocalHTTPModel(BaseModel):
    """Model behind a local HTTP endpoint that answers with an OpenAI-style chat payload."""

    def __init__(self, name: str, api_endpoint: str, **kwargs):
        super().__init__(name, ModelProvider.LOCAL_HTTP, **kwargs)

        # The HTTP client is mandatory for this provider.
        if requests is None:
            raise ImportError("Requests library required. Run: pip install requests")

        self.api_endpoint = api_endpoint

    def generate(self, prompt: str) -> str:
        """POST ``prompt`` as one user message and return the first choice's content."""

        def _post_once():
            body = {
                "messages": [{"role": "user", "content": prompt}],
                "temperature": self.temperature,
                "max_tokens": self.max_tokens,
            }
            # Extra config parameters may add keys or override the ones above.
            body.update(self.additional_params)

            reply = requests.post(
                self.api_endpoint,
                json=body,
                headers={"Content-Type": "application/json"},
                timeout=120,
            )
            reply.raise_for_status()

            first_choice = reply.json()["choices"][0]
            return first_choice["message"]["content"]

        return self._retry_with_backoff(_post_once)


# ============================================================================
# MODEL FACTORY
# ============================================================================

def create_model(model_config: Dict[str, Any]) -> BaseModel:
    """Instantiate a model from one entry of the config's ``models`` list.

    Keys read from ``model_config``:
        provider (required): a :class:`ModelProvider` value.
        name (required): model identifier.
        temperature / max_tokens: default 0.7 / 4096.
        additional_params: optional mapping merged into the constructor
            kwargs; an empty/null value is treated as absent.
        api_key_env (openai/anthropic): name of the environment variable
            holding the API key; defaults to ``OPENAI_API_KEY`` /
            ``ANTHROPIC_API_KEY``.
        use_pipeline / device (huggingface): default True / "auto".
        api_endpoint (local_http, required).

    Raises:
        ValueError: Unknown provider, missing API key, or missing endpoint.
    """
    provider = ModelProvider(model_config["provider"])
    name = model_config["name"]

    common_params = {
        "temperature": model_config.get("temperature", 0.7),
        "max_tokens": model_config.get("max_tokens", 4096)
    }
    # `additional_params:` with no value loads from YAML as None.
    common_params.update(model_config.get("additional_params") or {})

    if provider == ModelProvider.OPENAI:
        api_key_env = model_config.get("api_key_env", "OPENAI_API_KEY")
        api_key = os.getenv(api_key_env)
        if not api_key:
            raise ValueError(f"API key not found: {api_key_env}")
        return OpenAIModel(name=name, api_key=api_key, **common_params)

    elif provider == ModelProvider.ANTHROPIC:
        api_key_env = model_config.get("api_key_env", "ANTHROPIC_API_KEY")
        api_key = os.getenv(api_key_env)
        if not api_key:
            raise ValueError(f"API key not found: {api_key_env}")
        return AnthropicModel(name=name, api_key=api_key, **common_params)

    elif provider == ModelProvider.HUGGINGFACE:
        use_pipeline = model_config.get("use_pipeline", True)
        device = model_config.get("device", "auto")
        return HuggingFaceModel(name=name, use_pipeline=use_pipeline, device=device, **common_params)

    elif provider == ModelProvider.LOCAL_HTTP:
        api_endpoint = model_config.get("api_endpoint")
        if not api_endpoint:
            raise ValueError("local_http provider requires 'api_endpoint'")
        return LocalHTTPModel(name=name, api_endpoint=api_endpoint, **common_params)

    else:
        raise ValueError(f"Unsupported provider: {provider}")


# ============================================================================
# DATASET LOADERS
# ============================================================================

def load_ideas_dataset(filepath: str) -> List[Task]:
    """Load the LiveIdeaBench dataset from a CSV file.

    Each row (which must have a ``keywords`` column) becomes one IDEAS task
    with id ``IDEAS-<n>``, where n is the 1-based row position zero-padded
    to at least three digits.

    Empty cells are replaced by "" so they do not reach the prompt as the
    text "nan" (or the saved JSON as a non-standard ``NaN``).
    """
    df = pd.read_csv(filepath).fillna("")

    tasks = [
        Task(f"IDEAS-{idx+1:03d}", TaskType.IDEAS, {"keywords": row["keywords"]})
        for idx, row in df.iterrows()
    ]

    logger.info(f"Loaded {len(tasks)} idea generation tasks from {filepath}")
    return tasks


def load_math_dataset(filepath: str) -> List[Task]:
    """Load the OmniMath dataset from a CSV file.

    Requires a ``problem`` column. ``domain``, ``difficulty``, ``solution``
    and ``answer`` are optional and default to "". Task ids are
    ``MATH-<n>`` (1-based row position, zero-padded to at least 3 digits).

    Empty cells are replaced by "" so they do not reach the prompt as the
    text "nan" (or the saved JSON as a non-standard ``NaN``).
    """
    df = pd.read_csv(filepath).fillna("")
    tasks = []

    for idx, row in df.iterrows():
        task_data = {
            "problem": row["problem"],
            "domain": row.get("domain", ""),
            "difficulty": row.get("difficulty", ""),
            "solution": row.get("solution", ""),
            "answer": row.get("answer", "")
        }
        tasks.append(Task(f"MATH-{idx+1:03d}", TaskType.MATH, task_data))

    logger.info(f"Loaded {len(tasks)} math tasks from {filepath}")
    return tasks


def load_coding_dataset(filepath: str) -> List[Task]:
    """Load the DS1000 dataset from a CSV file.

    Requires a ``prompt`` column. ``code_context``, ``library``,
    ``problem_id`` and ``reference_code`` are optional and default to "".
    Task ids are ``CODE-<n>`` (1-based row position, zero-padded to at
    least 3 digits).

    Empty cells are replaced by "". This matters because a NaN
    ``code_context`` is truthy and would otherwise add a
    "Code context:\\nnan" section to the prompt.
    """
    df = pd.read_csv(filepath).fillna("")
    tasks = []

    for idx, row in df.iterrows():
        task_data = {
            "prompt": row["prompt"],
            "code_context": row.get("code_context", ""),
            "library": row.get("library", ""),
            "problem_id": row.get("problem_id", ""),
            "reference_code": row.get("reference_code", "")
        }
        tasks.append(Task(f"CODE-{idx+1:03d}", TaskType.CODING, task_data))

    logger.info(f"Loaded {len(tasks)} coding tasks from {filepath}")
    return tasks


def load_all_datasets(config: Dict[str, Any]) -> Dict[TaskType, List[Task]]:
    """Load every dataset listed under ``config["datasets"]``.

    Recognised keys are "ideas", "math" and "coding"; each entry must have a
    ``filepath``. Keys that are absent are simply not loaded.
    """
    loader_table = (
        ("ideas", TaskType.IDEAS, load_ideas_dataset),
        ("math", TaskType.MATH, load_math_dataset),
        ("coding", TaskType.CODING, load_coding_dataset),
    )
    section = config["datasets"]

    loaded = {}
    for key, task_type, loader in loader_table:
        if key in section:
            loaded[task_type] = loader(section[key]["filepath"])
    return loaded


# ============================================================================
# RESULT MANAGEMENT
# ============================================================================

def get_result_path(results_dir: str, task_type: TaskType, model_name: str, 
                   task_id: str, feedback_type: FeedbackType, run_number: int) -> str:
    """Build the JSON path for one run.

    Layout: ``<results_dir>/<task type>/<model>/<feedback type>/<task_id>/run_<n>.json``.
    Every "/" and ":" in the model name becomes "_", so a name such as
    ``meta-llama/Llama-3.1-8B-Instruct`` maps to a single directory.
    """
    safe_model_dir = model_name.translate(str.maketrans({"/": "_", ":": "_"}))
    segments = [
        results_dir,
        task_type.value,
        safe_model_dir,
        feedback_type.value,
        task_id,
        f"run_{run_number}.json",
    ]
    return os.path.join(*segments)


def check_run_exists(results_dir: str, task_type: TaskType, model_name: str, 
                    task_id: str, feedback_type: FeedbackType, run_number: int,
                    required_turns: Optional[int] = None) -> bool:
    """Check whether a run's result file exists, optionally requiring it to be complete.

    Result files are rewritten after every turn, so mere existence does not
    mean the run finished.

    Args:
        required_turns: If None (default), only check that the file exists.
            Otherwise, additionally require the saved file to contain at
            least this many turns. An unreadable file counts as 0 turns.

    Returns:
        True if the file exists and (when requested) has enough turns.
    """
    output_path = get_result_path(results_dir, task_type, model_name, task_id, feedback_type, run_number)
    if not os.path.exists(output_path):
        return False
    if required_turns is None:
        return True

    _, saved_turns = load_existing_run(results_dir, task_type, model_name, task_id, feedback_type, run_number)
    return saved_turns >= required_turns


def load_existing_run(results_dir: str, task_type: TaskType, model_name: str, 
                     task_id: str, feedback_type: FeedbackType, run_number: int) -> Tuple[List[Dict[str, Any]], int]:
    """Read a previously saved run, if any.

    Returns:
        ``(turns, turn_count)``. ``([], 0)`` is returned when no file exists,
        and also when the file cannot be read or parsed (the error is logged).
    """
    path = get_result_path(results_dir, task_type, model_name, task_id, feedback_type, run_number)
    if not os.path.exists(path):
        return [], 0

    try:
        with open(path, 'r', encoding='utf-8') as handle:
            saved = json.load(handle)
        saved_turns = saved.get('turns', [])
        return saved_turns, len(saved_turns)
    except Exception as e:
        logger.error(f"Error loading existing run from {path}: {e}")
        return [], 0


def save_run_result(results_dir: str, task_type: TaskType, model_name: str, task: Task,
                   feedback_type: FeedbackType, run_number: int, run_turns_data: List[Dict[str, Any]]) -> str:
    """Write a whole run (metadata plus all turns so far) to its JSON file.

    The file is fully rewritten on every call. Data is first written to a
    sibling ``.tmp`` file and then moved over the real file with
    ``os.replace``. An interruption or a JSON serialization error mid-write
    therefore leaves the previous version intact instead of a truncated file.
    A truncated file would still pass an existence check while failing to
    parse.

    Returns:
        Path of the written result file.
    """
    output_path = get_result_path(results_dir, task_type, model_name, task.task_id, feedback_type, run_number)

    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    result_data = {
        "model_name": model_name,
        "task_id": task.task_id,
        "task_type": task_type.value,
        "feedback_type": feedback_type.value,
        "run_number": run_number,
        "total_turns": len(run_turns_data),
        "run_timestamp": datetime.utcnow().isoformat(),
        "original_prompt": task.prompt_text,
        "task_data": task.data,
        "turns": run_turns_data
    }

    tmp_path = output_path + ".tmp"
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(result_data, f, indent=2, ensure_ascii=False)
        # Atomic on POSIX when source and destination share a filesystem.
        os.replace(tmp_path, output_path)
    finally:
        # The temp file only survives if writing or replacing failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    return output_path


# ============================================================================
# MAIN EXPERIMENT RUNNER
# ============================================================================

def _run_improvement_turns(model: BaseModel, task: Task, task_type: TaskType,
                           feedback_type: FeedbackType, run_number: int,
                           results_dir: str, improvement_iterations: int,
                           existing_turns: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Generate the turns after ``existing_turns`` up to ``improvement_iterations``, saving after each; stop at the first error."""
    run_turns_data = list(existing_turns)
    improvement_instruction = task.get_improvement_instruction(feedback_type)
    start_turn = len(existing_turns) + 1
    current_output = ""

    if existing_turns:
        logger.info(f"          Resuming from turn {start_turn}")
        current_output = existing_turns[-1].get('response', '')

    for turn_number in range(start_turn, improvement_iterations + 1):
        logger.info(f"          [{task_type.value}] Turn {turn_number}/{improvement_iterations}")

        if turn_number == 1:
            # First turn: the original task prompt.
            prompt_for_turn = task.prompt_text
        else:
            # Later turns: feed the previous response back with the instruction.
            prompt_for_turn = (
                "The following is a previous response:\n\n---\n\n"
                f"{current_output}\n\n---\n\n{improvement_instruction}"
            )

        try:
            logger.info(f"            Generating response...")
            response_text = model.generate(prompt_for_turn)
            current_output = response_text

            run_turns_data.append({
                "turn_number": turn_number,
                "prompt": prompt_for_turn,
                "response": response_text,
                "feedback_type": feedback_type.value,
                "improvement_instruction": improvement_instruction,
                "timestamp": datetime.utcnow().isoformat(),
                "is_initial_turn": turn_number == 1
            })

            # Persist after every turn so an interrupted run can be resumed.
            output_path = save_run_result(
                results_dir, task_type, model.name, task, feedback_type, run_number, run_turns_data
            )
            logger.info(f"            Turn completed. Run saved to: {os.path.basename(output_path)}")

        except Exception as e:
            logger.error(f"            Error generating response: {e}")
            logger.info(f"            Continuing to next run...")
            break

    return run_turns_data


def run_experiment_for_task_type(config: Dict[str, Any], models: List[BaseModel],
                                tasks: List[Task], task_type: TaskType,
                                results_dir: str) -> Tuple[int, int]:
    """Run the iterative-improvement experiment for one task type.

    For every (model, task, enabled feedback type, run number) combination,
    the task prompt is sent once. The previous response is then fed back
    with the feedback instruction until ``improvement_iterations`` turns
    exist. Results are saved after every turn, so a partially finished run is
    resumed from its last saved turn instead of being skipped.

    Args:
        config: Parsed config. Reads ``parameters.improvement_iterations``,
            ``parameters.runs_per_task`` and the optional
            ``enabled_feedback_types`` list. A missing, null or empty list
            means all feedback types.
        models: Initialized models to query.
        tasks: Tasks of ``task_type``.
        task_type: Task type being processed (used for paths and logging).
        results_dir: Root directory for the result JSON files.

    Returns:
        ``(completed_runs, skipped_runs)``. The first counts runs that reached
        the full number of turns during this call. The second counts runs
        skipped because a complete result file already existed.
    """
    params = config["parameters"]
    improvement_iterations = params["improvement_iterations"]
    runs_per_task = params["runs_per_task"]

    # Missing, null or empty all mean "every feedback type". This matches
    # what run_experiment logs for an empty list.
    enabled_feedback_types = config.get("enabled_feedback_types") or [
        "v1_improve", "v2_better", "v3_refine",
        "s1_novel", "s1_performance", "s1_elaboration",
        "s2_practical", "s2_maintainability", "s2_exploration"
    ]
    if isinstance(enabled_feedback_types, str):
        # Tolerate a single value written as a scalar instead of a list.
        enabled_feedback_types = [enabled_feedback_types]
    enabled_values = set(enabled_feedback_types)

    completed_runs = 0
    skipped_runs = 0

    logger.info(f"Starting {task_type.value} experiment")
    logger.info(f"Enabled feedback types: {enabled_feedback_types}")

    for model_idx, model in enumerate(models, 1):
        logger.info(f"  [{task_type.value}] Processing model {model_idx}/{len(models)}: {model.name}")

        for task_idx, task in enumerate(tasks, 1):
            logger.info(f"    [{task_type.value}] Processing task {task_idx}/{len(tasks)}: {task.task_id}")

            feedback_types_to_run = [
                fb_type for fb_type in task.get_valid_feedback_types()
                if fb_type.value in enabled_values
            ]
            logger.info(f"      [{task_type.value}] Running feedback types: {[ft.value for ft in feedback_types_to_run]}")

            for feedback_type in feedback_types_to_run:
                logger.info(f"      [{task_type.value}] Processing feedback type: {feedback_type.value}")

                # Multiple runs for statistical robustness.
                for run_number in range(1, runs_per_task + 1):
                    logger.info(f"        [{task_type.value}] Run {run_number}/{runs_per_task}")

                    existing_turns, last_completed_turn = load_existing_run(
                        results_dir, task_type, model.name, task.task_id, feedback_type, run_number
                    )

                    # The file is written after every turn, so only a file
                    # with all turns counts as a complete run.
                    if existing_turns and last_completed_turn >= improvement_iterations:
                        logger.info(f"          Skipping - complete run already exists")
                        skipped_runs += 1
                        continue

                    run_turns_data = _run_improvement_turns(
                        model, task, task_type, feedback_type, run_number,
                        results_dir, improvement_iterations, existing_turns
                    )

                    if run_turns_data and len(run_turns_data) >= improvement_iterations:
                        completed_runs += 1
                        logger.info(f"          Run {run_number} completed with {len(run_turns_data)} turns")
                    elif run_turns_data:
                        logger.warning(
                            f"          Run {run_number} incomplete: {len(run_turns_data)}/"
                            f"{improvement_iterations} turns; it will resume from the last saved turn next time"
                        )

    return completed_runs, skipped_runs


def run_experiment(config_path: str = "config.yaml", results_dir: str = "results") -> None:
    """Run the full multi-task experiment described by a YAML config file.

    Loads ``.env`` variables (when python-dotenv is installed), reads the
    configuration, loads all datasets, initializes every configured model,
    and then runs the per-task-type experiment loop, logging a summary for
    each task type and for the experiment as a whole.

    Args:
        config_path: Path to the YAML configuration file. It must be a
            mapping containing a ``parameters`` mapping (with
            ``improvement_iterations`` and ``runs_per_task``) and a
            ``models`` list; ``enabled_feedback_types`` is optional.
        results_dir: Directory passed to the task runners for saving
            results; created if it does not already exist.

    Returns:
        None. Returns early (after logging an error) if no model could be
        initialized.

    Raises:
        Exception: Any error raised while loading the configuration or the
            datasets is logged and re-raised. Failures to initialize an
            individual model are logged and that model is skipped.
    """
    logger.info("Starting Multi-Task Iterative Improvement Experiment with Multiple Feedback Types")

    # Load variables from a .env file when python-dotenv is available.
    if load_dotenv:
        load_dotenv()

    # Load configuration
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)

        # An empty YAML document parses to None (and a scalar/list is also
        # possible); reject those with a clear message instead of failing
        # later with "'NoneType' object is not subscriptable".
        if not isinstance(config, dict):
            raise ValueError(
                f"Configuration in {config_path!r} must be a YAML mapping, "
                f"got {type(config).__name__}"
            )

        params = config["parameters"]
        # Read the model list here so a missing "models" key is reported as
        # a configuration error instead of escaping later as a bare KeyError.
        model_configs = config["models"]

        logger.info("Experiment parameters:")
        logger.info(f"  Improvement iterations: {params['improvement_iterations']}")
        logger.info(f"  Runs per task: {params['runs_per_task']}")

        enabled_feedback_types = config.get("enabled_feedback_types", [])
        if enabled_feedback_types:
            logger.info(f"  Enabled feedback types: {enabled_feedback_types}")
        else:
            logger.info("  Using all available feedback types")

    except Exception as e:
        logger.error(f"Failed to load configuration: {e}")
        raise

    # Load datasets
    try:
        datasets = load_all_datasets(config)
        total_tasks = sum(len(tasks) for tasks in datasets.values())
        logger.info(f"Loaded {len(datasets)} task types with {total_tasks} total tasks")
    except Exception as e:
        logger.error(f"Failed to load datasets: {e}")
        raise

    # Initialize models; a model that fails to initialize is skipped so the
    # remaining models can still run.
    models = []
    for model_config in model_configs:
        try:
            model = create_model(model_config)
            models.append(model)
            logger.info(f"Successfully initialized model: {model.name}")
        except Exception as e:
            # Do not index model_config["name"] here: if the entry has no
            # "name" (or is not a mapping) that would raise a new error out
            # of this handler and abort the whole experiment.
            if isinstance(model_config, dict):
                model_name = model_config.get("name", "<unnamed>")
            else:
                model_name = repr(model_config)
            logger.error(f"Failed to initialize model {model_name}: {e}")
            continue

    if not models:
        logger.error("No models were successfully initialized. Exiting.")
        return

    logger.info(f"Initialized {len(models)} models")

    # Create results directory
    os.makedirs(results_dir, exist_ok=True)

    # Run experiments for each task type
    overall_completed = 0
    overall_skipped = 0

    for task_type, tasks in datasets.items():
        logger.info("=" * 80)
        logger.info(f"RUNNING {task_type.value.upper()} EXPERIMENT")
        logger.info("=" * 80)

        completed, skipped = run_experiment_for_task_type(
            config, models, tasks, task_type, results_dir
        )

        overall_completed += completed
        overall_skipped += skipped

        logger.info(f"{task_type.value.upper()} EXPERIMENT SUMMARY:")
        logger.info(f"  Runs completed: {completed}")
        logger.info(f"  Runs skipped: {skipped}")

    # Overall summary
    logger.info("=" * 80)
    logger.info("OVERALL EXPERIMENT COMPLETED")
    logger.info("=" * 80)
    logger.info(f"Total runs completed: {overall_completed}")
    logger.info(f"Total runs skipped: {overall_skipped}")
    logger.info(f"Results saved in: {os.path.abspath(results_dir)}")


# ============================================================================
# MAIN ENTRY POINT
# ============================================================================

def main():
    """Script entry point: run the experiment with its default settings.

    A keyboard interrupt (Ctrl-C) ends the run with an informational log
    message instead of a traceback. Any other exception is logged and then
    re-raised, so the script still terminates with the original error.
    """
    try:
        run_experiment()
    except KeyboardInterrupt:
        # Results are written to disk after every turn, which is what the
        # "progress has been saved" message below relies on.
        interrupt_messages = (
            "\nExperiment interrupted by user. Progress has been saved.",
            "You can resume by running the script again.",
        )
        for message in interrupt_messages:
            logger.info(message)
    except Exception as err:
        logger.error(f"Experiment failed with error: {err}")
        raise


# Only start the experiment when executed as a script, not when imported.
if __name__ == "__main__":
    main()