import argparse
import logging
import os
import random
import csv
from functools import partial

import numpy as np
import pandas as pd
import torch
import torch.optim as optim
from time import perf_counter as tpc

# --- Assuming these are importable from your source library ---
import teneva
from _src import (
    Protes, TensorFactorization, set_logger, get_map,
    DiabetesObjective, PressureVesselObjective, ConstraintWarcraft, 
    WarcraftObjectiveTF as WarcraftObjective, EggholderTF as Eggholder, 
    AckleyTF as Ackley, GAP_A_Objective, GAP_B_Objective,
    Ising_A_Objective, Ising_B_Objective, TSSObjective, SSSObjective
)


# ==============================================================================
# Main Runner Function
# ==============================================================================
def _batch_wrapper(I_batch, obj_func, categories=None, map_shape=None):
    """Score every index vector in ``I_batch`` with ``obj_func``.

    Args:
        I_batch: Iterable of index vectors (one per candidate).
        obj_func: Objective; its ``evaluate`` method is used when present,
            otherwise the object itself is called.
        categories: Optional lookup table. When given, each index is replaced
            by ``categories[index]`` before the objective is called; when
            ``None`` the raw index vector is passed through unchanged.
        map_shape: Optional shape. Only used together with ``categories``:
            the decoded values are turned into an array of this shape.

    Returns:
        ``np.ndarray`` holding one score per entry of ``I_batch``.
    """
    # The call target does not depend on the sample, so resolve it once
    # instead of repeating the hasattr check inside the loop.
    score = obj_func.evaluate if hasattr(obj_func, 'evaluate') else obj_func

    scores = []
    for i_indices in I_batch:
        if categories is not None:
            decoded_input = [categories[i] for i in i_indices]
            if map_shape is not None:
                decoded_input = np.array(decoded_input).reshape(map_shape)
        else:
            decoded_input = i_indices
        scores.append(score(decoded_input))
    return np.array(scores)


def _build_objective(settings):
    """Create the objective named by ``settings["function"]`` and its search space.

    Returns:
        Tuple ``(objective_function, d, n, categories, map_shape, wrapper_func)``:
        the objective instance, the number of dimensions, the number of values
        per dimension, the table used to decode the final result (or ``None``),
        the map shape (``None`` unless the function is "warcraft"), and the
        batched callable handed to the optimizer.

    Raises:
        ValueError: If ``settings["function"]`` is not a supported name.
    """
    function_name = settings["function"]
    categories = None
    map_shape = None
    # Whether the objective receives decoded category values (True) or the
    # raw index vector (False).
    feed_categories = True

    if function_name == "warcraft":
        weight_matrix = get_map(settings["map_option"])
        map_shape = weight_matrix.shape
        objective_function = WarcraftObjective(weight_matrix=weight_matrix)
        d = map_shape[0] * map_shape[1]
        categories = ["oo", "ab", "ac", "ad", "bc", "bd", "cd"]
        n = len(categories)
    elif function_name == "diabetes":
        objective_function = DiabetesObjective(seed=settings["seed"])
        d = len(objective_function.features)
        n = 5
    elif function_name == "pressure":
        objective_function = PressureVesselObjective(seed=settings["seed"])
        d = len(objective_function.features)
        n = 10
        # mid_points is only used to decode the final result; the objective
        # itself is evaluated on raw indices.
        categories = objective_function.mid_points
        feed_categories = False
    elif function_name in ["eggholder", "ackley"]:
        ObjectiveClass = Eggholder if function_name == "eggholder" else Ackley
        objective_function = ObjectiveClass(constrain=settings["constraint"])
        d = 2
        bounds = objective_function.bounds
        categories = list(range(bounds[0], bounds[1] + 1))
        n = len(categories)
    elif function_name in ["gap_a", "gap_b"]:
        ObjectiveClass = GAP_A_Objective if function_name == "gap_a" else GAP_B_Objective
        objective_function = ObjectiveClass()
        d = len(objective_function.features)
        n = objective_function.n_bins
    elif function_name in ["ising_a", "ising_b"]:
        ObjectiveClass = Ising_A_Objective if function_name == "ising_a" else Ising_B_Objective
        objective_function = ObjectiveClass()
        d = len(objective_function.features)
        n = 2
    elif function_name == "tss":
        objective_function = TSSObjective(is_constrained=settings["constraint"])
        d = 6
        categories = objective_function.operations
        n = len(categories)
    elif function_name == "sss":
        objective_function = SSSObjective(is_constrained=settings["constraint"])
        d = len(objective_function.features)
        categories = objective_function.channel_options
        n = len(categories)
    else:
        raise ValueError(f"Unsupported function type: {function_name}")

    wrapper_func = partial(
        _batch_wrapper,
        obj_func=objective_function,
        categories=categories if feed_categories else None,
        map_shape=map_shape,
    )
    return objective_function, d, n, categories, map_shape, wrapper_func


def _learn_constraint_cores(constraint_tensor, d):
    """Fit a factorization to ``constraint_tensor`` and return its cores as NumPy arrays.

    Args:
        constraint_tensor: Array-like exposing ``astype`` (the objective's
            ``_tensor_constraint``).
        d: Number of search dimensions; selects the factorization rank
            (10 when ``d < 10``, otherwise 20).

    Returns:
        List of NumPy arrays, one per entry returned by
        ``TensorFactorization.get_state()``.
    """
    logging.info("\n--- Phase 1: Building and Learning Constraint Tensor P ---")
    full_constraint_tensor_np = constraint_tensor.astype(np.float32)
    full_constraint_tensor_torch = torch.from_numpy(full_constraint_tensor_np)
    logging.info(f"Learning TT-decomposition for constraint tensor of shape {full_constraint_tensor_torch.shape}...")

    tf_rank_approx = 10 if d < 10 else 20
    tf_model = TensorFactorization(tensor=full_constraint_tensor_torch, rank=tf_rank_approx, method='train')
    tf_model.optimize(max_iter=3000, mse_tol=1e-4, lr=0.01)
    logging.info(f"Learning finished. Final MSE: {getattr(tf_model, 'mse_loss', torch.tensor(0)).item():.6f}")

    # detach() first: Tensor.numpy() raises on tensors that still track
    # gradients, which trained cores may do (depends on get_state()).
    return [core.detach().cpu().numpy() for core in tf_model.get_state()]


def _decode_best_params(function_name, categories, i_opt):
    """Translate the best index vector ``i_opt`` into parameter values.

    "pressure" uses a separate lookup table per dimension; other functions
    with ``categories`` share one table across dimensions; functions without
    ``categories`` return the indices themselves as a list.
    """
    if function_name == "pressure":
        return [categories[i][idx] for i, idx in enumerate(i_opt)]
    if categories is not None:
        return [categories[i] for i in i_opt]
    return i_opt.tolist()


def run_protes_optimization(settings):
    """Run one PROTES optimization described by ``settings`` and save the result.

    Args:
        settings: Dict with the keys "seed", "function", "constraint",
            "direction", "iter_bo", "results_dir", "name" and
            "protes_settings" (itself holding "rank", "k_batch", "k_top",
            "k_gd" and "lr"); "map_option" is additionally read when the
            function is "warcraft".

    Returns:
        None. The best value and parameters are logged and written to
        ``<results_dir>/<name>_results.csv``. If learning the constraint
        tensor fails, the error is logged and the function returns early
        without running the optimizer.

    Raises:
        ValueError: If ``settings["function"]`` is not a supported name.
    """
    seed = settings['seed']
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    function_name = settings["function"]
    logging.info(f"Setting up objective function: {function_name}")
    objective_function, d, n, categories, map_shape, wrapper_func = _build_objective(settings)
    logging.info(f"Search space configured: d={d}, n={n}")

    P_init = None
    if settings["constraint"]:
        constraint_tensor = getattr(objective_function, '_tensor_constraint', None)
        if constraint_tensor is not None:
            try:
                P_init = _learn_constraint_cores(constraint_tensor, d)
            except MemoryError:
                logging.error("MemoryError... Skipping constrained optimization.")
                return
            except Exception as e:
                # Best-effort by design: abort this run, but keep the
                # traceback in the log so the failure can be diagnosed.
                logging.exception(f"Error in constraint prep: {e}")
                return
        else:
            logging.warning(f"Constraint=True but no tensor found for '{function_name}'. Running unconstrained.")

    logging.info("\n--- Phase 2: Running PROTES Optimization ---")
    results_dir = settings["results_dir"]
    # Make sure the output directory exists even when the caller did not
    # create it beforehand.
    os.makedirs(results_dir, exist_ok=True)
    protes_settings = settings["protes_settings"]
    protes_optimizer = Protes(
        f=wrapper_func,
        d=d,
        n=n,
        is_max=settings["direction"],
        r=protes_settings["rank"],
        results_dir=results_dir,
        results_prefix=str(settings['seed']),
    )
    i_opt, y_opt = protes_optimizer.optimize(
        m=settings["iter_bo"],
        k=protes_settings["k_batch"],
        k_top=protes_settings["k_top"],
        k_gd=protes_settings["k_gd"],
        lr=protes_settings["lr"],
        P_init=P_init,
        log=True,
    )

    logging.info(f"\n--- Optimization Finished for {function_name} ---")
    logging.info(f"Best objective value found: {y_opt:.6f}")
    best_params_decoded = _decode_best_params(function_name, categories, i_opt)
    logging.info(f"Best parameters (indices): {i_opt.tolist()}")
    # For these functions the decoded form equals the indices, so it is not
    # logged a second time.
    if function_name not in ["diabetes", "gap_a", "gap_b", "ising_a", "ising_b"]:
        logging.info(f"Best parameters (decoded): {best_params_decoded}")
    if function_name == "warcraft":
        objective_function.visualize(np.array(best_params_decoded).reshape(map_shape))

    results_filepath = os.path.join(results_dir, f"{settings['name']}_results.csv")
    with open(results_filepath, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerow(['best_value', 'best_params_indices', 'best_params_decoded'])
        writer.writerow([y_opt, i_opt.tolist(), best_params_decoded])
    logging.info(f"Results saved to {results_filepath}")

def parse_args(argv=None):
    """Parse the command-line options of the PROTES benchmark runner.

    Args:
        argv: Optional sequence of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, exactly as before
            this parameter existed.

    Returns:
        ``argparse.Namespace`` with the attributes ``timestamp``, ``seed``,
        ``iter_bo``, ``function``, ``constraint``, ``direction``,
        ``map_option``, ``protes_rank``, ``protes_k_batch``, ``protes_k_top``,
        ``protes_k_gd``, ``protes_lr`` and ``base_dir``.

    Raises:
        SystemExit: Raised by argparse on invalid arguments (for example a
            missing or unknown ``--function``) or when help is requested.
    """
    supported_functions = [
        "diabetes", "pressure", "warcraft", "eggholder", "ackley",
        "gap_a", "gap_b", "ising_a", "ising_b", "tss", "sss",
    ]

    parser = argparse.ArgumentParser(description="Unified Benchmark for PROTES")

    # Run identification and budget.
    parser.add_argument("--timestamp", type=str, help="Timestamp for the experiment")
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument("--iter_bo", type=int, default=10000, help="Number of PROTES iterations (budget)")

    # Objective selection.
    parser.add_argument("--function", type=str, required=True, choices=supported_functions, help="Objective function to run.")
    parser.add_argument("--constraint", action="store_true", help="Use constraint in the objective function")
    parser.add_argument("--direction", action="store_true", help="Maximize the objective function (default is minimize)")
    parser.add_argument("--map_option", type=int, choices=[1, 2, 3], default=1, help="Map option for Warcraft")

    # PROTES hyper-parameters.
    parser.add_argument("--protes_rank", type=int, default=5, help="TT-rank for PROTES model")
    parser.add_argument("--protes_k_batch", type=int, default=100, help="Batch size (k) for PROTES")
    parser.add_argument("--protes_k_top", type=int, default=10, help="Elite samples (k_top) for PROTES")
    parser.add_argument("--protes_k_gd", type=int, default=1, help="Gradient steps for PROTES")
    parser.add_argument("--protes_lr", type=float, default=0.01, help="Learning rate for PROTES")

    # Output location.
    parser.add_argument("--base_dir", type=str, default="results_protes")

    return parser.parse_args(argv)

if __name__ == "__main__":
    # Script entry point: read the CLI options, prepare the output directory
    # and logger, assemble the settings dict and launch the optimization.
    cli = parse_args()

    # Use the supplied timestamp, or the current time when none was given.
    run_stamp = cli.timestamp or pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')

    # Warcraft runs get a per-map sub-directory; all others use the function name.
    is_warcraft = cli.function == "warcraft"
    subdir = f"warcraft_map{cli.map_option}" if is_warcraft else cli.function
    out_dir = os.path.join(cli.base_dir, run_stamp, subdir)
    os.makedirs(out_dir, exist_ok=True)

    run_tag = f"run_seed{cli.seed}" + ("_constrained" if cli.constraint else "")
    log_filepath = set_logger(run_tag, out_dir)

    # Key insertion order is kept as-is because the dict is logged verbatim below.
    run_settings = {
        "name": f"{run_stamp}_{run_tag}",
        "seed": cli.seed,
        "function": cli.function,
        "constraint": cli.constraint,
        "iter_bo": cli.iter_bo,
        "results_dir": out_dir,
        "protes_settings": {
            key: getattr(cli, f"protes_{key}")
            for key in ("rank", "k_batch", "k_top", "k_gd", "lr")
        },
    }
    # The GAP objectives are always maximized; everything else follows --direction.
    run_settings["direction"] = cli.function in ("gap_a", "gap_b") or cli.direction
    if is_warcraft:
        run_settings["map_option"] = cli.map_option

    logging.info(f"Experiment settings: {run_settings}")
    run_protes_optimization(run_settings)