
import gpytorch
import numpy as np
import torch
import torch.utils.data as data_utils
import tqdm
import argparse
import sys
import os
import time
from gpytorch.kernels import ScaleKernel, MaternKernel, RBFKernel, LinearKernel
from gpytorch.means import ConstantMean, ZeroMean
from bayesopt.surrogates.gp_default import DefaultMLLGP
from bayesopt.surrogates.gp import MLLGP
from bayesopt.surrogates.laplace import FixedFeatLaplace
from bayesopt.surrogates.fsplaplace import FixedFeatFSPLaplace
from bayesopt.acqf import ucb, ei, thompson_sampling
from utils import helpers
from sklearn.utils import shuffle as skshuffle
from utils.configs import LaplaceConfig, FSPLaplaceConfig 
from sklearn.preprocessing import StandardScaler

from gauche.kernels.fingerprint_kernels import TanimotoKernel


# Lookup table from the value of the --activation_fn CLI option to the torch
# module class that is instantiated between the linear layers of the networks.
ACTIVATIONS = dict(
    relu=torch.nn.ReLU,
    tanh=torch.nn.Tanh,
    lrelu=torch.nn.LeakyReLU,
)

# Command-line interface of the experiment script. Arguments are parsed at
# import time and exposed through the module-level ``args`` namespace.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--problem",
    choices=["redox-mer", "solvation", "kinase", "laser", "pce", "photoswitch", "ampc", "d4"],
    default="redox-mer",
)
# Defaulting to "random": without a default, omitting --method yielded
# ``None``, for which no surrogate is built but the model-based code path of
# the optimisation loop is still taken.
parser.add_argument(
    "--method",
    choices=["random", "gp", "laplace", "fsplaplace", "gp_default"],
    default="random",
)
parser.add_argument(
    "--feature_type",
    choices=[
        "fingerprints",
        "molformer",
        "t5-base-chem",
        "mordred",
        "degree_of_conjugation",
        "force_field",
        "dft",
        "all_features",
        "hand_crafted_expert",
        "hand_crafted_general",
        "data_driven",
    ],
    default="fingerprints",
)
parser.add_argument(
    "--kernel",
    choices=[
        "tanimoto",
        "matern_0.5",
        "matern_1.5",
        "matern_2.5",
        "rbf",
        "linear",
    ],
    default="matern_2.5",
)
parser.add_argument(
    "--activation_fn", choices=["relu", "lrelu", "tanh"], default="tanh"
)
# Numeric hyper-parameters are converted by argparse itself (type=float), so
# no post-hoc float(...) casts are needed after parsing.
parser.add_argument("--noise_var", type=float, default=1e-4)
parser.add_argument("--wd", type=float, default=1e-1)
parser.add_argument(
    "--map_context_points", choices=["uniform", "bo_candidates"], default="bo_candidates"
)
parser.add_argument(
    "--cov_context_points", choices=["sobol", "bo_candidates"], default="bo_candidates"
)
parser.add_argument("--lr", type=float, default=1e-3)
parser.add_argument(
    "--mean_fn", choices=["zero", "constant"], default="zero"
)
# --init_high_dim is on by default. Because it is a store_true flag whose
# default is already True, it could never be switched off; the companion
# --no_init_high_dim flag writes False to the same destination.
parser.add_argument(
    "--init_high_dim", dest="init_high_dim", default=True, action="store_true"
)
parser.add_argument(
    "--no_init_high_dim",
    dest="init_high_dim",
    action="store_false",
    help="Disable the high-dimensional initialization (enabled by default).",
)
parser.add_argument(
    "--task", default="param_sensitivity"
)
parser.add_argument("--acqf", choices=["ei", "ucb", "ts"], default="ts")
parser.add_argument("--n_init_data", type=int, default=10)
parser.add_argument("--exp_len", type=int, default=200)
parser.add_argument("--dtype", choices=["float32", "float64"], default="float64")
parser.add_argument(
    "--cuda",
    default=False,
    action="store_true",
    help="Abort if no CUDA device is available.",
)
parser.add_argument("--randseed", type=int, default=1)
parser.add_argument("--normalize_x", default=False, action="store_true")
parser.add_argument("--normalize_y", default=False, action="store_true")
args = parser.parse_args()

# Seed the NumPy and torch random number generators for reproducibility.
np.random.seed(args.randseed)
torch.manual_seed(args.randseed)

if not args.init_high_dim:
    print(
        "\n"
        "Beware, you are using default initialization which performs very poorly in high dimensions!"
        "\n"
    )

# Query CUDA availability once and reuse the answer below.
cuda_available = torch.cuda.is_available()
print("Running on GPU" if cuda_available else "Running on CPU")

# --cuda only enforces that a GPU is present; it does not select the device.
if args.cuda and not cuda_available:
    print("No CUDA detected!")
    sys.exit(1)

# The device follows CUDA availability, independently of the --cuda flag.
DEVICE = "cuda" if cuda_available else "cpu"

if args.dtype == "float32":
    DTYPE = torch.float32
else:
    DTYPE = torch.float64

# Per-problem flag that is forwarded to helpers.y_transform whenever an
# objective value is reported or stored.
MAXIMIZATION = {
    "redox-mer": False,
    "solvation": False,
    "kinase": False,
    "laser": True,
    "pce": True,
    "photoswitch": True,
    "ampc": False,
    "d4": False,
}.get(args.problem)

# Unknown problem names abort the run.
if MAXIMIZATION is None:
    print("Invalid test function!")
    sys.exit(1)

# Dataset
CACHE_PATH = f"data/cache/{args.problem}/"

# Composite feature types mapped to the cached feature sets that are
# concatenated per molecule, in concatenation order. Whenever fingerprints
# are part of a group they come first.
COMPOSITE_FEATURES = {
    "all_features": [
        "fingerprints",
        "molformer",
        "t5-base-chem",
        "mordred",
        "degree_of_conjugation",
        "force_field",
        "dft",
    ],
    "hand_crafted_expert": ["degree_of_conjugation", "force_field", "dft"],
    "data_driven": ["molformer", "t5-base-chem"],
    "hand_crafted_general": ["fingerprints", "mordred"],
}


def _load_cached_features(name):
    """Load one cached feature set and optionally standardise it.

    Reads ``{CACHE_PATH}{name}_feats.bin``. When ``--normalize_x`` is set,
    every feature set except ``"fingerprints"`` is standardised with a freshly
    fitted ``StandardScaler`` and returned as a list of float32 row tensors;
    otherwise the loaded object is returned untouched.
    """
    # NOTE(review): torch.load unpickles the file; only point it at trusted caches.
    feats = torch.load(CACHE_PATH + f"{name}_feats.bin")
    if args.normalize_x and name != "fingerprints":
        print("Normalizing features")
        x_preprocessor = StandardScaler()
        # Flatten to (n_samples, n_features) before fitting the scaler.
        as_matrix = np.array(feats).reshape(-1, feats[0].shape[-1])
        feats = list(torch.tensor(x_preprocessor.fit_transform(as_matrix)).float())
    return feats


if args.feature_type in COMPOSITE_FEATURES:
    feature_list = COMPOSITE_FEATURES[args.feature_type]
    # Composite feature types reuse the targets cached next to the fingerprints.
    targets = torch.load(CACHE_PATH + "fingerprints_targets.bin")
    per_source = [_load_cached_features(name) for name in feature_list]
    # Concatenate the per-source vectors of each molecule into one vector.
    features = [torch.cat(parts, dim=0) for parts in zip(*per_source)]
else:
    features = _load_cached_features(args.feature_type)
    targets = torch.load(CACHE_PATH + f"{args.feature_type}_targets.bin")


# Shuffle features and targets jointly, with the run seed.
features, targets = skshuffle(features, targets, random_state=args.randseed)
feature_dim = features[0].shape[-1]
# Largest target in the pool; used to exclude the optimum from the initial
# data and as the early-stopping threshold of the optimisation loop.
ground_truth_max = torch.tensor(targets).flatten().max()

# Print a one-line summary of the run configuration. Every method shares the
# same prefix; the method-specific hyper-parameters are appended to it.
print()
banner = f"Test Function: {args.problem}; Feature Type: {args.feature_type}; Randseed: {args.randseed};"
if args.method in ("gp", "gp_default"):
    banner += f" kernel: {args.kernel} - lr: {args.lr} - noise_var: {args.noise_var}"
elif args.method == "fsplaplace":
    banner += (
        f" kernel: {args.kernel}; activation: {args.activation_fn};"
        f" map_context_points: {args.map_context_points};"
        f" cov_context_points: {args.cov_context_points};"
        f" lr: {args.lr} - noise_var: {args.noise_var};"
    )
elif args.method == "laplace":
    banner += (
        f" wd: {args.wd}; activation: {args.activation_fn};"
        f" lr: {args.lr} - noise_var: {args.noise_var};"
    )
print(banner)
print(
    "-------------------------------------------------------------------------------------------------------"
)
print()

# Draw the initial training set uniformly at random from the candidate pool.
# Selected molecules are removed from `features` / `targets`.
init_xs, init_ys = [], []
while len(init_xs) < args.n_init_data:
    pick = np.random.randint(len(features))
    # Make sure that the optimum is not included
    is_optimum = targets[pick].item() >= ground_truth_max
    if not is_optimum:
        init_xs.append(features.pop(pick))
        init_ys.append(targets.pop(pick))

train_x = torch.stack(init_xs).to(DTYPE)
train_y = torch.stack(init_ys).to(DTYPE)

# Build the surrogate model selected by --method. For the composite feature
# types that start with fingerprints ("all_features", "hand_crafted_general"),
# the kernel is a sum of a Tanimoto kernel on the first 1024 dimensions and
# the kernel chosen by --kernel on the remaining dimensions.
if args.method == "laplace":
    cfg = LaplaceConfig(wd=args.wd, lr=args.lr, activation=args.activation_fn, noise_var=args.noise_var)
    activation = ACTIVATIONS[args.activation_fn]

    def get_net():
        """Return a fresh MLP (two hidden layers of 50 units, scalar output) cast to DTYPE."""
        return torch.nn.Sequential(
            torch.nn.Linear(feature_dim, 50),
            activation(),
            torch.nn.Linear(50, 50),
            activation(),
            torch.nn.Linear(50, 1),
        ).to(DTYPE)

    model = FixedFeatLaplace(
        train_x,
        train_y,
        get_net,
        hess_factorization="kron",
        laplace_config=cfg,
        dtype=DTYPE,
        normalize_x=False,  # previously: (args.feature_type != "fingerprints")
    )
    model = model.to(DEVICE).to(DTYPE)
elif args.method == "fsplaplace":
    cfg = FSPLaplaceConfig(
        cov_context_points=args.cov_context_points,
        map_context_points=args.map_context_points,
        lr=args.lr,
        activation=args.activation_fn,
        noise_var=args.noise_var,
    )
    activation = ACTIVATIONS[args.activation_fn]

    def get_net():
        """Return a fresh MLP (two hidden layers of 50 units, scalar output) on DTYPE/DEVICE."""
        return torch.nn.Sequential(
            torch.nn.Linear(feature_dim, 50),
            activation(),
            torch.nn.Linear(50, 50),
            activation(),
            torch.nn.Linear(50, 1),
        ).to(DTYPE).to(DEVICE)

    # The prior mean is always zero here; --mean_fn is not consulted.
    mean_fn = ZeroMean(batch_shape=torch.Size([1]))
    if args.feature_type in ["all_features", "hand_crafted_general"]:
        # Fingerprint kernel
        active_dims = range(1024)
        fingerprint_kernel = TanimotoKernel(active_dims=active_dims)
        # Global structure kernel
        active_dims = range(1024, feature_dim)
        if "matern" in args.kernel:
            nu = float(args.kernel.split("_")[-1])
            other_feature_kernel = MaternKernel(nu=nu, ard_num_dims=len(active_dims), active_dims=active_dims, lengthscale_constraint=gpytorch.constraints.Interval(1e-3, 1e3))
        elif args.kernel == "rbf":
            other_feature_kernel = RBFKernel(ard_num_dims=len(active_dims), active_dims=active_dims, lengthscale_constraint=gpytorch.constraints.Interval(1e-3, 1e3))
        elif args.kernel == "linear":
            # Restrict the linear kernel to the non-fingerprint dimensions,
            # like the Matern/RBF alternatives above.
            other_feature_kernel = LinearKernel(batch_shape=torch.Size([1]), active_dims=active_dims)
        else:
            raise ValueError(f"Invalid kernel: {args.kernel}")

        # The linear kernel is used without an enclosing ScaleKernel.
        if args.kernel != "linear":
            kernel_fn = (
                ScaleKernel(other_feature_kernel, batch_shape=torch.Size([1])) + ScaleKernel(fingerprint_kernel, batch_shape=torch.Size([1]))
            )
        else:
            kernel_fn = (
                other_feature_kernel + ScaleKernel(fingerprint_kernel, batch_shape=torch.Size([1]))
            )
    else:
        if "matern" in args.kernel:
            nu = float(args.kernel.split("_")[-1])
            kernel_fn = MaternKernel(nu=nu, ard_num_dims=feature_dim, lengthscale_constraint=gpytorch.constraints.Interval(1e-3, 1e3))
        elif args.kernel == "rbf":
            kernel_fn = RBFKernel(ard_num_dims=feature_dim, lengthscale_constraint=gpytorch.constraints.Interval(1e-3, 1e3))
        elif args.kernel == "tanimoto":
            kernel_fn = TanimotoKernel()
        elif args.kernel == "linear":
            kernel_fn = LinearKernel(batch_shape=torch.Size([1]))
        else:
            raise ValueError(f"Invalid kernel: {args.kernel}")
        kernel_fn = ScaleKernel(kernel_fn, batch_shape=torch.Size([1])).to(DTYPE).to(DEVICE) if args.kernel not in ["linear"] else kernel_fn

    model = FixedFeatFSPLaplace(
        train_x,
        train_y,
        torch.stack(features),  # candidates
        mean_fn.to(DTYPE).to(DEVICE),
        kernel_fn,
        initialize_nn=get_net,
        laplace_config=cfg,
        dtype=DTYPE,
        normalize_x=False,  # previously: (args.feature_type != "fingerprints")
    )
    model = model.to(DEVICE).to(DTYPE)
elif args.method == "gp":
    if args.feature_type in ["all_features", "hand_crafted_general"]:
        # Fingerprint kernel
        active_dims = range(1024)
        fingerprint_kernel = TanimotoKernel(active_dims=active_dims)
        # Global structure kernel
        active_dims = range(1024, feature_dim)
        if "matern" in args.kernel:
            nu = float(args.kernel.split("_")[-1])
            other_feature_kernel = MaternKernel(nu=nu, ard_num_dims=len(active_dims), active_dims=active_dims, lengthscale_constraint=gpytorch.constraints.Interval(1e-3, 1e3))
        elif args.kernel == "rbf":
            other_feature_kernel = RBFKernel(ard_num_dims=len(active_dims), active_dims=active_dims, lengthscale_constraint=gpytorch.constraints.Interval(1e-3, 1e3))
        elif args.kernel == "linear":
            # Restrict the linear kernel to the non-fingerprint dimensions,
            # like the Matern/RBF alternatives above.
            other_feature_kernel = LinearKernel(variance_constraint=gpytorch.constraints.Interval(1e-3, 1e3), active_dims=active_dims)
        else:
            raise ValueError(f"Invalid kernel: {args.kernel}")

        if args.kernel != "linear":
            kernel_fn = ScaleKernel(other_feature_kernel) + ScaleKernel(fingerprint_kernel)
        else:
            kernel_fn = other_feature_kernel + ScaleKernel(fingerprint_kernel)
    else:
        if "matern" in args.kernel:
            nu = float(args.kernel.split("_")[-1])
            kernel_fn = MaternKernel(nu=nu, ard_num_dims=feature_dim, lengthscale_constraint=gpytorch.constraints.Interval(1e-3, 1e3))
        elif args.kernel == "rbf":
            kernel_fn = RBFKernel(ard_num_dims=feature_dim, lengthscale_constraint=gpytorch.constraints.Interval(1e-3, 1e3))
        elif args.kernel == "tanimoto":
            kernel_fn = TanimotoKernel()
        elif args.kernel == "linear":
            kernel_fn = LinearKernel()
        else:
            raise ValueError(f"Invalid kernel: {args.kernel}")
    model = MLLGP(
        train_X=train_x,
        train_Y=train_y,
        # Composite kernels are already scaled per component and the linear
        # kernel is left unscaled; everything else gets one ScaleKernel.
        kernel=ScaleKernel(kernel_fn) if args.feature_type not in ["all_features", "hand_crafted_general"] and args.kernel != "linear" else kernel_fn,
        noise_var=args.noise_var,
        device=DEVICE,
        dtype=DTYPE,
        normalize_x=False,  # previously: (args.feature_type != "fingerprints")
    )
    model = model.to(DEVICE).to(DTYPE)
elif args.method == "gp_default":
    # Same kernel layout as for "gp", but without lengthscale/variance constraints.
    if args.feature_type in ["all_features", "hand_crafted_general"]:
        # Fingerprint kernel
        active_dims = range(1024)
        fingerprint_kernel = TanimotoKernel(active_dims=active_dims)
        # Global structure kernel
        active_dims = range(1024, feature_dim)
        if "matern" in args.kernel:
            nu = float(args.kernel.split("_")[-1])
            other_feature_kernel = MaternKernel(nu=nu, ard_num_dims=len(active_dims), active_dims=active_dims)
        elif args.kernel == "rbf":
            other_feature_kernel = RBFKernel(ard_num_dims=len(active_dims), active_dims=active_dims)
        elif args.kernel == "linear":
            # Restrict the linear kernel to the non-fingerprint dimensions,
            # like the Matern/RBF alternatives above.
            other_feature_kernel = LinearKernel(active_dims=active_dims)
        else:
            raise ValueError(f"Invalid kernel: {args.kernel}")

        if args.kernel != "linear":
            kernel_fn = ScaleKernel(other_feature_kernel) + ScaleKernel(fingerprint_kernel)
        else:
            kernel_fn = other_feature_kernel + ScaleKernel(fingerprint_kernel)
    else:
        if "matern" in args.kernel:
            nu = float(args.kernel.split("_")[-1])
            kernel_fn = MaternKernel(nu=nu, ard_num_dims=feature_dim)
        elif args.kernel == "rbf":
            kernel_fn = RBFKernel(ard_num_dims=feature_dim)
        elif args.kernel == "tanimoto":
            kernel_fn = TanimotoKernel()
        elif args.kernel == "linear":
            kernel_fn = LinearKernel()
        else:
            raise ValueError(f"Invalid kernel: {args.kernel}")
    model = DefaultMLLGP(
        train_X=train_x,
        train_Y=train_y,
        mean=ZeroMean() if args.mean_fn == "zero" else ConstantMean(),
        kernel=ScaleKernel(kernel_fn) if args.feature_type not in ["all_features", "hand_crafted_general"] and args.kernel != "linear" else kernel_fn,
        noise_var=args.noise_var,
        lr=args.lr,
        device=DEVICE,
        dtype=DTYPE,
        normalize_x=False,  # previously: (args.feature_type != "fingerprints")
        normalize_y=args.normalize_y,
        init_high_dim=args.init_high_dim,
    )
    model = model.to(DEVICE).to(DTYPE)
else:  # Random search
    model = None

# Best objective value among the initial training points.
best_y = train_y.max().item()

# The progress bar is disabled; it only serves as the iteration range.
pbar = tqdm.trange(args.exp_len, file=sys.stdout, disable=True)
print(f"0/{args.exp_len} - [Best f(x) = {helpers.y_transform(best_y, MAXIMIZATION):.3f}]", flush=True)

# One slot per iteration plus one for the initial state. Every slot of the
# best-value trace starts at the transformed ground-truth optimum, so slots
# that are never overwritten (after an early stop) keep that value.
n_slots = args.exp_len + 1
optimum_value = helpers.y_transform(ground_truth_max, MAXIMIZATION)
trace_best_y = [optimum_value] * n_slots
trace_timing = [0.0] * n_slots
trace_best_acqval = []

# Wall-clock durations of the surrogate update and of the prediction sweep.
timing_train = []
timing_preds = []

# Main optimisation loop: in every iteration one candidate is removed from the
# pool (`features` / `targets`), either at random or by maximising the
# acquisition function under the surrogate posterior.
for i in pbar:
    # 1-based step number; 0 denotes the initial state reported before the loop.
    step = i + 1

    # Start the iteration timer (CUDA events on GPU, wall clock otherwise).
    if DEVICE == "cuda":
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        torch.cuda.synchronize()
        start.record()
    else:
        start = time.time()

    if args.method == "random":
        idx = np.random.randint(len(features))
        new_x = features.pop(idx)
        new_y = targets.pop(idx)
    else:
        # Score all remaining candidates in batches of 256.
        dataloader = data_utils.DataLoader(
            data_utils.TensorDataset(torch.stack(features), torch.stack(targets)),
            batch_size=256,
            shuffle=False,
        )

        preds, uncerts, labels = [], [], []
        acq_vals = []
        start_pred = time.time()

        for x, y in dataloader:
            x, y = x.to(DEVICE).to(DTYPE), y.to(DEVICE).to(DTYPE)
            posterior = model.posterior(x)
            f_mean, f_var = posterior.mean, posterior.variance

            if args.acqf == "ei":
                acq_vals.append(ei(f_mean, f_var, best_y))
            elif args.acqf == "ucb":
                acq_vals.append(ucb(f_mean, f_var))
            else:
                acq_vals.append(thompson_sampling(f_mean, f_var))

            preds.append(f_mean)
            uncerts.append(f_var.sqrt())
            labels.append(y)

        timing_preds.append(time.time() - start_pred)

        acq_vals = torch.cat(acq_vals, dim=0).cpu().squeeze()
        preds, uncerts, labels = (
            torch.cat(preds, dim=0).cpu(),
            torch.cat(uncerts, dim=0).cpu(),
            torch.cat(labels, dim=0).cpu(),
        )
        # Predictive MSE over the remaining pool (reported only).
        test_loss = torch.nn.MSELoss()(preds, labels).item()

        # Pick a molecule (a row in the current dataset) that maximizes the acquisition
        idx_best = torch.argmax(acq_vals).item()
        new_x, new_y = features.pop(idx_best), targets.pop(idx_best)

        trace_best_acqval.append(torch.max(acq_vals).item())

    # Update the current best y
    if new_y.item() > best_y:
        best_y = new_y.item()

    # The counter uses `step` so that it continues from the "0/N" line printed
    # before the loop and ends at "N/N".
    if args.method == "random":
        print(
            f"{step}/{args.exp_len} - [Best f(x) = {helpers.y_transform(best_y, MAXIMIZATION):.3f}, "
            + f"curr f(x) = {helpers.y_transform(new_y.item(), MAXIMIZATION):.3f}]",
            flush=True
        )
    else:
        print(
            f"{step}/{args.exp_len} - [Best f(x) = {helpers.y_transform(best_y, MAXIMIZATION):.3f}, "
            + f"curr f(x) = {helpers.y_transform(new_y.item(), MAXIMIZATION):.3f}, test MSE = {test_loss:.3f}]",
            flush=True
        )

        # Update surrogate
        start_train = time.time()
        model = model.condition_on_observations(new_x.unsqueeze(0), new_y.unsqueeze(0), idx_best=idx_best)
        timing_train.append(time.time() - start_train)

    # Housekeeping
    if DEVICE == "cuda":
        end.record()
        torch.cuda.synchronize()
        timing = start.elapsed_time(end) / 1000  # elapsed_time is in milliseconds
    else:
        timing = time.time() - start

    trace_best_y[step] = helpers.y_transform(best_y, MAXIMIZATION)
    trace_timing[step] = timing

    # Early stopping if we already got the max
    if best_y >= ground_truth_max:
        break

# Save results
path = f"results/{args.task}/{args.problem}/{args.method}/{args.feature_type}"

# exist_ok avoids the check-then-create race when several runs that share an
# output directory start at the same time.
os.makedirs(path, exist_ok=True)

# Method-specific hyper-parameters are encoded in the file name suffix.
if args.method == "gp":
    suffix = f"_{args.kernel}_{args.noise_var}"
elif args.method == "gp_default":
    suffix = f"_{args.kernel}_{args.noise_var}_{args.lr}_{args.mean_fn}_{args.init_high_dim}"
    suffix += '_norm_x' if args.normalize_x else '_not_norm_x'
    suffix += '_norm_y' if args.normalize_y else '_not_norm_y'
elif args.method == "fsplaplace":
    suffix = f"_{args.kernel}_{args.activation_fn}_{args.map_context_points}_{args.cov_context_points}_{args.lr}_{args.noise_var}"
elif args.method == "laplace":
    suffix = f"_{args.wd}_{args.activation_fn}_{args.lr}_{args.noise_var}"
else:
    suffix = ""

# Part of the file name shared by all outputs of this run.
run_tag = f"{args.n_init_data}_{args.acqf}_{args.randseed}{suffix}"

# Each trace goes to "<path>/<name>_<run_tag>.npy". timing_train is collected
# in the optimisation loop and is persisted alongside the other traces.
traces = {
    "timing_preds": timing_preds,
    "timing_train": timing_train,
    "trace_best_acqval": trace_best_acqval,
    "trace_best_y": trace_best_y,
    "trace_timing": trace_timing,
}
for trace_name, trace_values in traces.items():
    np.save(f"{path}/{trace_name}_{run_tag}.npy", trace_values)