import os
import random
from random import randint
import numpy as np
import pandas as pd
import torch

# Set all random seeds
def set_seed(seed=1968):
    """Seed Python, NumPy and PyTorch (CPU plus every CUDA device) RNGs.

    Also forces deterministic cuDNN kernels and turns off cuDNN autotuning.
    ``PYTHONHASHSEED`` is exported as well. Python reads that variable only at
    interpreter startup, so it affects child processes started afterwards,
    not ``str`` hashing in the already-running interpreter.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # covers multi-GPU setups
    )
    for seeder in seeders:
        seeder(seed)
    # Trade cuDNN speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)

# Seed Python, NumPy and PyTorch once at import time (model init uses torch's RNG).
set_seed(1968)

# Experiment size knobs:
#   max_vars       -- input variables per domain (named x, y, z, w in order)
#   n_data_samples -- rows sampled from each dataset
#   n_domains      -- how many leading domains training draws from
#                     (other values used before: 16, 8, 4)
max_vars, n_data_samples, n_domains = 3, 1000, 2

def get_domain(srbench_path, return_sorts=False, *, n_samples=None, n_vars=None, seed=1968):
    """Load a tab-separated SRBench/PMLB dataset and turn it into a domain dict.

    ``n_samples`` rows are drawn without replacement. The input columns are
    taken from the first ``n_vars`` columns of the file (excluding
    ``"target"``) and renamed ``x, y, z, w`` in order. Names the dataset does
    not supply (up to ``n_vars``) are added as all-zero columns, so every
    domain built with the same ``n_vars`` has the same keys.

    Args:
        srbench_path: path to a ``.tsv``/``.tsv.gz`` file with a ``"target"``
            column.
        return_sorts: when True, rows are sorted by ``x``, and for every input
            column (in column order, skipping ``"target"``) the list of row
            positions that sorts that column is returned too.
        n_samples: rows to sample. Defaults to the module-level
            ``n_data_samples``.
        n_vars: number of leading columns considered as inputs. Defaults to
            the module-level ``max_vars``.
        seed: value the *global* NumPy RNG is reseeded with before sampling.
            This reseeding is visible to the caller afterwards.

    Returns:
        ``(domain, sorts)``. ``domain`` maps each column name to a NumPy
        array of length ``n_samples``. ``sorts`` is ``None`` unless
        ``return_sorts`` is True.

    Raises:
        ValueError: if more input columns are found than there are names
            (``x, y, z, w``).
    """
    col_vocab = ["x", "y", "z", "w"]
    if n_samples is None:
        n_samples = n_data_samples
    if n_vars is None:
        n_vars = max_vars
    srbench_df = pd.read_csv(srbench_path, sep="\t")
    # DataFrame.sample without ``random_state`` draws from the global NumPy
    # RNG, so reseeding it makes the sampled rows repeatable.
    np.random.seed(seed)
    srbench_df = srbench_df.sample(n_samples, replace=False, axis="index")
    var_cols = [_col for _col in list(srbench_df.columns)[:n_vars] if _col != "target"]
    n_found = len(var_cols)
    if n_found > len(col_vocab):
        raise ValueError(
            f"{srbench_path}: {n_found} input columns selected but only "
            f"{len(col_vocab)} variable names ({', '.join(col_vocab)}) exist"
        )
    # .copy() gives an independent frame, so adding padding columns below
    # does not hit pandas' chained-assignment warning.
    srbench_df = srbench_df[var_cols + ["target"]].copy()
    srbench_df.columns = col_vocab[:n_found] + ["target"]
    # Zero-fill the variable names this dataset lacks.
    for _col in col_vocab[n_found:n_vars]:
        srbench_df[_col] = 0.0
    sorts = None
    if return_sorts:
        srbench_df = srbench_df.sort_values("x", axis="index", ascending=True)
        sorts = [
            np.argsort(srbench_df[_var]).tolist()
            for _var in srbench_df.columns
            if _var != "target"
        ]
    srbench_domain = {
        _col: np.array(srbench_df[_col].tolist()) for _col in srbench_df.columns
    }
    return srbench_domain, sorts

def make_random_domain(nvars, nsamples=n_data_samples, low=0.0, high=10.0):
    """Return ``nvars`` independent uniform random feature columns.

    Variables are named from ``"xyzwtuv"`` in order. Each is an array of
    ``nsamples`` draws from the half-open interval [low, high), taken from
    the global NumPy RNG.

    Raises:
        ValueError: if ``nvars`` is negative or larger than the 7 available
            names. Slicing the name string would otherwise silently return a
            different number of variables than requested.
    """
    varvocab = "xyzwtuv"
    if not 0 <= nvars <= len(varvocab):
        raise ValueError(
            f"nvars must be between 0 and {len(varvocab)}, got {nvars}"
        )
    return {
        _var: np.random.uniform(low=low, high=high, size=nsamples)
        for _var in varvocab[:nvars]
    }

def pick_domain(domains):
    """Return one element of ``domains`` chosen uniformly at random.

    Draws from Python's global ``random`` generator (not NumPy's).
    """
    last_pos = len(domains) - 1
    chosen_pos = randint(0, last_pos)
    return domains[chosen_pos]

# ###### At least 16 Feynman domains, each with exactly `max_vars` input variables; feynman_I_34_1 is always moved to the first position (domain id 0) as the evaluation domain

####### SRBench data
# Evaluation dataset (rows sampled and sorted by x via get_domain). The Feynman
# domain loop below matches its glob results against this path to locate the
# evaluation domain, so keep the path in the same form as that glob pattern.
srbench_path = "../srbench/pmlb/datasets/feynman_I_34_1/feynman_I_34_1.tsv.gz"
srbench_domain, srbench_sorts = get_domain(srbench_path, return_sorts=True)

from glob import glob
from htssr.bench.feynman import problems

# Collect the candidate domains: every Feynman dataset that has an entry with a
# non-empty "prefix" in ``problems`` and exactly ``max_vars`` input variables
# (``max_vars + 1`` columns including the target). The glob result is sorted
# so the domain order, and therefore every domain id, does not depend on the
# filesystem's directory-listing order.
feynman_paths = sorted(glob("../srbench/pmlb/datasets/feynman_*/*.tsv.gz"))
feynman_subset = []  # domain dicts, indexed by domain id
feynman_sorts = []  # per-domain sort permutations, same indexing
feynman_titles = []  # dataset name per domain id, same indexing
feynman_dict = {}  # dataset name -> domain dict
feynman_domain_ids = {}  # dataset name -> domain id
eval_domain_id = None
for _path in feynman_paths:
    # ".../feynman_I_34_1/feynman_I_34_1.tsv.gz" -> "feynman_I_34_1"
    title = os.path.basename(_path).split(".")[0]
    if title not in problems:
        continue
    if len(problems[title]["prefix"]) == 0:
        continue
    # Only the header is needed to count columns; get_domain reads the rows.
    n_columns = len(pd.read_csv(_path, sep="\t", nrows=0).columns)
    if n_columns == (max_vars + 1):
        _domain, _sorts = get_domain(_path, return_sorts=True)
        _domain_id = len(feynman_subset)
        feynman_subset.append(_domain)
        feynman_sorts.append(_sorts)
        feynman_titles.append(title)
        if os.path.normpath(_path) == os.path.normpath(srbench_path):
            eval_domain_id = _domain_id
        feynman_dict[title] = _domain
        feynman_domain_ids[title] = _domain_id

assert len(feynman_subset) >= 16, (
    f"expected at least 16 Feynman domains, found {len(feynman_subset)}"
)
if eval_domain_id is None:
    raise RuntimeError(
        f"evaluation dataset {srbench_path!r} is not among the selected Feynman domains"
    )

# Move the evaluation domain to id 0 by swapping it with the domain that held
# that id. Keep every per-domain structure, including the name -> id map,
# consistent with the swap.
for _per_domain in (feynman_subset, feynman_sorts, feynman_titles):
    _per_domain[eval_domain_id], _per_domain[0] = _per_domain[0], _per_domain[eval_domain_id]
feynman_domain_ids[feynman_titles[0]] = 0
feynman_domain_ids[feynman_titles[eval_domain_id]] = eval_domain_id
eval_domain_id = 0

from htssr.nnet import (
    AllPairsSymNumNet,
    AllPairsTreeNet,
    AllPairsConvTreeNet,
    AllPairsCatTreeNet,
    AllPairsSingleDomainTreeNet,
    AllPairsMultiDomainTreeNet,
)
from htssr.cost import all_pairs_venn_loss
from htssr.sampling import make_exceptions
from htssr.batching import (
    make_heldout_batch,
    make_all_pairs_free_batch,
    free_batch_from_samples,
    make_free_heldout_batch,
)
from htssr.grammar import expansions
from htssr.search import search_prefix, make_learned_heuristic

import numpy as np
import torch
from torch.optim import Adam, SGD
from sklearn.metrics import classification_report
import pandas as pd
from random import seed
from time import time, ctime

# Use the first GPU when one is available.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Draw the per-variable offsets with a fixed seed so the symbolic domain below
# is the same on every run. Then reseed NumPy from the wall clock
# (microseconds), so subsequent global NumPy draws are NOT reproducible.
# NOTE: `time` here is the function from `from time import time` above; a
# later `import time` rebinds the name to the module.
np.random.seed(1968)
noises = np.random.uniform(low=-1.0, high=1.0, size=(3, 6))
np.random.seed(int(1e6 * time()) % (1 << 31))

# Symbolic evaluation domain: 6 points from a 1000-point grid on [-20, 20],
# three near each end. "x" takes them directly. Each further variable
# ("y", "z", "w", for max_vars variables in total) is "x" shifted by its own
# fixed row of `noises`.
domain = np.linspace(-20.0, 20.0, 1000)[[10, 20, 30, -30, -20, -10]]
domain = {
    "x": domain,
    **{
        ["y", "z", "w"][vid]: (domain + noises[vid])
        for vid in range(max_vars - 1)
    }
    # "y": domain + noises[0],
    # "z": domain + noises[1],
    # "w": domain + noises[2],
}

from htssr.canon import size_enumeration
# Per-max_vars size limit passed to size_enumeration (presumably the largest
# expression size enumerated -- confirm in htssr.canon); only max_vars 1-4
# are supported.
max_domain_size = {
    1: 9,
    2: 9,
    3: 8,
    4: 8,
}[max_vars]
enum_expr, key_expr, ptrs = size_enumeration(max_domain_size, domain)

# np.random.seed(1968)
# feat_noises = np.random.uniform(low=-1.0, high=1.0, size=(3, 1000))
# np.random.seed(int(1e6 * time()) % (1 << 31))

# Numeric features: the sampled rows of the evaluation dataset (feynman_I_34_1).
feat_domain = srbench_domain
# Tensor form of its sort permutations. Not referenced later in this script;
# training and evaluation pass the `feynman_sorts` lists instead.
feat_sorts = torch.tensor(srbench_sorts).long().to(device)

from multiprocessing import Process, shared_memory, Lock
from multiprocessing import Array, Value
from multiprocessing.managers import SharedMemoryManager
import numpy as np
import pickle
import time

from htssr.sampling import generate_all_pairs_free_samples
### TODO: use expand_rules(expansions, ..., level=4) instead of expansions
from htssr.grammar import expansions
from htssr.utils import (
    unroll,
    to_infix,
    expand_rules,
    fast_eval_expr,
    rolling_fast_eval_expr,
)
from htssr.grammar import expansions

# Flags not referenced elsewhere in this script (training passes noisy_fast=False directly).
dummy, noisy_fast = False, False
# Batch size handed to make_all_pairs_free_batch.
ids_size = 32
# Length threshold: passed to expand_rules, and held-out expressions longer
# than this are scored as extrapolation during evaluation.
max_pos = 8
# Held-out length window. heldout_end_pos is also the forced padding length
# (other values tried: max(max_pos - 3, 5), max_pos + max_pos // 2, max_pos + 6).
heldout_start_pos, heldout_end_pos = max_pos - 3, max_pos + 4

expanded_expansions = expand_rules(
    key_expr, expansions, domain, max_pos, levels=1, test_canon=False
)

# Not referenced elsewhere in this script.
max_buffer_size = 300000 # 282450

# Held-out expressions, stored in the pickle per number of variables.
# NOTE: pickle.load can execute arbitrary code -- only point heldout_path at
# trusted, locally generated files.
heldout_path = "expr_dataset/heldout_feynman_mid_compl_rules_xyz10.pkl"
with open(heldout_path, "rb") as f:
    heldout_dict = pickle.load(f)
# Deduplicate. Entries must be hashable (presumably token-id tuples -- confirm).
heldout_ids = set(heldout_dict[max_vars])
heldout_rolls = [[list(_ids)] for _ids in heldout_ids]

# heldout_ids, heldout_rolls = make_exceptions(
#     enum_expr,
#     key_expr,
#     ptrs,
#     domain,
#     nexc=128,
#     min_pos=heldout_start_pos,
#     max_pos=heldout_end_pos,
#     expansions=expanded_expansions,
#     test_canon=True,
# )

# True for each held-out expression longer than max_pos. Periodic evaluation
# scores only those rows. Assumes unroll yields one roll per held-out expression.
extrapolation_mask = [len(roll) > max_pos for roll in unroll(heldout_rolls)]

# Fixed evaluation batch built on the evaluation dataset's features.
heldout_batch_ids, heldout_parent_ids, heldout_batch_vals, heldout_batch_y = make_free_heldout_batch(
    heldout_rolls,
    feat_domain,
    force_padding=heldout_end_pos,
    device=device,
)

model = AllPairsMultiDomainTreeNet(
    num_in_size=n_data_samples,
    num_girth=2048,
    num_digits=67,
    num_dbase=2.0,
    sym_dmodel=1024,
    sym_nhead=4,
    sym_num_layers=4,
    clf_agg="mean",
    num_sorts=max_vars,
    device=device,
).to(device)

optim = Adam(model.parameters(), lr=1e-5)

n_epochs = 1000
epoch_len = 50  # training iterations per epoch
acc_steps = 1  # gradient-accumulation iterations per optimizer step
assert epoch_len % acc_steps == 0  # every epoch ends on an optimizer step
epoch_examples = ids_size * epoch_len
all_loss = []
all_mce = []
load_state = False
nickname = f"single_scalability_domains{n_domains}_s{max_vars}_fixing"
log_path = f"SymNumNet/training_log_{nickname}.txt"
model_path = f"SymNumNet/bb/model_{nickname}.pth"
search_log_path = f"SymNumNet/search_log_{nickname}.txt"

# Create the log directories up front so opening the logs cannot fail on a fresh checkout.
for _log_dir in {os.path.dirname(log_path), os.path.dirname(search_log_path)}:
    if _log_dir:
        os.makedirs(_log_dir, exist_ok=True)

if load_state:
    # map_location lets a checkpoint written on GPU be restored on a CPU-only host.
    saved_state = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(saved_state["model"])
    optim.load_state_dict(saved_state["optim"])
    # The (currently disabled) checkpoint code in the training loop saves
    # after epoch `epoch` has been trained, so continue with the next epoch
    # instead of training that one again.
    prev_epochs = saved_state["epoch"] + 1
else:
    # Fresh run: truncate both logs.
    with open(log_path, "w") as f:
        f.write("")
    with open(search_log_path, "w") as f:
        f.write("")
    prev_epochs = len(all_loss) // epoch_len  # all_loss is empty here, so 0

# ###### Start of training

for epoch in range(prev_epochs, prev_epochs + n_epochs):
    model.train()
    optim.zero_grad()
    t0 = time.time()
    for iteration in range(epoch_len):
        # Each step trains on one domain drawn uniformly from the first
        # n_domains; its list position doubles as the domain id for the model.
        curr_domain_pos = randint(0, n_domains - 1)
        curr_feat_domain = feynman_subset[curr_domain_pos]
        curr_sorts = feynman_sorts[curr_domain_pos]
        # heldout_ids is passed in, presumably so held-out expressions are
        # excluded from training batches -- confirm in htssr.batching.
        batch_ids, batch_parents, batch_vals, batch_y = make_all_pairs_free_batch(
            heldout_ids,
            ids_size,
            curr_feat_domain,
            min_pos=1,
            max_pos=heldout_end_pos,
            root_src=False,
            noisy_fast=False,
            expansions=expanded_expansions,
            device=device,
            force_padding=heldout_end_pos,
            test_canon=False,
            domain=domain,
            key_expr=key_expr,
        )
        y_pred = model(
            batch_vals,
            expr_ids=batch_ids,
            parent_ids=batch_parents,
            domain_ids=torch.tensor(curr_domain_pos, dtype=torch.long, device=device),
            sorts=curr_sorts,
        )
        loss = all_pairs_venn_loss(y_pred, batch_y)
        # Gradient accumulation: acc_steps scaled losses add up to one averaged update.
        scaled_loss = loss / acc_steps
        scaled_loss.backward()
        if iteration % acc_steps == (acc_steps - 1):
            optim.step()
            optim.zero_grad()
        all_loss.append(loss.item())
    dt = time.time() - t0  # wall time of this epoch's training iterations
    if epoch % 25 == 0:
        epoch_loss = np.mean(all_loss[-epoch_len:])
        log_str = f"[{ctime()}][Epoch {epoch}] Loss: {epoch_loss:.3f}"
        # Regular evaluation on the held-out set of the evaluation domain.
        cls_cols = ["0", "1"]
        # Inference mode (disables train-only behavior such as dropout); the
        # next epoch's model.train() call switches training mode back on.
        model.eval()
        t_eval = time.time()
        with torch.no_grad():
            heldout_y_pred = model.predict(
                heldout_batch_vals,
                expr_ids=heldout_batch_ids,
                parent_ids=heldout_parent_ids,
                domain_ids=torch.tensor(eval_domain_id, dtype=torch.long, device=device),
                sorts=feynman_sorts[eval_domain_id],
            )
        heldout_y_pred.fill_diagonal_(1)
        # Score only held-out expressions longer than max_pos (extrapolation).
        flat_heldout_y_pred = (
            heldout_y_pred[extrapolation_mask]
            .reshape(-1)
            .detach()
            .cpu()
            .numpy()
        )
        flat_heldout_batch_y = (
            heldout_batch_y[extrapolation_mask]
            .reshape(-1)
            .detach()
            .cpu()
            .numpy()
        )
        # The .cpu() copies above wait for pending GPU work, so this interval
        # covers the whole inference.
        eval_dt = time.time() - t_eval
        clf_report = classification_report(
            flat_heldout_batch_y,
            flat_heldout_y_pred,
            output_dict=True,
        )
        clf_report = pd.DataFrame(clf_report)
        clf_report = str(clf_report[cls_cols].loc["f1-score"])
        clf_report = (
            clf_report
            .replace("\n", " | ")
            .replace("Name: f1-score, dtype: float64", "")
        )
        with open(log_path, "a") as f:
            f.write(
                f"{log_str} ; f1-score: | {clf_report} "
                f"Train Time: {dt:.4f} (s) Eval Infer Time: {eval_dt:.4f} (s)\n"
            )
    # if epoch % 1000 == 0:
    #     torch.save(
    #         {
    #             "model": model.state_dict(),
    #             "optim": optim.state_dict(),
    #             "epoch": epoch,
    #         },
    #         model_path,
    #     )

####### Run search
# Try to rediscover every held-out expression with the model-guided search,
# appending one record per problem to the search log.
model.eval()
solved = []
_heldout_ids = [list(_ids) for _ids in heldout_ids]
num_problems = len(_heldout_ids)
for prob_pos, _ids in enumerate(_heldout_ids):
    # Values of the hidden expression on the evaluation dataset's features.
    _secret_vals = fast_eval_expr(_ids, feat_domain)
    heuristic, tc_secret_vals = make_learned_heuristic(
        model,
        _secret_vals,
        domain,
        feat_domain,
        device=device,
        tree_vec=True,
        force_padding=heldout_end_pos,
    )
    # The search may not return anything longer than the hidden expression.
    sol, vis, trace, params = search_prefix(
        key_expr,
        heuristic,
        tc_secret_vals,
        domain,
        feat_domain,
        eps=1e-8,
        max_len=len(_ids),
        max_visited=10240,
        expansions=expansions,
        expansion_cache=None,
        levels=1,
        penalty_factor=0.0,
        prune_dist=1.1,
        fit_params=True,
        max_fit_iter=15,
        n_param_inits=1,
        fit_err_tol=1e-9,
        dropout=0,
        test_canon=False,
        beam_w=128,
        topk_children=-1,
        domain_id=torch.tensor(eval_domain_id).long().to(device),
        sorts=feynman_sorts[eval_domain_id],
    )
    found = sol is not None
    if found:
        solved.append(_ids)
    total_solved = len(solved)
    _header = f"[{total_solved}/{(prob_pos + 1)} solved][{ctime()}][Lenght: {len(_ids)}]"
    if found:
        _header += f"[Soln. Length: {len(sol)}]"
        _outcome = f"<=>\n{to_infix(sol)}"
    else:
        _outcome = "Soln. not found."
    with open(search_log_path, "a") as f:
        f.write(f"{_header}\n{to_infix(_ids)}\n{_outcome}\n\n")
