import argparse
import os
import numpy as np
from ast import literal_eval
from recourse.utils import (
    load_cost_matrix,
    generate_random_capacities
)
from recourse.matching import solve_initial_matching
from recourse.best_distribution import optimal_capacity
from recourse.cost_aware import solve_cost_aware_matching
from utils import generate_welfare_plot, save_matching_output, get_result_folder


def _parse_init_cap(text, m):
    """Parse the ``--init-cap`` string into a per-column capacity dict.

    Args:
        text: A Python-literal dict string, e.g. ``"{0: 2, 1: 3}"``.
        m: Number of columns of the weight matrix. The dict must contain
            exactly the keys ``0 .. m-1`` so that every column gets a capacity.

    Returns:
        ``(init_capacity, K)``: the parsed dict and the sum of its values.

    Raises:
        ValueError: if ``text`` is not a valid literal, is not a dict, does not
            have exactly the keys ``0 .. m-1``, or its values cannot be summed.
    """
    try:
        # literal_eval only evaluates Python literals (no calls/names), so it
        # is safe to use on command-line input.
        init_capacity = literal_eval(text)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as e:
        raise ValueError(f"Invalid --init-cap format: {e}") from e

    # Explicit check (not `assert`, which is stripped under `python -O`), done
    # before touching .values() so e.g. a list gives a meaningful message.
    if not isinstance(init_capacity, dict):
        raise ValueError("Invalid --init-cap format: Initial capacity must be a dictionary")

    # Every column needs a capacity (otherwise the vector build fails with a
    # bare KeyError), and stray keys would be counted in K but never used.
    missing = [j for j in range(m) if j not in init_capacity]
    unexpected = [key for key in init_capacity if key not in range(m)]
    if missing or unexpected:
        raise ValueError(
            f"Invalid --init-cap format: keys must be exactly 0..{m - 1} "
            f"(missing: {missing}, unexpected: {unexpected})"
        )

    try:
        K = sum(init_capacity.values())
    except TypeError as e:
        raise ValueError(f"Invalid --init-cap format: {e}") from e
    return init_capacity, K


def _top_k_max_weight_sums(weights, max_k):
    """Return ``[S_0, S_1, ..., S_max_k]``, where ``S_k`` is the sum of the
    ``k`` largest row maxima of ``weights``.

    For ``k`` larger than the number of rows, ``S_k`` stays at the sum of all
    row maxima. ``S_0`` is the integer ``0`` (same as ``sum([])``).
    """
    # Sort once (the original re-sorted on every k); a running sum adds the
    # same values in the same order as sum(sorted(...)[:k]).
    top = sorted((row.max() for row in weights), reverse=True)
    sums = []
    running = 0
    for k in range(max_k + 1):
        if 0 < k <= len(top):
            running += top[k - 1]
        sums.append(running)
    return sums


def main():
    """Command-line entry point: load a dataset and run the selected equations.

    Loads costs/weights from ``--folder`` and resolves an initial capacity per
    column (from ``--init-cap``, else randomly with total ``--K``, else randomly
    with total ``n``). Then, depending on ``--approach``, it:

    * ``eq1``: solves the matching for the initial capacities;
    * ``eq2``: calls ``optimal_capacity`` with total capacity ``K`` and plots
      it against the top-k row-maximum weight sums;
    * ``eq3``: solves the cost-aware matching with penalty ``--beta``.

    Results are written to the folder returned by ``get_result_folder``.

    Raises:
        FileNotFoundError: if ``--folder`` does not exist.
        ValueError: if ``--init-cap`` is malformed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-f', '--folder', required=True, help='Path to dataset folder')
    parser.add_argument('-a', '--approach', choices=['eq1', 'eq2', 'eq3', 'all'], default='all')
    parser.add_argument('-g', '--gamma', type=float, default=10, help='Gamma value for weight function')
    parser.add_argument('-b', '--beta', type=float, default=0.03, help='Beta value for capacity change penalty')
    parser.add_argument('-s', '--seed', type=int, default=20, help='Random seed')
    parser.add_argument('--K', type=int, default=None, help='Total capacity (optional)')
    parser.add_argument('--init-cap', type=str, default=None, help='Initial capacity dict as string, e.g., "{0:2, 1:3}"')
    args = parser.parse_args()

    if not os.path.exists(args.folder):
        raise FileNotFoundError(f"The specified folder does not exist: {args.folder}")

    # NOTE(review): this passes the raw --K (possibly None), not the effective
    # K resolved below. Confirm that this is the intended folder key.
    result_dir = get_result_folder(args.approach, args.gamma, args.beta, args.K, args.seed)

    costs, weights = load_cost_matrix(args.folder, gamma=args.gamma)
    n, m = weights.shape

    # --init-cap takes precedence over --K; with neither, total capacity is n.
    if args.init_cap:
        init_capacity, K = _parse_init_cap(args.init_cap, m)
    elif args.K is not None:
        init_capacity = generate_random_capacities(args.K, m, seed=args.seed)
        K = args.K
    else:
        init_capacity = generate_random_capacities(n, m, seed=args.seed)
        K = n

    init_cap_vector = np.array([init_capacity[j] for j in range(m)])

    if args.approach in ['eq1', 'all']:
        print("▶ Running Equation (1)...")
        model1, z1 = solve_initial_matching(weights, init_cap_vector)
        save_matching_output(model1, z1, costs, weights, init_cap_vector, os.path.join(result_dir, "eq1_results.txt"))

    if args.approach in ['eq2', 'all']:
        print("▶ Running Equation (2)...")
        best_dist, IW, model2, z2 = optimal_capacity(weights, K)

        # Reference curve for the plot: sum of the k largest row maxima,
        # for k = 0 .. n*m.
        IW_range = _top_k_max_weight_sums(weights, n * m)

        save_matching_output(model2, z2, costs, weights, best_dist, os.path.join(result_dir, "eq2_results.txt"))
        generate_welfare_plot(IW_range, IW, n, m, os.path.join(result_dir, "eq2_welfare_plot.png"))

    if args.approach in ['eq3', 'all']:
        print("▶ Running Equation (3)...")
        model3, z3, c3 = solve_cost_aware_matching(weights, init_cap_vector, beta=args.beta)
        # c3[j].X is presumably the solver's solution value for an integer
        # variable (Gurobi-style `.X`). Such values may sit slightly off the
        # integer (e.g. 2.9999999), so round instead of truncating with int().
        vector_c3 = [int(round(c3[j].X)) for j in range(m)]
        save_matching_output(model3, z3, costs, weights, vector_c3, os.path.join(result_dir, "eq3_results.txt"))

# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
