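"""
DIAGNOSTIC: find out why rewards are collapsing to 0.0.

Runs FixedDecomposedGameOptUCBHedge's solution construction and per-subproblem
improvement on a 20-city bi-objective TSP, checks that every solution is a
valid permutation, and flags suspicious (non-negative) rewards. It also probes
sensitivity_learning's improver when that module is importable.
"""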

import math
import random
import sys
from collections import Counter

import numpy as np
import torch

sys.path.insert(0, '.')

from project_MOCO.MOCO_supplementary.our_method_dl_UCB_variant import FixedDecomposedGameOptUCBHedge
from MOCO.problems import BiObjectiveTSP

# Seed every RNG source this script touches so runs are reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# 20-city bi-objective TSP instance; evaluate() below scores tours against it.
problem = BiObjectiveTSP(n_cities=20)

def evaluate(sol, n_cities=20):
    """Validate a tour and return its scalarized reward.

    The tour must be a permutation of ``range(n_cities)``. An invalid tour is
    reported on stdout and receives a large penalty instead of being passed to
    ``problem.evaluate``.

    Args:
        sol: Sequence of city indices.
        n_cities: Expected number of cities. The default of 20 matches the
            module-level ``problem``.

    Returns:
        ``-(0.5 * o1 + 0.5 * o2)`` for a valid tour, where ``o1`` and ``o2``
        are the two objectives from ``problem.evaluate``. Returns ``-1e10``
        for an invalid tour.
    """
    # Check that sol is a valid permutation before touching the problem.
    if len(sol) != n_cities:
        print(f"  ERROR: Solution length {len(sol)} != {n_cities}")
        return -1e10

    if len(set(sol)) != n_cities:
        print(f"  ERROR: Solution has duplicates: {sorted(sol)}")
        return -1e10

    if not all(0 <= x < n_cities for x in sol):
        print(f"  ERROR: Invalid values in solution: {sol}")
        return -1e10

    o1, o2 = problem.evaluate(sol)
    reward = -(0.5 * o1 + 0.5 * o2)

    # Objectives are expected to be positive tour costs, so the reward should
    # be strictly negative. A plain `reward >= 0` test lets NaN through
    # (NaN >= 0 is False), so non-finite rewards are flagged explicitly.
    # float() also accepts numpy scalars and single-element tensors.
    if not math.isfinite(float(reward)) or reward >= 0:
        print(f"  WARNING: reward={reward:.4f}, o1={o1:.4f}, o2={o2:.4f}")
        print(f"  Solution: {sol}")

    return reward

# Title banner framed by 60-character rules.
print("\n".join(("=" * 60, "DIAGNOSTIC: Testing your optimizer", "=" * 60)))

# Optimizer under test; item_weights=None forces permutation mode.
opt = FixedDecomposedGameOptUCBHedge(
    problem_size=20,
    evaluate_fn=evaluate,
    decomposition_size=15,
    overlap=6,
    ucb_coefficient=3.0,
    learning_rate=0.5,
    item_weights=None,
)

# Show the optimizer's max_iterations and subproblems attributes.
print("\nmax_iterations: {}".format(opt.max_iterations))
print("subproblems: {}".format(opt.subproblems))

# Test initial solution construction
print("\n--- Testing construct_solution_with_decomposition ---")
sol1 = opt.construct_solution_with_decomposition()
print(f"Solution: {sol1}")
# A valid tour is exactly a permutation of 0..19. Counting distinct cities
# alone would accept e.g. a 21-element list with one repeated city.
print(f"Valid permutation: {len(sol1) == 20 and set(sol1) == set(range(20))}")
r1 = evaluate(sol1)
print(f"Reward: {r1:.4f}")

# Test improve_subproblem manually
print("\n--- Testing improve_subproblem on subproblem 0 ---")
print(f"Subproblem 0: {opt.subproblems[0]}")
improved_sol, improvement = opt.improve_subproblem(sol1, opt.subproblems[0])
print(f"Improved solution: {improved_sol}")
# Same full check as for sol1 (length AND exact set of cities).
print(f"Valid permutation: {len(improved_sol) == 20 and set(improved_sol) == set(range(20))}")
print(f"Improvement: {improvement:.4f}")
r2 = evaluate(improved_sol)
print(f"Reward: {r2:.4f}")
# If `improvement` is meant as a reward delta, it should match this value.
# A mismatch points at the improver's bookkeeping rather than at the tour.
print(f"Actual reward change: {r2 - r1:.4f}")

# Test all subproblems
print("\n--- Testing improve_subproblem on ALL subproblems ---")
current_sol = sol1.copy()
# Reward of current_sol. It is used to cross-check each reported improvement
# against the reward change actually measured by evaluate().
current_r = r1
for sp_idx, sp in enumerate(opt.subproblems):
    print(f"\nSubproblem {sp_idx}: {sp}")
    improved_sol, improvement = opt.improve_subproblem(current_sol, sp)

    # Validate: a valid tour is exactly a permutation of 0..19. Checking only
    # the number of distinct cities would accept over-long tours with repeats.
    is_valid = len(improved_sol) == 20 and set(improved_sol) == set(range(20))
    print(f"  Valid: {is_valid}")

    if not is_valid:
        # Counter works for lists and numpy arrays alike (ndarray has no
        # .count). It reports each duplicated city once, with its count.
        city_counts = Counter(improved_sol)
        duplicates = {city: cnt for city, cnt in city_counts.items() if cnt > 1}
        print(f"  BROKEN SOLUTION: {improved_sol}")
        print(f"  Length: {len(improved_sol)}")
        print(f"  Duplicates: {duplicates}")
        print(f"  Missing: {set(range(20)) - set(improved_sol)}")
        print(f"  Out of range: {set(improved_sol) - set(range(20))}")

    r = evaluate(improved_sol)
    print(f"  Reward: {r:.4f}, Improvement: {improvement:.4f}")
    # If `improvement` is meant as a reward delta, these should agree.
    print(f"  Actual reward change: {r - current_r:.4f}")

    if improvement > 0:
        current_sol = improved_sol
        current_r = r
        print("  -> Accepted")
        if not is_valid:
            # Acceptance is based on `improvement` alone (kept as-is). An
            # invalid tour accepted here taints every later subproblem, so
            # call it out.
            print("  WARNING: accepted an INVALID solution")

print("\n--- Check sensitivity_learning ---")
try:
    from sensitivity_learning import create_problem_agnostic_improver
    print("sensitivity_learning imported successfully")

    # Test the improver directly
    improver = create_problem_agnostic_improver("TSP", 20)
    print(f"Improver function: {improver}")

    # Test it on a random permutation (drawn from the seeded `random` stream)
    test_sol = list(range(20))
    random.shuffle(test_sol)
    print(f"\nTest solution: {test_sol}")

    result = improver(opt, test_sol, [0, 1, 2, 3, 4])
    print(f"Result: {result}")
    print(f"Result type: {type(result)}")

    if isinstance(result, tuple) and len(result) == 2:
        imp_sol, imp = result
        print(f"Improved solution: {imp_sol}")
        print(f"Improvement: {imp}")
        # Full permutation check: right length AND exactly the cities 0..19.
        print(f"Valid: {len(imp_sol) == 20 and set(imp_sol) == set(range(20))}")
    else:
        # Surface unexpected result shapes instead of silently skipping them
        # (or failing on unpacking and being reported as a generic error).
        print("Result is not an (improved_solution, improvement) pair")

except Exception as e:  # top-level diagnostic boundary: report and continue
    print(f"Error with sensitivity_learning: {e}")
    import traceback
    traceback.print_exc()

# Closing banner: blank line, then the title framed by 60-character rules.
print("\n".join(("", "=" * 60, "DIAGNOSTIC COMPLETE", "=" * 60)))