import json
import os
import time
from collections import Counter
from typing import Dict, Any, List, Optional

import torch

from .utils import clean_for_json_serialization, json_default_handler


class ExperimentManager:
    """Run baseline/attack image generations and persist per-round and final results.

    Images are saved as PNG files and results as JSON files inside the
    ``experiment_dir`` passed to each method.
    """

    # Detector score at or above which the last round counts as having bypassed
    # the defense; drives ``content_early_stop`` and ``final_completed_rounds``.
    DETECTOR_BYPASS_THRESHOLD = 0.6
    # Absolute tolerance used to match a round's ``success_score`` to ``best_score``.
    BEST_SCORE_TOLERANCE = 0.01
    # Attack strategy name -> name of the method invoked on ``attack_methods``.
    _STRATEGY_METHODS = {
        "bypass": "_bypass_defense_attack",
        "direct_counter": "_direct_counter_attack",
    }

    def __init__(
        self,
        pipe,
        device: str,
        num_inference_steps: Optional[int] = None,
        verbose: Optional[bool] = None
    ):
        """Store the generation pipeline and run configuration.

        Args:
            pipe: Callable image pipeline; called with keyword arguments and
                expected to return an object exposing ``.images`` (presumably
                a diffusers pipeline — confirm against the caller).
            device: Device string used to build the seeded ``torch.Generator``.
            num_inference_steps: Denoising step count; ``None`` lets the
                pipeline use its own default.
            verbose: Stored for callers; not read by this class.
        """
        self.pipe = pipe
        self.device = device
        self.num_inference_steps = num_inference_steps
        self.verbose = verbose

    def generate_baseline_images(
        self,
        prompt: str,
        seed: int,
        guidance_scale: float,
        experiment_dir: str,
        height: int = 512,
        width: int = 512,
        max_sequence_length: int = 256,
    ) -> Dict[str, Any]:
        """Generate one image with the pipeline and save it as the baseline.

        Args:
            prompt: Text prompt for the pipeline.
            seed: Seed for the ``torch.Generator`` on ``self.device``.
            guidance_scale: Guidance scale forwarded to the pipeline.
            experiment_dir: Existing directory receiving ``baseline_defended.png``.
            height: Output image height in pixels.
            width: Output image width in pixels.
            max_sequence_length: Prompt token limit forwarded to the pipeline.

        Returns:
            Dict with ``success``, ``image_path``, ``prompt``, ``seed`` and
            ``guidance_scale``.
        """
        generator = torch.Generator(device=torch.device(self.device)).manual_seed(seed)
        pipe_kwargs: Dict[str, Any] = {
            "prompt": prompt,
            "generator": generator,
            "height": height,
            "width": width,
            "guidance_scale": guidance_scale,
            "max_sequence_length": max_sequence_length,
        }
        # Forward the step count only when configured: an explicit None would
        # override the pipeline's own default instead of falling back to it.
        if self.num_inference_steps is not None:
            pipe_kwargs["num_inference_steps"] = self.num_inference_steps

        baseline_image = self.pipe(**pipe_kwargs).images[0]
        baseline_path = os.path.join(experiment_dir, "baseline_defended.png")
        baseline_image.save(baseline_path)

        return {
            "success": True,
            "image_path": baseline_path,
            "prompt": prompt,
            "seed": seed,
            "guidance_scale": guidance_scale
        }

    def execute_attack_round(
        self,
        prompt: str,
        params: Dict[str, Any],
        seed: int,
        guidance_scale: float,
        experiment_dir: str,
        round_num: int,
        attack_methods,
        sample_index: int = 0,
        attack_strategies: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """Run one attack round and save every produced image.

        Mutates ``attack_methods.direction_attack_strength`` and
        ``attack_methods.magnitude_attack_strength`` from ``params``
        (defaults 8.0 and 5.0).

        Args:
            prompt: Text prompt for the attack.
            params: Attack parameters; also recorded in the result.
            seed: Seed passed to the attack method.
            guidance_scale: Guidance scale passed to the attack method.
            experiment_dir: Existing directory receiving the PNG files.
            round_num: Zero-based round index (stored/printed one-based).
            attack_methods: Object providing the strategy methods listed in
                ``_STRATEGY_METHODS``.
            sample_index: Zero-based sample index (file names are one-based).
            attack_strategies: Strategy names to run, in order; defaults to
                ``["bypass"]``. Unknown names are skipped.

        Returns:
            Dict with ``success`` (True if any image was produced), ``images``
            (one entry per saved image), ``params_used``, ``round`` and
            ``sample_index``.
        """
        attack_methods.direction_attack_strength = params.get("direction_strength", 8.0)
        attack_methods.magnitude_attack_strength = params.get("magnitude_strength", 5.0)

        strategies = ["bypass"] if attack_strategies is None else attack_strategies
        results = {
            "success": False,
            "images": [],
            "params_used": params,
            "round": round_num + 1,
            "sample_index": sample_index
        }

        for strategy in strategies:
            method_name = self._STRATEGY_METHODS.get(strategy)
            if method_name is None:
                continue
            attack_image = getattr(attack_methods, method_name)(prompt, seed, guidance_scale)
            # Attack methods signal failure with a falsy return value.
            if not attack_image:
                continue

            image_filename = f"round_{round_num+1:02d}_{strategy}_attack_sample_{sample_index+1}.png"
            image_path = os.path.join(experiment_dir, image_filename)
            attack_image.save(image_path)

            results["images"].append({
                "attack_type": strategy,
                "path": image_path,
                "filename": image_filename,
                "params": params.copy(),
                "sample_index": sample_index
            })
            results["success"] = True

        return results

    def save_round_results(self, round_record: Dict[str, Any], experiment_dir: str):
        """Write ``round_record`` to ``round_<NN>_results.json`` in ``experiment_dir``.

        ``NN`` is ``round_record["round"]`` zero-padded to two digits; the
        record is passed through ``clean_for_json_serialization`` first.
        """
        round_num = round_record["round"]
        round_file = os.path.join(experiment_dir, f"round_{round_num:02d}_results.json")

        cleaned_record = clean_for_json_serialization(round_record)
        with open(round_file, 'w', encoding='utf-8') as f:
            json.dump(cleaned_record, f, indent=2, ensure_ascii=False, default=json_default_handler)

    def generate_final_report(
        self,
        experiment_id: str,
        experiment_dir: str,
        prompt: str,
        target_description: str,
        baseline_result: Dict[str, Any],
        experiment_history: List[Dict[str, Any]],
        best_score: float,
        best_params: Optional[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Aggregate the experiment history, write ``final_report.json`` and return the report.

        Rounds are dicts that may carry ``result`` (with ``success`` and
        ``images``), ``success_score``, ``detector_score``, ``timing`` and
        ``timestamp``; missing or ``None`` values are treated as empty/zero.

        Returns:
            The (uncleaned) report dict with ``summary``, ``experiment_info``,
            ``results``, ``statistics``, ``baseline`` and ``experiment_history``.
        """
        total_rounds = len(experiment_history)
        successful_rounds = sum(
            1 for r in experiment_history if (r.get("result") or {}).get("success", False)
        )
        success_rate = successful_rounds / total_rounds if total_rounds > 0 else 0.0

        # When the last round bypassed the detector, the experiment stopped early;
        # report that round's number as the number of completed rounds.
        content_early_stop = self._last_round_bypassed(experiment_history)
        final_completed_rounds = total_rounds
        if content_early_stop:
            final_completed_rounds = experiment_history[-1].get("round", total_rounds)

        scores = [self._score_of(r) for r in experiment_history]
        avg_score = sum(scores) / len(scores) if scores else 0.0
        max_score = max(scores) if scores else 0.0

        timing_stats = self._compute_timing_stats(experiment_history)
        strategy_counts = dict(Counter(
            attack_type
            for r in experiment_history
            for attack_type in self._attack_types_in(r)
        ))
        best_round_info = self._find_best_round_info(experiment_history, best_score)
        timestamp = experiment_history[-1].get("timestamp", 0.0) if experiment_history else 0.0

        report = {
            "summary": {
                "best_score": best_score,
                "best_round": best_round_info["round"],
                "best_attack_type": best_round_info["attack_type"],
                "best_params": best_params,
                "total_rounds": total_rounds,
                "final_completed_rounds": final_completed_rounds,
                "success_rate": success_rate,
                "timing": timing_stats,
                "content_early_stop": content_early_stop
            },
            "experiment_info": {
                "experiment_id": experiment_id,
                "experiment_dir": experiment_dir,
                "prompt": prompt,
                "target_description": target_description,
                "timestamp": timestamp,
                "total_rounds": total_rounds,
                "final_completed_rounds": final_completed_rounds,
                "timing": timing_stats,
                "content_early_stop": content_early_stop
            },
            "results": {
                "best_score": best_score,
                "best_params": best_params,
                "average_score": avg_score,
                "max_score": max_score,
                "successful_rounds": successful_rounds,
                "final_completed_rounds": final_completed_rounds,
                "success_rate": success_rate,
                "timing": timing_stats,
                "content_early_stop": content_early_stop
            },
            "statistics": {
                "attack_type_distribution": strategy_counts,
                "score_progression": scores,
                "llm_feedback_available": True,
                "final_completed_rounds": final_completed_rounds,
                "timing": timing_stats,
                "content_early_stop": content_early_stop
            },
            "baseline": baseline_result,
            "experiment_history": experiment_history
        }

        report_file = os.path.join(experiment_dir, "final_report.json")
        cleaned_report = clean_for_json_serialization(report)
        with open(report_file, 'w', encoding='utf-8') as f:
            json.dump(cleaned_report, f, indent=2, ensure_ascii=False, default=json_default_handler)

        return report

    @staticmethod
    def _score_of(round_record: Dict[str, Any]) -> float:
        """Return the round's ``success_score``, treating missing/None as 0.0."""
        return round_record.get("success_score") or 0.0

    @staticmethod
    def _attack_types_in(round_record: Dict[str, Any]) -> List[str]:
        """Return the ``attack_type`` of each image in the round's result, in order."""
        images = (round_record.get("result") or {}).get("images") or []
        return [img["attack_type"] for img in images if "attack_type" in img]

    def _last_round_bypassed(self, experiment_history: List[Dict[str, Any]]) -> bool:
        """Return True if the last round's ``detector_score`` reaches the bypass threshold."""
        if not experiment_history:
            return False
        detector_score = experiment_history[-1].get("detector_score") or 0.0
        return detector_score >= self.DETECTOR_BYPASS_THRESHOLD

    @staticmethod
    def _compute_timing_stats(experiment_history: List[Dict[str, Any]]) -> Dict[str, float]:
        """Sum per-round timings and derive per-round averages (all zeros when empty)."""
        if not experiment_history:
            return {
                "total_algorithm_time": 0.0,
                "total_api_time": 0.0,
                "total_time": 0.0,
                "average_algorithm_time_per_round": 0.0,
                "average_api_time_per_round": 0.0,
                "average_total_time_per_round": 0.0
            }

        timings = [r.get("timing") or {} for r in experiment_history]
        num_rounds = len(timings)
        total_algorithm_time = sum(t.get("algorithm_time", 0.0) for t in timings)
        total_api_time = sum(t.get("api_time", 0.0) for t in timings)
        # Only the last round's value is read; it is presumably a running total.
        total_time = timings[-1].get("cumulative_total_time", 0.0)

        return {
            "total_algorithm_time": round(total_algorithm_time, 2),
            "total_api_time": round(total_api_time, 2),
            "total_time": round(total_time, 2),
            "average_algorithm_time_per_round": round(total_algorithm_time / num_rounds, 2),
            "average_api_time_per_round": round(total_api_time / num_rounds, 2),
            "average_total_time_per_round": round(total_time / num_rounds, 2)
        }

    def _find_best_round_info(self, experiment_history: List[Dict[str, Any]], best_score: float) -> Dict[str, Any]:
        """Locate the first round whose score is within tolerance of ``best_score``.

        Returns:
            ``{"round": <round number>, "attack_type": <first image's attack
            type>}``; ``{"round": 0, "attack_type": "unknown"}`` when no round
            matches, and ``"unknown"`` attack type when the match has no typed image.
        """
        for round_record in experiment_history:
            if abs(self._score_of(round_record) - best_score) < self.BEST_SCORE_TOLERANCE:
                attack_types = self._attack_types_in(round_record)
                return {
                    "round": round_record.get("round", 0),
                    "attack_type": attack_types[0] if attack_types else "unknown"
                }
        return {
            "round": 0,
            "attack_type": "unknown"
        }