import argparse
import os
import sys
import time

import numpy as np
import pandas as pd
import torch
import tqdm
from foundation_models import (
    MolFormerRegressor,
    RobertaRegressor,
    T5Regressor,
    GPT2Regressor,
    Llama2Regressor,
)
from foundation_models import (
    get_molformer_tokenizer,
    get_roberta_tokenizer,
    get_t5_tokenizer,
    get_gpt2_tokenizer,
    get_llama2_tokenizer,
)
from llm_bayesopt import LLALoRALLMBayesOpt, FSPLLALoRALLMBayesOpt
from bayesopt.acqf import ucb, ei, thompson_sampling
from problems.data_processor import (
    RedoxDataProcessor,
    SolvationDataProcessor,
    KinaseDockingDataProcessor,
    LaserEmitterDataProcessor,
    PhotovoltaicsPCEDataProcessor,
    PhotoswitchDataProcessor,
    D4DockingDataProcessor, 
    AmpCDockingDataProcessor
)
from problems.prompting import PromptBuilder
from utils import helpers
from utils.configs import LoraLaplaceConfig, LLMFeatureType, LoraFSPLaplaceConfig
from peft import LoraConfig, get_peft_model
import math

from gpytorch.kernels import ScaleKernel, MaternKernel, RBFKernel, LinearKernel
from gauche.kernels.fingerprint_kernels import TanimotoKernel
from gpytorch.means import ZeroMean

from gpytorch.constraints import Interval

from sklearn.preprocessing import StandardScaler


# Maps the --activation_fn CLI choice to the torch.nn activation class
# (the class itself, not an instance) handed to the regressor heads.
ACTIVATIONS = dict(
    relu=torch.nn.ReLU,
    tanh=torch.nn.Tanh,
    lrelu=torch.nn.LeakyReLU,
)

# ---------------------------------------------------------------------------
# Command-line interface. Numeric options are converted by argparse itself
# (``type=``), so a malformed value produces a usage error instead of a
# traceback from a later manual cast. Defaults are unchanged.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument(
    "--problem",
    choices=["redox-mer", "solvation", "kinase", "laser", "pce", "photoswitch", "ampc", "d4"],
    default="redox-mer",
)
parser.add_argument("--method", choices=["laplace", "fsplaplace"], default="laplace")
parser.add_argument(
    "--foundation_model",
    default="molformer",
    choices=[
        "molformer",
        "roberta-large",
        "t5-base",
        "t5-base-chem",
        "gpt2-medium",
        "gpt2-large",
        "llama-2-7b",
    ],
)
parser.add_argument(
    "--prompt_type",
    choices=["single-number", "just-smiles", "completion"],
    default="just-smiles",
)
parser.add_argument(
    "--laplace_type", choices=["last_layer", "all_layer"], default="all_layer"
)
parser.add_argument(
    "--kernel",
    choices=[
        "tanimoto",
        "matern_0.5",
        "matern_1.5",
        "matern_2.5",
        "rbf",
        "linear",
    ],
    default="matern_2.5",
)
parser.add_argument(
    "--activation_fn", choices=["relu", "lrelu", "tanh"], default="relu"
)
parser.add_argument("--noise_var", type=float, default=1e-4)
parser.add_argument("--wd", type=float, default=5e-4)
parser.add_argument(
    "--context_point_feature_type",
    choices=["fingerprints", "molformer", "t5-base-chem", "mordred", "degree_of_conjugation", "force_field", "dft"],
    default="molformer",
)
parser.add_argument("--lr", type=float, default=5e-5)
parser.add_argument(
    "--map_context_points", choices=["uniform", "bo_candidates"], default="bo_candidates"
)
parser.add_argument(
    "--cov_context_points", choices=["sobol", "bo_candidates"], default="bo_candidates"
)
# LoRA hyper-parameters; lora_alpha is derived later as int(lora_alpha_factor * lora_r).
parser.add_argument("--lora_alpha_factor", type=float, default=1.0)
parser.add_argument("--lora_dropout", type=float, default=0.1)
parser.add_argument("--lr_lora", type=float, default=1e-4)
parser.add_argument("--lora_r", type=int, default=8)
parser.add_argument("--acqf", choices=["ei", "ucb", "ts"], default="ts")
parser.add_argument("--n_init_data", type=int, default=10)
parser.add_argument("--exp_len", type=int, default=200)
parser.add_argument("--randseed", type=int, default=1)
parser.add_argument("--dtype", choices=["float32", "float64"], default="float64")
args = parser.parse_args()

# The weight-space Laplace path is always run in single precision.
if args.dtype == "float64" and args.method == "laplace":
    print("Changing to float32", flush=True)
    args.dtype = "float32"

# MolFormer consumes raw SMILES only, so other prompt formats are overridden.
if args.foundation_model == "molformer":
    args.prompt_type = "just-smiles"


device = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on GPU" if device == "cuda" else "Running on CPU")

# Seed NumPy (initial-design sampling) and torch.
np.random.seed(args.randseed)
torch.manual_seed(args.randseed)

# Tokenizer matching the chosen backbone; dispatch is by name/substring.
if args.foundation_model == "molformer":
    tokenizer = get_molformer_tokenizer()
elif "roberta" in args.foundation_model:
    tokenizer = get_roberta_tokenizer(args.foundation_model)
elif "t5" in args.foundation_model:
    # "t5-base-chem" resolves to the GT4SD chemistry-augmented T5 checkpoint.
    foundation_model_real = (
        "GT4SD/multitask-text-and-chemistry-t5-base-augm"
        if "chem" in args.foundation_model
        else args.foundation_model
    )
    tokenizer = get_t5_tokenizer(foundation_model_real)
elif "gpt2" in args.foundation_model:
    tokenizer = get_gpt2_tokenizer(args.foundation_model)
elif "llama-2" in args.foundation_model:
    tokenizer = get_llama2_tokenizer(args.foundation_model)

# Tokenizers without a padding token reuse their EOS token for padding.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# problem -> (dataset CSV, objective column, objective is maximized?, data-processor class)
_PROBLEM_SPECS = {
    "redox-mer": ("data/redox_mer_with_iupac.csv.gz", "Ered", False, RedoxDataProcessor),
    "solvation": ("data/redox_mer_with_iupac.csv.gz", "Gsol", False, SolvationDataProcessor),
    "kinase": ("data/enamine10k.csv.gz", "score", False, KinaseDockingDataProcessor),
    "laser": ("data/laser_multi10k.csv.gz", "Fluorescence Oscillator Strength", True, LaserEmitterDataProcessor),
    "pce": ("data/photovoltaics_pce10k.csv.gz", "pce", True, PhotovoltaicsPCEDataProcessor),
    "photoswitch": ("data/photoswitches.csv.gz", "Pi-Pi* Transition Wavelength", True, PhotoswitchDataProcessor),
    "ampc": ("data/Zinc_AmpC_Docking_filtered.csv.gz", "dockscore", False, AmpCDockingDataProcessor),
    "d4": ("data/Zinc_D4_Docking_filtered.csv.gz", "dockscore", False, D4DockingDataProcessor),
}

if args.problem not in _PROBLEM_SPECS:
    print("Invalid test function!")
    sys.exit(1)

# Load the candidate pool and build the prompt builder / processor for it.
_csv_path, OBJ_COL, MAXIMIZATION, _processor_cls = _PROBLEM_SPECS[args.problem]
dataset = pd.read_csv(_csv_path)
prompt_builder = PromptBuilder(kind=args.prompt_type)
data_processor = _processor_cls(prompt_builder, tokenizer)

# FSP-Laplace needs a cached feature vector per candidate. The vectors are
# attached as a "features" column (presumably consumed by the function-space
# prior via candidate_X — confirm).
if args.method == "fsplaplace":
    CACHE_PATH = f"data/cache/{args.problem}/"
    features = torch.load(CACHE_PATH + f"{args.context_point_feature_type}_feats.bin")
    features = [feature.numpy().reshape(-1) for feature in features]
    # The Matern/RBF kernels need the input dimension (ARD + initial lengthscale)
    # for every feature type, fingerprints included, so compute it unconditionally.
    feature_dim = features[0].shape[-1]
    if args.context_point_feature_type != "fingerprints":
        # Binary fingerprints are left as-is; continuous descriptors are standardized.
        print("Normalizing features", flush=True)
        features = list(
            StandardScaler().fit_transform(features)
        )
    dataset["features"] = features

# Turn into a maximization problem if necessary
if not MAXIMIZATION:
    dataset[OBJ_COL] = -dataset[OBJ_COL]
ground_truth_max = dataset[OBJ_COL].max()

print()
if args.method == "laplace":
    print(
        f"Test Function: {args.problem}; Foundation Model: {args.foundation_model}; Prompt Type: {args.prompt_type}; Randseed: {args.randseed}; wd: {args.wd}; activation: {args.activation_fn}; lr: {args.lr} - noise_var: {args.noise_var} - lr_lora {args.lr_lora} - lora_alpha_factor: {args.lora_alpha_factor} - lora_dropout: {args.lora_dropout} - lora_r: {args.lora_r};", flush=True
    )
elif args.method == "fsplaplace":
    print(
        f"Test Function: {args.problem}; Foundation Model: {args.foundation_model}; Prompt Type: {args.prompt_type}; Randseed: {args.randseed}; kernel: {args.kernel}; activation: {args.activation_fn}; context_point_feature_type: {args.context_point_feature_type}; map_context_points: {args.map_context_points}; cov_context_points: {args.cov_context_points}; lr: {args.lr} - noise_var: {args.noise_var} - lr_lora {args.lr_lora} - lora_alpha_factor: {args.lora_alpha_factor} - lora_dropout: {args.lora_dropout} - lora_r: {args.lora_r};", flush=True
    )
print(
    "---------------------------------------------------------------------------------------------------------------"
)
print()

# Initial design: n_init_data random rows, never the optimum itself. Rows that
# would be rejected below are counted first, so too few eligible rows raise
# instead of looping forever.
n_eligible = int((~(dataset[OBJ_COL] >= ground_truth_max)).sum())
if n_eligible < args.n_init_data:
    raise ValueError(
        f"n_init_data={args.n_init_data} exceeds the {n_eligible} non-optimal rows available"
    )

dataset_train = []
while len(dataset_train) < args.n_init_data:
    # NOTE(review): assumes helpers.pop_df keeps `dataset` indexed 0..len-1 after
    # removing a row, so a positional draw is a valid .loc label — confirm.
    idx = np.random.randint(len(dataset))
    # Make sure that the optimum is not included
    if dataset.loc[idx][OBJ_COL] >= ground_truth_max:
        continue
    dataset_train.append(helpers.pop_df(dataset, idx))


def get_model():
    """Build the regressor for ``args.foundation_model`` and wrap it with LoRA.

    The backbone class, its checkpoint name and the attention projections that
    receive LoRA adapters are chosen from the model name. The ``head`` module
    is listed in ``modules_to_save``; its original (pre-PEFT) parameters are
    frozen afterwards.

    Returns:
        The PEFT model produced by ``get_peft_model``.

    Raises:
        NotImplementedError: if the model name matches no supported family.
    """
    name = args.foundation_model
    if name == "molformer":
        backbone = MolFormerRegressor(tokenizer, activation_fn=ACTIVATIONS[args.activation_fn], dtype=args.dtype)
        lora_targets = ["query", "value"]
    else:
        # (regressor class, checkpoint, LoRA target modules) per model family.
        if "roberta" in name:
            regressor_cls, kind, lora_targets = RobertaRegressor, name, ["query", "value"]
        elif "gpt2" in name:
            regressor_cls, kind, lora_targets = GPT2Regressor, name, ["c_attn"]
        elif "llama-2" in name:
            regressor_cls, kind, lora_targets = Llama2Regressor, name, ["q_proj", "v_proj"]
        elif "t5" in name:
            # "t5-base-chem" resolves to the GT4SD chemistry-augmented T5 checkpoint.
            kind = "GT4SD/multitask-text-and-chemistry-t5-base-augm" if "chem" in name else name
            regressor_cls, lora_targets = T5Regressor, ["q", "v"]
        else:
            raise NotImplementedError
        backbone = regressor_cls(
            kind=kind,
            tokenizer=tokenizer,
            reduction=LLMFeatureType.AVERAGE,
            activation_fn=ACTIVATIONS[args.activation_fn],
        )

    lora_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=int(args.lora_alpha_factor * args.lora_r),
        target_modules=lora_targets,
        lora_dropout=args.lora_dropout,
        bias="none",
        modules_to_save=["head"],
    )
    peft_model = get_peft_model(backbone, lora_config)
    # PEFT wraps "head" (it exposes .original_module); freeze that original copy
    # so only the saved/trainable copy of the head receives gradients.
    for param in peft_model.base_model.head.original_module.parameters():
        param.requires_grad = False
    return peft_model


# Train + Laplace
# Select the approximate-inference configuration for the surrogate.
if args.method == "laplace":
    if args.laplace_type == "all_layer":
        # Kronecker-factored ("kron") Hessian over all trainable weights.
        config = LoraLaplaceConfig(
            noise_var=args.noise_var,
            wd=args.wd,
            lr=args.lr,
            lr_lora=args.lr_lora,
            hess_factorization="kron",
            subset_of_weights="all",
            marglik_mode="posthoc",
            prior_prec_structure="layerwise",
        )
    else:
        # Full Hessian, restricted to the last layer.
        config = LoraLaplaceConfig(
            noise_var=args.noise_var,
            wd=args.wd,
            lr=args.lr,
            lr_lora=args.lr_lora,
            marglik_mode="posthoc",
            prior_prec_structure="layerwise",
            hess_factorization="full",
            subset_of_weights="last_layer",
        )
elif args.method == "fsplaplace":
    # Function-space Laplace; --laplace_type and --wd are not passed on this path.
    config = LoraFSPLaplaceConfig(
        lr=args.lr,
        lr_lora=args.lr_lora,
        noise_var=args.noise_var,
        cov_context_points=args.cov_context_points, 
        map_context_points=args.map_context_points, 
    )
# else:
#     raise NotImplementedError
# (Unreachable: argparse restricts --method to "laplace" / "fsplaplace".)

# if args.problem == "photoswitch":
#     config.lr = 1e-2
#     config.lr_lora = 3e-3

# Inputs get an EOS token appended for every backbone except MolFormer and T5.
APPEND_EOS = args.foundation_model != "molformer" and (
    "t5" not in args.foundation_model
)

if args.method == "laplace":    
    # get_model is passed as a factory rather than a built model (presumably so
    # the surrogate can rebuild the network when it retrains — confirm).
    model = LLALoRALLMBayesOpt(
        get_model,
        dataset_train,
        data_processor,
        dtype=args.dtype,
        device=device,
        laplace_config=config,
        append_eos=APPEND_EOS,
    )
elif args.method == "fsplaplace":
    # Function-space prior: zero mean plus the kernel selected by --kernel.
    mean_fn = ZeroMean(batch_shape=torch.Size([1]))
    if "matern" in args.kernel:
        # nu is the numeric suffix of the choice, e.g. "matern_2.5" -> 2.5.
        nu = float(args.kernel.split("_")[-1])
        # One lengthscale per input dimension (ARD), constrained to [1e-3, 1e3] and
        # initialised to sqrt(feature_dim) (presumably to keep distances between
        # standardized feature vectors O(1) — heuristic, confirm).
        # feature_dim must have been set when the cached features were loaded.
        kernel_fn = MaternKernel(nu=nu, ard_num_dims=feature_dim, lengthscale_constraint=Interval(1e-3, 1e3))
        kernel_fn.lengthscale= math.sqrt(feature_dim)
    elif args.kernel == "rbf":
        # Same ARD setup and lengthscale initialisation as the Matern branch.
        kernel_fn = RBFKernel(ard_num_dims=feature_dim, lengthscale_constraint=Interval(1e-3, 1e3))
        kernel_fn.lengthscale= math.sqrt(feature_dim)
    elif args.kernel == "tanimoto":
        # Tanimoto similarity from gauche's fingerprint kernels.
        kernel_fn = TanimotoKernel()
    elif args.kernel == "linear":
        kernel_fn = LinearKernel(batch_shape=torch.Size([1]))
    else: 
        raise ValueError(f"Invalid kernel: {args.kernel}")
    
    # Every kernel except the linear one is wrapped in a ScaleKernel (output scale).
    if args.kernel != "linear":
        kernel_fn = ScaleKernel(kernel_fn, batch_shape=torch.Size([1]))

    # The remaining candidate pool is passed as candidate_X (presumably the source
    # of "bo_candidates" context points — confirm).
    model = FSPLLALoRALLMBayesOpt(
        get_model=get_model,
        training_set=dataset_train,
        data_processor=data_processor,
        candidate_X=dataset,
        mean_fn=mean_fn,
        kernel_fn=kernel_fn,
        laplace_config=config,
        dtype=args.dtype,
        device=device,
        append_eos=APPEND_EOS,
    )


# ---------------------------------------------------------------------------
# Bayesian-optimization loop: score every remaining candidate with the
# acquisition function, query the argmax and condition the surrogate on it.
# ---------------------------------------------------------------------------
best_y = pd.DataFrame(dataset_train)[OBJ_COL].max()
pbar = tqdm.trange(args.exp_len, position=0, colour="green", file=sys.stdout, disable=True)
print(
    f"[Best f(x) = {helpers.y_transform(best_y, MAXIMIZATION):.3f}]", flush=True
)

# Entry 0 is the best value of the initial design; entry i + 1 is the best
# value after BO iteration i. The trace is pre-filled with the optimum so
# that, after an early stop (optimum found), the tail is already correct.
trace_best_y = [helpers.y_transform(ground_truth_max, MAXIMIZATION)] * (
    args.exp_len + 1
)
trace_best_y[0] = helpers.y_transform(best_y, MAXIMIZATION)
trace_timing = [0.0] * (args.exp_len + 1)  # wall-clock seconds per BO iteration
trace_acqvals = [-math.inf] * (args.exp_len + 1)  # acquisition value of each queried point

timing_train = []  # seconds spent conditioning the surrogate, per iteration
timing_preds = []  # seconds spent predicting over the candidate pool, per iteration


def _timestamp():
    """Return ``time.perf_counter()``, synchronizing CUDA first so queued GPU work is counted."""
    if device == "cuda":
        torch.cuda.synchronize()
    return time.perf_counter()


# Prediction batch size is bounded by GPU memory (Laplace + MolFormer: 15 is the
# max on a 2080 GTX; FSP-Laplace was reported to handle 4096 on an H100).
# It does not change across iterations, so it is chosen once.
if device == "cuda":
    gpu_name = torch.cuda.get_device_name(0).lower()
    print("gpu name", gpu_name)
    if args.method == "fsplaplace":
        batch_size = 1024 if "h100" in gpu_name else 256
    else:
        batch_size = 32 if "h100" in gpu_name else 15
else:
    batch_size = 32

for i in pbar:
    iter_start = _timestamp()

    # The candidate pool shrinks every iteration, so the loader is rebuilt.
    dataloader = data_processor.get_dataloader(
        dataset, batch_size=batch_size, shuffle=False, append_eos=APPEND_EOS
    )

    preds, uncerts, labels = [], [], []
    acq_vals = []
    sub_pbar = tqdm.tqdm(
        dataloader,
        position=1,
        colour="blue",
        desc="[Prediction over dataset]",
        leave=False,
        file=sys.stdout,
        disable=True,
    )

    pred_start = _timestamp()
    for data in sub_pbar:
        posterior = model.posterior(data)
        f_mean, f_var = posterior.mean, posterior.variance
        if args.acqf == "ei":
            acq_vals.append(ei(f_mean, f_var, best_y))
        elif args.acqf == "ucb":
            acq_vals.append(ucb(f_mean, f_var))
        else:
            acq_vals.append(thompson_sampling(f_mean, f_var))

        preds.append(f_mean)
        uncerts.append(f_var.sqrt())
        labels.append(data["labels"])
    timing_preds.append(_timestamp() - pred_start)

    # atleast_1d keeps indexing valid when one candidate is left (squeeze -> 0-d).
    acq_vals = torch.atleast_1d(torch.cat(acq_vals, dim=0).cpu().squeeze())
    preds, uncerts, labels = (
        torch.cat(preds, dim=0).cpu(),
        torch.cat(uncerts, dim=0).cpu(),
        torch.cat(labels, dim=0),
    )
    test_loss = torch.nn.MSELoss()(preds, labels).item()

    # Report the 10 highest-scoring candidates (fewer if the pool is smaller).
    _, top_idx = acq_vals.topk(k=min(10, acq_vals.shape[-1]))
    for label, mean, std, acq in zip(labels[top_idx], preds[top_idx], uncerts[top_idx], acq_vals[top_idx]):
        print(
            f"True: {label.item():.3f}, Mean: {mean.item():.3f}, Std: {std.item():.3f}, Acqf: {acq.item():.3f}", flush=True
        )

    # Pick a molecule (a row in the current dataset) that maximizes the acquisition
    idx_best = torch.argmax(acq_vals).item()
    trace_acqvals[i + 1] = acq_vals.flatten()[idx_best].item()
    new_data = helpers.pop_df(dataset, idx_best)

    # Update the current best y
    if new_data[OBJ_COL] > best_y:
        best_y = new_data[OBJ_COL]
        print(best_y)

    # Early stopping if we already got the max (the trace tail already holds it).
    if best_y >= ground_truth_max:
        break

    # Update surrogate
    train_start = _timestamp()
    model = model.condition_on_observations(new_data)
    timing_train.append(_timestamp() - train_start)

    print(
        f"{i}/{args.exp_len} - [Best f(x) = {helpers.y_transform(best_y, MAXIMIZATION):.3f}, "
        + f"curr f(x) = {helpers.y_transform(new_data[OBJ_COL], MAXIMIZATION):.3f}, "
        + f"test MSE: {test_loss:.3f}]", 
        flush=True
    )

    # Record this iteration's results.
    trace_best_y[i + 1] = helpers.y_transform(best_y, MAXIMIZATION)
    trace_timing[i + 1] = _timestamp() - iter_start

# Save results. Parallel sweep jobs may share this directory, so create it
# idempotently; a separate exists() check would race with other jobs.
path = f"results/param_sensitivity/{args.problem}/{args.method}_{args.foundation_model}_finetuning/"
os.makedirs(path, exist_ok=True)

# Hyper-parameters that distinguish runs within the sensitivity study.
suffix = f"{args.lora_alpha_factor}_{args.lora_r}_{args.lora_dropout}_{args.lr_lora}"
if args.method == "fsplaplace":
    suffix += f"_{args.kernel}_{args.activation_fn}_{args.map_context_points}_{args.cov_context_points}_{args.lr}_{args.noise_var}"
elif args.method == "laplace":
    suffix += f"_{args.wd}_{args.activation_fn}_{args.lr}_{args.noise_var}"

# Shared tail of every output file name.
run_id = f"{args.n_init_data}_{args.acqf}_{args.laplace_type}_{args.randseed}_{suffix}"
# For non-MolFormer backbones the trace files are also prefixed with the prompt type.
trace_prefix = "" if args.foundation_model == "molformer" else f"{args.prompt_type}_"

np.save(f"{path}/timing_train_{run_id}.npy", timing_train)
np.save(f"{path}/timing_preds_{run_id}.npy", timing_preds)
np.save(f"{path}/trace_acqvals_{run_id}.npy", trace_acqvals)
np.save(f"{path}/{trace_prefix}trace_best_y_{run_id}.npy", trace_best_y)
np.save(f"{path}/{trace_prefix}trace_timing_{run_id}.npy", trace_timing)
