------
## 最终结论
```python
def can_partition(nums, k):
    """Return whether *nums* can be split into *k* non-empty equal-sum subsets.

    Every element must be used exactly once. The search builds one subset at
    a time by backtracking over the values in descending order.

    Args:
        nums: List of non-negative integers. The list is not modified.
        k: Number of subsets to form.

    Returns:
        True if such a partition exists, otherwise False. Also False when
        ``k`` is not positive or exceeds ``len(nums)``, because the subsets
        must be non-empty.

    Raises:
        ValueError: If *nums* contains a negative value. The pruning used
            here is only sound for non-negative input.

    Example:
        >>> can_partition([4, 3, 2, 3, 5, 2, 1], 4)
        True
    """
    if k <= 0 or k > len(nums):
        return False
    if any(value < 0 for value in nums):
        raise ValueError("nums must contain only non-negative integers")

    target_sum, remainder = divmod(sum(nums), k)
    if remainder:
        return False
    if target_sum == 0:
        # All values are zero and k <= len(nums): one zero per subset works.
        return True

    # Work on a sorted copy so the caller's list keeps its original order.
    # Descending order makes oversized choices fail as early as possible.
    ordered = sorted(nums, reverse=True)
    if ordered[0] > target_sum:
        return False  # The largest value cannot fit into any subset.

    used = [False] * len(ordered)

    def backtrack(start_index, current_sum, count):
        """Try to finish the current subset; *count* subsets are complete."""
        if count == k - 1:
            # The unused values sum to target_sum by construction, so the
            # last subset is valid without searching for it.
            return True
        if current_sum == target_sum:
            return backtrack(0, 0, count + 1)  # Start a new subset.

        last_failed = None
        for i in range(start_index, len(ordered)):
            value = ordered[i]
            if used[i] or value == last_failed:
                # An equal value already failed from this exact state.
                continue
            if current_sum + value > target_sum:
                continue

            used[i] = True
            if backtrack(i + 1, current_sum + value, count):
                return True
            used[i] = False  # Backtrack.

            if current_sum == 0:
                # The largest free value has to start some subset; if it
                # cannot start this one, no arrangement can succeed.
                return False
            last_failed = value
        return False

    return backtrack(0, 0, 0)

# Demo: ask whether seven integers can be split into four equal-sum groups.
nums, k = [4, 3, 2, 3, 5, 2, 1], 4
print(can_partition(nums, k))  # Output: True
```