#!/usr/bin/env python3

import argparse
import ast
import json
import os
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional

from PIL import Image
from tqdm import tqdm


class AITZProcessor:
    """Enrich AITZ JSON records with image metadata and normalized UI boxes.

    Every ``*.json`` file under ``input_dir`` is loaded. The top-level value
    is either a list of records or a single record. For each dict record whose
    ``image_key`` names an existing file (resolved relative to the JSON file's
    directory):

    * ``image_width`` / ``image_height`` / ``image_channels`` are added when
      Pillow can open the image;
    * a ``ui_positions`` Python-literal string is re-serialized as JSON, with
      each box divided by the image size.

    Changed files are written under ``output_dir`` (mirroring the input
    layout). When no output directory is given, they overwrite the inputs.

    NOTE: the transformation is not idempotent. Re-running in place parses the
    already-normalized ``ui_positions`` and divides them by the image size a
    second time.
    """

    def __init__(self, input_dir: str, output_dir: Optional[str] = None, image_key: str = 'image_path'):
        """Store settings, validate ``input_dir`` and create ``output_dir``.

        Args:
            input_dir: Directory searched recursively for ``*.json`` files.
            output_dir: Destination root for processed files. ``None`` or empty
                means files are modified in place.
            image_key: Record field holding the (relative) image path.

        Raises:
            ValueError: If ``input_dir`` does not exist.
        """
        self.input_dir = Path(input_dir)
        self.output_dir = Path(output_dir) if output_dir else None
        self.image_key = image_key

        if not self.input_dir.exists():
            raise ValueError(f"Input directory does not exist: {self.input_dir}")

        if self.output_dir is not None:
            self.output_dir.mkdir(parents=True, exist_ok=True)

    def get_image_info(self, image_path: Path) -> Optional[Dict[str, int]]:
        """Return the image's width, height and band count, or ``None`` on failure.

        Any error while opening the image is reported as a warning and turned
        into ``None``. This is deliberately best-effort, because Pillow can raise
        many different exception types for unreadable files.
        """
        try:
            with Image.open(image_path) as img:
                width, height = img.size
                return {
                    'image_width': width,
                    'image_height': height,
                    'image_channels': len(img.getbands()),
                }
        except Exception as e:
            print(f"Warning: Cannot load image {image_path}: {e}")
            return None

    def normalize_ui_positions(self, positions: List[List[float]], width: int, height: int) -> List[List[float]]:
        """Scale pixel boxes by the image size.

        Each entry with at least four values is read as ``[y, x, h, w]`` and
        mapped to ``[y/height, x/width, h/height, w/width]``. Values beyond the
        fourth are dropped. Shorter entries are passed through unchanged.

        If either dimension is non-positive (e.g. the image could not be
        read), ``positions`` is returned as-is.
        """
        if width <= 0 or height <= 0:
            return positions

        return [
            [pos[0] / height, pos[1] / width, pos[2] / height, pos[3] / width]
            if len(pos) >= 4 else pos
            for pos in positions
        ]

    def process_record(self, record: Dict[str, Any], root: Path) -> bool:
        """Enrich one record in place.

        Args:
            record: Parsed JSON record. Non-dict values are ignored.
            root: Directory that relative image paths are resolved against.

        Returns:
            True when the record references an existing image file, even if
            the image itself could not be decoded. False otherwise (not a
            dict, no ``image_key``, a non-string path, or a missing file).
        """
        # A top-level JSON list may hold non-dict items; skip them instead of
        # letting ``in`` / ``[]`` raise TypeError and abort the whole run.
        if not isinstance(record, dict) or self.image_key not in record:
            return False

        raw_path = record[self.image_key]
        if not isinstance(raw_path, str):
            print(f"Warning: Invalid {self.image_key!r} value: {raw_path!r}")
            return False

        image_path = root / raw_path
        # is_file(): a directory (e.g. from an empty path) is not an image.
        if not image_path.is_file():
            print(f"Warning: Image not found: {image_path}")
            return False

        image_info = self.get_image_info(image_path)
        if image_info:
            record.update(image_info)

        if 'ui_positions' in record:
            try:
                # literal_eval only evaluates literals, so untrusted input cannot run code.
                positions = ast.literal_eval(record['ui_positions'])
                if isinstance(positions, list) and positions:
                    # Falls back to dimensions already stored on the record, if any.
                    width = record.get('image_width', 0)
                    height = record.get('image_height', 0)
                    normalized = self.normalize_ui_positions(positions, width, height)
                    record['ui_positions'] = json.dumps(normalized, ensure_ascii=False)
            except (ValueError, SyntaxError, TypeError) as e:
                print(f"Warning: Cannot parse ui_positions: {e}")

        return True

    @staticmethod
    def _write_json_atomic(path: Path, obj: Any) -> None:
        """Write ``obj`` as indented JSON to ``path`` without ever truncating it.

        The payload is fully serialized before any file is touched. It is then
        written to a sibling temp file, which is moved over ``path`` with
        ``os.replace``. An error mid-write therefore cannot leave ``path``
        half-written or empty. This matters in in-place mode, where ``path`` is
        the only copy of the data.
        """
        text = json.dumps(obj, ensure_ascii=False, indent=2)
        tmp_path = path.with_name(f".{path.name}.{os.getpid()}.tmp")
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                f.write(text)
            os.replace(tmp_path, path)
        finally:
            # Only still present if the write or replace failed.
            if tmp_path.exists():
                tmp_path.unlink()

    def process_json_file(self, json_path: Path) -> bool:
        """Load, enrich and save one JSON file.

        Returns:
            True if at least one record was processed and the result was
            written. False on read/parse errors, when no record qualified, or
            when writing failed.
        """
        try:
            with open(json_path, 'r', encoding='utf-8') as f:
                obj = json.load(f)
        except (OSError, ValueError) as e:  # ValueError covers JSON/Unicode decode errors
            print(f"Error reading {json_path}: {e}")
            return False

        records = obj if isinstance(obj, list) else [obj]
        # Deliberately no short-circuit: every record must be processed, since
        # process_record mutates them in place.
        processed = False
        for record in records:
            if self.process_record(record, json_path.parent):
                processed = True

        if not processed:
            return False

        try:
            if self.output_dir is not None:
                output_path = self.output_dir / json_path.relative_to(self.input_dir)
                output_path.parent.mkdir(parents=True, exist_ok=True)
            else:
                output_path = json_path
            self._write_json_atomic(output_path, obj)
        except (OSError, TypeError, ValueError) as e:
            print(f"Error writing file: {e}")
            return False
        return True

    def _collect_json_files(self) -> List[Path]:
        """List input ``*.json`` files, excluding any inside a nested ``output_dir``.

        If ``output_dir`` lies under ``input_dir``, results from a previous run
        would otherwise be picked up again as inputs. They would be normalized a
        second time and written to a further-nested copy.
        """
        json_files = list(self.input_dir.rglob('*.json'))
        if self.output_dir is None:
            return json_files
        out_root = self.output_dir.resolve()
        if out_root == self.input_dir.resolve():
            return json_files  # same directory: every file is an input
        kept = []
        for path in json_files:
            try:
                path.resolve().relative_to(out_root)
            except ValueError:
                kept.append(path)
        return kept

    def run(self) -> None:
        """Process every input JSON file and print how many were written."""
        json_files = self._collect_json_files()
        if not json_files:
            print(f"No JSON files found in {self.input_dir}")
            return

        processed_count = 0
        for json_path in tqdm(json_files, desc="Processing JSON files"):
            if self.process_json_file(json_path):
                processed_count += 1

        print(f"Processed {processed_count}/{len(json_files)} files")


def main():
    """Parse command-line arguments and run ``AITZProcessor`` over the input tree.

    Any exception raised while constructing or running the processor is
    reported on stderr, and the process exits with status 1.
    """
    parser = argparse.ArgumentParser(description='Process AITZ dataset')
    # NOTE(review): '/INPUT_DIR' looks like a placeholder default — confirm, or
    # pass --input_dir explicitly.
    parser.add_argument('--input_dir', type=str, default='/INPUT_DIR',
                        help='Input directory containing JSON files')
    parser.add_argument('--output_dir', type=str, default=None,
                        help='Output directory for processed files (if not specified, files are modified in place)')
    parser.add_argument('--image_key', type=str, default='image_path',
                        help='Field name that points to the image path')

    args = parser.parse_args()

    try:
        processor = AITZProcessor(
            input_dir=args.input_dir,
            output_dir=args.output_dir,
            image_key=args.image_key,
        )
        processor.run()
    except Exception as e:  # top-level boundary: report the failure and exit non-zero
        # Diagnostics go to stderr so stdout stays clean for the progress summary.
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)
    print("Done")


if __name__ == '__main__':
    main()
