import torch
import numpy as np
import matplotlib.pyplot as plt
from numpy.random import standard_t, standard_normal
import matplotlib as mpl
import argparse
# Larger tick labels for every figure produced by this script.
mpl.rcParams.update({"xtick.labelsize": 17, "ytick.labelsize": 17})

import os.path
import sys
import inspect
# Put the directory one level above this file at the front of sys.path so that
# `utils.flows` can be imported when the script is run directly.
_this_file = inspect.getfile(inspect.currentframe())
currentdir = os.path.dirname(os.path.abspath(_this_file))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)
from utils.flows import Vanilla_Flow, TAF, mTAF

# Name of the torch device to use: "cuda" when a GPU is visible, else "cpu".
# (The previous `torch.device(...)` calls built objects that were immediately
# discarded; they did not change any global default.)
device = "cuda" if torch.cuda.is_available() else "cpu"

D = 16  # number of dimensions / marginals of the data and the flows

parser = argparse.ArgumentParser(
    description="Plot flow samples generated from tail events of the base distribution."
)
parser.add_argument("--num_heavy", type=int, default=8,
                    help=f"number of heavy-tailed marginals (0..{D}); selects the saved models")
parser.add_argument("--df", type=int, default=2,
                    help="degrees-of-freedom setting; selects the saved models")
parser.add_argument("--plot", type=str, default="mTAF",
                    help='"mTAF" for the mTAF tail plots; any other value plots Vanilla vs. TAF')
args = parser.parse_args()
# Reject values that would make num_light negative or exceed D.
if not 0 <= args.num_heavy <= D:
    parser.error(f"--num_heavy must be between 0 and {D}, got {args.num_heavy}")


df = args.df
num_heavy = args.num_heavy
setting = f"df{df}h{num_heavy}"  # e.g. "df2h8"; used to build model/ordering paths
type_plot = args.plot  # available options: "mTAF" and "vanillaTAF"
num_light = D - num_heavy
target_marg = D  # 1-based index of the plotted marginal (the last one)
num_samps = 1000  # base draws from which the 50 largest-norm ones are kept

path_vanilla = "models/vanilla_" + setting
path_taf = "models/taf_" + setting
path_mtaf = "models/mtaf_" + setting

# 1. Load Models
######################
# All three flows share the same constructor arguments for the CLI setting.
dim_arg = str(D)
path_tails = f"models/tail_est/{setting}/tail_estimator"  # these are just placeholders; ignore them

# mTAF is configured with the tail-estimator path before its weights are loaded.
mtaf = mTAF(dim_arg, num_heavy, df)
mtaf.config(path_tails)
mtaf.load_model(path=path_mtaf)

taf = TAF(dim_arg, num_heavy, df)
taf.load_model(path=path_taf)

vanilla = Vanilla_Flow(dim_arg, num_heavy, df)
vanilla.load_model(path=path_vanilla)

# Create the figure: four panels sharing a fixed y-range for the mTAF plot,
# otherwise two side-by-side panels (Vanilla vs. TAF).
if type_plot != "mTAF":
    f, (ax_vanilla, ax_taf) = plt.subplots(1, 2, sharey=True)
else:
    f, mtaf_axes = plt.subplots(1, 4, sharey=True, figsize=(16, 6))
    ax_mtaf1, ax_mtaf2, ax_mtaf3, ax_mtaf4 = mtaf_axes
    for panel in mtaf_axes:
        panel.set_ylim([0, 0.15])

def _load_mtaf_setting(mtaf_setting):
    """Build and load the mTAF model saved for a setting string like "df2h4".

    Uses the same constructor arguments, tail-estimator placeholder path and
    checkpoint path pattern as the global model loading, but for
    ``mtaf_setting`` instead of the command-line setting.

    Returns:
        ``(model, num_heavy)`` for that setting.
    """
    h_pos = mtaf_setting.index("h")
    setting_df = int(mtaf_setting[len("df"):h_pos])
    setting_heavy = int(mtaf_setting[h_pos + 1:])
    model = mTAF(str(D), setting_heavy, setting_df)
    model.config("models/tail_est/" + mtaf_setting + "/tail_estimator")  # placeholder path
    model.load_model(path="models/mtaf_" + mtaf_setting)
    return model, setting_heavy


def _base_dfs(model):
    """Return the last parameter of ``model.base_dist`` as a NumPy array."""
    return list(model.base_dist.parameters())[-1].detach().cpu().numpy()


def _position_in_ordering(path, target):
    """Return the index at which ``target`` appears in the permutation saved at ``path``.

    Raises:
        ValueError: if ``target`` does not appear exactly once.
    """
    hits = np.flatnonzero(np.load(path) == target)
    if hits.size != 1:
        raise ValueError(f"expected exactly one {target} in {path}, found {hits.size}")
    return int(hits[0])


def _sample_base_marginal(dfs, comp, n_light, n_samps):
    """Draw ``(n_samps, 1)`` samples from base component ``comp``.

    NOTE(review): assumes the light-tailed (Gaussian) components come first in
    the base distribution and ``dfs`` holds one df per heavy-tailed component
    after them, as implied by the ``dfs[comp - n_light]`` indexing; TODO confirm.
    The check is explicit because a negative index would silently wrap around
    instead of raising.
    """
    if comp >= n_light:
        print("The marginal is heavy-tailed")
        return standard_t(dfs[comp - n_light], [n_samps, 1])
    print("The marginal is light-tailed")
    return standard_normal([n_samps, 1])


def _largest_norm_rows(samps, num_tail=50):
    """Return the ``num_tail`` rows of ``samps`` with the largest norm (unordered)."""
    order = np.argpartition(-np.linalg.norm(samps, axis=1), num_tail)
    return samps[order[:num_tail]]


def _embed_in_base(tail_event, comp, dim=D):
    """Return zero base samples of width ``dim`` with ``tail_event[:, 0]`` in column ``comp``."""
    base = np.zeros((tail_event.shape[0], dim))
    base[:, comp] = tail_event[:, 0]
    return base


if type_plot == "mTAF":
    tail_col = target_marg - 1  # 0-based column of the plotted marginal
    for ax, mtaf_setting in zip((ax_mtaf1, ax_mtaf2, ax_mtaf3, ax_mtaf4),
                                ["df2h1", "df2h2", "df2h4", "df2h8"]):
        # Each panel uses the model (and its data) trained for its own setting.
        panel_mtaf, panel_heavy = _load_mtaf_setting(mtaf_setting)
        panel_light = D - panel_heavy

        # histogram of the test data for the target marginal
        data_test = panel_mtaf.data_test
        ax.hist(data_test[:, tail_col], range=(-10, 10), bins=30, density=True)

        # base coordinate that the flow's permutation maps onto the target marginal
        relevant_comp = _position_in_ordering(
            "models/mtaf_" + mtaf_setting + "_ordering.npy", tail_col)

        # push the 50 most extreme draws of that base marginal through the flow
        dfs = _base_dfs(panel_mtaf)
        samps = _sample_base_marginal(dfs, relevant_comp, panel_light, num_samps)
        base_samps = _embed_in_base(_largest_norm_rows(samps), relevant_comp)
        flow_tail = panel_mtaf.sample_with(base_samps).detach().cpu().numpy()

        print("plotting...")
        ax.hist(flow_tail[:, tail_col], bins=10, density=True, alpha=0.5)
    plt.savefig("mTAF_tails.pdf")

    # Now demonstrate that tail samples in *other* base coordinates do not
    # produce tail samples in the target marginal. This reuses the model,
    # data, ordering and dfs of the last setting from the loop above.
    othermarg_setting = mtaf_setting
    plt.close()
    f, (ax_mtaf1, ax_mtaf2, ax_mtaf3, ax_mtaf4) = plt.subplots(1, 4, sharey=True, figsize=(16, 6))
    # generate tail samples from the 3rd, 7th, 11th, and 15th base coordinate
    for ax, comp in zip((ax_mtaf1, ax_mtaf2, ax_mtaf3, ax_mtaf4), [2, 6, 10, 14]):
        if comp == relevant_comp:  # never perturb the coordinate driving the target
            comp += 1
        samps = _sample_base_marginal(dfs, comp, panel_light, num_samps)
        base_samps = _embed_in_base(_largest_norm_rows(samps), comp)
        flow_tail = panel_mtaf.sample_with(base_samps).detach().cpu().numpy()

        print("plotting...")
        ax.hist(data_test[:, tail_col], range=(-10, 10), bins=30, density=True)
        ax.hist(flow_tail[:, tail_col], bins=10, density=True, alpha=0.5)
    plt.savefig("mtaf_" + othermarg_setting + "_othermarg.pdf")

else:  # Plot Vanilla vs. TAF
    tail_col = target_marg - 1  # 0-based column of the plotted marginal
    # get the permutations:
    relevant_comp_taf = _position_in_ordering(
        "models/taf_" + setting + "_ordering.npy", tail_col)
    relevant_comp_vanilla = _position_in_ordering(
        "models/vanilla_" + setting + "_ordering.npy", tail_col)

    # 2. Plot the test-data marginal in both panels
    # NOTE(review): the test data is taken from the mTAF model for this setting;
    # presumably all three models share the same data — confirm.
    data_test = mtaf.data_test
    for ax, title in ((ax_vanilla, "Vanilla"), (ax_taf, "TAF")):
        ax.set_ylim([0, 0.15])
        ax.hist(data_test[:, tail_col], range=(-10, 10), bins=30, density=True)
        ax.set_title(title, fontsize=20)

    # 3. Sample Tail Events
    ##### vanilla: Gaussian base, so the tail event comes from a standard normal
    samps = standard_normal([num_samps, 1])
    base_samps = _embed_in_base(_largest_norm_rows(samps), relevant_comp_vanilla)
    try:
        # pass the full D-column base sample, as for the TAF/mTAF flows
        flow_tail = vanilla.sample_with(base_samps).detach().cpu().numpy()
    except Exception as err:  # best-effort: still produce the TAF panel
        print("Vanilla failed to sample tail-events!", err)
    else:
        ax_vanilla.hist(flow_tail[:, tail_col], bins=10, density=True, alpha=0.5)

    ##### taf: Student-t base with the learned degrees of freedom
    df = _base_dfs(taf)
    samps = standard_t(df, [num_samps, 1])
    base_samps = _embed_in_base(_largest_norm_rows(samps), relevant_comp_taf)
    flow_tail = taf.sample_with(base_samps).detach().cpu().numpy()
    ax_taf.hist(flow_tail[:, tail_col], bins=10, density=True, alpha=0.5)

    os.makedirs("plots", exist_ok=True)  # savefig does not create directories
    plt.savefig("plots/vanillavstaf" + setting + ".pdf")



