import ast
import os
from datetime import datetime
from typing import List, Optional

import pandas as pd

from src.dataset_processing.common.models.dataset_entry import DatasetEntry
from src.loggers.setup_logging import setup_logging
from src.reliability_eval.pipeline.config import IntegratedPipelineConfig
from src.reliability_eval.pipeline.evaluation_pipelines.types import PipelineType
from src.reliability_eval.pipeline.results import EvaluationResults

# Module-level logger shared by every function in this file; it is whatever
# the project's setup_logging() helper returns.
logger = setup_logging()

def get_project_path(path):
    """Return the value of the environment variable named ``path``.

    Args:
        path: Name of the environment variable to read (e.g. ``'PROJECT_PATH'``).

    Returns:
        The non-empty string stored in that environment variable.

    Raises:
        ValueError: If the variable is not set or is set to an empty string.
    """
    env_value = os.environ.get(path)
    if env_value:
        return env_value
    raise ValueError(f"Required environment variable '{path}' is missing or empty in .env file")

class ExcelWriter:
    """Class to handle writing evaluation results to Excel files.

    All methods are static; the class only serves as a namespace. The entry
    points are ``_generate_evaluation_excel`` (builds the output path and
    opens the workbook) and ``_write_results_to_excel`` (fills one sheet).
    """

    # Columns placed first, in this order, when they exist in the data.
    # Any other column follows in the order it was first added to a row.
    _ORDERED_COLUMNS = (
        'entry_idx',
        'question',
        'answer',
        'accuracy',
        'full_decoded_text',
        'first_token_probability',
        'token_info_tuples',
        'effective_lengths',
        'padded_sequences',
        'semantic_entropy',
        'exp_id',
        'num_repeats',
        'dataset_name',
        'perturbation_type',
        'perturbation_intensity',
        'model',
        'generation_strategy',
        'prompt_strategy',
        'max_new_tokens',
    )

    # Score entries that are written as their str() form: scores key -> column name.
    _STRINGIFIED_FIELDS = {
        'full_text': 'full_decoded_text',
        'effective_lengths': 'effective_lengths',
        'padded_sequences': 'padded_sequences',
        'semantic_entropy': 'semantic_entropy',
    }

    # Upper bound applied to auto-sized column widths.
    _MAX_COLUMN_WIDTH = 50

    @staticmethod
    def _generate_evaluation_excel(
        results: EvaluationResults,
        config: IntegratedPipelineConfig,
        dataset_entries: List[DatasetEntry],
        num_max_entries: Optional[int] = None
    ) -> None:
        """Generates detailed Excel report of evaluation results.

        The workbook is written to
        ``<PROJECT_PATH>/results/<config_string>/results_<config_string>.xlsx``
        where ``config_string`` encodes the experiment settings and the
        current date. The directory is created if it does not exist.

        Args:
            results: Evaluation results whose ``scores`` and ``metrics`` are exported.
            config: Pipeline configuration used for the run.
            dataset_entries: Entries that were evaluated; one row is written per entry.
            num_max_entries: Optional cap on the number of rows written.

        Raises:
            ValueError: If the ``PROJECT_PATH`` environment variable is missing or empty.
        """
        logger.info("Generating evaluation Excel report")

        dataset_config = config.dataset_config
        generation_config = config.generation_config

        # Create results directory structure
        base_path = os.path.join(get_project_path('PROJECT_PATH'), 'results')
        timestamp = datetime.now().strftime("%Y%m%d")
        exp_id = config.exp_id
        model_name = str(config.model_identifier).replace("/", "_")
        dataset_type = str(dataset_config.dataset_type.value)
        raw_perturbation_type = ExcelWriter._perturbation_type_value(dataset_config)
        perturbation_type = "none" if raw_perturbation_type is None else str(raw_perturbation_type)
        perturbation_intensity = (
            str(dataset_config.perturbation_intensity)
            if hasattr(dataset_config, 'perturbation_intensity')
            else "none"
        )
        num_entries = len(dataset_entries)
        num_repeats = generation_config.num_repeats
        temperature = generation_config.temperature
        max_new_tokens = generation_config.max_new_tokens
        random_seed = config.random_seed
        config_string = f"exp-{exp_id}_model-{model_name}_dataset-{dataset_type}_pert-{perturbation_type}_int-{perturbation_intensity}_entries-{num_entries}_repeats-{num_repeats}_temp-{temperature}_tokens-{max_new_tokens}_seed-{random_seed}_time-{timestamp}"

        # base_path is a plain string, so paths are joined with os.path.join
        # (the "/" operator only works on pathlib.Path objects).
        result_dir = os.path.join(base_path, config_string)
        os.makedirs(result_dir, exist_ok=True)
        excel_path = os.path.join(result_dir, f"results_{config_string}.xlsx")

        with pd.ExcelWriter(excel_path, engine='openpyxl') as writer:
            # Write all results to a single sheet
            ExcelWriter._write_results_to_excel(
                writer=writer,
                results=results,
                config=config,
                dataset_entries=dataset_entries,
                sheet_name='Results',
                num_max_entries=num_max_entries
            )

        logger.info(f"Saved evaluation results to {excel_path}")

    @staticmethod
    def _write_results_to_excel(
        writer: pd.ExcelWriter,
        results: EvaluationResults,
        config: IntegratedPipelineConfig,
        dataset_entries: List[DatasetEntry],
        sheet_name: str,
        num_max_entries: Optional[int] = None
    ) -> None:
        """Writes all results to a single sheet.

        Each dataset entry becomes one row holding the entry's question and
        answer, the run configuration, the per-entry values found in
        ``results.scores`` and the aggregate values found in
        ``results.metrics`` (the latter are identical on every row).

        Args:
            writer: Open pandas Excel writer to add the sheet to.
            results: Evaluation results providing ``scores`` and ``metrics``.
            config: Pipeline configuration used for the run.
            dataset_entries: Entries to export, in row order.
            sheet_name: Name of the sheet to create.
            num_max_entries: Optional cap on the number of rows; a falsy value
                (``None`` or ``0``) means no cap.
        """
        dataset_config = config.dataset_config
        generation_config = config.generation_config

        # Config-level fallbacks used when an entry's metadata has no override.
        default_perturbation_type = ExcelWriter._perturbation_type_value(dataset_config)
        default_perturbation_intensity = getattr(dataset_config, 'perturbation_intensity', None)

        # Aggregate metrics do not depend on the entry, so compute them once.
        metric_columns = ExcelWriter._collect_metric_columns(results)

        # Prepare data for DataFrame
        rows = []
        for idx, entry in enumerate(dataset_entries):
            if num_max_entries and idx >= num_max_entries:
                break

            row = {
                # General information
                'entry_idx': idx,
                'exp_id': config.exp_id,

                # Dataset entry information
                'question': entry.question,
                'answer': entry.answer,
                'dataset_name': dataset_config.dataset_name,
                'perturbation_type': entry.metadata.get('perturbation_type', default_perturbation_type),
                'perturbation_intensity': entry.metadata.get('perturbation_intensity', default_perturbation_intensity),

                # Model configuration
                'model': str(config.model_identifier),
                'device': config.device,
                'seed': config.random_seed,

                # Generation configuration
                'generation_strategy': generation_config.generation_strategy.value,
                'prompt_strategy': generation_config.prompt_strategy.value,
                'num_repeats': generation_config.num_repeats,
                'max_new_tokens': generation_config.max_new_tokens,
                'temperature': generation_config.temperature
            }

            # Add individual pipeline scores
            for raw_name, score_tensor in results.scores.items():
                pipeline_name = ExcelWriter._pipeline_key(raw_name)
                try:
                    ExcelWriter._add_score_columns(row, pipeline_name, score_tensor, idx)
                except Exception as e:
                    # Best effort: one unreadable score must not abort the whole export.
                    logger.warning(f"Failed to process {pipeline_name} for idx {idx}: {str(e)}")
                    row[f'{pipeline_name}'] = None

            # Add accuracy if available
            if "accuracy" in results.scores:
                accuracy_scores = results.scores["accuracy"]
                row['accuracy'] = accuracy_scores[idx].item() if idx < len(accuracy_scores) else None

            # Add metric results
            row.update(metric_columns)

            rows.append(row)

        df = pd.DataFrame(rows)

        # Preferred columns first, then everything else in its current order.
        preferred = [col for col in ExcelWriter._ORDERED_COLUMNS if col in df.columns]
        remaining = [col for col in df.columns if col not in ExcelWriter._ORDERED_COLUMNS]
        df = df.reindex(columns=preferred + remaining)

        if df.empty:
            # The sheet is still written: an openpyxl workbook without any
            # sheet cannot be saved, which would fail when the writer closes.
            logger.warning(f"No rows to write for sheet '{sheet_name}'; writing an empty sheet")

        df.to_excel(writer, index=False, sheet_name=sheet_name)

        if not df.empty:
            ExcelWriter._autosize_columns(writer.sheets[sheet_name])

    @staticmethod
    def _pipeline_key(name):
        """Return the string key for a scores/metrics entry (enum members yield their value)."""
        return name.value if isinstance(name, PipelineType) else name

    @staticmethod
    def _perturbation_type_value(dataset_config):
        """Return the configured perturbation type's value, or None if absent or unset."""
        perturbation_type = getattr(dataset_config, 'perturbation_type', None)
        # Enum members expose .value; anything else (including None) is returned as-is.
        return getattr(perturbation_type, 'value', perturbation_type)

    @staticmethod
    def _first_token_probability(token_info):
        """Extract element ``[0][0][2]`` of ``token_info``, or None if it is not available.

        The original author's comment describes ``[0][0][2]`` as the first
        token's probability — TODO confirm against the producer of
        ``token_info``. A string input is parsed as a Python literal; it is
        never evaluated as code.
        """
        try:
            parsed = ast.literal_eval(token_info) if isinstance(token_info, str) else token_info
            probability = parsed[0][0][2]
            # Unwrap single-element tensor/array scalars so the cell holds a plain number.
            to_item = getattr(probability, 'item', None)
            return to_item() if callable(to_item) else probability
        except (IndexError, KeyError, TypeError, ValueError, SyntaxError, RuntimeError):
            return None

    @staticmethod
    def _add_score_columns(row, pipeline_name, score_tensor, idx):
        """Add the column(s) for one scores entry at position ``idx`` to ``row``.

        Nothing is added when ``idx`` is past the end of ``score_tensor`` or
        when the entry is ``accuracy`` (handled separately by the caller).
        """
        if idx >= len(score_tensor) or pipeline_name == "accuracy":
            return

        value = score_tensor[idx]
        if pipeline_name == 'token_info':
            row['token_info_tuples'] = str(value)
            row['first_token_probability'] = ExcelWriter._first_token_probability(value)
        elif pipeline_name in ExcelWriter._STRINGIFIED_FIELDS:
            row[ExcelWriter._STRINGIFIED_FIELDS[pipeline_name]] = str(value)
        else:
            row[f'{pipeline_name}_score'] = value.item()

    @staticmethod
    def _collect_metric_columns(results):
        """Build the ``<pipeline>_<metric>`` columns from ``results.metrics``, skipping None values."""
        columns = {}
        for raw_name, metrics in results.metrics.items():
            pipeline_name = ExcelWriter._pipeline_key(raw_name)
            if metrics.aucpr is not None:
                columns[f'{pipeline_name}_aucpr'] = metrics.aucpr
            if metrics.aucroc is not None:
                columns[f'{pipeline_name}_aucroc'] = metrics.aucroc
            if metrics.brier is not None:
                columns[f'{pipeline_name}_brier'] = metrics.brier
            if metrics.mean_scores is not None:
                columns[f'{pipeline_name}_mean'] = metrics.mean_scores
        return columns

    @staticmethod
    def _autosize_columns(worksheet):
        """Set each column's width to its longest cell text plus padding, capped at the maximum."""
        for column_cells in worksheet.columns:
            column_cells = list(column_cells)
            if not column_cells:
                continue
            # Measure the text form so numeric cells contribute too; empty cells count as 0.
            max_length = max(
                (len(str(cell.value)) for cell in column_cells if cell.value is not None),
                default=0,
            )
            adjusted_width = min(max_length + 2, ExcelWriter._MAX_COLUMN_WIDTH)
            worksheet.column_dimensions[column_cells[0].column_letter].width = adjusted_width