from time import time
from random import seed
import sys
sys.path.append(".")
from pickle import dump, load
import os
import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin
import os
import random
from random import randint
import numpy as np
import pandas as pd
import torch
from glob import glob
from heapq import heapify, heappush, heappop

def edit_primitives(var_names):
    """Rewrite ``htssr/primitives.py`` in place so it declares ``var_names``.

    Two textual edits are applied to the file (path is relative to the
    current working directory):

    * the first line containing ``special_symbol = `` is replaced by an
      assignment of ``var_names[0]``;
    * the span from the first line containing ``_variables`` up to and
      including the next line containing ``]`` is replaced by a freshly
      generated ``_variables = [...]`` list with one entry per name.

    Args:
        var_names: Non-empty sequence of variable-name strings.

    Raises:
        AssertionError: If ``var_names`` is empty.
        ValueError: If the ``_variables`` list (or its closing ``]``) cannot
            be located. The file is left untouched in that case.
    """
    assert len(var_names) > 0
    path = "htssr/primitives.py"
    special_symbol = var_names[0]
    with open(path, "r") as f:
        lines = f.read().split("\n")

    # Only the first matching line is rewritten; a file without such a line
    # is left as-is for this part.
    for idx, line in enumerate(lines):
        if "special_symbol = " in line:
            lines[idx] = f'special_symbol = "{special_symbol}"'
            break

    # Locate the block to replace *before* opening the file for writing, so a
    # missing marker cannot truncate or corrupt the module.
    start = next(
        (idx for idx, line in enumerate(lines) if "_variables" in line),
        None,
    )
    if start is None:
        raise ValueError(f"no '_variables' definition found in {path}")
    end = next(
        (idx for idx in range(start, len(lines)) if "]" in lines[idx]),
        None,
    )
    if end is None:
        raise ValueError(f"unterminated '_variables' list in {path}")

    entries = "\n    ".join(
        f'("{name}", _identity, lambda: "{name}", _identity),'
        for name in var_names
    )
    replacement = f"_variables = [\n    {entries}\n]"

    # The joins are done outside the f-string: a backslash inside an f-string
    # replacement field is a SyntaxError on Python versions before 3.12.
    head = "\n".join(lines[:start])
    tail = "\n".join(lines[end + 1:])
    with open(path, "w") as f:
        f.write(f"{head}\n{replacement}\n{tail}")

# edit_primitives(cnames)

import numpy as np
import torch
from torch.optim import Adam, SGD
from sklearn.metrics import classification_report
import pandas as pd
from random import seed
from time import time, ctime
from pickle import load
from multiprocessing import Process, shared_memory, Lock
from multiprocessing import Array, Value
from multiprocessing.managers import SharedMemoryManager
import numpy as np
import pickle

# Module-level sizing constants. `ids_size` and `heldout_end_pos` are read by
# DummyRegressor.fit when it builds training batches.
max_vars, ids_size, max_pos = 4, 32, 8

# Positions derived from max_pos. Alternative formulas noted by the author:
# max(max_pos - 3, 5) for the start and (max_pos + max_pos // 2) for the end.
heldout_start_pos = max_pos - 3
heldout_end_pos = max_pos + 4

# NOTE(review): `hyper_params` and `eval_kwargs` are not read anywhere in this
# file; presumably they are picked up by an external harness -- confirm.
hyper_params = list()
eval_kwargs = dict(scale_x=False, scale_y=False)

# def make_canon_domain(col_names):
#     np.random.seed(1968)
#     noises = np.random.uniform(low=-1.0, high=1.0, size=(len(col_names) - 1, 6))
#     np.random.seed(int(time()))
#     domain = np.linspace(-20.0, 20.0, 1000)[[10, 20, 30, -30, -20, -10]]
#     domain_dic = {
#         col_names[0]: domain,
#     }
#     for pos, col_name in enumerate(col_names[1:]):
#         domain_dic[col_name] = domain + noises[pos]
#     return domain_dic

def make_canon_domain(col_names, max_vars):
    """Build the fixed "canonical" evaluation domain for up to 10 variables.

    The first variable ``"x"`` gets six points picked from a 1000-point
    linspace over [-20, 20]; every further variable gets the same six points
    plus a deterministic per-variable noise vector drawn uniformly from
    [-1, 1) with seed 1968.

    The noise comes from a private ``RandomState`` so the process-wide NumPy
    RNG is left untouched. Previously the global generator was seeded with
    1968 and then re-seeded from the wall clock, which silently discarded any
    seed the caller had set. The generated values are unchanged.

    Args:
        col_names: Ignored. Keys always come from the built-in canonical
            name list; the parameter is kept for interface compatibility.
        max_vars: Number of variables to include. Values above 10 raise
            ``IndexError`` because only 10 canonical names exist.

    Returns:
        dict mapping canonical variable name -> ``np.ndarray`` of 6 floats.
        ``"x"`` is always present, even when ``max_vars`` is 0.
    """
    canon_names = ["x", "y", "z", "w", "a", "b", "c", "d", "e", "f"]
    rng = np.random.RandomState(1968)
    # Shape kept at (max_vars, 6) so the draws match the historical values,
    # even though only the first max_vars - 1 rows are used.
    noises = rng.uniform(low=-1.0, high=1.0, size=(max_vars, 6))
    base = np.linspace(-20.0, 20.0, 1000)[[10, 20, 30, -30, -20, -10]]
    domain = {"x": base}
    for pos in range(1, max_vars):
        domain[canon_names[pos]] = base + noises[pos - 1]
    return domain

def make_feature_domain(df, y, nsamples=1000):
    """Draw a random row subset of ``df`` (with matching targets) as arrays.

    Rows are sampled without replacement using pandas' ``DataFrame.sample``.
    The target is attached as a temporary ``"___y"`` column so features and
    targets stay paired through the shuffle.

    Args:
        df: Feature ``DataFrame``. It is copied, not modified.
        y: Target values, assigned as a column of the copy (so a pandas
            Series is aligned on index, an array positionally).
        nsamples: Maximum number of rows to draw. Requests larger than
            ``len(df)`` are clamped to ``len(df)`` instead of raising, so the
            default works on small frames.

    Returns:
        Tuple ``(domain_dic, sampled_y, sorts)``:
        ``domain_dic`` maps column name -> ``np.ndarray`` of sampled values,
        ``sampled_y`` is the matching target array, and ``sorts`` holds, per
        column (in column order), the argsort of the sampled values as a list.
    """
    frame = df.copy()
    frame["___y"] = y
    if nsamples is not None:
        # Sampling without replacement cannot return more rows than exist.
        nsamples = min(nsamples, len(frame))
    frame = frame.sample(
        nsamples,
        replace=False,
        axis="index",
    )
    sampled_y = frame.pop("___y").to_numpy()
    sorts = [np.argsort(frame[column]).tolist() for column in frame.columns]
    domain_dic = {
        column: np.array(values)
        for column, values in frame.to_dict(orient="list").items()
    }
    return domain_dic, sampled_y, sorts

def make_eval_domain(df):
    """Convert every column of ``df`` into a NumPy array keyed by column name.

    Args:
        df: ``DataFrame`` whose columns are the evaluation inputs.

    Returns:
        dict mapping each column name to an ``np.ndarray`` of its values, in
        the frame's row order.
    """
    columns_as_lists = df.to_dict(orient="list")
    return {
        column: np.array(values)
        for column, values in columns_as_lists.items()
    }

def _rename_canonical_variables(expression, canon_to_true):
    """Replace canonical variable tokens in ``expression`` by their true names.

    The substitution is a single regex pass on whole-word matches. That way a
    variable such as ``x`` is not rewritten inside a longer token (``exp``),
    and a freshly inserted true name is never re-substituted when it happens
    to equal another canonical name (e.g. columns named ``y`` and ``x``).
    """
    import re

    if not canon_to_true:
        return expression
    alternatives = "|".join(re.escape(name) for name in canon_to_true)
    pattern = re.compile(rf"\b(?:{alternatives})\b")
    return pattern.sub(lambda match: str(canon_to_true[match.group(0)]), expression)


def _expand_square_calls(text):
    """Rewrite every ``SQUARE(arg)`` in ``text`` as ``(arg)**2``.

    Arguments are delimited by balanced-parenthesis matching, so nested
    parentheses and nested ``SQUARE`` calls are handled. The right-most call
    is expanded first; its argument can then contain no further ``SQUARE(``.
    If a call has no matching closing parenthesis the text is returned as it
    stands instead of looping forever.
    """
    marker = "SQUARE("
    while True:
        start = text.rfind(marker)
        if start < 0:
            return text
        open_idx = start + len(marker) - 1
        depth = 0
        close_idx = -1
        for idx in range(open_idx, len(text)):
            char = text[idx]
            if char == "(":
                depth += 1
            elif char == ")":
                depth -= 1
                if depth == 0:
                    close_idx = idx
                    break
        if close_idx < 0:
            return text
        argument = text[open_idx + 1:close_idx]
        text = f"{text[:start]}({argument})**2{text[close_idx + 1:]}"


# Mixin classes should always be on the left-hand side for a correct MRO
class DummyRegressor(RegressorMixin, BaseEstimator):
    """Scikit-learn style estimator wrapping the ``htssr`` expression search.

    ``fit`` trains a neural heuristic on the fly and periodically runs
    ``search_prefix`` until an expression is found; ``predict`` evaluates the
    stored expression on new data.

    Attributes set by ``fit``:
        col_renamer: dict with ``('c', canonical) -> true`` and
            ``('t', true) -> canonical`` column-name entries.
        _sol_ids: Solution returned by ``search_prefix`` (``None`` if no
            solution was found).
        _params: Fitted parameters returned alongside the solution, or
            ``None``.
        is_fitted_: ``True`` once the training/search loop has run.
        expression: Infix string of the solution using the true column
            names. Initialised to ``"x"`` in ``__init__``.
    """

    def __init__(self, *, param=1):
        self.param = param
        self.expression = "x"

    def fit(self, X, y=None):
        """Search for an expression of the columns of ``X`` that matches ``y``.

        Columns are mapped to the canonical names ``x, y, z, w``;
        ``htssr/primitives.py`` is rewritten for those names via
        ``edit_primitives`` before the ``htssr`` modules are imported.

        Args:
            X: ``DataFrame`` of features (at most 4 columns are supported).
            y: Target values, forwarded to ``make_feature_domain``.

        Returns:
            ``self`` when a solution was found. ``None`` when ``X`` has more
            than 4 columns (nothing is fitted) or when the search ends
            without a solution. NOTE(review): scikit-learn convention is to
            always return ``self``; the ``None`` returns are kept for
            backward compatibility.
        """
        col_names = list(X.columns)
        if len(col_names) > 4:
            return None
        cnames = ["x", "y", "z", "w", "a", "b", "c", "d", "e", "f"][:len(col_names)]
        # Must happen before the htssr imports below so the rewritten
        # primitives file is what gets imported.
        # NOTE(review): if htssr.primitives was already imported in this
        # process, Python's module cache presumably keeps the old definitions
        # -- confirm whether repeated fits with a different column count work.
        edit_primitives(cnames)

        # Project-local imports (duplicates from the original merged; the
        # imported names, including unused ones, are unchanged).
        from htssr.bench.feynman import problems
        from htssr.primitives import special_symbol, variables
        from htssr.search import search_prefix, make_learned_heuristic
        from htssr.nnet import (
            AllPairsSymNumNet,
            AllPairsTreeNet,
            AllPairsConvTreeNet,
            AllPairsCatTreeNet,
            AllPairsSingleDomainTreeNet,
            AllPairsMultiDomainTreeNet,
            AllPairsSimpresNet,
        )
        from htssr.cost import all_pairs_venn_loss
        from htssr.sampling import make_exceptions, generate_all_pairs_free_samples
        from htssr.batching import (
            make_heldout_batch,
            make_all_pairs_free_batch,
            free_batch_from_samples,
            make_free_heldout_batch,
        )
        from htssr.grammar import expansions
        from htssr.utils import (
            unroll,
            str2ids,
            to_infix,
            to_infix_params,
            expand_rules,
            fast_eval_expr,
            rolling_fast_eval_expr,
            rolling_tree_eval_expr,
        )
        from htssr.canon import size_enumeration

        # Two-way mapping between canonical and true column names.
        col_renamer = {}
        for canon_name, true_name in zip(cnames, col_names):
            col_renamer[('c', canon_name)] = true_name
            col_renamer[('t', true_name)] = canon_name
        self.col_renamer = col_renamer

        n_vars = len(col_names)
        domain = make_canon_domain(cnames, n_vars)
        ####### The canon dataset can be used without worrying mapping the variables names
        # Canonical expressions are enumerated up to size 7 on every fit
        # (no on-disk cache is used).
        enum_expr, key_expr, ptrs = size_enumeration(7, domain)
        print("Canon size", len(enum_expr))

        n_data_samples = min(1000, len(X))
        ####### TODO: get sorts
        _feat_domain, secret_vals, feat_sorts = make_feature_domain(
            X, y, nsamples=n_data_samples
        )
        # Re-key the sampled features by canonical variable name.
        feat_domain = {
            col_renamer[('t', true_name)]: values
            for true_name, values in _feat_domain.items()
        }

        device = "cuda:0" if torch.cuda.is_available() else "cpu"

        # Alternative configurations left commented out in the original:
        # (num_girth=2048, sym_dmodel=1024, sym_num_layers=4) and
        # (num_girth=512, sym_dmodel=512, sym_num_layers=8).
        net = AllPairsSingleDomainTreeNet(
            num_in_size=n_data_samples,
            num_girth=1024,
            num_digits=67,
            num_dbase=2.0,
            sym_dmodel=768,
            sym_nhead=4,
            sym_num_layers=4,
            clf_agg="mean",
            use_diffs=False,
            sorts=feat_sorts,
            device=device,
        ).to(device)

        optim = Adam(net.parameters(), lr=1e-5)

        # Search budget depends on the dataset size.
        if len(X) < 1000:
            _max_len = 8
            _max_visited = 102400
            _beam_w = 16384
            _search_every = [599, 999]
        else:
            _max_len = 10
            _max_visited = 102400
            _beam_w = 8192
            _search_every = [99, 199, 399, 599, 799]

        n_epochs = 2000
        epoch_len = 50
        acc_steps = 1
        sol, params = None, None

        for epoch in range(n_epochs):
            net.train()
            optim.zero_grad()
            for iteration in range(epoch_len):
                batch_ids, batch_parents, batch_vals, batch_y = make_all_pairs_free_batch(
                    set(), # heldout_ids,
                    ids_size,
                    feat_domain,
                    min_pos=1,
                    max_pos=heldout_end_pos,
                    root_src=False,
                    noisy_fast=False,
                    expansions=expansions,
                    device=device,
                    force_padding=heldout_end_pos,
                    test_canon=False,
                    domain=domain,
                    key_expr=key_expr,
                )
                y_pred = net(
                    batch_vals,
                    expr_ids=batch_ids,
                    parent_ids=batch_parents,
                )
                loss = all_pairs_venn_loss(y_pred, batch_y)
                # Gradient accumulation: step every `acc_steps` iterations.
                scaled_loss = loss / acc_steps
                scaled_loss.backward()
                if iteration % acc_steps == (acc_steps - 1):
                    optim.step()
                    optim.zero_grad()

            # A search is attempted only at the scheduled epochs; the
            # schedule repeats every 1000 epochs.
            if (epoch % 1000) not in _search_every:
                continue

            net.eval()
            heuristic, tc_secret_vals = make_learned_heuristic(
                net,
                secret_vals,
                domain,
                feat_domain,
                device=device,
                tree_vec=True,
                force_padding=heldout_end_pos,
            )
            sol, _vis, _trace, params = search_prefix(
                key_expr,
                heuristic,
                tc_secret_vals,
                domain,
                feat_domain,
                eps=1e-8,
                max_len=_max_len,
                max_visited=_max_visited,
                expansions=expansions,
                expansion_cache=None,
                levels=1,
                penalty_factor=0.0,
                prune_dist=1.1,
                fit_params=True,
                max_fit_iter=10,
                n_param_inits=1,
                fit_err_tol=1e-3,
                h_err_tol=1e-3,
                dropout=0,
                test_canon=False,
                beam_w=_beam_w,
                topk_children=-1,
                domain_id=None,
                sorts=None,
            )
            # Stop training as soon as a search succeeds.
            if sol is not None:
                break

        ####### End of heuristic training

        ####### TODO: Make final Sympy-friendly expression (put params inside expression)
        self._sol_ids = sol
        self._params = params

        ####### Set fitted and expression
        self.is_fitted_ = True
        if sol is None:
            return None
        if params is not None:
            print(params)
            infix_sol = to_infix_params(sol, params)
        else:
            infix_sol = to_infix(sol)

        # Map canonical names back to the caller's column names, then turn
        # SQUARE(...) calls into plain ``(...)**2``.
        infix_sol = _rename_canonical_variables(
            infix_sol,
            {canon_name: col_renamer[('c', canon_name)] for canon_name in cnames},
        )
        infix_sol = _expand_square_calls(infix_sol)

        print(infix_sol)
        print("Params: ", params)
        self.expression = infix_sol
        return self

    def predict(self, X):
        """Evaluate the fitted expression on the rows of ``X``.

        Args:
            X: ``DataFrame`` with the same column names that were passed to
                ``fit``.

        Returns:
            The result of ``fast_eval_expr`` when no parameters were fitted;
            otherwise the output of ``tc_fast_eval_expr`` detached, moved to
            the CPU and converted to a NumPy array.
        """
        from htssr.utils import fast_eval_expr, tc_fast_eval_expr
        raw_domain = make_eval_domain(X)
        # Re-key the inputs by canonical variable name.
        feat_domain = {
            self.col_renamer[('t', true_name)]: values
            for true_name, values in raw_domain.items()
        }

        if self._params is None:
            return fast_eval_expr(self._sol_ids, feat_domain)

        ans = tc_fast_eval_expr(
            self._sol_ids,
            feat_domain,
            self._params,
        )
        # .cpu() is required before .numpy() when the result lives on a CUDA
        # device; it is a no-op for CPU tensors.
        return ans.detach().cpu().numpy()

# Module-level estimator instance built with default parameters.
# NOTE(review): presumably looked up by name by an external benchmark
# harness -- confirm before renaming or removing.
est = DummyRegressor()

def model(est, X=None):
    """Return the expression string stored on ``est``.

    Args:
        est: Object exposing an ``expression`` attribute.
        X: Unused; accepted for interface compatibility.

    Returns:
        The value of ``est.expression``.
    """
    expression = est.expression
    return expression
