import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import time
import json
import pickle
import argparse
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

from datetime import datetime
from tqdm import tqdm, trange
from sklearn.model_selection import train_test_split
from sklearn.manifold import TSNE
from striatum.bandit import LinUCB, LinThompSamp
from striatum.storage import MemoryHistoryStorage, MemoryModelStorage, MemoryActionStorage, Action
from sklearn.datasets import fetch_openml
from sklearn.preprocessing import OrdinalEncoder
from sklearn.utils import shuffle
from coba.learners import VowpalSquarecbLearner

from dataconverter import DataConverter
from model import C3
from neuralucb import NeuralUCBDiag
from neuralts import NeuralTS
from visualization import plot_embeddings, plot_cum_regret, plot_buffer_embeddings


# Command-line interface: positional run settings plus optional boolean switches.
# (Dataset name and OpenML version are read from the config file, not the CLI.)
parser = argparse.ArgumentParser()
parser.add_argument("exp_dir", help="Directory to store results", type=str)
parser.add_argument(
    "algorithm",
    help="Choice of algorithm",
    choices=["c3", "linucb", "neuralucb", "lts", "neuralts", "squarecb"],
    type=str,
)
parser.add_argument("config", help="JSON file containing configurations file", type=str)
parser.add_argument("--seed", help="Seed number", default=42, type=int)

# --flag / --no-flag switches (argparse.BooleanOptionalAction); left as None when omitted.
toggle_flags = (
    ("--standardized", "Standardizes mean and standard deviation of input features"),
    ("--skip_update", "Skips updating data during online bandit evaluation"),
    ("--use_oracle", "Uses oracle's prediction instead of the ground truth"),
)
for flag_name, flag_help in toggle_flags:
    parser.add_argument(flag_name, help=flag_help, action=argparse.BooleanOptionalAction)

args = parser.parse_args()

# Run-level constants taken from the command line (DATASET / VERSION come from the config).
SEED, EXP_DIR, ALGORITHM, CONFIG_PATH = args.seed, args.exp_dir, args.algorithm, args.config


def get_path(x):
    """Return ``EXP_DIR/<SEED>/<x>``: the location of a per-run output file named *x*."""
    return os.path.join(EXP_DIR, str(SEED), x)

# Create EXP_DIR/<SEED> (including any missing parent directories). exist_ok=True
# tolerates the directory already existing, e.g. when another seed's run created
# EXP_DIR concurrently. Unlike a retry loop, a real failure such as missing
# permissions now raises instead of spinning forever.
os.makedirs(os.path.join(EXP_DIR, str(SEED)), exist_ok=True)

# Redirect all subsequent print() output into a line-buffered log file for this run.
sys.stdout = open(get_path("log_eval.txt"), "w", buffering=1)

# Load the JSON experiment configuration: dataset selection, split sizes and model params.
with open(CONFIG_PATH) as config_file:
    config = json.load(config_file)
params = config["params"]

# Upper-case run constants (assigned left to right, in the same order as the config keys).
(VERSION, DATASET, NUM_TEST_STEPS, MAX_DATASET_SIZE,
 TRAIN_RATIO, VAL_RATIO, REWARD_PROP) = (
    config[key]
    for key in ("VERSION", "DATASET", "NUM_TEST_STEPS", "MAX_DATASET_SIZE",
                "TRAIN_RATIO", "VAL_RATIO", "REWARD_PROP")
)

started_at = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
print(f"Running {ALGORITHM} on {DATASET} (v{VERSION}) - {started_at}")

# Pick the compute device and seed NumPy / PyTorch (and every CUDA device, if present).
has_cuda = torch.cuda.is_available()
DEVICE = torch.device("cuda" if has_cuda else "cpu")

print(f"DEVICE = {DEVICE} | skip_update = {args.skip_update}")

np.random.seed(SEED)
torch.manual_seed(SEED)
if has_cuda:
    torch.cuda.manual_seed_all(SEED)
    # Prefer reproducible cuDNN kernels over autotuned (possibly non-deterministic) ones.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Fetch the OpenML dataset and derive the bandit dimensions from it.
data_bunch = fetch_openml(DATASET, version=VERSION, parser="auto")

NUM_ARMS = len(set(data_bunch["target"]))   # one arm per distinct target value
CTX_SIZE = data_bunch["data"].shape[1]      # number of feature columns

# A null first-layer width in the config is filled in with the network input size:
# the context features plus a one-hot encoding of the arm.
layer_widths = params["init"]["layer_nums"]
if layer_widths[0] is None:
    layer_widths[0] = CTX_SIZE + NUM_ARMS

# Snapshot every module-level constant (UPPER_CASE name of length >= 3) that is
# JSON-serializable into args.json, so the exact run settings are recorded.
save_dict = {}
for name, value in dict(globals()).items():
    if len(name) < 3 or name.upper() != name:
        continue
    try:
        json.dumps(value)
    except (TypeError, ValueError):
        # Non-serializable values (e.g. DEVICE, a torch.device) are skipped on purpose.
        # Catch only serialization errors, not KeyboardInterrupt/SystemExit.
        continue
    save_dict[name] = value

with open(get_path("args.json"), "w") as f:
    json.dump(save_dict, f, indent=2)
with open(get_path("hyperparams.json"), "w") as f:
    json.dump(params, f, indent=2)

# Data preprocessing: float features plus one label column per row.
X = data_bunch["data"].values.astype(float)

if not args.use_oracle:
    # Ground-truth labels, ordinal-encoded to 0..K-1 and stored as floats.
    label_encoder = OrdinalEncoder(dtype=int)
    raw_targets = data_bunch["target"].values.reshape(-1, 1)
    y = label_encoder.fit_transform(raw_targets).astype(float)
else:
    # Labels predicted by a pre-trained oracle on the raw (unscaled) features.
    # NOTE(review): pickle.load executes arbitrary code — only load trusted oracle files.
    oracle_path = os.path.join("oracles", f"oracle_{DATASET}.pkl")
    with open(oracle_path, "rb") as oracle_file:
        oracle = pickle.load(oracle_file)
    y = oracle.predict(X).reshape(-1, 1)

if args.standardized:
    # NOTE(review): statistics are taken over the whole dataset, before the
    # train/val/test split below, so held-out rows influence the scaling.
    feature_mean = X.mean(axis=0, keepdims=True)
    feature_std = X.std(axis=0, keepdims=True) + 1e-8
    X = (X - feature_mean) / feature_std
elif DATASET == "mnist_784":
    X = X / 255.  # map pixels to [0, 1]

# Shuffle jointly, then keep at most MAX_DATASET_SIZE rows.
X, y = shuffle(X, y, random_state=SEED)
X, y = X[:MAX_DATASET_SIZE], y[:MAX_DATASET_SIZE]


# Train / validation / test splits: TRAIN_RATIO of the rows go to training and the
# remainder is divided between validation and test by VAL_RATIO.
X, X_val, y, y_val = train_test_split(X, y, train_size=TRAIN_RATIO, random_state=SEED)
X_val, X_test, y_val, y_test = train_test_split(X_val, y_val, train_size=VAL_RATIO, random_state=SEED*2)

# Turn each split into logged bandit data (contexts, actions, rewards).
dc = DataConverter(X, y, NUM_ARMS)
C_data, A_data, R_data, true_label = dc.get_pretraining_split(
    proportion=1, reward_prop=REWARD_PROP, seed=SEED)

dc_val = DataConverter(X_val, y_val, NUM_ARMS)
C_val_data, A_val_data, R_val_data, _ = dc_val.get_pretraining_split(
    proportion=1, reward_prop=REWARD_PROP, seed=SEED)

dc_test = DataConverter(X_test, y_test, NUM_ARMS)


def _to_regression_tensors(contexts, actions, rewards):
    """Return (inputs, targets): contexts concatenated with one-hot actions, and rewards as a column."""
    one_hot_actions = np.eye(NUM_ARMS)[actions]
    inputs = torch.Tensor(np.hstack([contexts, one_hot_actions]))
    targets = torch.Tensor(rewards.reshape(-1, 1))
    return inputs, targets


X_train, y_train = _to_regression_tensors(C_data, A_data, R_data)
X_val, y_val = _to_regression_tensors(C_val_data, A_val_data, R_val_data)

print("X_train.shape, y_train.shape, y_train.mean(), X_val.shape, y_val.shape, y_val.mean()")
print(X_train.shape, y_train.shape, y_train.mean().item(), X_val.shape, y_val.shape, y_val.mean().item())

# Per-step evaluation traces filled in by whichever algorithm branch runs below.
all_actions = list()
cum_regret = [0]        # cumulative regret, starting from 0 before the first step
time_steps = list()     # seconds elapsed since start_time at each recorded step

# One validation step is interleaved every `freq_eval` pretraining samples.
# Clamp to >= 1: when the training split is smaller than the validation split the
# integer division is 0, which would make `i % freq_eval` raise ZeroDivisionError
# and the C3 per-step training slices empty. The inner max() also guards an empty
# validation split.
freq_eval = max(1, len(C_data) // max(1, len(dc_val)))
start_time = time.time()

if ALGORITHM == "c3":
    # --- C3: fit offline on the logged training split, then replay the validation split ---
    print(f"Training C3Model - {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    # model = C3(X_init=X_train, y_init=y_train, seed=SEED, **params["init"])
    model = C3(seed=SEED, **params["init"])

    model.fit(X_train, y_train, None, DEVICE, X_val=X_val, y_val=y_val, prob_val=None,
              plot=True, save_plot_name=get_path("loss_plot.png"), **params["fit"])

    # torch.save(model, get_path("c3model_offline.pt"))

    print(f"Getting training regrets - {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    model.clear_buffer(seed=SEED)
    C, C_arms = dc_val.reset(seed=SEED)

    for i in trange(len(dc_val)):
        # Choose an arm for the current validation context; only `action` and the
        # extra output `w_new` (forwarded to store_buffer) are used.
        action, _mean, _stderr, w_new = model.infer_batch(torch.Tensor(C_arms))
        reward, regret = dc_val.take_action(action)

        # Buffer the *logged* validation interaction (context of the logged arm and its
        # logged reward) rather than the arm just chosen ...
        model.store_buffer(torch.Tensor(C_arms[A_val_data[i]]).view(1, -1),
                           torch.Tensor([R_val_data[i]]).view(1, 1),
                           torch.Tensor(w_new))
        # ... and stream in the next `freq_eval` training samples alongside it.
        model.store_buffer(torch.Tensor(X_train[i*freq_eval:(i+1)*freq_eval]),
                           torch.Tensor(y_train[i*freq_eval:(i+1)*freq_eval]))

        all_actions.append(action)
        cum_regret.append(cum_regret[-1] + regret)
        time_steps.append(time.time() - start_time)

        if i != len(dc_val) - 1:
            C, C_arms = dc_val.next()
        # NOTE(review): stopping after i == NUM_TEST_STEPS records NUM_TEST_STEPS + 1
        # steps; the other algorithm branches use the same convention.
        if i == NUM_TEST_STEPS:
            break

    model.X_init = X_train
    model.y_init = y_train

    # Generating resultant embedding plots
    print(f"Generating embedding plot - {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    with torch.no_grad():
        projected = model.project(X_train[:1000])
        # fit() received DEVICE, so project() may return a CUDA tensor; .numpy() only
        # works on CPU tensors, and .cpu() is a no-op when it is already on the CPU.
        emb_red = TSNE(perplexity=35, random_state=SEED).fit_transform(projected.cpu().numpy())
    np.save(get_path("embred_offline.npy"), emb_red)
    # NOTE(review): emb_red covers at most the first 1000 training rows while the full
    # y_train / A_data / R_data are passed — confirm plot_embeddings aligns them.
    plot_embeddings(emb_red, y_train, A_data, R_data, NUM_ARMS, save_path=get_path("embedding_offline.png"))

    # print(f"Starting online bandit evaluation - {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} [{int(time.time() - start_time)}] sec")
    # model.clear_buffer(seed=SEED)
    # C, C_arms = dc_test.reset(seed=SEED)

    # for i in trange(NUM_TEST_STEPS):
    #     action, mean, stderr, w_new = model.infer_batch(torch.Tensor(C_arms))
    #     reward, regret = dc_test.take_action(action)

    #     if not args.skip_update:
    #         model.store_buffer(torch.Tensor(C_arms[action]).view(1, -1),
    #                         torch.Tensor([reward]).view(1, 1),
    #                         torch.Tensor(w_new))

    #     all_actions.append(action)
    #     cum_regret.append(cum_regret[-1] + regret)
    #     time_steps.append(time.time() - start_time)

    #     C, C_arms = dc_test.next()
        
elif ALGORITHM == "linucb":
    linucb = LinUCB(MemoryHistoryStorage(), MemoryModelStorage(), MemoryActionStorage(), context_dimension=CTX_SIZE, alpha=1.96)
    linucb.add_action([Action() for _ in range(NUM_ARMS)])

    C, C_arms = dc_val.reset(seed=SEED)    

    print(f"Pretraining LinUCB - {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    for i in range(len(C_data)):
        history_id, action_obj = linucb.get_action({j: C_data[i] for j in range(NUM_ARMS)})
        linucb.reward(history_id, {A_data[i]: R_data[i]})
        if i % 2000 == 0:
            print(f"\tStep {i} / {len(C_data)}")

        # if i < len(dc_val):
        if i % freq_eval == 0:
            j = i // freq_eval
            if j >= len(dc_val):
                continue
            history_id, action_obj = linucb.get_action({j: C for j in range(NUM_ARMS)})
            action = action_obj.action.id
            reward, regret = dc_val.take_action(action)

            linucb.reward(history_id, {A_val_data[j]: R_val_data[j]})

            all_actions.append(action)
            cum_regret.append(cum_regret[-1] + regret)
            time_steps.append(time.time() - start_time)
            
            if j != len(dc_val) - 1:
                C, C_arms = dc_val.next()

            if j == NUM_TEST_STEPS:
                break

    # torch.save(linucb, get_path("linucb_offline.pt"))

    # print(f"Starting online bandit evaluation - {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} [{int(time.time() - start_time)} sec]")
    # C, C_arms = dc_test.reset(seed=SEED)

    # for i in trange(NUM_TEST_STEPS):
    #     history_id, action_obj = linucb.get_action({j: C for j in range(NUM_ARMS)})

    #     action = action_obj.action.id
    #     reward, regret = dc_test.take_action(action)

    #     if not args.skip_update:
    #         linucb.reward(history_id, {action: reward})

    #     all_actions.append(action)
    #     cum_regret.append(cum_regret[-1] + regret)
    #     time_steps.append(time.time() - start_time)

    #     C, C_arms = dc_test.next()

elif ALGORITHM == "lts":
    lints = LinThompSamp(MemoryHistoryStorage(), MemoryModelStorage(), MemoryActionStorage(), 
                         epsilon=1/np.log(len(X_train) + len(X_test)), 
                         context_dimension=CTX_SIZE, random_state=SEED)
    lints.add_action([Action() for _ in range(NUM_ARMS)])

    C, C_arms = dc_val.reset(seed=SEED)

    print(f"Pretraining LinTS - {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    for i in range(len(C_data)):
        history_id, action_obj = lints.get_action({j: C_data[i] for j in range(NUM_ARMS)})
        lints.reward(history_id, {A_data[i]: R_data[i]})
        if i % 2000 == 0:
            print(f"\tStep {i} / {len(C_data)}")

        if i % freq_eval == 0:
            j = i // freq_eval
            if j >= len(dc_val):
                continue
            history_id, action_obj = lints.get_action({j: C for j in range(NUM_ARMS)})
            action = action_obj.action.id
            reward, regret = dc_val.take_action(action)

            lints.reward(history_id, {A_val_data[j]: R_val_data[j]})

            all_actions.append(action)
            cum_regret.append(cum_regret[-1] + regret)
            time_steps.append(time.time() - start_time)
            
            if j != len(dc_val) - 1:
                C, C_arms = dc_val.next()
            if j == NUM_TEST_STEPS:
                break
    # torch.save(lints, get_path("lints_offline.pt"))

    # print(f"Starting online bandit evaluation - {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} [{int(time.time() - start_time)} sec]")
    # C, C_arms = dc_test.reset(seed=SEED)

    # for i in trange(NUM_TEST_STEPS):
    #     history_id, action_obj = lints.get_action({j: C for j in range(NUM_ARMS)})

    #     action = action_obj.action.id
    #     reward, regret = dc_test.take_action(action)

    #     if not args.skip_update:
    #         lints.reward(history_id, {action: reward})

    #     all_actions.append(action)
    #     cum_regret.append(cum_regret[-1] + regret)
    #     time_steps.append(time.time() - start_time)

    #     C, C_arms = dc_test.next()
elif ALGORITHM == "squarecb":
    gamma = 10
    ALL_ACTIONS = list(range(NUM_ARMS))

    oracle = VowpalSquarecbLearner(gamma_scale=gamma)

    C, C_arms = dc_val.reset(seed=SEED)

    print(f"Pretraining LinTS - {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    for i in range(len(C_data)):
        print(C_data[i].tolist())
        print([A_data[i]])
        action, prob, _ = oracle.predict(C_data[i].tolist(), [int(A_data[i])])
        oracle.learn(C_data[i].tolist(), action, R_data[i], prob)

        if i % freq_eval == 0:
            j = i // freq_eval
            if j >= len(dc_val):
                continue
            action, prob, _ = oracle.predict(C.tolist(), ALL_ACTIONS)
            reward, regret = dc_val.take_action(action)
            oracle.learn(C.tolist(), action, reward, prob)

            all_actions.append(action)
            cum_regret.append(cum_regret[-1] + regret)
            time_steps.append(time.time() - start_time)
            
            if j != len(dc_val) - 1:
                C, C_arms = dc_val.next()
            if j == NUM_TEST_STEPS:
                break

elif ALGORITHM == "neuralts":
    neuralts = NeuralTS(dim=CTX_SIZE * NUM_ARMS, nu=0.00001, lamdba=0.00001)
    
    C_val, C_arms_val = dc_val.reset(seed=SEED)

    print(f"Pretraining NeuralTS - {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    for i, (C, A, R) in enumerate(zip(C_data, A_data, R_data)):
        C_prime = NeuralTS.convert_data(C, NUM_ARMS)
        neuralts.force_select(C_prime, A)
        neuralts.train(C_prime[A], R)
        if i % 2000 == 0:
            print(f"\tStep {i} / {len(C_data)}")

        if i % freq_eval == 0:
            j = i // freq_eval
            if j >= len(dc_val):
                continue
            C_prime_val = NeuralTS.convert_data(C_val, dc_val.num_arms)
            action, *_ = neuralts.select(C_prime_val)
            reward, regret = dc_val.take_action(action)

            neuralts.train(C_prime[A_val_data[j]], R_val_data[j])

            all_actions.append(action)
            cum_regret.append(cum_regret[-1] + regret)
            time_steps.append(time.time() - start_time)
            
            if j != len(dc_val) - 1:
                C_val, C_arms_val = dc_val.next()
            if j == NUM_TEST_STEPS:
                break

    # print(f"Starting online bandit evaluation - {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    # C, C_arms = dc_test.reset(seed=SEED)

    # for i in trange(NUM_TEST_STEPS):
    #     C_prime = NeuralTS.convert_data(C, dc_test.num_arms)

    #     action, *_ = neuralts.select(C_prime)
    #     reward, regret = dc_test.take_action(action)

    #     if not args.skip_update:
    #         neuralts.train(C_prime[action], reward)

    #     all_actions.append(action)
    #     cum_regret.append(cum_regret[-1] + regret)
    #     time_steps.append(time.time() - start_time)

    #     C, C_arms = dc_test.next()

else:
    neuralucb = NeuralUCBDiag(CTX_SIZE * NUM_ARMS, nu=0.00001, lamdba=0.00001)

    C_val, C_arms_val = dc_val.reset(seed=SEED)

    print(f"Pretraining NeuralUCB - {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} [{int(time.time() - start_time)} sec]")
    for i, (C, A, R) in enumerate(zip(C_data, A_data, R_data)):
        C_prime = NeuralUCBDiag.convert_data(C, NUM_ARMS)
        neuralucb.force_select(C_prime, A)
        neuralucb.train(C_prime[A], R)
        if i % 2000 == 0:
            print(f"\tStep {i} / {len(C_data)}")

        if i % freq_eval == 0:
            j = i // freq_eval
            if j >= len(dc_val):
                continue
            C_prime_val = NeuralUCBDiag.convert_data(C_val, dc_val.num_arms)
            action, *_ = neuralucb.select(C_prime_val)
            reward, regret = dc_val.take_action(action)

            neuralucb.train(C_prime[A_val_data[j]], R_val_data[j])

            all_actions.append(action)
            cum_regret.append(cum_regret[-1] + regret)
            time_steps.append(time.time() - start_time)
            
            if j != len(dc_val) - 1:
                C_val, C_arms_val = dc_val.next()
            if j == NUM_TEST_STEPS:
                break

    # torch.save(neuralucb, get_path("neuralucb_offline.pt"))
    
    # print(f"Starting online bandit evaluation - {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    # C, C_arms = dc_test.reset(seed=SEED)

    # for i in trange(NUM_TEST_STEPS):
    #     C_prime = NeuralUCBDiag.convert_data(C, dc_test.num_arms)

    #     action, *_ = neuralucb.select(C_prime)
    #     reward, regret = dc_test.take_action(action)

    #     if not args.skip_update:
    #         neuralucb.train(C_prime[action], reward)

    #     all_actions.append(action)
    #     cum_regret.append(cum_regret[-1] + regret)
    #     time_steps.append(time.time() - start_time)

    #     C, C_arms = dc_test.next()

print(f"Done online bandit evaluation - {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}  [{int(time.time() - start_time)} sec]")

# Persist the raw per-step traces so regret curves can be re-plotted / compared later.
results_dict = {
    "cum_regret": cum_regret,
    "all_actions": all_actions,
    "time_steps": time_steps
}

with open(get_path("results_dict.pkl"), "wb") as f:
    pickle.dump(results_dict, f)

plot_cum_regret([cum_regret], [all_actions], [ALGORITHM], NUM_ARMS,
                save_path=get_path("results.png"))

# if ALGORITHM == "c3":
#     with torch.no_grad():
#         emb_red = TSNE(perplexity=35, random_state=SEED).fit_transform(model.X_buff[len(X_train):].numpy())
#     plot_buffer_embeddings(emb_red, model, NUM_ARMS, start_index=len(X_train), save_path=get_path("embedding_buffer.png"))
#     np.save(get_path("embred_buffer.npy"), emb_red)

print(f"Execution complete - {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

# sys.stdout currently points at the log file opened at start-up. Restore the
# interpreter's original stream before closing it, so any later write (e.g. during
# interpreter shutdown) does not fail with "I/O operation on closed file".
log_file = sys.stdout
sys.stdout = sys.__stdout__
log_file.close()
