import itertools
import pandas as pd
import numpy as np
from tqdm import tqdm
from typing import Tuple, Callable
import re
from copy import deepcopy


def simplex_grid_points(n, k):
    """Enumerate the regular grid of resolution ``k`` on the simplex with ``n`` coordinates.

    Each multiset of ``k`` draws from ``range(n)`` is turned into one point:
    coordinate ``i`` is the share of the draws that picked axis ``i``.  Every
    point is therefore a tuple of ``n`` non-negative multiples of ``1 / k``
    whose sum is one (up to floating-point rounding).

    Args:
        n: number of coordinates of each point.
        k: grid resolution, i.e. the number of draws shared among the axes.

    Returns:
        A list of tuples of floats, ordered like
        ``itertools.combinations_with_replacement(range(n), k)``.

    Raises:
        ZeroDivisionError: if ``k`` is 0 while ``n`` is positive.
    """
    axes = range(n)
    return [
        # Share of the k draws that landed on each axis.
        tuple(draw.count(axis) / k for axis in axes)
        for draw in itertools.combinations_with_replacement(axes, k)
    ]


def constrained_simplex_grid_points(n, k, index, min_val=0.5, max_val=0.95,
                                    c_points: int = 4):
    """Build simplex grid points in which one coordinate is pinned to preset values.

    The pinned coordinate takes ``c_points`` evenly spaced values from
    ``min_val`` to ``max_val`` (both included, see ``np.linspace``).  For every
    such value ``v`` the remaining ``n - 1`` coordinates are the points of
    ``simplex_grid_points(n - 1, k)`` scaled by ``1 - v``, so each returned
    point has ``n`` coordinates that sum to one (up to rounding).

    Args:
        n: number of coordinates of each returned point.
        k: grid resolution handed to ``simplex_grid_points`` for the free
            coordinates.
        index: position of the pinned coordinate inside each returned point.
            Negative values count from the end of the returned point
            (``-1`` is the last coordinate, ``-2`` the one before it, ...).
        min_val: smallest value of the pinned coordinate.
        max_val: largest value of the pinned coordinate.
        c_points: how many values the pinned coordinate takes.

    Returns:
        A list of tuples, grouped by pinned value in increasing
        ``np.linspace`` order.
    """
    free_points = np.asarray(simplex_grid_points(n - 1, k))
    pinned_values = np.linspace(min_val, max_val, c_points)

    results = []
    for new_pt in pinned_values:
        # Shrink the free coordinates so that, together with the pinned one,
        # the point still sums to one.
        scaled = (free_points * (1 - new_pt)).tolist()
        for p in scaled:
            if index >= 0:
                position = index
            else:
                # ``p`` is one element shorter than the final point, so a
                # negative index must be shifted by ``len(p) + 1`` to refer to
                # the final layout (plain ``p.insert(index, ...)`` would land
                # one slot too far left).  Clamp like ``list.insert`` does.
                position = max(index + len(p) + 1, 0)
            p.insert(position, new_pt)
            results.append(tuple(p))

    return results


def iter_once(test_env, model, weight, num_rewards: int):
    """Play a single episode and return the summed reward vector.

    Args:
        test_env: environment whose ``reset()`` returns ``(obs, info)`` and
            whose ``step(action)`` returns
            ``(obs, reward, done, timeout, info)``.
        model: policy exposing ``predict(obs, state=...)`` that returns a pair
            whose first element is the action.
        weight: preference passed to the policy as ``state`` on every step.
        num_rewards: length of the accumulated reward vector.

    Returns:
        A NumPy array of shape ``(num_rewards,)`` holding the sum of all
        rewards received until the episode ended.
    """
    total = np.zeros(shape=(num_rewards,))
    obs, _ = test_env.reset()
    while True:
        # NOTE(review): the other rollout helpers in this file also pass
        # ``langevin=False`` to ``predict``; here it is intentionally left out.
        action, _ = model.predict(obs, state=weight)
        obs, reward, done, timeout, _ = test_env.step(action)
        total += reward
        # The episode is over on either a terminal state or a time-out.
        if done or timeout:
            break
    return total


def vec_iter_once(test_env, model, weight, num_rewards: int, n_env: int):
    """Play one episode in each of ``n_env`` parallel environments.

    Rewards are accumulated per environment only until that environment first
    reports ``done``; whatever it returns afterwards is ignored.  The loop ends
    once every environment has finished at least once.

    Args:
        test_env: batched environment whose ``reset()`` returns the
            observations and whose ``step(action)`` returns
            ``(obs, reward, done, info)``; ``reward`` and ``done`` are indexed
            by environment along their first axis.
        model: policy exposing ``predict(obs, state=..., langevin=...)``.
        weight: preference array; it is stacked ``n_env`` times along a new
            leading axis before being passed as ``state``.
        num_rewards: number of reward components per environment.
        n_env: number of parallel environments.

    Returns:
        A NumPy array of shape ``(n_env, num_rewards)`` with the summed
        rewards of each environment's first episode.
    """
    totals = np.zeros(shape=(n_env, num_rewards,))
    finished = np.zeros(shape=(n_env,), dtype=np.bool_)
    # Same preference for every environment: (n_env, *weight.shape).
    vec_weight = np.repeat(weight[None], axis=0, repeats=n_env)

    obs = test_env.reset()
    while not finished.all():
        action, _ = model.predict(obs, state=vec_weight, langevin=False)
        obs, reward, done, _ = test_env.step(action)

        # Only environments still in their first episode keep scoring; the
        # reward of the terminating step itself is included.
        running = ~finished
        totals[running] += reward[running]

        if done.any():
            finished[np.nonzero(done)] = True
    return totals


def iter_finance_once(test_env, model, weight, num_rewards: int,
                      weight_id: int,
                      episode_id: int):
    """Play a single episode, returning the summed rewards and all info dicts.

    Every info dict produced by the environment (the one from ``reset`` and one
    per ``step``) is tagged in place with ``episode_id`` and ``weight_id`` and a
    deep copy of it is recorded.

    Args:
        test_env: environment whose ``reset()`` returns ``(obs, info)`` and
            whose ``step(action)`` returns
            ``(obs, reward, done, timeout, info)``.
        model: policy exposing ``predict(obs, state=..., langevin=...)``.
        weight: preference passed to the policy as ``state`` on every step.
        num_rewards: length of the accumulated reward vector.
        weight_id: identifier stored under ``'weight_id'`` in every info dict.
        episode_id: identifier stored under ``'episode_id'`` in every info dict.

    Returns:
        ``(score, infos)`` where ``score`` is a NumPy array of shape
        ``(num_rewards,)`` and ``infos`` is the list of recorded info copies,
        starting with the one returned by ``reset``.
    """

    def _tagged_copy(record):
        # Tag the environment's own dict first, then snapshot it so later
        # changes made by the environment cannot alter the recorded entry.
        record['episode_id'] = episode_id
        record['weight_id'] = weight_id
        return deepcopy(record)

    obs, first_info = test_env.reset()
    total = np.zeros(shape=(num_rewards,))
    history = [_tagged_copy(first_info)]

    finished = False
    while not finished:
        action, _ = model.predict(obs, state=weight, langevin=False)
        obs, reward, done, timeout, step_info = test_env.step(action)
        total += reward
        # Stop on either a terminal state or a time-out.
        finished = done or timeout
        history.append(_tagged_copy(step_info))
    return total, history


def test(test_env, model, num_grids: int, reward_dim: int, name: str,
         num_execution: int = 30, ) -> pd.DataFrame:
    """Evaluate ``model`` on every point of the full simplex weight grid.

    For each weight of ``simplex_grid_points(reward_dim, num_grids)`` the
    policy is rolled out ``num_execution`` times with ``iter_once``.

    Args:
        test_env: environment handed to ``iter_once``.
        model: policy handed to ``iter_once``.
        num_grids: grid resolution of the weight simplex.
        reward_dim: number of reward components (and weight coordinates).
        name: label written to the ``name`` column of every row.
        num_execution: number of episodes per weight.

    Returns:
        A DataFrame with one row per episode and the columns ``name``,
        ``weight_id`` (1-based position of the weight in the grid),
        ``execution`` (1-based episode number), ``weight_1..weight_<reward_dim>``
        and ``score_1..score_<reward_dim>``.
    """
    records = []
    grid = simplex_grid_points(reward_dim, num_grids)
    for weight_id, w in enumerate(tqdm(grid), start=1):
        for execution in range(1, num_execution + 1):
            score = iter_once(test_env, model, np.array(w), reward_dim)
            records.append({
                "name": name,
                "weight_id": weight_id,
                "execution": execution,
                **{f"weight_{i + 1}": w[i] for i in range(reward_dim)},
                **{f"score_{i + 1}": score[i] for i in range(reward_dim)},
            })
    return pd.DataFrame(records)


def constrained_simplex_finance_test(test_env, model, index: int, num_grids: int, reward_dim: int, name: str,
                                     num_execution: int = 30, min_val=0.05, max_val=0.95,
                                     c_points: int = 8) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Evaluate ``model`` on a constrained weight grid and keep per-step infos.

    The weights come from ``constrained_simplex_grid_points``; each one is
    rolled out ``num_execution`` times with ``iter_finance_once``.

    Args:
        test_env: environment handed to ``iter_finance_once``.
        model: policy handed to ``iter_finance_once``.
        index: position of the pinned weight coordinate.
        num_grids: grid resolution for the free weight coordinates.
        reward_dim: number of reward components (and weight coordinates).
        name: label written to the ``name`` column of every summary row.
        num_execution: number of episodes per weight.
        min_val: smallest value of the pinned coordinate.
            NOTE(review): defaults to 0.05 here but to 0.5 in the sibling
            helpers of this file -- confirm this is intended.
        max_val: largest value of the pinned coordinate.
        c_points: number of values taken by the pinned coordinate.

    Returns:
        ``(summary, details)``.  ``summary`` has one row per episode with the
        columns ``name``, ``weight_id`` (1-based), ``execution`` (1-based),
        ``weight_*`` and ``score_*``.  ``details`` is built from every info
        dict collected by ``iter_finance_once`` across all episodes.
    """
    weights = constrained_simplex_grid_points(reward_dim, num_grids, index=index, min_val=min_val,
                                              max_val=max_val, c_points=c_points)
    rows = []
    details = []
    for weight_id, w in enumerate(tqdm(weights), start=1):
        for execution in range(1, num_execution + 1):
            score, infos = iter_finance_once(test_env, model, np.array(w), reward_dim, weight_id=weight_id,
                                             episode_id=execution)
            row = {
                "name": name,
                "weight_id": weight_id,
                "execution": execution,
            }
            for i in range(reward_dim):
                row[f"weight_{i + 1}"] = w[i]
            for i in range(reward_dim):
                row[f"score_{i + 1}"] = score[i]
            rows.append(row)
            # Extend in place: ``details = details + infos`` would copy the
            # whole accumulated list on every episode (quadratic overall).
            details.extend(infos)
    return pd.DataFrame(rows), pd.DataFrame(details)


def constrained_simplex_test(test_env, model, index: int, num_grids: int, reward_dim: int, name: str,
                             num_execution: int = 30, min_val=0.5, max_val=0.95,
                             c_points: int = 4) -> pd.DataFrame:
    """Evaluate ``model`` on a weight grid with one pinned coordinate.

    The weights come from ``constrained_simplex_grid_points``; each one is
    rolled out ``num_execution`` times with ``iter_once``.

    Args:
        test_env: environment handed to ``iter_once``.
        model: policy handed to ``iter_once``.
        index: position of the pinned weight coordinate.
        num_grids: grid resolution for the free weight coordinates.
        reward_dim: number of reward components (and weight coordinates).
        name: label written to the ``name`` column of every row.
        num_execution: number of episodes per weight.
        min_val: smallest value of the pinned coordinate.
        max_val: largest value of the pinned coordinate.
        c_points: number of values taken by the pinned coordinate.

    Returns:
        A DataFrame with one row per episode and the columns ``name``,
        ``weight_id`` (1-based), ``execution`` (1-based), ``weight_*`` and
        ``score_*``.
    """
    grid = constrained_simplex_grid_points(
        reward_dim, num_grids, index=index,
        min_val=min_val, max_val=max_val, c_points=c_points,
    )
    dims = range(reward_dim)
    records = []
    for weight_id, w in enumerate(tqdm(grid), start=1):
        for execution in range(1, num_execution + 1):
            score = iter_once(test_env, model, np.array(w), reward_dim)
            record = {"name": name, "weight_id": weight_id, "execution": execution}
            # Weight columns first, then score columns, to fix the column order.
            record.update((f"weight_{i + 1}", w[i]) for i in dims)
            record.update((f"score_{i + 1}", score[i]) for i in dims)
            records.append(record)
    return pd.DataFrame(records)


def vec_constrained_simplex_test(test_env, model, index: int, num_grids: int, reward_dim: int, name: str,
                                 n_env: int,
                                 num_execution: int = 30, min_val=0.5, max_val=0.95,
                                 c_points: int = 8) -> pd.DataFrame:
    """Evaluate ``model`` on a constrained weight grid using batched rollouts.

    For each weight from ``constrained_simplex_grid_points`` the batched
    rollout ``vec_iter_once`` is repeated until ``num_execution`` episodes have
    been recorded; surplus episodes of the last batch are discarded.

    Args:
        test_env: batched environment handed to ``vec_iter_once``.
        model: policy handed to ``vec_iter_once``.
        index: position of the pinned weight coordinate.
        num_grids: grid resolution for the free weight coordinates.
        reward_dim: number of reward components (and weight coordinates).
        name: label written to the ``name`` column of every row.
        n_env: number of parallel environments in ``test_env``.
        num_execution: number of episodes recorded per weight.
        min_val: smallest value of the pinned coordinate.
        max_val: largest value of the pinned coordinate.
        c_points: number of values taken by the pinned coordinate.

    Returns:
        A DataFrame with one row per recorded episode and the columns
        ``name``, ``weight_id`` (1-based), ``execution``, ``weight_*`` and
        ``score_*``.
    """
    grid = constrained_simplex_grid_points(reward_dim, num_grids, index=index, min_val=min_val,
                                           max_val=max_val, c_points=c_points)
    print("num_execution",  num_execution)

    records = []
    for weight_id, w in enumerate(tqdm(grid), start=1):
        # NOTE(review): ``execution`` is 0-based here, whereas the non-batched
        # helpers in this file number episodes from 1 -- confirm intended.
        execution = 0
        while execution < num_execution:
            batch = vec_iter_once(test_env, model, np.array(w), reward_dim, n_env=n_env)
            # Keep only as many episodes of this batch as are still missing.
            missing = num_execution - execution
            for score in batch[:missing]:
                record = {"name": name, "weight_id": weight_id, "execution": execution}
                record.update({f"weight_{i + 1}": w[i] for i in range(reward_dim)})
                record.update({f"score_{i + 1}": score[i] for i in range(reward_dim)})
                records.append(record)
                execution += 1
    return pd.DataFrame(records)


def df_to_points(df: pd.DataFrame, aggregation: Callable | None = None):
    points = []
    weight_cols = [c for c in df.columns if re.fullmatch(r"weight_\d+", c)]
    score_cols = [c for c in df.columns if re.fullmatch(r"score_\d+", c)]
    if aggregation is None:
        aggregation = lambda x: x
    for weight_id, group in df.groupby("weight_id"):
        weight = group.iloc[0][weight_cols].to_numpy()
        scores = group[score_cols].to_numpy().tolist()
        points.append({
            "weight_id": int(weight_id),
            "weight": weight.astype(np.float128),
            "scores": aggregation(np.asarray(scores, dtype=np.float128))
        })
    return points


def pareto_mask_max(Y):
    """Return a boolean mask of the non-dominated rows of ``Y`` (maximisation).

    Row ``i`` dominates row ``j`` when it is at least as large in every column
    and strictly larger in at least one.  Identical rows do not dominate each
    other, so duplicates of a non-dominated row are all kept.

    Args:
        Y: 2-D array with one row per point and one column per objective.

    Returns:
        A 1-D boolean array, ``True`` for every row that no other row
        dominates.
    """
    candidates = Y[:, np.newaxis, :]
    rivals = Y[np.newaxis, :, :]
    # Entry [i, j] compares point i (candidate) against point j (rival).
    never_worse = (candidates >= rivals).all(axis=2)
    somewhere_better = (candidates > rivals).any(axis=2)
    beaten = (never_worse & somewhere_better).any(axis=0)
    return np.logical_not(beaten)


def _hv_nd(Y, ref):
    """Recursively compute the hypervolume dominated by ``Y`` above ``ref``.

    Objectives are maximised.  Only rows that exceed ``ref`` strictly in every
    column contribute, and dominated rows are dropped first.  The volume is
    then swept along the first objective: between two consecutive distinct
    first-coordinate values the cross-section is the hypervolume, in the
    remaining objectives, of all points reaching at least the upper value.

    Args:
        Y: 2-D array with one row per point and one column per objective.
        ref: 1-D reference point with one entry per objective.

    Returns:
        The hypervolume as a scalar (``0.0`` when no point contributes).
    """
    if Y.size == 0:
        return 0.0
    above_ref = np.all(Y > ref, axis=1)
    Y = Y[above_ref]
    if Y.size == 0:
        return 0.0
    front = Y[pareto_mask_max(Y)]

    # Base case: with one objective the volume is a segment length.
    if front.shape[1] == 1:
        return np.maximum(0.0, np.max(front[:, 0]) - ref[0])

    first_axis = front[:, 0]
    levels = np.unique(first_axis)
    levels = levels[levels > ref[0]]
    levels.sort()

    volume = 0.0
    lower = ref[0]
    for upper in levels:
        thickness = upper - lower
        if thickness > 0:
            # Points reaching at least ``upper`` span the whole slab.
            cross_section = front[first_axis >= upper][:, 1:]
            volume += thickness * _hv_nd(cross_section, ref[1:])
        lower = upper
    return volume


def hypervolume(points, reference_point):
    """Return the hypervolume dominated by ``points`` above ``reference_point``.

    Both inputs are converted to NumPy's extended-precision ``longdouble``
    dtype and handed to ``_hv_nd`` (maximisation convention).

    Args:
        points: array-like with one row per point and one column per objective.
        reference_point: array-like with one entry per objective.

    Returns:
        The hypervolume as a scalar.
    """
    # ``np.longdouble`` instead of ``np.float128``: the latter is only defined
    # on platforms with a 128-bit long double, where both are the same type.
    Y = np.asarray(points, dtype=np.longdouble)
    r = np.asarray(reference_point, dtype=np.longdouble)
    return _hv_nd(Y, r)


def hypervolume_weighted_axes(points, reference_point, weights):
    """Return the hypervolume after scaling every objective axis by a weight.

    Points and reference point are multiplied column-wise by ``weights`` before
    the hypervolume is computed with ``_hv_nd``.

    Args:
        points: array-like with one row per point and one column per objective.
        reference_point: array-like with one entry per objective.
        weights: 1-D array-like of strictly positive scale factors, one per
            objective.

    Returns:
        The hypervolume of the rescaled points as a scalar.

    Raises:
        ValueError: if ``weights`` is not 1-D with one entry per objective, or
            if any weight is not strictly positive.
    """
    # ``np.longdouble`` instead of ``np.float128``: the latter is only defined
    # on platforms with a 128-bit long double, where both are the same type.
    Y = np.asarray(points, dtype=np.longdouble)
    r = np.asarray(reference_point, dtype=np.longdouble)
    w = np.asarray(weights, dtype=np.longdouble)
    if w.ndim != 1 or w.shape[0] != Y.shape[1]:
        raise ValueError("weights shape mismatch")
    if np.any(w <= 0):
        raise ValueError("weights must be positive")
    return _hv_nd(Y * w, r * w)


def hypervolume_for_preferences(points, preferences, reference_point: np.ndarray | None = None):
    if reference_point is None:
        reference_point = np.zeros_like(preferences[0])
    Y = np.asarray(points, dtype=np.float128)
    r = np.asarray(reference_point, dtype=np.float128)
    P = np.asarray(preferences, dtype=np.float128)

    hv_list = np.array([_hv_nd(Y * w, r * w) for w in P])
    return hv_list


if __name__ == '__main__':
    # Manual smoke check: print a small constrained grid (5 coordinates,
    # resolution 4, last coordinate pinned to 3 values) and its size.
    from pprint import pprint

    demo_points = constrained_simplex_grid_points(n=5, k=4, index=-1, c_points=3)
    pprint(demo_points)
    print(len(demo_points))
