#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import csv
import glob
import sys
import argparse
import subprocess
import time
import random

import torch
import numpy as np
from rich import pretty
from rich.console import Console
from rich.table import Table

from main import main  

# Enable rich's pretty-printing hook (rich.pretty.install) as soon as this module is loaded.
pretty.install()


# Fallback list of preferred GPU indices, used when no usable --gpu_ids value is supplied.
GPU_IDS_DEFAULT = [0, 1, 2, 3]   
# Default for the --epochs option, and the fallback epoch count when none is set on args.
EPOCHS_DEFAULT  = 70
# Default for --base_root: directory under which per-domain logs and result CSVs are written.
BASE_ROOT_DEFAULT = "./ablation/BAG_net_bash1_ori"



# Domain letters; a letter's position in this list is the domain index used as test_index.
# NOTE(review): with the "pacs" dataset these presumably stand for Art painting, Cartoon,
# Photo and Sketch -- confirm against the data loader.
DOMAIN_NAMES = ["A", "C", "P", "S"]  # index: 0 1 2 3


# Per-domain hyperparameter sets, keyed by domain letter (see DOMAIN_NAMES).
# apply_params_for_domain() copies every key/value of the selected entry onto the
# args namespace, overriding the parser defaults of the same name.
# NOTE(review): the name suggests these came out of a hyperparameter search -- confirm.
BEST_PARAMS_BY_DOMAIN = {
    "P": {  
        "batch_size": 128,
        "lr": 6.005992052862356e-05,
        "z_c_dim": 128,
        "z_s_dim": 80,
        "reg_lambda": 0.119856609944216,
        "hide_dim": 336,
        "train_z_s_epoch": 5,
        "z_s_classfier_lr": 0.0004704000369667202,
        "vae_lambda": 5.0894969126760786e-05,
        "epochs": 70,
    },
    "A": {  
        "batch_size": 256,
        "lr": 0.00010361384605114215,
        "z_c_dim": 112,
        "z_s_dim": 112,
        "reg_lambda": 0.22879306584785106,
        "hide_dim": 352,
        "train_z_s_epoch": 4,
        "z_s_classfier_lr": 0.00030321518607435483,
        "vae_lambda": 1.0322670193656338e-05,
        "epochs": 70,
    },
    "C": {  
        "batch_size": 64,
        "lr": 0.00011326674880985933,
        "z_c_dim": 48,
        "z_s_dim": 128,
        "reg_lambda": 0.13569103550679992,
        "hide_dim": 240,
        "train_z_s_epoch": 6,
        "z_s_classfier_lr": 0.00036170622598554424,
        "vae_lambda": 3.8494870638899416e-05,
        "epochs": 70,
    },
    "S": {  
        "batch_size": 128,
        "lr": 0.00021640893021461794,
        "z_c_dim": 48,
        "z_s_dim": 32,
        "reg_lambda": 0.0825803882070895,
        "hide_dim": 112,
        "train_z_s_epoch": 3,
        "z_s_classfier_lr": 0.015244631058719038,
        "vae_lambda": 1.8201858790104406e-05,
        "epochs": 70,
    },
}

# Column order of the hyperparameter fields written after "seed" and "acc" in each
# per-domain results.csv; the names match the keys of BEST_PARAMS_BY_DOMAIN entries.
PARAM_KEY_ORDER = [
    "batch_size", "lr", "z_c_dim", "z_s_dim", "reg_lambda",
    "hide_dim", "train_z_s_epoch", "z_s_classfier_lr", "vae_lambda", "epochs"
]


def _str2bool(value):
    """Convert a command-line token to a bool for use as an argparse ``type=``.

    ``type=bool`` is a trap: ``bool("False")`` is ``True`` because every
    non-empty string is truthy. This converter understands the usual
    spellings instead and rejects anything else.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    # The empty string stays False, as it was under ``bool("")``.
    if text in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_arg_parser():
    """Build the argument parser shared by the launcher and the worker processes.

    The same parser serves two purposes: it parses the real command line in
    the ``__main__`` block, and ``run_one_domain`` calls it with an empty
    argument list to obtain a namespace of defaults to hand to ``main``.

    Returns:
        argparse.ArgumentParser: parser with every option registered.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--n_envs', type=int, default=5)
    p.add_argument('--irm_reg_lambda', type=float, default=52.98316906283707)
    p.add_argument('--phi_odim',  type=int, default=3)
    p.add_argument('--fine_tune_lr',  type=float, default=1e-4)
    p.add_argument('--n_finetune_loop',  type=int, default=20)

    p.add_argument('--model_name', type=str, default="adp_invar_anti_causal")
    p.add_argument('--compare_all_invariant_models', action='store_true')
    # Parsed with _str2bool so that "--classification False" really yields False.
    p.add_argument('--classification', type=_str2bool, default=True)

    p.add_argument('--run_fine_tune_test', action='store_true')
    p.add_argument('--n_fine_tune_tests', type=int, default=10)
    p.add_argument('--n_fine_tune_points', nargs='+', type=int)

    p.add_argument('--dataset', type=str, default="pacs")
    p.add_argument('--causal_dir_syn', type=str, default="anti")
    p.add_argument('--syn_dataset_train_size', type=int, default=1024)
    p.add_argument('--downsample', action='store_true')
    p.add_argument('--data_dir', type=str, default="")
    p.add_argument('--run_fine_tune_test_standalone', action='store_true')
    p.add_argument('--print_base_graph', action='store_true')
    p.add_argument('--verbose', action='store_true')
    p.add_argument('--cvs_dir', type=str, default="./test.cvs")
    p.add_argument('--model_save_dir', type=str, default="")
    p.add_argument('--hyper_param_tuning', action='store_true')
    p.add_argument('--save_test_phi', action='store_true')
    p.add_argument('--nb_workers', type=int, default=16)
    p.add_argument('--random_seed', type=int, default=42)
    p.add_argument('--balanced_dataset', action='store_true')
    p.add_argument('--maml_only', action='store_true')
    p.add_argument('--disentagnle_plot', action='store_true')

    p.add_argument('--val_index', type=int, default=1)
    p.add_argument('--z_dim', type=int, default=96)
    p.add_argument('--resnet_dim', type=int, default=48)
    p.add_argument('--environment_num', type=int, default=3)
    p.add_argument('--environment_dim', type=int, default=10)
    p.add_argument('--gamma', type=float, default=0.9)
    p.add_argument('--reg_lambda_2', type=float, default=10)
    p.add_argument('--C_max', type=float, default=20., metavar='N')
    p.add_argument('--C_stop_iter', type=int, default=30, metavar='N')
    p.add_argument('--beta', type=float, default=1.)

    # The options from --batch_size to --vae_lambda share their names with the keys
    # of BEST_PARAMS_BY_DOMAIN, which overrides them per domain inside the worker.
    p.add_argument('--batch_size', type=int, default=128)
    p.add_argument('--lr',  type=float, default=1e-4)
    p.add_argument("--z_c_dim", type=int, default=128)
    p.add_argument("--z_s_dim", type=int, default=128)
    p.add_argument('--reg_lambda', type=float, default=0.1)
    p.add_argument('--hide_dim', type=int, default=512)
    p.add_argument('--train_z_s_epoch', type=int, default=5)
    p.add_argument('--z_s_classfier_lr', type=float, default=1e-3)
    p.add_argument('--vae_lambda', type=float, default=1e-3)
    p.add_argument('--tune', action='store_true')

    p.add_argument('--n_outer_loop',  type=int, default=30)
    p.add_argument('--tune_n_trials', type=int, default=30)

    p.add_argument('--test_index', type=int, default=2)
    p.add_argument('--save_dir_name', type=str, default="P23")
    p.add_argument('--load_csv', type=str, default="")

    # Launcher / worker plumbing used by this script itself.
    p.add_argument('--worker', action='store_true', help='[internal] run a single domain job')
    p.add_argument('--seed', type=int, default=42)
    p.add_argument('--base_root', type=str, default=BASE_ROOT_DEFAULT)
    p.add_argument('--epochs', type=int, default=EPOCHS_DEFAULT)
    p.add_argument('--gpu_ids', type=str, default="0,1,2,3", help='preferred GPU indices, e.g., 0,1,2,3')
    p.add_argument('--device_index', type=int, default=0, help='[internal] which GPU index this worker should use')
    return p





def parse_gpu_ids(s: str):
    """Parse a comma-separated string of GPU indices into a list of ints.

    Fields that are empty or contain only whitespace are ignored, so
    trailing or doubled commas are tolerated. Any other field is passed to
    ``int()`` and raises ``ValueError`` if it is not an integer.
    """
    gpu_ids = []
    for field in s.split(","):
        if field.strip() == "":
            continue
        gpu_ids.append(int(field))
    return gpu_ids

def apply_params_for_domain(args, domain_letter: str):
    """Overlay the stored hyperparameters for one domain onto ``args``.

    Every key of ``BEST_PARAMS_BY_DOMAIN[domain_letter]`` is set as an
    attribute on ``args``. The function then derives ``z_dim`` as the sum of
    ``z_c_dim`` and ``z_s_dim``, sets ``data`` to ``"pacs"``, attaches the
    domain list and makes sure ``epochs`` exists.

    Args:
        args: namespace-like object to mutate in place.
        domain_letter: key into ``BEST_PARAMS_BY_DOMAIN``.

    Returns:
        The hyperparameter dict that was applied (the shared module-level
        object, not a copy).

    Raises:
        KeyError: if ``domain_letter`` has no entry.
    """
    domain_params = BEST_PARAMS_BY_DOMAIN[domain_letter]
    for name in domain_params:
        setattr(args, name, domain_params[name])

    # Derived and fixed settings that accompany the per-domain values.
    total_latent = args.z_c_dim + args.z_s_dim
    args.z_dim = total_latent
    args.data = "pacs"
    args.domain_names = DOMAIN_NAMES
    args.epochs = getattr(args, "epochs", EPOCHS_DEFAULT)
    return domain_params


def run_one_domain(domain_index: int, seed: int, base_root: str, device_index: int) -> float:
    """Run ``main`` for one domain and append the outcome to that domain's results.csv.

    The argument namespace is rebuilt from parser defaults (an empty argv),
    then overridden with the stored per-domain hyperparameters, so options
    given on the worker's command line other than the four parameters below
    do not reach ``main``.

    Args:
        domain_index: index into ``DOMAIN_NAMES``; also stored as ``args.test_index``.
        seed: seed applied to ``random``, NumPy and torch, and used in output paths.
        base_root: root directory; output goes to ``<base_root>/<letter><seed>/``.
        device_index: CUDA device to bind. A negative value (the launcher's
            CPU marker) or one beyond the visible device count skips binding.

    Returns:
        The value returned by ``main(args)`` converted to ``float``
        (presumably the test accuracy -- confirm against ``main``).
    """
    import datetime

    # Defaults only: the worker's own CLI values arrive as explicit parameters.
    args = get_arg_parser().parse_args([])

    domain_letter = DOMAIN_NAMES[domain_index]
    used_params = apply_params_for_domain(args, domain_letter)

    args.base_root = base_root
    args.test_index = int(domain_index)
    args.epochs = used_params.get("epochs", getattr(args, "epochs", EPOCHS_DEFAULT))
    args.test_name = f"{domain_letter}{seed}"
    args.log = os.path.join(args.base_root, args.test_name, str(seed))
    args.device_index = device_index

    # Seed every RNG this script can reach before any work happens in main().
    args.random_seed = seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    cuda_ok = torch.cuda.is_available()
    if cuda_ok:
        torch.cuda.manual_seed_all(seed)

    cnt = torch.cuda.device_count() if cuda_ok else 0
    try:
        # Require 0 <= index < count. The previous ``count > index`` test let the
        # launcher's CPU marker (-1) through on CUDA machines, which ended up
        # logged as a misleading "set_device_failed".
        if cuda_ok and 0 <= device_index < cnt:
            torch.cuda.set_device(device_index)
            bound_name = torch.cuda.get_device_name(device_index)
        else:
            bound_name = "cpu"
    except Exception as e:
        # Best effort: a failed bind is reported in the log line below, not fatal.
        bound_name = f"set_device_failed: {e}"

    cur = torch.cuda.current_device() if (cuda_ok and cnt > 0) else None
    print(f"[Worker] PID={os.getpid()} | Domain={domain_letter} | Seed={seed} | Start={datetime.datetime.now().isoformat(timespec='seconds')}")
    print(f"[Worker] torch.cuda.is_available={cuda_ok} | device_count={cnt} | "
          f"requested_device_index={device_index} | bound={bound_name} | current={cur}")

    results_csv = os.path.join(args.base_root, args.test_name, "results.csv")
    os.makedirs(os.path.dirname(results_csv), exist_ok=True)

    # Convert up front so a non-numeric result fails before a bad row is written
    # and the CSV always holds a value that float() can read back.
    acc = float(main(args))

    # An existing but empty file still needs the header row.
    needs_header = (not os.path.isfile(results_csv)) or os.path.getsize(results_csv) == 0
    with open(results_csv, "a", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        if needs_header:
            w.writerow(["seed", "acc"] + PARAM_KEY_ORDER)
        # Copy so the shared per-domain dict is not modified.
        row_params = dict(used_params)
        row_params["epochs"] = args.epochs
        w.writerow([seed, acc] + [row_params[k] for k in PARAM_KEY_ORDER])

    print(f"[Worker] Domain={domain_letter} finished. Acc={acc:.6f} -> {results_csv}")
    return acc


def aggregate_all_domains(base_root: str, seed: int):
    """Gather one accuracy per domain into a summary CSV and print it as a table.

    For each letter in ``DOMAIN_NAMES`` the file
    ``<base_root>/<letter><seed>/results.csv`` is preferred; if it is missing,
    the first match (sorted) of ``<base_root>/<letter>*/results.csv`` is used.
    Workers append to results.csv, so the LAST row is the most recent run and
    is the one reported. Domains with no file, no data rows or an unreadable
    row are left out of the summary.

    Args:
        base_root: directory holding the per-domain folders; the summary is
            written to ``<base_root>/summary_domains_seed<seed>.csv``.
        seed: seed used to locate result folders and to name the summary.
    """
    console = Console()
    rows = []
    for letter in DOMAIN_NAMES:
        exact = os.path.join(base_root, f"{letter}{seed}", "results.csv")
        if os.path.isfile(exact):
            path = exact
        else:
            candidates = sorted(glob.glob(os.path.join(base_root, f"{letter}*", "results.csv")))
            if not candidates:
                continue
            path = candidates[0]

        # Keep only the final data row: results.csv is opened in append mode by
        # the worker, so earlier rows belong to older runs.
        last_row = None
        with open(path, newline="", encoding="utf-8") as f:
            for record in csv.DictReader(f):
                last_row = record
        if last_row is None:
            console.print(f"[Aggregate] No result rows in {path}; skipping domain {letter}.", style="yellow")
            continue

        try:
            s = int(last_row["seed"])
            acc = float(last_row["acc"])
        except (KeyError, TypeError, ValueError):
            # Fallback for files without the expected columns: take the seed from
            # the digits in the folder name and the accuracy from the second column.
            try:
                dirname = os.path.basename(os.path.dirname(path))
                s = int(''.join(ch for ch in dirname if ch.isdigit()) or seed)
                acc = float(last_row.get("acc", list(last_row.values())[1]))
            except (IndexError, TypeError, ValueError):
                console.print(f"[Aggregate] Could not parse {path}; skipping domain {letter}.", style="yellow")
                continue
        rows.append((letter, s, acc))

    # rows is already in DOMAIN_NAMES order because that list drives the loop.
    summary_csv = os.path.join(base_root, f"summary_domains_seed{seed}.csv")
    with open(summary_csv, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(["domain", "seed", "acc"])
        w.writerows(rows)

    table = Table(title=f"Summary across domains (seed={seed})", show_header=True, header_style="bold")
    table.add_column("Domain", justify="center")
    table.add_column("Seed", justify="right")
    table.add_column("Acc", justify="right")
    for l, s, a in rows:
        table.add_row(l, str(s), f"{a:.6f}")
    console.print(table)
    console.print(f"[Aggregate] Saved summary to: {summary_csv}")



def launch_four_domains(seed: int, base_root: str, gpu_ids: list):
    """Run one worker subprocess per domain and then aggregate the results.

    Each worker is this same script re-invoked with ``--worker``; its stdout
    and stderr go to ``<base_root>/<letter><seed>/train.log``. With at least
    one visible GPU all workers start at once, assigned round-robin over the
    usable GPU ids. With no GPU they run one after another with
    ``--device_index -1``.

    Workers that exit with a non-zero status are reported. If this function
    is interrupted, any worker still running is terminated and every log
    handle is closed.

    Args:
        seed: seed forwarded to every worker and used in folder names.
        base_root: root output directory forwarded to every worker.
        gpu_ids: preferred GPU indices; an empty value falls back to
            ``GPU_IDS_DEFAULT``. Ids outside ``0 .. device_count-1`` are dropped.
    """
    console = Console()

    try:
        total = torch.cuda.device_count()
    except Exception:
        total = 0
    console.print(f"[Launcher] torch.cuda.device_count() = {total}")

    if total == 0:
        console.print("[red]No GPU visible to this Python session. Falling back to CPU (slow).[/red]")

    preferred = list(gpu_ids) if gpu_ids else GPU_IDS_DEFAULT
    # Negative ids are rejected too; they are not valid CUDA device indices.
    available = [i for i in preferred if 0 <= i < total]
    if total > 0 and not available:
        # None of the preferred ids exist here: use the first (up to four) visible GPUs.
        available = list(range(min(4, total)))

    domain_indices = range(len(DOMAIN_NAMES))
    jobs = []    # (process, log handle, domain letter, log path) per started worker
    failed = []  # letters of domains whose worker exited with a non-zero status

    def spawn(d_idx, device_index, desc):
        """Start the worker for one domain and return its job tuple."""
        letter = DOMAIN_NAMES[d_idx]
        test_name = f"{letter}{seed}"
        log_dir = os.path.join(base_root, test_name)
        os.makedirs(log_dir, exist_ok=True)
        log_path = os.path.join(log_dir, "train.log")

        env = os.environ.copy()
        env["PYTHONUNBUFFERED"] = "1"
        cmd = [
            sys.executable, os.path.abspath(__file__),
            "--worker",
            "--seed", str(seed),
            "--base_root", base_root,
            "--test_index", str(d_idx),
            "--device_index", str(device_index),
        ]

        # The handle must outlive this call (the child writes to it), so it is
        # closed in reap() or in the cleanup below rather than by a with-block.
        log_f = open(log_path, "w", buffering=1, encoding="utf-8")
        try:
            console.print(f"[Launcher] Domain={letter} -> {desc} | LOG: {log_path}")
            p = subprocess.Popen(cmd, env=env, stdout=log_f, stderr=subprocess.STDOUT)
        except BaseException:
            log_f.close()
            raise
        job = (p, log_f, letter, log_path)
        jobs.append(job)
        return job

    def reap(job):
        """Wait for one worker, close its log and record a failure if any."""
        p, log_f, letter, log_path = job
        try:
            returncode = p.wait()
        finally:
            log_f.close()
        if returncode != 0:
            failed.append(letter)
            console.print(
                f"[Launcher] Domain={letter} worker exited with code {returncode}; see {log_path}",
                style="red",
            )

    try:
        if total >= 1:
            # Start everything first so the workers run concurrently.
            for i, d_idx in enumerate(domain_indices):
                dev = available[i % len(available)]
                spawn(d_idx, dev, f"GPU {dev}")
            for job in jobs:
                reap(job)
        else:
            # CPU fallback: run the workers strictly one at a time.
            for d_idx in domain_indices:
                reap(spawn(d_idx, -1, "CPU"))
    finally:
        # Only does real work when we are leaving early (error or Ctrl-C).
        for p, log_f, _letter, _log_path in jobs:
            if p.poll() is None:
                p.terminate()
                p.wait()
            log_f.close()

    if failed:
        console.print(
            f"[Launcher] Workers failed for domains: {', '.join(failed)}. "
            "Their summary entries may be missing or come from an earlier run.",
            style="red",
        )

    aggregate_all_domains(base_root, seed=seed)

if __name__ == "__main__":
    # Script entry point. Without --worker this process is the launcher that
    # starts the per-domain workers; with --worker it runs a single domain.
    parser = get_arg_parser()
    args = parser.parse_args()

    if not args.worker:
        pref = parse_gpu_ids(args.gpu_ids)
        if not pref:
            # Nothing usable in --gpu_ids: fall back to the built-in list.
            pref = GPU_IDS_DEFAULT
        print(f"args.gpu_ids: {pref}")
        launch_four_domains(seed=args.seed, base_root=args.base_root, gpu_ids=pref)
    else:
        # Bind the requested GPU early; a negative index means "do not bind".
        should_bind = (
            torch.cuda.is_available()
            and args.device_index is not None
            and args.device_index >= 0
        )
        if should_bind:
            try:
                torch.cuda.set_device(args.device_index)
            except Exception as err:
                print(f"[Worker] set_device({args.device_index}) failed: {err}")

        worker_kwargs = {
            "domain_index": args.test_index,
            "seed": args.seed,
            "base_root": args.base_root,
            "device_index": args.device_index,
        }
        run_one_domain(**worker_kwargs)

