import json
import os
import random
from sys import argv
import pandas as pd
import numpy as np
import rbo as rbo_lib
import re

from tqdm import tqdm

from sharpe_optimization2 import get_optimal_weights


# Pool Parameters
MAX_SIZE_SUBPOOLS = 500  # maximum number of asset names per sub-pool
N_SUBPOOLS = 1  # number of random sub-pools drawn when RANDOM_SUBPOOLS is True
RANDOM_SUBPOOLS = True  # random sub-pools (True) or consecutive chunks (False)

RATE_FREE = 0  # 4.26 / 100
N_TRADING_DAYS = 252
N_YEARLY_TRADES = N_TRADING_DAYS * 72   # We are at the scale of 5 minutes (72 bars per trading day)

# Settings for the top-k portfolio strategy (get_return_pool)
TOP_K = 1
ONLY_POS = True
ABSOLUTE_VALUE = False
NO_WEIGHTS = False
# Rescale the positive and negative optimal weights by their leg totals
# (get_optimal_return_pool)
NORMALIZE_WEIGHTS = False
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
# Command line: <base directory whose path contains "horizon<N>"> <covariance directory>
BASE = argv[1]
COVS = argv[2] or None
# Output toggles for compute_metrics
PRINT_RBO = True
PRINT_MR = True
PRINT_RISK = True
PRINT_SHARPE = True
PRINT_RMSE = True
PRINT_MAPE = True
PRINT_DIST = True
# Extracts the prediction horizon from the base path. At least one digit is
# required so that a bare "horizon" is not matched with an empty number.
RE_HORIZON = re.compile(r"horizon(?P<horizon>[0-9]+)")


def get_pools(results, max_size_subpools, dir):
    """Partition the distinct ``Name`` values of ``results`` into pools.

    The names are de-duplicated, sorted and then shuffled with ``random``.
    With fewer than ``max_size_subpools`` names, a single pool holding all of
    them is returned. Otherwise, when RANDOM_SUBPOOLS is set, N_SUBPOOLS pools
    are each drawn independently with ``random.sample``. When it is not set,
    the shuffled names are cut into consecutive chunks of at most
    ``max_size_subpools`` names. ``dir`` is unused.
    """
    shuffled = sorted(set(results["Name"]))
    random.shuffle(shuffled)
    if len(shuffled) < max_size_subpools:
        return [shuffled]
    if not RANDOM_SUBPOOLS:
        return [
            shuffled[start:start + max_size_subpools]
            for start in range(0, len(shuffled), max_size_subpools)
        ]
    return [random.sample(shuffled, k=max_size_subpools) for _ in range(N_SUBPOOLS)]


def get_pools_2(results, max_size_subpools, dir):
    """Return consecutive sub-pools of names, cached in a JSON file under ``dir``.

    The cache file is ``<dir>/pool_<max_size_subpools>_<SEED>.json``. If it
    does not exist yet, the distinct names of ``results`` are sorted, shuffled
    with ``random`` and split into chunks of at most ``max_size_subpools``
    names, and the chunks are written to it. The pools are always returned as
    read back from that file, so an existing cache wins over ``results``.
    """
    pool_path = dir + "/" + "pool_{}_{}.json".format(max_size_subpools, SEED)
    if not os.path.exists(pool_path):
        shuffled = sorted(set(results["Name"]))
        random.shuffle(shuffled)
        step = max_size_subpools
        chunks = [shuffled[start:start + step] for start in range(0, len(shuffled), step)]
        with open(pool_path, "w") as out_file:
            json.dump(chunks, out_file)
    with open(pool_path) as cache_file:
        return json.load(cache_file)


def get_rates(column="13 Wk Bank Discount Rate", path="USTREASURY-BILLRATES.csv"):
    """Load treasury bill rates keyed by date.

    Args:
        column: Name of the rate column to read.
        path: CSV file to read; it must contain a ``Date`` column and
            ``column``. Defaults to the file previously hard-coded here.

    Returns:
        dict mapping each ``Date`` value to ``column / 100`` (presumably
        converting percent to a fraction). On duplicate dates the last row wins.
    """
    df = pd.read_csv(path)
    # Vectorised instead of iterrows(): same mapping without building a
    # Series per row.
    return dict(zip(df["Date"], df[column] / 100))


def get_optimal_return_pool(results, pool, covariances, horizon=None):
    """Realised return of the optimised portfolio over one pool of assets.

    Args:
        results: DataFrame with ``Name``, ``Pred`` and ``True`` columns.
        pool: Collection of names; only matching rows of ``results`` are used.
        covariances: DataFrame indexed (rows and columns) by asset names
            without the ``.csv`` suffix.
        horizon: Divisor applied to the predictions before optimisation.
            Defaults to the module-level ``HORIZON`` set by the script entry
            point.

    Returns:
        Dot product of the pool's ``True`` values (sorted by ``Name``) with the
        weights returned by ``get_optimal_weights``.
    """
    if horizon is None:
        horizon = HORIZON
    subpool = results[results["Name"].isin(pool)].sort_values(by="Name")
    # Covariance labels are the asset names without their ".csv" suffix.
    cov_names = [name.replace(".csv", "") for name in subpool["Name"]]
    covariances = covariances.loc[cov_names, cov_names]
    pred_returns = subpool["Pred"]
    optimal_weights = get_optimal_weights(pred_returns / horizon, covariances)
    if NORMALIZE_WEIGHTS:
        pos_total_weight = optimal_weights[optimal_weights > 0].sum()
        if pos_total_weight == 0:
            pos_total_weight = 1
        # The negative weights sum to a value <= 0. Divide by its magnitude so
        # the short leg keeps its sign (sums to -1). Dividing by the signed sum
        # would turn every short position into a long one.
        neg_total_weight = abs(optimal_weights[optimal_weights < 0].sum())
        if neg_total_weight == 0:
            neg_total_weight = 1
        optimal_weights[optimal_weights > 0] /= pos_total_weight
        optimal_weights[optimal_weights < 0] /= neg_total_weight
    return subpool["True"].values.dot(optimal_weights)


def get_return_pool(results, pool, top_k, *, only_pos=None, absolute_value=None, no_weights=None):
    """Return of a simple top-k portfolio built from one pool's predictions.

    Rows of ``results`` whose ``Name`` is in ``pool`` are shuffled and their
    ``Pred`` values are shifted so that all of them are non-negative. They are
    then sorted ascending, by ``Pred``, or by ``|Pred|`` of the unshifted
    values when ``absolute_value`` is set. The first ``top_k`` rows form the
    "positive" leg and the last ``top_k`` rows the "negative" leg. Each leg's
    weights are its shifted ``Pred`` divided by the leg total. Equal weights
    are used when all relevant predictions are zero or ``no_weights`` is set.
    The negative leg's weights are then shifted down by one.

    Args:
        results: DataFrame with ``Name``, ``Pred`` and ``True`` columns.
        pool: Collection of names to restrict ``results`` to.
        top_k: Number of rows in each leg.
        only_pos: Return only the positive leg; defaults to ``ONLY_POS``.
        absolute_value: Sort by ``|Pred|``; defaults to ``ABSOLUTE_VALUE``.
        no_weights: Force equal weights; defaults to ``NO_WEIGHTS``.

    Returns:
        Weighted sum of ``True`` over the positive leg, plus the negative leg
        unless ``only_pos``. Returns 0.0 when the positive leg is empty.
    """
    if only_pos is None:
        only_pos = ONLY_POS
    if absolute_value is None:
        absolute_value = ABSOLUTE_VALUE
    if no_weights is None:
        no_weights = NO_WEIGHTS
    subpool = results[results["Name"].isin(pool)].copy().sample(frac=1)
    subpool["abs"] = subpool["Pred"].abs()
    # Only positive weights: shift Pred so every value is >= 0 (> 0 unless the
    # minimum was exactly 0).
    subpool["Pred"] += abs(subpool["Pred"].min()) * 2
    subpool = subpool.sort_values(by="abs" if absolute_value else "Pred")
    # NOTE(review): with an ascending sort, head() puts the *lowest*
    # predictions in the "positive" leg and tail() the highest in the
    # "negative" leg. Confirm this orientation is intended.
    positive = subpool.head(top_k)
    negative = subpool.tail(top_k)
    norm = len(positive)
    if norm == 0:
        # Empty pool or top_k == 0: nothing to invest in.
        return 0.0
    total_weights_pos = positive["Pred"].abs().sum()
    total_weights_neg = negative["Pred"].abs().sum()
    total_weights = total_weights_pos if only_pos else total_weights_pos + total_weights_neg
    if total_weights == 0 or no_weights:
        # Equal weights. Build them on each leg's own index: a default
        # RangeIndex would misalign with the rows' labels in the products
        # below and silently turn most terms into NaN.
        norm_pos_weights = pd.Series(1.0 / norm, index=positive.index)
        norm_neg_weights = pd.Series(1.0 / norm, index=negative.index)
    else:
        # We normalize the positives and the negatives to sum to one
        norm_pos_weights = positive["Pred"] / total_weights_pos
        norm_neg_weights = negative["Pred"] / total_weights_neg
    # We make the negative really negative
    norm_neg_weights = norm_neg_weights - 1
    return_pos = (norm_pos_weights * positive["True"]).sum()
    if only_pos:
        return return_pos
    return return_pos + (norm_neg_weights * negative["True"]).sum()


def compute_metrics(returns, predictions, goldstandard, rbo):
    """Print backtest statistics for one model.

    Args:
        returns: Sequence of per-period portfolio returns, in processing order.
        predictions: List of array-likes of predicted values, one per period;
            concatenated for the error metrics.
        goldstandard: List of array-likes of realised values aligned with
            ``predictions``.
        rbo: Sequence of rank-biased-overlap scores; only the mean is printed.

    Returns:
        None. Output goes to stdout, with sections gated by the PRINT_* flags.
    """
    if len(returns) == 0 or len(predictions) == 0:
        # np.concatenate raises on an empty list, and none of the statistics
        # below are defined without returns (e.g. a model whose CSVs were
        # all empty).
        print("Number of predictions:", len(predictions))
        print("Number of returns:", len(returns))
        print("Nothing to evaluate: no returns or no predictions")
        return
    true_returns = compute_true_returns(returns)
    # NOTE(review): mean_return and risk are additionally divided by
    # len(returns) on top of the annualisation. The factor cancels in `sharpe`
    # only when RATE_FREE is 0. Confirm this scaling is intended.
    mean_return = np.mean(returns) * N_YEARLY_TRADES / len(returns)
    true_returns_mean = np.mean(true_returns) * N_YEARLY_TRADES
    risk = np.std(returns) * np.sqrt(N_YEARLY_TRADES) / len(returns)
    true_risk = np.std(true_returns) * np.sqrt(N_YEARLY_TRADES)
    sharpe = (mean_return - RATE_FREE) / risk
    true_sharpe = (true_returns_mean - RATE_FREE) / true_risk
    print("Number of predictions:", len(predictions))
    print("Number of returns:", len(returns))
    if PRINT_MR:
        print("Mean return/Final return:", "%.5f" % (mean_return * 100), "%")
        print("True Mean return:", "%.5f" % (true_returns_mean * 100), "%")
    if PRINT_RISK:
        print("Risk:", risk)
        print("True Risk", true_risk)
    if PRINT_SHARPE:
        print("Sharpe:", sharpe)
        # NOTE(review): sqrt((1 + SR^2 / 2) / T) is the usual standard error of
        # a *per-period* Sharpe ratio. Here SR is annualised while T counts
        # periods; confirm the interval is meant this way.
        print("True Sharpe:", true_sharpe, "+-(95%)", 1.96 * np.sqrt((1 + true_sharpe ** 2 / 2) / len(true_returns)))
    if PRINT_RBO:
        print("Mean RBO", np.mean(rbo))
    preds = np.concatenate(predictions)
    gs = np.concatenate(goldstandard)
    if PRINT_RMSE:
        print("RMSE", "%.5f" % np.sqrt(((preds - gs) ** 2).mean()))
    if PRINT_MAPE:
        # gs holds returns (ratio - 1), so 1 + gs is the realised ratio.
        print("MAPE", "%.5f" % (100 * np.abs((preds - gs) / (1 + gs)).mean()), "%")
    if PRINT_DIST:
        print("Dist LAST", np.sqrt(preds ** 2).mean())
    # A "win" is a prediction with the same strict sign as the realised value.
    number_of_wins = ((preds > 0) & (gs > 0)).sum() + ((preds < 0) & (gs < 0)).sum()
    print("Number of wins: ", number_of_wins, "over", len(preds), number_of_wins / len(preds) * 100, "%")


def compute_true_returns(returns):
    """Express each per-period return relative to the capital accumulated so far.

    The starting capital is ``len(returns)``, and each return is added to the
    capital once it has been realised. Each output value is
    ``(capital + ret) / capital - 1``, where ``capital`` is the amount
    available before that period.

    Args:
        returns: Sequence of per-period returns, in chronological order.

    Returns:
        list with one value per input return; empty for empty input.
    """
    n_returns = len(returns)
    true_returns = []
    accumulated = 0
    for ret in returns:
        capital = n_returns + accumulated
        true_returns.append((capital + ret) / capital - 1)
        accumulated += ret
    return true_returns


def get_rbo(results, pool):
    """Rank-biased overlap between the true and predicted orderings of a pool.

    Rows of ``results`` whose ``Name`` is in ``pool`` are ranked once by
    ``True`` and once by ``Pred`` (both ascending). The two name orderings are
    compared with ``rbo_lib.RankingSimilarity(...).rbo()``.
    """
    in_pool = [name in pool for name in results["Name"]]
    selected = results[in_pool].copy()
    by_true, by_pred = (selected.sort_values(by=key)["Name"].values for key in ("True", "Pred"))
    return rbo_lib.RankingSimilarity(by_true, by_pred).rbo()


if __name__ == '__main__':
    global_pools = dict()  # filename -> pools, shared by every model
    res = dict()  # model -> one mean pool return per processed file
    preds = dict()  # model -> per-file prediction Series
    gs = dict()  # model -> per-file realised-value Series
    rbo = dict()  # model -> per-pool RBO scores
    # The "LAST" baseline (every prediction is 0) is built only while the
    # first model directory is processed.
    first = True
    horizon_match = RE_HORIZON.search(BASE)
    if horizon_match is None or not horizon_match.group("horizon"):
        raise SystemExit("No 'horizon<N>' found in base directory: " + BASE)
    HORIZON = int(horizon_match.group("horizon"))
    N_YEARLY_TRADES /= HORIZON
    # all_rates = get_rates()
    for entry in os.listdir(BASE):
        # One dir per model
        model_name = entry.replace("/", "")
        model_dir = BASE + "/" + entry
        if not os.path.isdir(model_dir):
            continue
        print("Evaluating model", model_name)
        # Process the files chronologically: compute_true_returns accumulates
        # returns in order. The key reverses the "-"-separated parts of the
        # stem (presumably date fields, most significant last; TODO confirm
        # the naming scheme and zero-padding).
        filenames = sorted(
            os.listdir(model_dir),
            key=lambda name: name.split(".")[0].split("-")[::-1],
        )
        for filename in tqdm(filenames):
            if not filename.endswith(".csv"):
                continue
            for store in (res, preds, gs, rbo):
                store.setdefault(model_name, [])
            if first:
                for store in (res, preds, gs, rbo):
                    store.setdefault("LAST", [])
            results = pd.read_csv(model_dir + "/" + filename)
            if len(results) == 0:
                print("Empty for", model_dir, filename)
                continue
            covs = pd.read_csv(COVS + "/" + filename, index_col="index")
            # Turn model outputs into returns (presumably they are price ratios).
            results["Pred"] = results["Pred"] - 1
            results["True"] = results["True"] - 1
            # Baseline that predicts no change for every asset.
            results_last = pd.DataFrame({"Name": results["Name"], "True": results["True"]})
            results_last["Pred"] = [0 for _ in range(len(results_last))]
            preds[model_name].append(results["Pred"])
            gs[model_name].append(results["True"])
            if first:
                preds["LAST"].append(results_last["Pred"])
                gs["LAST"].append(results_last["True"])
            if filename not in global_pools:
                global_pools[filename] = get_pools(results, MAX_SIZE_SUBPOOLS, model_dir)
            pools = global_pools[filename]
            returns_pools = []
            returns_pools_last = []
            for pool in pools:
                #res[model_name].append(get_return_pool(results, pool, TOP_K))
                returns_pools.append(get_optimal_return_pool(results, pool, covs))
                rbo[model_name].append(get_rbo(results, pool))
                if first:
                    #res["LAST"].append(get_return_pool(results_last, pool, TOP_K))
                    returns_pools_last.append(get_optimal_return_pool(results_last, pool, covs))
                    rbo["LAST"].append(get_rbo(results_last, pool))
            res[model_name].append(np.mean(returns_pools))
            if first:
                res["LAST"].append(np.mean(returns_pools_last))
        first = False
    #models = sorted(res.keys())
    models = ["LAST", "scinet", "DARNN", "AR312", "AR1248", "AR2048"]
    for model in models:
        if model not in res:
            continue
        print("########## MODEL:", model)
        compute_metrics(res[model], preds[model], gs[model], rbo[model])

