"""
Batch Feature Extractor
Unified processing for all folders under dataset with offline storage functionality
"""

import os
import json
import numpy as np
import pandas as pd
from typing import Dict, List, Tuple, Optional
from pathlib import Path
import time
from datetime import datetime
import pickle
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm

from .compact_feature_extractor import CompactSATFeatureExtractor, CompactSATFeatures


class BatchFeatureExtractor:
    """Batch feature extractor.

    Runs ``CompactSATFeatureExtractor`` on every ``*.cnf`` file found directly
    inside each sub-folder of ``dataset_root`` and stores the per-folder
    results (NPZ / JSON / CSV / pickle) under ``output_dir``.
    """

    # Output formats used when a caller does not pass ``formats`` explicitly.
    DEFAULT_FORMATS = ('npz', 'json')

    def __init__(self, dataset_root: str = "dataset", output_dir: str = "features",
                 max_workers: int = 4, skip_existing: bool = True):
        """
        Initialize batch feature extractor

        Args:
            dataset_root: dataset directory path
            output_dir: feature output directory
            max_workers: maximum number of parallel workers
            skip_existing: whether to skip existing feature files
        """
        self.dataset_root = Path(dataset_root)
        self.output_dir = Path(output_dir)
        self.max_workers = max_workers
        self.skip_existing = skip_existing
        self.extractor = CompactSATFeatureExtractor()

        # Create output directory (including missing parents, so nested
        # paths such as "runs/2024/features" work).
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Setup logging
        self._setup_logging()

        # Target folder list is discovered lazily, see _resolve_target_folders().
        self.target_folders = None

    def _setup_logging(self):
        """Configure root logging (file + console) and create ``self.logger``.

        ``logging.basicConfig`` is a no-op when the root logger already has
        handlers, so the handlers are only constructed when they will actually
        be installed; otherwise an unused ``FileHandler`` would leave an empty
        log file and an open file handle behind.
        """
        if not logging.getLogger().handlers:
            log_file = self.output_dir / f"extraction_log_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
            logging.basicConfig(
                level=logging.INFO,
                format='%(asctime)s - %(levelname)s - %(message)s',
                handlers=[
                    logging.FileHandler(log_file),
                    logging.StreamHandler()
                ]
            )
        self.logger = logging.getLogger(__name__)

    def _get_target_folders(self) -> List[str]:
        """Return the sorted names of the dataset sub-folders to process.

        Every directory directly under ``dataset_root`` is a target except
        those whose name ends with ``_train`` (a folder named exactly
        ``train`` is kept).

        Raises:
            FileNotFoundError: if ``dataset_root`` does not exist.
        """
        if not self.dataset_root.exists():
            raise FileNotFoundError(f"Dataset root not found: {self.dataset_root}")

        target_folders = sorted(
            entry.name
            for entry in self.dataset_root.iterdir()
            if entry.is_dir() and not entry.name.endswith('_train')
        )
        self.logger.info(f"Found {len(target_folders)} target folders: {target_folders}")
        return target_folders

    def _resolve_target_folders(self) -> List[str]:
        """Return ``self.target_folders``, discovering it on first use.

        A caller may also assign ``self.target_folders`` explicitly beforehand
        to restrict processing to a subset of folders.
        """
        if self.target_folders is None:
            self.target_folders = self._get_target_folders()
        return self.target_folders

    def extract_single_file(self, cnf_file_path: str) -> Tuple[str, Optional[np.ndarray], Optional[str]]:
        """Extract the feature vector of one CNF file.

        Any exception raised by the underlying extractor is caught and logged
        so that one bad file does not abort the whole batch.

        Args:
            cnf_file_path: path of the CNF file

        Returns:
            (filename, feature_vector, error_message) -- on success
            ``error_message`` is None, on failure ``feature_vector`` is None.
        """
        filename = os.path.basename(cnf_file_path)
        try:
            start_time = time.time()
            feature_vector = self.extractor.extract_features_to_vector(cnf_file_path)
            extraction_time = time.time() - start_time

            self.logger.debug(f"Extracted features for {filename} in {extraction_time:.2f}s")
            return filename, feature_vector, None

        except Exception as e:
            error_msg = f"Error extracting features from {filename}: {str(e)}"
            self.logger.error(error_msg)
            return filename, None, error_msg

    def _build_folder_result(self, folder_name: str, file_count: int, features: Dict,
                             errors: Dict, feature_vectors: List[np.ndarray],
                             filenames: List[str], extraction_time: float) -> Dict:
        """Assemble the per-folder result dict with the full, uniform key set."""
        feature_matrix = (np.array(feature_vectors, dtype=np.float32)
                          if feature_vectors else np.array([]))
        return {
            'folder_name': folder_name,
            'file_count': file_count,
            'successful_extractions': len(features),
            'failed_extractions': len(errors),
            'features': features,
            'errors': errors,
            'extraction_time': extraction_time,
            'feature_matrix': feature_matrix,
            'filenames': filenames,
            'feature_names': self.extractor.get_feature_names(),
            'extraction_date': datetime.now().isoformat()
        }

    def extract_folder_features(self, folder_name: str) -> Dict:
        """Extract features for every ``*.cnf`` file in one dataset folder.

        Files are processed in a thread pool of ``max_workers`` threads.  The
        rows of ``feature_matrix`` and the entries of ``filenames`` are
        ordered by file name, independent of thread completion order.

        Args:
            folder_name: name of a folder directly under ``dataset_root``

        Returns:
            Dict with keys ``folder_name``, ``file_count``,
            ``successful_extractions``, ``failed_extractions``, ``features``
            (filename -> list), ``errors`` (filename -> message),
            ``extraction_time``, ``feature_matrix``, ``filenames``,
            ``feature_names`` and ``extraction_date``.  The same keys are
            present when the folder contains no CNF files.

        Raises:
            FileNotFoundError: if the folder does not exist.
        """
        folder_path = self.dataset_root / folder_name
        if not folder_path.exists():
            raise FileNotFoundError(f"Folder not found: {folder_path}")

        # Sorted so the output order is reproducible between runs.
        cnf_files = sorted(folder_path.glob("*.cnf"))
        if not cnf_files:
            self.logger.warning(f"No CNF files found in {folder_path}")
            return self._build_folder_result(folder_name, 0, {}, {}, [], [], 0)

        self.logger.info(f"Processing {len(cnf_files)} CNF files in {folder_name}")

        start_time = time.time()
        # filename -> (feature_vector, error_message), filled as workers finish.
        outcomes = {}
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            futures = [
                executor.submit(self.extract_single_file, str(cnf_file))
                for cnf_file in cnf_files
            ]
            for future in tqdm(as_completed(futures), total=len(cnf_files),
                               desc=f"Extracting {folder_name}"):
                result_filename, feature_vector, error_msg = future.result()
                outcomes[result_filename] = (feature_vector, error_msg)

        extraction_time = time.time() - start_time

        features = {}
        errors = {}
        feature_vectors = []
        filenames = []
        for cnf_file in cnf_files:
            feature_vector, error_msg = outcomes[cnf_file.name]
            if feature_vector is not None:
                features[cnf_file.name] = feature_vector.tolist()
                feature_vectors.append(feature_vector)
                filenames.append(cnf_file.name)
            else:
                errors[cnf_file.name] = error_msg

        result = self._build_folder_result(
            folder_name, len(cnf_files), features, errors,
            feature_vectors, filenames, extraction_time
        )

        self.logger.info(f"Completed {folder_name}: {len(features)} successful, "
                        f"{len(errors)} failed, {extraction_time:.2f}s")

        return result

    def save_folder_features(self, folder_result: Dict, formats: Optional[List[str]] = None):
        """Write one folder's extraction result to ``output_dir/<folder_name>/``.

        Args:
            folder_result: dict as returned by ``extract_folder_features``
            formats: any of 'npz', 'json', 'csv', 'pickle'; defaults to
                ``DEFAULT_FORMATS`` ('npz' and 'json').  CSV is only written
                when the feature matrix is non-empty.
        """
        if formats is None:
            formats = list(self.DEFAULT_FORMATS)

        folder_name = folder_result['folder_name']
        folder_output_dir = self.output_dir / folder_name
        folder_output_dir.mkdir(parents=True, exist_ok=True)

        if 'npz' in formats:
            npz_file = folder_output_dir / f"{folder_name}_features.npz"
            np.savez_compressed(
                npz_file,
                features=folder_result['feature_matrix'],
                filenames=np.array(folder_result['filenames']),
                feature_names=np.array(folder_result['feature_names']),
                folder_name=folder_name,
                extraction_date=folder_result['extraction_date']
            )
            self.logger.info(f"Saved NPZ features: {npz_file}")

        if 'json' in formats:
            json_file = folder_output_dir / f"{folder_name}_features.json"
            # The numpy feature matrix is left out: 'features' already holds
            # the same vectors as plain lists, which JSON can serialize.
            json_data = {
                'folder_name': folder_result['folder_name'],
                'file_count': folder_result['file_count'],
                'successful_extractions': folder_result['successful_extractions'],
                'failed_extractions': folder_result['failed_extractions'],
                'features': folder_result['features'],
                'errors': folder_result['errors'],
                'extraction_time': folder_result['extraction_time'],
                'feature_names': folder_result['feature_names'],
                'extraction_date': folder_result['extraction_date']
            }
            with open(json_file, 'w', encoding='utf-8') as f:
                json.dump(json_data, f, indent=2)
            self.logger.info(f"Saved JSON features: {json_file}")

        if 'csv' in formats and len(folder_result['feature_matrix']) > 0:
            csv_file = folder_output_dir / f"{folder_name}_features.csv"
            df = pd.DataFrame(
                folder_result['feature_matrix'],
                columns=folder_result['feature_names'],
                index=folder_result['filenames']
            )
            df.to_csv(csv_file)
            self.logger.info(f"Saved CSV features: {csv_file}")

        if 'pickle' in formats:
            pickle_file = folder_output_dir / f"{folder_name}_features.pkl"
            with open(pickle_file, 'wb') as f:
                pickle.dump(folder_result, f)
            self.logger.info(f"Saved Pickle features: {pickle_file}")

    def extract_all_features(self, formats: Optional[List[str]] = None) -> Dict:
        """Extract and save features for every target folder.

        The target folders are discovered from ``dataset_root`` on first use
        unless ``self.target_folders`` was assigned beforehand.  When
        ``skip_existing`` is set, a folder whose NPZ output already exists is
        skipped.  A failure in one folder is logged and recorded in the
        summary; the remaining folders are still processed.

        Args:
            formats: output formats passed to ``save_folder_features``;
                defaults to ``DEFAULT_FORMATS``.

        Returns:
            Summary dict with keys ``extraction_summary`` (per folder),
            ``total_files``, ``total_successful``, ``total_failed``,
            ``total_time`` and ``folders_processed``.  The same dict is also
            written to a timestamped JSON file in ``output_dir``.

        Raises:
            FileNotFoundError: if the folders have to be discovered and
                ``dataset_root`` does not exist.
        """
        if formats is None:
            formats = list(self.DEFAULT_FORMATS)

        target_folders = self._resolve_target_folders()

        overall_start_time = time.time()
        overall_results = {
            'extraction_summary': {},
            'total_files': 0,
            'total_successful': 0,
            'total_failed': 0,
            'total_time': 0,
            'folders_processed': []
        }

        self.logger.info(f"Starting batch feature extraction for {len(target_folders)} folders")

        for folder_name in target_folders:
            if self.skip_existing:
                # Only the NPZ file is used as the "already done" marker.
                existing_npz = self.output_dir / folder_name / f"{folder_name}_features.npz"
                if existing_npz.exists():
                    self.logger.info(f"Skipping {folder_name} (already exists)")
                    continue

            try:
                folder_result = self.extract_folder_features(folder_name)
                self.save_folder_features(folder_result, formats)

                overall_results['extraction_summary'][folder_name] = {
                    'file_count': folder_result['file_count'],
                    'successful': folder_result['successful_extractions'],
                    'failed': folder_result['failed_extractions'],
                    'time': folder_result['extraction_time']
                }
                overall_results['total_files'] += folder_result['file_count']
                overall_results['total_successful'] += folder_result['successful_extractions']
                overall_results['total_failed'] += folder_result['failed_extractions']
                overall_results['folders_processed'].append(folder_name)

            except Exception as e:
                self.logger.error(f"Failed to process folder {folder_name}: {str(e)}")
                overall_results['extraction_summary'][folder_name] = {'error': str(e)}

        overall_results['total_time'] = time.time() - overall_start_time

        # The directory may have been removed since construction; recreate it
        # so the summary can always be written.
        self.output_dir.mkdir(parents=True, exist_ok=True)
        summary_file = self.output_dir / f"extraction_summary_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
        with open(summary_file, 'w', encoding='utf-8') as f:
            json.dump(overall_results, f, indent=2)

        self.logger.info(f"Batch extraction completed in {overall_results['total_time']:.2f}s")
        self.logger.info(f"Total: {overall_results['total_files']} files, "
                        f"{overall_results['total_successful']} successful, "
                        f"{overall_results['total_failed']} failed")
        self.logger.info(f"Summary saved to: {summary_file}")

        return overall_results

    def load_folder_features(self, folder_name: str, format_type: str = 'npz') -> Optional[Dict]:
        """Load previously saved features of one folder.

        Args:
            folder_name: folder whose features should be loaded
            format_type: 'npz', 'json' or 'pickle' (CSV output cannot be
                loaded through this method)

        Returns:
            The stored data as a dict, or None when the file does not exist
            or ``format_type`` is not one of the supported values.  For
            'npz' the dict has the keys ``feature_matrix``, ``filenames``,
            ``feature_names``, ``folder_name`` and ``extraction_date``.
        """
        folder_path = self.output_dir / folder_name

        if format_type == 'npz':
            npz_file = folder_path / f"{folder_name}_features.npz"
            if npz_file.exists():
                # NOTE(security): allow_pickle=True can execute arbitrary code
                # when loading a tampered file; only load trusted outputs.
                # The context manager closes the underlying zip file handle.
                with np.load(npz_file, allow_pickle=True) as data:
                    return {
                        'feature_matrix': data['features'],
                        'filenames': data['filenames'].tolist(),
                        'feature_names': data['feature_names'].tolist(),
                        'folder_name': str(data['folder_name']),
                        'extraction_date': str(data['extraction_date'])
                    }
        elif format_type == 'json':
            json_file = folder_path / f"{folder_name}_features.json"
            if json_file.exists():
                with open(json_file, 'r', encoding='utf-8') as f:
                    return json.load(f)
        elif format_type == 'pickle':
            pickle_file = folder_path / f"{folder_name}_features.pkl"
            if pickle_file.exists():
                # NOTE(security): unpickling runs arbitrary code from the
                # file; only load pickles produced by a trusted run.
                with open(pickle_file, 'rb') as f:
                    return pickle.load(f)

        return None

    def get_combined_features(self, folders: Optional[List[str]] = None) -> Tuple[np.ndarray, List[str], List[str]]:
        """Stack the saved NPZ features of several folders into one matrix.

        Folders without a saved NPZ file, and folders whose saved matrix is
        empty (no successful extraction), are skipped.

        Args:
            folders: folder names to combine; defaults to the target folders
                (discovered from ``dataset_root`` on first use).

        Returns:
            (feature_matrix, file_labels, feature_names) where each label is
            ``"<folder>_<filename>"``.  When nothing could be loaded the
            result is ``(np.array([]), [], [])``.
        """
        if folders is None:
            folders = self._resolve_target_folders()

        all_features = []
        all_labels = []
        feature_names = None

        for folder_name in folders:
            folder_data = self.load_folder_features(folder_name, 'npz')
            if folder_data is None:
                continue

            features = folder_data['feature_matrix']
            # An empty matrix has shape (0,) and cannot be vstacked with
            # (n, d) matrices from other folders.
            if features.size == 0:
                continue

            if feature_names is None:
                feature_names = folder_data['feature_names']

            all_features.append(features)
            all_labels.extend(
                f"{folder_name}_{filename}" for filename in folder_data['filenames']
            )

        if all_features:
            combined_features = np.vstack(all_features)
            return combined_features, all_labels, feature_names
        else:
            return np.array([]), [], []


def main():
    """Run a batch extraction with the default settings and print a summary.

    Features are read from ``dataset/`` and written to ``features/`` in NPZ,
    JSON and CSV form; folders that already have output are skipped.
    """
    batch = BatchFeatureExtractor(
        dataset_root="dataset",
        output_dir="features",
        max_workers=4,
        skip_existing=True,
    )
    summary = batch.extract_all_features(formats=['npz', 'json', 'csv'])

    report_lines = [
        "Extraction completed!",
        f"Processed folders: {summary['folders_processed']}",
        f"Total files: {summary['total_files']}",
        f"Successful: {summary['total_successful']}",
        f"Failed: {summary['total_failed']}",
        f"Total time: {summary['total_time']:.2f}s",
    ]
    for line in report_lines:
        print(line)


# Script entry point: run the batch extraction only when this file is executed
# directly, not when it is imported.
# NOTE(review): the module uses a package-relative import
# (``from .compact_feature_extractor import ...``), so it presumably has to be
# started with ``python -m <package>.<module>`` rather than by file path.
if __name__ == "__main__":
    main()