import os
import shutil
import tempfile
import unittest

import numpy as np

from src.data_loader import DataLoader, create_dummy_sensitive_data
from src.extractor import ActivationExtractor

class TestProject(unittest.TestCase):
    """Smoke tests for ``DataLoader`` and ``ActivationExtractor``.

    Fixture data is written to a private temporary directory, so the tests
    never touch (or delete) a ``test_data`` directory that already exists in
    the current working directory.
    """

    @classmethod
    def setUpClass(cls):
        """Create a temp directory holding the dummy sensitive dataset."""
        cls.model_name = "Qwen/Qwen2.5-0.5B-Instruct"
        cls.data_dir = tempfile.mkdtemp(prefix="test_data_")
        cls.sensitive_path = os.path.join(cls.data_dir, "test_ci.json")
        try:
            create_dummy_sensitive_data(cls.sensitive_path)
        except BaseException:
            # unittest does not call tearDownClass when setUpClass raises,
            # so remove the temp directory here before propagating.
            shutil.rmtree(cls.data_dir, ignore_errors=True)
            raise

    def test_data_loader(self):
        """Load the local sensitive data, then (if reachable) benign data."""
        print("\n=== Testing Data Loader ===")
        loader = DataLoader()
        sensitive = loader.load_sensitive_data(self.sensitive_path)
        self.assertGreater(len(sensitive), 0)
        print(f"Loaded {len(sensitive)} sensitive samples.")

        # Guard only the download itself: an assertion failure below must be
        # reported as a failure, not disguised as a network problem.
        # NOTE(review): the concrete exception type raised on network errors
        # is not visible here; narrow this once it is known.
        try:
            benign = loader.load_benign_data(num_samples=5)
        except Exception as exc:
            self.skipTest(f"Skipping Alpaca download test due to network: {exc}")
        self.assertEqual(len(benign), 5)
        print(f"Loaded {len(benign)} benign samples.")

    def test_extractor_shape(self):
        """Check that the last-token hidden state is a non-empty 1-D vector."""
        print("\n=== Testing Extractor Shape ===")
        # Constructing the extractor may need to fetch model weights; skip
        # when that is impossible. Presumably the model is loaded in the
        # constructor (TODO confirm), so the shape checks below stay outside
        # the guard and a real regression fails the test.
        try:
            extractor = ActivationExtractor(
                self.model_name, device="cpu", dtype="float32"
            )
        except Exception as exc:
            self.skipTest(f"Skipping extractor test: {exc}")

        vec = extractor.get_last_token_hidden_state("Hello world", layer_idx=5)
        print(f"Vector shape: {vec.shape}")
        self.assertEqual(len(vec.shape), 1)
        self.assertGreater(vec.shape[0], 0)

    @classmethod
    def tearDownClass(cls):
        """Remove the temporary fixture directory created in setUpClass."""
        shutil.rmtree(cls.data_dir, ignore_errors=True)

if __name__ == "__main__":
    # Executed directly (not imported): hand control to unittest's CLI runner,
    # which discovers and runs the TestCase classes in this module.
    unittest.main()