from typing import Tuple, List, Dict, Optional, Callable

import csv
import numpy as np
import os
import time
import gsw_test
import test_helper as helper
import kv


def generate_X(
    ns: List[int],
    d: int,
    num_trials: int,
    data_dir: str = "data",
    seed: int = 0,
):
    """Generate and save one covariate dataset file per sample size in ``ns``.

    Output paths come from ``helper.build_X_paths(ns, d, num_trials, data_dir)``
    and are paired with ``ns`` in order. For each ``n``,
    ``gsw_test.generate_X_dataset`` is asked for ``num_trials`` datasets of
    size ``n`` x ``d``.

    Args:
        ns: Sample sizes; one output file is written per entry.
        d: Covariate dimension.
        num_trials: Number of datasets stored in each file.
        data_dir: Directory for the output files (created if missing).
        seed: Seed forwarded to ``gsw_test.generate_X_dataset``. The same seed
            is passed for every ``n``. Defaults to 0, which preserves the
            previous hard-coded behaviour.
    """
    os.makedirs(data_dir, exist_ok=True)

    X_paths = helper.build_X_paths(ns, d, num_trials, data_dir)
    for n, X_file in zip(ns, X_paths):
        gsw_test.generate_X_dataset(
            num_datasets=num_trials,
            n=n,
            d=d,
            seed=seed,
            out_path=X_file,
        )

def generate_assignments(
    ns: List[int],
    d: int,
    num_trials: int,
    ws: List[np.ndarray],
    phis: List[float],
    gsw_its: int = 1000,
    data_dir: str = "data",
    balance: bool = False,
    exp: bool = False,
    gsw_low_rank_approx: bool = False,
):
    """Generate GSW, PSOD, MSOD and classic assignments for every n in ``ns``.

    For each sample size ``n``, the covariate file from ``helper.build_X_paths``
    is read and the following are written into ``data_dir``:

    * one GSW assignment file per entry returned by
      ``helper.build_gsw_assign_paths`` for that ``n``;
    * ``psod_assignment_{kernel}_n{n}_d{d}_num_trials{num_trials}.npz`` for
      every kernel in the local ``kernels`` table;
    * ``msod_assignment_{kernel}_n{n}_d{d}_num_trials{num_trials}.npz`` for
      the ``gaus`` and ``exp`` kernels, only when ``kv.mosek_available``;
    * ``assignments_runtime_n{n}_d{d}_num_trials{num_trials}.csv`` with the
      per-method timings for that ``n``;
    * the classic assignment file from ``helper.build_classic_assign_path``.

    After every ``n`` has been processed, the GSW timing of each ``n`` is
    written to ``gsw_runtime_n{min(ns)}_to_{max(ns)}_num_trials{num_trials}.csv``.

    Args:
        ns: Sample sizes to process (must be non-empty).
        d: Covariate dimension; checked against each X file.
        num_trials: Number of datasets per X file.
        ws: Weight vectors forwarded (per GSW output path) to GSW.
        phis: GSW ``phi`` values used to build the GSW output paths.
        gsw_its: Iteration count for GSW and for the classic assignments.
        data_dir: Directory holding inputs and receiving outputs.
        balance: Forwarded to GSW and used to choose GSW output paths.
        exp: Forwarded to GSW.
        gsw_low_rank_approx: Forwarded to GSW and used to choose GSW paths.

    Raises:
        ValueError: If ``ns`` is empty, or if an X file's (n, d) does not
            match the requested values.
        RuntimeError: If ``kv.MSODHeuristic`` does not return a 2-tuple.
    """
    # Note: MSOD requires Mosek; run under a compatible env (e.g. `conda activate mosek311`)
    if not ns:
        raise ValueError("ns must contain at least one sample size.")

    os.makedirs(data_dir, exist_ok=True)
    X_paths = helper.build_X_paths(ns, d, num_trials, data_dir)
    # Per n, a list of (w, phi, out_path) triples. The arguments do not depend
    # on the loop variable, so build it once rather than once per n.
    gsw_assign_paths_by_n = helper.build_gsw_assign_paths(
        ns,
        d,
        num_trials,
        ws,
        phis,
        data_dir,
        balance=balance,
        low_rank_approx=gsw_low_rank_approx,
    )

    # Kernel constructors: every entry is used for PSOD; MSODKs select the
    # subset that is also run through MSOD (when Mosek is available).
    kernels = {
        "lin": lambda x: kv.LinearKernel(x, normalize=False),
        "quad": lambda x: kv.PolynomialKernel(x, deg=2, normalize=False),
        "gaus": lambda x: kv.GaussianKernel(x, s=1.0),
        "exp": lambda x: kv.ExpKernel(x, normalize=False),
    }
    MSODKs = ["gaus", "exp"]
    boot = 100

    gsw_runtime_rows: List[Tuple[int, float]] = []

    for i_n, n in enumerate(ns):
        X_file = X_paths[i_n]

        # ---- GSW assignments: one output file per (w, phi) ----
        gsw_start_time = time.time()
        for w, phi, gsw_assign_file in gsw_assign_paths_by_n[i_n]:
            gsw_test.generate_gsw_assignments_from_X_npz(
                input_path=X_file,
                w=np.array(w, dtype=float),
                phi=phi,
                gsw_its=gsw_its,
                out_path=gsw_assign_file,
                balance=balance,
                exp=exp,
                low_rank_approx=gsw_low_rank_approx,
            )
        gsw_runtime = time.time() - gsw_start_time
        print(f"n: {n}, gsw_time: {gsw_runtime:.3f}s")
        gsw_runtime_rows.append((n, gsw_runtime))

        # ---- PSOD assignments per kernel ----
        # Use the NpzFile as a context manager so the archive is closed; the
        # indexed array is fully loaded into memory before the file closes.
        with np.load(X_file) as data:
            X_all = data["X"]
        num_X, n_x, d_x = X_all.shape
        if n_x != n or d_x != d:
            raise ValueError("Mismatch between X file and requested n,d.")

        runtime_rows: List[Tuple[str, str, object]] = [("gsw", "", gsw_runtime)]
        for kname, kfunc in kernels.items():
            start = time.time()
            assignments = np.empty((num_X, n), dtype=int)
            for i_x, X in enumerate(X_all):
                K = kfunc(X)
                assignments[i_x] = np.array(kv.PSOD(K), dtype=int)
            elapsed = time.time() - start
            runtime_rows.append(("psod", kname, elapsed))
            print(f"psod_time of {kname}: {elapsed:.3f}s")

            psod_path = os.path.join(
                data_dir, f"psod_assignment_{kname}_n{n}_d{d}_num_trials{num_trials}.npz"
            )
            np.savez_compressed(
                psod_path,
                assignments=assignments,
                kernel=kname,
                n=n,
                d=d,
                source=X_file,
            )

        # ---- MSOD assignments (requires Mosek) ----
        if kv.mosek_available:
            for kname in MSODKs:
                if kname not in kernels:
                    continue
                start = time.time()
                weights_list = []
                assignments_list = []
                for X in X_all:
                    K = kernels[kname](X)
                    msod_result = kv.MSODHeuristic(K, boot)
                    if not isinstance(msod_result, tuple) or len(msod_result) != 2:
                        raise RuntimeError(f"MSODHeuristic failed for kernel {kname}")
                    t, us = msod_result
                    weights_list.append(np.array(t, dtype=float))
                    assignments_list.append(np.array(us, dtype=int))
                elapsed = time.time() - start
                runtime_rows.append(("msod", kname, elapsed))
                print(f"msod_time of {kname}: {elapsed:.3f}s")

                msod_path = os.path.join(
                    data_dir, f"msod_assignment_{kname}_n{n}_d{d}_num_trials{num_trials}.npz"
                )
                np.savez_compressed(
                    msod_path,
                    weights=np.stack(weights_list, axis=0),
                    assignments=np.stack(assignments_list, axis=0),
                    kernel=kname,
                    boot=boot,
                    n=n,
                    d=d,
                    source=X_file,
                )
        else:
            for kname in MSODKs:
                runtime_rows.append(("msod_skipped_no_mosek", kname, ""))

        # ---- per-n runtime record ----
        runtime_path = os.path.join(
            data_dir, f"assignments_runtime_n{n}_d{d}_num_trials{num_trials}.csv"
        )
        with open(runtime_path, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["method", "kernel", "runtime_seconds"])
            writer.writerows(runtime_rows)

        # ---- classic assignments ----
        classic_assign_file = helper.build_classic_assign_path(n, d, num_trials, data_dir)
        gsw_test.generate_classic_assignments_from_X_npz(
            input_path=X_file,
            its=gsw_its,
            out_path=classic_assign_file,
        )

    # GSW timings across all n, written once every n has been processed.
    gsw_runtime_path = os.path.join(
        data_dir, f"gsw_runtime_n{min(ns)}_to_{max(ns)}_num_trials{num_trials}.csv"
    )
    with open(gsw_runtime_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["n", "runtime_seconds"])
        writer.writerows(gsw_runtime_rows)


def _group_metrics_by_design(
    cov: Dict[str, float],
    metrics: List[str],
    keep: Optional[Callable[[str], bool]] = None,
) -> Dict[str, Dict[str, float]]:
    """Split flat ``"<design>_<metric>"`` keys into ``{design: {metric: value}}``.

    Metrics are tried in order and the first matching ``_<metric>`` suffix
    wins. Keys ending in none of ``metrics`` are dropped. When ``keep`` is
    given, only keys for which ``keep(key)`` is true are considered.
    """
    grouped: Dict[str, Dict[str, float]] = {}
    for key, value in cov.items():
        if keep is not None and not keep(key):
            continue
        for metric in metrics:
            suffix = f"_{metric}"
            if key.endswith(suffix):
                grouped.setdefault(key[: -len(suffix)], {})[metric] = value
                break
    return grouped


def _write_csv_table(
    path: str,
    table: Dict[str, Dict[str, float]],
    row_names: List[str],
    col_names: List[str],
) -> None:
    """Write ``table[row][col]`` as CSV with a leading ``design`` column.

    Missing cells (absent or ``None``) are written as empty strings.
    """
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["design"] + list(col_names))
        for row_name in row_names:
            cells = table.get(row_name, {})
            row: List[object] = [row_name]
            for col in col_names:
                value = cells.get(col)
                row.append("" if value is None else value)
            writer.writerow(row)


def _write_min_normalized_csv(
    path: str,
    table: Dict[str, Dict[str, float]],
    row_names: List[str],
    col_names: List[str],
) -> None:
    """Like ``_write_csv_table``, but divide each cell by its column minimum.

    The minimum of each column is taken over the non-missing cells in
    ``row_names`` only. Ratios are formatted with two decimals. A cell is left
    blank when its value is missing, or when the column minimum is missing or
    exactly 0.
    """
    mins: Dict[str, float] = {}
    for col in col_names:
        present = [
            value
            for value in (table.get(r, {}).get(col) for r in row_names)
            if value is not None
        ]
        if present:
            mins[col] = min(present)

    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["design"] + list(col_names))
        for row_name in row_names:
            cells = table.get(row_name, {})
            row: List[object] = [row_name]
            for col in col_names:
                value = cells.get(col)
                denom = mins.get(col)
                if value is None or denom is None or denom == 0.0:
                    row.append("")
                else:
                    row.append(f"{value / denom:.2f}")
            writer.writerow(row)


def evaluate(
    ns: List[int],
    d: int,
    num_trials: int,
    ws: List[np.ndarray],
    phis: List[float],
    sigma: float = 0.0,
    data_dir: str = "data",
    report_cov: bool = True,
    f0_funcs: Optional[Dict[str, Callable[[np.ndarray], float]]] = None,
) -> Tuple[
    Dict[int, Dict[str, float]],
    Dict[int, Dict[str, Dict[str, float]]],
]:
    """Evaluate saved assignments for each n and write CSV summaries.

    For every ``n``, the GSW assignment paths of the plain, balanced
    (``balance=True``) and low-rank (``low_rank_approx=True``) variants are
    concatenated and evaluated together with the classic assignment file.
    NOTE(review): those files presumably must already exist, e.g. produced by
    ``generate_assignments`` with each flag combination. Confirm.

    If ``report_cov`` is true, ``gsw_test.evaluate_cov_from_saved`` is called.
    Its ``"<design>_<metric>"`` keys are tabulated into ``data_dir`` as:

    * ``cov_results_n{n}.csv``: raw values;
    * ``cov_results_n{n}_norm.csv``: each metric divided by its column minimum;
    * ``cov_results_n{n}_gsw_phi0.5_norm.csv``: the same, restricted to
      non-GSW designs plus GSW keys containing ``_phi0.5``.

    Otherwise ``gsw_test.evaluate_gsw_condvar_from_saved`` is called. Its
    ``{f0_name: {design: value}}`` result is written as:

    * ``condvar_results_n{n}_sigma{sigma}.csv``: raw values;
    * ``condvar_results_n{n}_sigma{sigma}_norm.csv``: column-min normalised;
    * ``condvar_results_n{n}_phi0.01_sigma{sigma}_norm.csv``: normalised over
      non-GSW designs plus GSW designs ending in ``phi0.01``.

    Returns:
        ``(cov_results, condvar_results)``, both keyed by ``n``. Only the one
        matching ``report_cov`` is populated; the other stays empty.
    """
    cov_metrics = ["U", "gram_p1", "gram_p2", "gram_p3", "gram_p4", "gram_p5"]

    X_paths = helper.build_X_paths(ns, d, num_trials, data_dir)
    gsw_assign_paths_by_n_plain = helper.build_gsw_assign_paths(
        ns, d, num_trials, ws, phis, data_dir
    )
    gsw_assign_paths_by_n_balance = helper.build_gsw_assign_paths(
        ns, d, num_trials, ws, phis, data_dir, balance=True
    )
    gsw_assign_paths_by_n_lowrank = helper.build_gsw_assign_paths(
        ns, d, num_trials, ws, phis, data_dir, low_rank_approx=True
    )
    # Per n: plain + balanced + low-rank (w, phi, path) triples, in that order.
    gsw_assign_paths_by_n = [
        base + balanced + lowrank
        for base, balanced, lowrank in zip(
            gsw_assign_paths_by_n_plain,
            gsw_assign_paths_by_n_balance,
            gsw_assign_paths_by_n_lowrank,
        )
    ]

    cov_results: Dict[int, Dict[str, float]] = {}
    condvar_results: Dict[int, Dict[str, Dict[str, float]]] = {}

    for i_n, n in enumerate(ns):
        X_file = X_paths[i_n]
        gsw_assign_paths = [p for _, _, p in gsw_assign_paths_by_n[i_n]]
        classic_assign_file = helper.build_classic_assign_path(n, d, num_trials, data_dir)

        if report_cov:
            cov = gsw_test.evaluate_cov_from_saved(
                X_file,
                assignments_npz_paths=gsw_assign_paths,
                classic_assignments_npz_path=classic_assign_file,
            )
            cov_results[n] = cov

            os.makedirs(data_dir, exist_ok=True)
            grouped = _group_metrics_by_design(cov, cov_metrics)
            designs = sorted(grouped)
            _write_csv_table(
                os.path.join(data_dir, f"cov_results_n{n}.csv"),
                grouped, designs, cov_metrics,
            )
            _write_min_normalized_csv(
                os.path.join(data_dir, f"cov_results_n{n}_norm.csv"),
                grouped, designs, cov_metrics,
            )

            # Keep every non-GSW design, but only the phi=0.5 GSW designs.
            gsw_grouped = _group_metrics_by_design(
                cov,
                cov_metrics,
                keep=lambda key: not key.startswith("gsw") or "_phi0.5" in key,
            )
            _write_min_normalized_csv(
                os.path.join(data_dir, f"cov_results_n{n}_gsw_phi0.5_norm.csv"),
                gsw_grouped, sorted(gsw_grouped), cov_metrics,
            )
        else:
            results = gsw_test.evaluate_gsw_condvar_from_saved(
                X_file,
                assignments_npz_paths=gsw_assign_paths,
                classic_assignments_npz_path=classic_assign_file,
                f0_funcs=f0_funcs,
                sigma=sigma,
            )
            condvar_results[n] = results

            os.makedirs(data_dir, exist_ok=True)
            # Transpose {f0_name: {design: value}} into {design: {f0_name: value}}
            # so the rows of every CSV are designs.
            by_design: Dict[str, Dict[str, float]] = {}
            for fname, per_design in results.items():
                for design, value in per_design.items():
                    by_design.setdefault(design, {})[fname] = value
            design_names = sorted(by_design)
            f0_names = sorted(results)

            _write_csv_table(
                os.path.join(data_dir, f"condvar_results_n{n}_sigma{sigma}.csv"),
                by_design, design_names, f0_names,
            )
            _write_min_normalized_csv(
                os.path.join(data_dir, f"condvar_results_n{n}_sigma{sigma}_norm.csv"),
                by_design, design_names, f0_names,
            )

            # Non-GSW designs plus GSW designs at phi=0.01 only.
            phi_designs = [
                design
                for design in design_names
                if not design.startswith("gsw") or design.endswith("phi0.01")
            ]
            _write_min_normalized_csv(
                os.path.join(data_dir, f"condvar_results_n{n}_phi0.01_sigma{sigma}_norm.csv"),
                by_design, phi_designs, f0_names,
            )

    return cov_results, condvar_results


def _parse_int_list(value: str) -> List[int]:
    return [int(x) for x in value.split(",") if x.strip()]


def _parse_float_list(value: str) -> List[float]:
    return [float(x) for x in value.split(",") if x.strip()]


def _parse_ws(value: str) -> List[np.ndarray]:
    """Parse ``;``-separated weight vectors into float arrays.

    Each segment is a comma-separated float list, e.g.
    ``"0.5,0.5;1,0"`` -> ``[array([0.5, 0.5]), array([1., 0.])]``.
    Blank segments (e.g. from a trailing ``;``) are ignored.
    """
    segments = (segment.strip() for segment in value.split(";"))
    return [
        np.array(_parse_float_list(segment), dtype=float)
        for segment in segments
        if segment
    ]


def main() -> None:
    """Command-line entry point: generate or evaluate assignments.

    ``--mode generate`` calls ``generate_assignments``; ``--mode evaluate``
    calls ``evaluate``. ``--ns`` and ``--phis`` are comma-separated lists.
    ``--ws`` is a ``;``-separated list of comma-separated weight vectors.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Run assignment generation/evaluation.")
    parser.add_argument("--mode", choices=["generate", "evaluate"], required=True)
    parser.add_argument("--ns", default="10,15,20,25,30,35,40,45,50")
    parser.add_argument("--d", type=int, default=10)
    parser.add_argument("--num-trials", type=int, default=100)
    parser.add_argument("--phis", default="0.01, 0.1, 0.5, 0.9, 0.99")
    # Larger alternative weight set: "1.0;1.0,0.5;1.0,0.5,0.5;1.0,1.0;1.0,1.0,1.0"
    parser.add_argument("--ws", default="1.0;1.0,0.5;1.0,0.5,0.5")
    parser.add_argument("--gsw-its", type=int, default=1000)
    parser.add_argument("--data-dir", default="data")
    parser.add_argument("--balance", action="store_true")
    parser.add_argument("--exp", action="store_true")
    parser.add_argument("--lowrank", action="store_true")
    parser.add_argument("--sigma", type=float, default=0.0)
    # store_true: args.no_report_cov is False unless the flag is given, so
    # covariance reporting is on by default (matching evaluate's default) and
    # the flag switches evaluate to the conditional-variance path.
    parser.add_argument(
        "--no-report-cov",
        action="store_true",
        help="In evaluate mode, compute conditional variances instead of covariance metrics.",
    )

    args = parser.parse_args()

    ns = _parse_int_list(args.ns)
    phis = _parse_float_list(args.phis)
    ws = _parse_ws(args.ws)

    if args.mode == "generate":
        generate_assignments(
            ns=ns,
            d=args.d,
            num_trials=args.num_trials,
            ws=ws,
            phis=phis,
            gsw_its=args.gsw_its,
            data_dir=args.data_dir,
            balance=args.balance,
            exp=args.exp,
            gsw_low_rank_approx=args.lowrank,
        )
    else:
        evaluate(
            ns=ns,
            d=args.d,
            num_trials=args.num_trials,
            ws=ws,
            phis=phis,
            sigma=args.sigma,
            data_dir=args.data_dir,
            report_cov=not args.no_report_cov,
        )


if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not when imported.
    main()
    
