import sys
import os

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import pickle
import argparse
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F

from tqdm import tqdm, trange
from scipy.spatial.distance import cdist

from model import C3

# Print plain decimals instead of scientific notation for both libraries.
np.set_printoptions(suppress=True)
torch.set_printoptions(sci_mode=False)

# Setting up argparser
parser = argparse.ArgumentParser()
parser.add_argument("exp_dir", help="Directory to store results", type=str)
parser.add_argument("seeds", help="Range of seeds (separated by comma)", type=str)

args = parser.parse_args()

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Validate the seed range *before* anything indexes into it, and report
# problems through argparse so the user gets a usage message instead of a
# traceback (and so the check is not stripped when running with ``python -O``).
try:
    SEED_RANGE = [int(s) for s in args.seeds.split(",")]
except ValueError:
    parser.error("seeds must be two comma-separated integers, e.g. 0,10")
if len(SEED_RANGE) != 2:
    parser.error("seeds must be in the form (lower_bound, upper_bound)")
if SEED_RANGE[1] <= SEED_RANGE[0]:
    parser.error("seeds upper_bound must be greater than lower_bound")

# SEED_RANGE is consumed with np.arange(*SEED_RANGE), i.e. the upper bound is exclusive.
NUM_SEEDS = SEED_RANGE[1] - SEED_RANGE[0]
NUM_ENVS = 200  # number of training environments generated per seed
NUM_POINTS_PER_ENV = 100  # (arm, reward) samples drawn from each environment
CORRELATIONS = [-1., -0.6, -0.2, 0.2, 0.6, 1.]  # one non-anchor arm per entry
NUM_ARMS = len(CORRELATIONS) + 1  # the extra arm is the anchor
VAL_RATIO = 0.1
SIGMA = 1  # passed to C3 as ``sigma`` and used for the exp(-d / (2 * SIGMA ** 2)) weights


class NoContextCorrArms:
    """Context-free Bernoulli bandit whose arms are generated from an anchor arm.

    Arm 0 (the anchor) has a success probability drawn from
    ``Beta(anc_alpha, anc_beta)``. Every further arm corresponds to one entry
    of ``corrs`` and has its success probability drawn from a Beta
    distribution whose parameters come from :meth:`compute_corr_params`,
    evaluated at the anchor's sampled probability.

    Args:
        anc_alpha: First shape parameter of the anchor's Beta prior.
        anc_beta: Second shape parameter of the anchor's Beta prior.
        corrs: Iterable of coefficients, one per non-anchor arm.
        c: Multiplier applied to both Beta parameters of the non-anchor arms.
    """

    def __init__(self, anc_alpha, anc_beta, corrs, c=1):
        self.anc_alpha, self.anc_beta = anc_alpha, anc_beta
        self.corrs = corrs
        self.c = c

        # Both stay None until reset() is called.
        self.rng = None
        self.arm_probs = None

    def reset(self, seed=None):
        """Re-seed the environment and sample a fresh success probability per arm."""
        self.rng = np.random.RandomState(seed)

        # The anchor is sampled first and always sits at index 0.
        anchor_prob = self.rng.beta(self.anc_alpha, self.anc_beta)
        self.arm_probs = [anchor_prob]

        for rho in self.corrs:
            shape_a, shape_b = self.compute_corr_params(
                self.anc_alpha, self.anc_beta, self.c, anchor_prob, rho
            )
            self.arm_probs.append(self.rng.beta(shape_a, shape_b))

    def step(self, arm):
        """Pull ``arm`` and return a single Bernoulli reward (0 or 1)."""
        draw = self.rng.binomial
        return draw(1, self.arm_probs[arm])

    def get_pretraining_data(self, num_samples, seed=None):
        """Reset the environment with ``seed`` and collect random-policy samples.

        Arms are picked uniformly at random by a dedicated generator that is
        seeded with the same ``seed``; rewards come from the environment's own
        generator via :meth:`step`.

        Returns:
            A pair ``(arms, rewards)`` of arrays with ``num_samples`` entries
            each; ``arms`` has integer dtype.
        """
        # Any state left by an earlier reset() is discarded here.
        self.reset(seed=seed)
        arm_picker = np.random.RandomState(seed)

        arm_count = len(self.arm_probs)
        pulled = []
        observed = []

        for _ in range(num_samples):
            choice = arm_picker.choice(arm_count)
            pulled.append(choice)
            observed.append(self.step(choice))

        return np.array(pulled, dtype=int), np.array(observed)

    def compute_corr_params(self, a, b, c, samples, corr):
        """Return the Beta shape parameters for an arm tied to the anchor.

        The two returned values sum to ``c * (a + b)``, and the mean of the
        resulting Beta distribution is ``corr * (samples - 0.5) + 0.5``.
        """
        total = a + b
        first = c * (corr * (samples - 0.5) + 0.5) * total
        second = c * (corr * (0.5 - samples) + 0.5) * total
        return first, second


# Bandit family: an anchor arm with a Beta(1, 1) prior plus one arm per entry of CORRELATIONS.
bandit = NoContextCorrArms(1, 1, np.array(CORRELATIONS), c=50)

# Constructor / training configuration for C3; their exact meaning is defined
# by the model module, which is not visible in this file.
layer_nums = [NUM_ARMS, 256, 2]
loss_coef = {"bce": 1, "ece": 5}

# all_distances[i] holds the pairwise distances between arm embeddings for the i-th seed.
all_distances = np.zeros((NUM_SEEDS, NUM_ARMS, NUM_ARMS))

# Row k of the identity matrix is the one-hot encoding of arm k (built once, reused below).
one_hot = np.eye(NUM_ARMS)

# Validation environments use the seeds directly after the training ones.
NUM_VAL_ENVS = 10

for i, seed in tqdm(enumerate(np.arange(*SEED_RANGE)), total=NUM_SEEDS):
    np.random.seed(seed)
    torch.manual_seed(seed)

    X_list = list()
    y_list = list()

    X_val_list = list()
    y_val_list = list()

    # NOTE(review): get_pretraining_data() re-seeds the bandit itself, so each
    # environment is fully determined by the seed passed here (``e``) and the
    # generated data is identical for every outer ``seed``. The former
    # ``bandit.reset(seed=e * seed)`` calls were overwritten immediately and
    # had no effect; confirm that per-seed data variation was not intended.
    for e in range(NUM_ENVS):
        A_data, R_data = bandit.get_pretraining_data(NUM_POINTS_PER_ENV, seed=e)
        X_list.append(torch.Tensor(one_hot[A_data]))
        y_list.append(torch.Tensor(R_data).unsqueeze(dim=-1))

    for e in range(NUM_VAL_ENVS):
        A_val_data, R_val_data = bandit.get_pretraining_data(NUM_POINTS_PER_ENV, seed=NUM_ENVS + e)
        X_val_list.append(torch.Tensor(one_hot[A_val_data]))
        y_val_list.append(torch.Tensor(R_val_data).unsqueeze(dim=-1))

    model = C3(layer_nums, sigma=SIGMA, X_init=None, y_init=None, weight_factor=1.)

    model.fit_multi(X_list, y_list, None, DEVICE, loss_coef, X_val_list, y_val_list, None,
                    epochs=4, lr=1e-3, base_ratio=0.2, usage_ratio=0.5,
                    val_base_ratio=0.8, val_seed=seed * seed, pick_best_val=False,
                    plot=False, tqdm_pbar=False)

    # Embed every arm (one-hot input) and measure all pairwise distances.
    with torch.no_grad():
        # .cpu() is a no-op for CPU tensors and keeps .numpy() valid should
        # project() return a tensor that lives on DEVICE.
        emb_vec = model.project(torch.eye(NUM_ARMS)).cpu().numpy()
    distances = cdist(emb_vec, emb_vec)

    all_distances[i] = distances

# Aggregate the pairwise distance matrices over the seed axis.
mean_distances = np.mean(all_distances, axis=0)
std_distances = np.std(all_distances, axis=0)

# Row 0 belongs to the anchor arm; columns 1: are the arms built from the correlations.
anchor_mean = mean_distances[0, 1:]
anchor_spread = 1.96 * std_distances[0, 1:]
weight_scale = 2 * SIGMA ** 2

# Distances mapped through exp(-d / (2 * SIGMA ** 2)). Because the map is
# decreasing, ``rbf_lower`` (built from the smaller distance) is numerically
# the larger of the two band edges.
rbf_weights = np.exp(-anchor_mean / weight_scale)
rbf_lower = np.exp(-(anchor_mean - anchor_spread) / weight_scale)
rbf_upper = np.exp(-(anchor_mean + anchor_spread) / weight_scale)

# Degree-1 least-squares fits against the correlation of each arm.
coef_d = np.polyfit(bandit.corrs, anchor_mean, deg=1)
coef_w = np.polyfit(bandit.corrs, rbf_weights, deg=1)

# Plot the mean anchor-to-arm distance against the correlation of each arm,
# with a 1.96 * standard-error band over seeds and a linear fit.
anchor_distance = mean_distances[0, 1:]
band_half_width = 1.96 * std_distances[0, 1:] / np.sqrt(NUM_SEEDS)

plt.figure(figsize=(4, 4))
# plt.figure(figsize=(12, 5))
# plt.subplot(121)
plt.scatter(bandit.corrs, anchor_distance, label="Mean distance", c="tab:blue")
plt.fill_between(bandit.corrs,
                 anchor_distance - band_half_width,
                 anchor_distance + band_half_width, alpha=0.2, color="tab:blue")
plt.plot(bandit.corrs, coef_d[0] * bandit.corrs + coef_d[1], "--", c="tab:blue", alpha=0.5, label="Line of best fit")
plt.xticks(bandit.corrs)
plt.ylabel("Distance from anchor")
plt.xlabel("True Correlation $\\rho$")
plt.legend()

# Optional second panel (disabled): the same view for the exp(-d / (2 * SIGMA ** 2)) weights.
# plt.subplot(122)
# plt.scatter(bandit.corrs, rbf_weights, label="mean RBF weight", c="tab:orange")
# plt.fill_between(bandit.corrs, rbf_lower, rbf_upper, alpha=0.2, color="tab:orange")
# plt.plot(bandit.corrs, coef_w[0] * bandit.corrs + coef_w[1], "--", c="tab:orange", alpha=0.5, label="Line of best fit")
# plt.xticks(bandit.corrs)
# plt.ylabel("RBF Weights")
# plt.xlabel("True Correlation $\\rho$")
# plt.legend()

plt.tight_layout()

# Create the parent directory first so a missing folder does not discard the
# whole run at the very last step.
# NOTE(review): the CLI help calls exp_dir a directory, yet it is handed to
# savefig as the output file path -- confirm which of the two is intended.
output_parent = os.path.dirname(args.exp_dir)
if output_parent:
    os.makedirs(output_parent, exist_ok=True)

plt.savefig(args.exp_dir, dpi=100)
plt.close()
