#!/usr/bin/env python3
"""
Test script for S3 bucket operations using boto3.
Tests basic upload and download functionality for data storage.
"""

import boto3
import os
import tempfile
import json
from botocore.exceptions import ClientError, NoCredentialsError
import logging
from typing import Optional, Dict

# Configure logging: set up the root logger at INFO level via basicConfig and
# create the module-level logger used by every function and method below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class S3BucketTester:
    """Thin wrapper around a boto3 S3 client for smoke-testing one bucket.

    Every operation logs its outcome and reports success or failure through
    its return value instead of raising, so callers can chain simple
    ``if not tester.step(): ...`` checks. ``setup_s3_connection`` must be
    called before any other operation because it creates ``self.s3_client``.
    """

    def __init__(self, bucket_name: str, region_name: str = 'eu-central-1',
                 aws_access_key_id: Optional[str] = None,
                 aws_secret_access_key: Optional[str] = None,
                 aws_session_token: Optional[str] = None):
        """
        Initialize S3 bucket tester

        Args:
            bucket_name: Name of the S3 bucket
            region_name: AWS region name
            aws_access_key_id: AWS access key ID (optional, uses default credentials if None)
            aws_secret_access_key: AWS secret access key (optional)
            aws_session_token: AWS session token (optional, for temporary credentials)
        """
        self.bucket_name = bucket_name
        self.region_name = region_name
        self.aws_access_key_id = aws_access_key_id
        self.aws_secret_access_key = aws_secret_access_key
        self.aws_session_token = aws_session_token
        # Created by setup_s3_connection(); None until then.
        self.s3_client = None

    def setup_s3_connection(self) -> bool:
        """
        Setup S3 client connection

        Explicit credentials are used only when both the access key ID and
        the secret access key are set; otherwise boto3 resolves credentials
        itself. The connection is probed with a ``list_buckets`` call.

        Returns:
            bool: True if connection successful, False otherwise
        """
        try:
            client_kwargs = {'region_name': self.region_name}
            if self.aws_access_key_id and self.aws_secret_access_key:
                logger.info("Using provided IAM credentials")
                client_kwargs.update(
                    aws_access_key_id=self.aws_access_key_id,
                    aws_secret_access_key=self.aws_secret_access_key,
                    aws_session_token=self.aws_session_token,
                )
            else:
                logger.info("Using default AWS credentials")
            self.s3_client = boto3.client('s3', **client_kwargs)

            # Test connection
            self.s3_client.list_buckets()
            logger.info("S3 connection established")
            return True
        except NoCredentialsError:
            logger.error("AWS credentials not found")
            return False
        except ClientError as e:
            logger.error(f"AWS client error: {e}")
            return False
        except Exception as e:
            # Boundary of the connection step: report anything else (bad
            # region, network failure, ...) as a failed connection.
            logger.error(f"Unexpected error: {e}")
            return False

    def check_bucket_exists(self) -> bool:
        """
        Check if the specified bucket exists

        Returns:
            bool: True if bucket exists, False otherwise (missing bucket or
            any other client error such as access denied)
        """
        try:
            self.s3_client.head_bucket(Bucket=self.bucket_name)
            logger.info(f"Bucket '{self.bucket_name}' exists")
            return True
        except ClientError as e:
            # Error codes are strings and are not always numeric (e.g.
            # 'NoSuchBucket', 'AccessDenied'), so compare them as text
            # instead of converting with int(), which could raise ValueError.
            error_code = str(e.response.get('Error', {}).get('Code', ''))
            if error_code in ('404', 'NoSuchBucket'):
                logger.error(f"Bucket '{self.bucket_name}' does not exist")
            else:
                logger.error(f"Error checking bucket: {e}")
            return False

    def upload_file(self, local_file_path: str, s3_key: str) -> bool:
        """
        Upload a file to S3 bucket

        Args:
            local_file_path: Path to local file
            s3_key: S3 object key (path in bucket)

        Returns:
            bool: True if upload successful, False otherwise
        """
        try:
            self.s3_client.upload_file(local_file_path, self.bucket_name, s3_key)
            logger.info(f"Uploaded: {local_file_path} -> s3://{self.bucket_name}/{s3_key}")
            return True
        except FileNotFoundError:
            logger.error(f"Local file not found: {local_file_path}")
            return False
        except (ClientError, boto3.exceptions.S3UploadFailedError) as e:
            # boto3's managed transfer re-raises client errors that occur
            # during an upload as S3UploadFailedError, so catch both.
            logger.error(f"Error uploading file: {e}")
            return False

    def download_file(self, s3_key: str, local_file_path: str) -> bool:
        """
        Download a file from S3 bucket

        Args:
            s3_key: S3 object key (path in bucket)
            local_file_path: Path where to save the downloaded file

        Returns:
            bool: True if download successful, False otherwise
        """
        try:
            self.s3_client.download_file(self.bucket_name, s3_key, local_file_path)
            logger.info(f"Downloaded: s3://{self.bucket_name}/{s3_key} -> {local_file_path}")
            return True
        except ClientError as e:
            logger.error(f"Error downloading file: {e}")
            return False

    def list_objects(self, prefix: str = "") -> Optional[list]:
        """
        List objects in the S3 bucket

        All result pages are followed, so the listing is not cut off at the
        size of a single ``list_objects_v2`` response.

        Args:
            prefix: Prefix to filter objects

        Returns:
            list: List of object keys (empty if nothing matches), None if error
        """
        try:
            paginator = self.s3_client.get_paginator('list_objects_v2')
            pages = paginator.paginate(Bucket=self.bucket_name, Prefix=prefix)
            # Pages without matches carry no 'Contents' entry.
            objects = [
                obj['Key']
                for page in pages
                for obj in page.get('Contents', [])
            ]
        except ClientError as e:
            logger.error(f"Error listing objects: {e}")
            return None

        if objects:
            logger.info(f"Found {len(objects)} objects with prefix '{prefix}'")
        else:
            logger.info(f"No objects found with prefix '{prefix}'")
        return objects

    def delete_object(self, s3_key: str) -> bool:
        """
        Delete an object from S3 bucket

        Args:
            s3_key: S3 object key to delete

        Returns:
            bool: True if deletion successful, False otherwise
        """
        try:
            self.s3_client.delete_object(Bucket=self.bucket_name, Key=s3_key)
            logger.info(f"Deleted: s3://{self.bucket_name}/{s3_key}")
            return True
        except ClientError as e:
            logger.error(f"Error deleting object: {e}")
            return False


def create_test_files() -> Dict[str, str]:
    """
    Write the sample upload fixtures into a fresh temporary directory

    Two files are produced: a short plain-text file and a small JSON
    configuration document.

    Returns:
        dict: Maps the file type ('text', 'json') to the created file's path
    """
    workdir = tempfile.mkdtemp()

    # Plain-text fixture
    text_path = os.path.join(workdir, "test_data.txt")
    text_lines = (
        "This is a test file for S3 upload/download testing.\n",
        "It contains sample data for validation.\n",
    )
    with open(text_path, 'w') as handle:
        handle.writelines(text_lines)

    # JSON fixture
    json_path = os.path.join(workdir, "test_config.json")
    hyper_params = {"learning_rate": 0.01, "batch_size": 64, "epochs": 100}
    payload = {"experiment_name": "s3_test", "parameters": hyper_params}
    with open(json_path, 'w') as handle:
        handle.write(json.dumps(payload, indent=2))

    logger.info(f"Created test files in: {workdir}")
    return {'text': text_path, 'json': json_path}


def cleanup_test_files(test_files: Dict[str, str]):
    """
    Remove the temporary test files and their directory once it is empty

    Removal is best effort: a failure for one file is logged as a warning
    and the remaining files are still processed.

    Args:
        test_files: Dictionary of test file paths
    """
    for path in test_files.values():
        try:
            if os.path.exists(path):
                os.remove(path)
            parent = os.path.dirname(path)
            if not os.path.exists(parent):
                continue
            # Only drop the directory after its last entry is gone.
            if os.listdir(parent):
                continue
            os.rmdir(parent)
        except Exception as e:
            logger.warning(f"Could not clean up file {path}: {e}")


def _remove_download_dir(download_dir: str) -> None:
    """Best-effort removal of the download directory and the files in it."""
    try:
        for file_name in os.listdir(download_dir):
            os.remove(os.path.join(download_dir, file_name))
        os.rmdir(download_dir)
    except OSError as e:
        logger.warning(f"Could not clean up download directory: {e}")


def _run_transfer_checks(tester: "S3BucketTester", test_files: Dict[str, str],
                         s3_keys: Dict[str, str]) -> bool:
    """Upload, list, download and verify the test files; record uploaded keys in s3_keys."""
    # Test 4: Upload files
    for file_type, file_path in test_files.items():
        s3_key = f"test_data/{file_type}_file_{os.path.basename(file_path)}"
        if not tester.upload_file(file_path, s3_key):
            return False
        # Recorded only after a successful upload so that the caller deletes
        # exactly the objects this run created.
        s3_keys[file_type] = s3_key

    # Test 5: List objects
    if tester.list_objects("test_data/") is None:
        return False

    # Test 6: Download files
    temp_download_dir = tempfile.mkdtemp()
    try:
        for s3_key in s3_keys.values():
            download_path = os.path.join(
                temp_download_dir, f"downloaded_{os.path.basename(s3_key)}"
            )
            if not tester.download_file(s3_key, download_path):
                return False

            # Verify downloaded file
            if not (os.path.exists(download_path) and os.path.getsize(download_path) > 0):
                logger.error(f"Downloaded file verification failed: {download_path}")
                return False
        return True
    finally:
        # Runs on success, on failure and on exceptions alike.
        _remove_download_dir(temp_download_dir)


def run_s3_tests(bucket_name: str, region_name: str = 'eu-central-1',
                 aws_access_key_id: Optional[str] = None,
                 aws_secret_access_key: Optional[str] = None,
                 aws_session_token: Optional[str] = None) -> bool:
    """
    Run S3 bucket tests

    Steps: connect, check the bucket, create local test files, upload them,
    list the ``test_data/`` prefix, download and verify each file. Whatever
    the outcome, the objects uploaded by this run, the local test files and
    the download directory are cleaned up afterwards. A failed S3 deletion
    is only logged as a warning and does not change the result.

    Args:
        bucket_name: Name of the S3 bucket to test
        region_name: AWS region name
        aws_access_key_id: AWS access key ID (optional)
        aws_secret_access_key: AWS secret access key (optional)
        aws_session_token: AWS session token (optional)

    Returns:
        bool: True if all tests pass, False otherwise
    """
    logger.info("Starting S3 bucket tests")

    tester = S3BucketTester(
        bucket_name=bucket_name,
        region_name=region_name,
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key,
        aws_session_token=aws_session_token
    )

    # Test 1: Setup connection
    if not tester.setup_s3_connection():
        return False

    # Test 2: Check bucket exists
    if not tester.check_bucket_exists():
        return False

    # Test 3: Create test files
    test_files = create_test_files()

    s3_keys: Dict[str, str] = {}
    try:
        passed = _run_transfer_checks(tester, test_files, s3_keys)
    finally:
        try:
            # Test 7: Clean up S3 objects (also after a failed run, so no
            # test objects are left behind in the bucket)
            for s3_key in s3_keys.values():
                if not tester.delete_object(s3_key):
                    logger.warning(f"Failed to delete {s3_key}")
        finally:
            # Clean up local files even if the S3 cleanup itself raised
            cleanup_test_files(test_files)

    if passed:
        logger.info("All tests completed successfully")
    return passed


def main() -> int:
    """
    Main function to run S3 tests

    Credentials are taken from the AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY
    and AWS_SESSION_TOKEN environment variables. When they are not set, None
    is passed on and the default AWS credentials are used instead.

    Returns:
        int: Process exit status - 0 if all tests passed, 1 if a test failed
        or an unexpected error occurred, 130 if interrupted by the user
    """

    # Configuration
    BUCKET_NAME = "scipi1-public"  # Change this to your bucket name
    REGION_NAME = "eu-central-1"    # Central Europe region

    # SECURITY: never hard-code AWS credentials in source code. Any access
    # key that has ever been committed must be treated as compromised and
    # rotated. Credentials are read from the environment only; empty values
    # are treated the same as unset ones.
    access_key_id = os.getenv('AWS_ACCESS_KEY_ID') or None
    secret_access_key = os.getenv('AWS_SECRET_ACCESS_KEY') or None
    session_token = os.getenv('AWS_SESSION_TOKEN') or None

    # Explicit credentials are only used when both halves of the key pair
    # are present, so report exactly that condition.
    using_explicit = bool(access_key_id and secret_access_key)

    print("S3 BUCKET TEST SCRIPT")
    print(f"Bucket: {BUCKET_NAME}")
    print(f"Region: {REGION_NAME}")
    print(f"Using explicit credentials: {using_explicit}")

    try:
        success = run_s3_tests(
            bucket_name=BUCKET_NAME,
            region_name=REGION_NAME,
            aws_access_key_id=access_key_id,
            aws_secret_access_key=secret_access_key,
            aws_session_token=session_token
        )
    except KeyboardInterrupt:
        print("Tests interrupted by user")
        return 130
    except Exception as e:
        # Top-level boundary of the script: report and signal failure.
        print(f"Unexpected error during testing: {e}")
        return 1

    if success:
        print("All S3 tests passed successfully!")
        return 0
    print("Some S3 tests failed. Check the logs above.")
    return 1


if __name__ == "__main__":
    # Propagate the result as the process exit status so shells and CI jobs
    # can detect a failed run.
    raise SystemExit(main())
