"""
Verification script for PostgreSQL Task 2: Employee Retention Analysis
"""

import os
import sys
from contextlib import closing
from decimal import Decimal

import psycopg2

def _tolerant_numeric_pair(actual, expected):
    """Return True when both values are numbers and at least one is non-integral."""
    numeric_types = (int, float, Decimal)
    fractional_types = (float, Decimal)
    # bool is a subclass of int; keep it on the exact-equality path.
    if isinstance(actual, bool) or isinstance(expected, bool):
        return False
    if not (isinstance(actual, numeric_types) and isinstance(expected, numeric_types)):
        return False
    # int/int pairs are compared exactly (counts, ids, day numbers).
    return isinstance(actual, fractional_types) or isinstance(expected, fractional_types)


def rows_match(actual_row, expected_row):
    """
    Compare two result rows field by field with a numeric tolerance.

    Numeric fields match when they differ by at most 0.1 and at least one
    side is a float or Decimal. Any mix of int, float and Decimal qualifies.
    This lets a solution column stored as REAL/DOUBLE PRECISION (float)
    match a NUMERIC (Decimal) ground-truth value, and vice versa.
    Pairs of plain ints, and all non-numeric fields (strings, dates,
    None, ...), must be exactly equal.

    Returns True when both rows have the same length and every field
    matches; otherwise False.
    """
    if len(actual_row) != len(expected_row):
        return False

    for actual, expected in zip(actual_row, expected_row):
        if _tolerant_numeric_pair(actual, expected):
            if abs(float(actual) - float(expected)) > 0.1:
                return False
        elif actual != expected:
            return False

    return True

def get_connection_params() -> dict:
    """Build psycopg2 connection keyword arguments from environment variables.

    POSTGRES_HOST defaults to "localhost" and POSTGRES_PORT to 5432 (the port
    is converted with int()). POSTGRES_DATABASE, POSTGRES_USERNAME and
    POSTGRES_PASSWORD have no default and yield None when unset.
    """
    env = os.getenv
    return dict(
        host=env("POSTGRES_HOST", "localhost"),
        port=int(env("POSTGRES_PORT", 5432)),
        database=env("POSTGRES_DATABASE"),
        user=env("POSTGRES_USERNAME"),
        password=env("POSTGRES_PASSWORD"),
    )

def verify_retention_analysis_results(conn) -> bool:
    """Check employees.employee_retention_analysis against a reference query.

    The reference query recomputes, per department, the total, current and
    former head counts and the retention percentage from
    employees.department and employees.department_employee. Both result
    sets are ordered by department name and compared pairwise with
    ``rows_match``.

    On a row-count difference or any row mismatch, prints a diagnostic (at
    most five mismatching rows plus a total) and returns False. Otherwise
    prints a success line and returns True.
    """
    with conn.cursor() as cursor:
        # Rows written by the solution under test.
        cursor.execute("""
            SELECT department_name, total_employees_ever, current_employees, 
                   former_employees, retention_rate
            FROM employees.employee_retention_analysis
            ORDER BY department_name
        """)
        produced = cursor.fetchall()

        # Reference computation: assignments whose to_date is 9999-01-01 are
        # counted as current; the LEFT JOIN keeps departments with no staff.
        cursor.execute("""
            SELECT
            d.dept_name AS department_name,
            COUNT(DISTINCT de.employee_id) AS total_employees_ever,
            COUNT(DISTINCT de.employee_id) FILTER (WHERE de.to_date = DATE '9999-01-01') AS current_employees,
            (COUNT(DISTINCT de.employee_id)
            - COUNT(DISTINCT de.employee_id) FILTER (WHERE de.to_date = DATE '9999-01-01')) AS former_employees,
            (COUNT(DISTINCT de.employee_id) FILTER (WHERE de.to_date = DATE '9999-01-01'))::DECIMAL
                / NULLIF(COUNT(DISTINCT de.employee_id), 0) * 100 AS retention_rate
            FROM employees.department d
            LEFT JOIN employees.department_employee de
            ON d.id = de.department_id
            GROUP BY d.id, d.dept_name
            ORDER BY d.dept_name
        """)
        reference = cursor.fetchall()

    if len(produced) != len(reference):
        print(f"❌ Expected {len(reference)} retention analysis results, got {len(produced)}")
        return False

    bad_rows = [
        (row_no, got, want)
        for row_no, (got, want) in enumerate(zip(produced, reference), start=1)
        if not rows_match(got, want)
    ]
    # Report at most five mismatching rows to keep the output readable.
    for row_no, got, want in bad_rows[:5]:
        print(f"❌ Row {row_no} mismatch: expected {want}, got {got}")
    if bad_rows:
        print(f"❌ Total mismatches: {len(bad_rows)}")
        return False

    print(f"✅ Employee retention analysis results are correct ({len(produced)} records)")
    return True

def verify_high_risk_results(conn) -> bool:
    """Check employees.high_risk_employees against a reference query.

    The reference query covers only employees with a current salary row and
    a current department assignment (to_date = 9999-01-01). It reports each
    such employee's full name, current department, tenure in days, current
    salary and a risk category. The category is derived from the
    department's retention rate and the employee's tenure. Both result sets
    are ordered by employee_id and compared pairwise with ``rows_match``.

    On a row-count difference or any row mismatch, prints a diagnostic (at
    most five mismatching rows plus a total) and returns False. Otherwise
    prints a success line and returns True.
    """
    with conn.cursor() as cursor:
        # Rows written by the solution under test.
        cursor.execute("""
            SELECT employee_id, full_name, current_department, tenure_days, 
                   current_salary, risk_category
            FROM employees.high_risk_employees
            ORDER BY employee_id
        """)
        produced = cursor.fetchall()

        # Reference computation (current employees only).
        # NOTE(review): tenure_days is measured against CURRENT_DATE at
        # verification time, so a table populated on an earlier day would
        # differ; confirm this is acceptable.
        cursor.execute("""
            WITH current_salary AS (
            SELECT employee_id, amount AS current_amount
            FROM (
                SELECT s.*,
                    ROW_NUMBER() OVER (PARTITION BY s.employee_id
                                        ORDER BY s.from_date DESC, s.amount DESC) AS rn
                FROM employees.salary s
                WHERE s.to_date = DATE '9999-01-01'
            ) x
            WHERE rn = 1
            ),
            current_dept AS (
            SELECT employee_id, department_id
            FROM (
                SELECT de.*,
                    ROW_NUMBER() OVER (PARTITION BY de.employee_id
                                        ORDER BY de.from_date DESC, de.department_id) AS rn
                FROM employees.department_employee de
                WHERE de.to_date = DATE '9999-01-01'
            ) x
            WHERE rn = 1
            ),
            dept_retention AS (
            SELECT
                d.id   AS department_id,
                d.dept_name,
                COUNT(DISTINCT de.employee_id) AS total_employees_ever,
                COUNT(DISTINCT de.employee_id) FILTER (WHERE de.to_date = DATE '9999-01-01') AS current_employees,
                (COUNT(DISTINCT de.employee_id) FILTER (WHERE de.to_date = DATE '9999-01-01'))::NUMERIC
                / NULLIF(COUNT(DISTINCT de.employee_id), 0) * 100 AS retention_rate
            FROM employees.department d
            LEFT JOIN employees.department_employee de
                    ON de.department_id = d.id
            GROUP BY d.id, d.dept_name
            )
            SELECT
            e.id AS employee_id,
            CONCAT(e.first_name, ' ', e.last_name) AS full_name,
            d.dept_name AS current_department,
            (CURRENT_DATE - e.hire_date)::INTEGER AS tenure_days,
            cs.current_amount::INTEGER AS current_salary,
            CASE
                WHEN dr.retention_rate < 80  AND (CURRENT_DATE - e.hire_date) < 1095 THEN 'high_risk'
                WHEN dr.retention_rate < 85  AND (CURRENT_DATE - e.hire_date) < 1825 THEN 'medium_risk'
                ELSE 'low_risk'
            END AS risk_category
            FROM employees.employee e
            JOIN current_salary cs ON cs.employee_id = e.id
            JOIN current_dept   cd ON cd.employee_id = e.id
            JOIN employees.department d ON d.id = cd.department_id
            JOIN dept_retention dr ON dr.department_id = d.id
            ORDER BY e.id;
        """)
        reference = cursor.fetchall()

    if len(produced) != len(reference):
        print(f"❌ Expected {len(reference)} high risk analysis results, got {len(produced)}")
        return False

    bad_rows = [
        (row_no, got, want)
        for row_no, (got, want) in enumerate(zip(produced, reference), start=1)
        if not rows_match(got, want)
    ]
    # Report at most five mismatching rows to keep the output readable.
    for row_no, got, want in bad_rows[:5]:
        print(f"❌ Row {row_no} mismatch: expected {want}, got {got}")
    if bad_rows:
        print(f"❌ Total mismatches: {len(bad_rows)}")
        return False

    print(f"✅ High risk employee analysis results are correct ({len(produced)} records)")
    return True

def verify_turnover_trend_results(conn) -> bool:
    """Check employees.turnover_trend_analysis against a reference query.

    The reference query treats an employee as departed when they have no
    salary row with to_date = 9999-01-01. Their departure date and final
    salary come from their latest non-current salary row. Departures dated
    1985-01-01 through 2002-12-31 are then grouped by year, reporting the
    count, average tenure in days and average final salary. Both result
    sets are ordered by departure_year and compared pairwise with
    ``rows_match``.

    On a row-count difference or any row mismatch, prints a diagnostic (at
    most five mismatching rows plus a total) and returns False. Otherwise
    prints a success line and returns True.
    """
    with conn.cursor() as cursor:
        # Rows written by the solution under test.
        cursor.execute("""
            SELECT departure_year, departures_count, avg_tenure_days, avg_final_salary
            FROM employees.turnover_trend_analysis
            ORDER BY departure_year
        """)
        produced = cursor.fetchall()

        # Reference computation of yearly departures.
        cursor.execute("""
            WITH last_non_current_salary AS (
            SELECT
                s.employee_id,
                s.to_date      AS departure_date,
                s.amount       AS final_salary,
                ROW_NUMBER() OVER (
                PARTITION BY s.employee_id
                ORDER BY s.to_date DESC, s.from_date DESC, s.amount DESC
                ) AS rn
            FROM employees.salary s
            WHERE s.to_date <> DATE '9999-01-01'
                AND NOT EXISTS (
                SELECT 1
                FROM employees.salary s_cur
                WHERE s_cur.employee_id = s.employee_id
                    AND s_cur.to_date = DATE '9999-01-01'
                )
            ),
            departed AS (
            SELECT employee_id, departure_date, final_salary
            FROM last_non_current_salary
            WHERE rn = 1
            ),
            with_tenure AS (
            SELECT
                e.id AS employee_id,
                d.departure_date,
                d.final_salary,
                (d.departure_date - e.hire_date)::INTEGER AS tenure_days
            FROM employees.employee e
            JOIN departed d ON d.employee_id = e.id
            )
            SELECT
            EXTRACT(YEAR FROM departure_date)::INTEGER AS departure_year,
            COUNT(*)::INTEGER                         AS departures_count,
            AVG(tenure_days)                          AS avg_tenure_days,
            AVG(final_salary)                         AS avg_final_salary
            FROM with_tenure
            WHERE departure_date BETWEEN DATE '1985-01-01' AND DATE '2002-12-31'
            GROUP BY EXTRACT(YEAR FROM departure_date)
            ORDER BY departure_year;
        """)
        reference = cursor.fetchall()

    if len(produced) != len(reference):
        print(f"❌ Expected {len(reference)} turnover trend results, got {len(produced)}")
        return False

    bad_rows = [
        (row_no, got, want)
        for row_no, (got, want) in enumerate(zip(produced, reference), start=1)
        if not rows_match(got, want)
    ]
    # Report at most five mismatching rows to keep the output readable.
    for row_no, got, want in bad_rows[:5]:
        print(f"❌ Row {row_no} mismatch: expected {want}, got {got}")
    if bad_rows:
        print(f"❌ Total mismatches: {len(bad_rows)}")
        return False

    print(f"✅ Turnover trend analysis results are correct ({len(produced)} records)")
    return True

def main():
    """Run all verification checks and exit with the overall status.

    Exits with status 1 when POSTGRES_DATABASE is unset, when any check
    fails, or when an error is raised. Exits with status 0 only when all
    three checks pass. Checks run in order and stop at the first failure.
    """
    print("=" * 50)

    # Get connection parameters
    conn_params = get_connection_params()

    if not conn_params["database"]:
        print("❌ No database specified")
        sys.exit(1)

    try:
        # closing() calls conn.close() whether the checks return normally or
        # raise, so the connection is never leaked. Errors raised by close()
        # itself still reach the handlers below.
        with closing(psycopg2.connect(**conn_params)) as conn:
            # `and` short-circuits: later checks are skipped once one fails.
            success = (
                verify_retention_analysis_results(conn)
                and verify_high_risk_results(conn)
                and verify_turnover_trend_results(conn)
            )

        # sys.exit raises SystemExit, which the Exception handler below does
        # not catch, so the exit status propagates unchanged.
        if success:
            print("\n🎉 Task verification: PASS")
            sys.exit(0)
        else:
            print("\n❌ Task verification: FAIL")
            sys.exit(1)

    except psycopg2.Error as e:
        print(f"❌ Database error: {e}")
        sys.exit(1)
    except Exception as e:
        print(f"❌ Verification error: {e}")
        sys.exit(1)

# Run the verification only when executed as a script, not when imported.
if __name__ == "__main__":
    main()