import argparse
import json
import pandas as pd
from downstream.downstream_model import DrugPropertyModelPooling
import torch
import numpy as np
import random
import os
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from model.tokenizer_smiles import SmilesTokenizer
import model.smilesgraph
from downstream.utils import evaluation, kfold_indices, sort_df_with_smiles_by_scaffold
import optuna
import copy
import math


def early_stopping_callback(study, trial):
    """Optuna callback that stops the study after too many stale trials.

    Reads and updates the module-level ``best_score`` and
    ``no_improvement_count`` and reads ``args.optuna_patience``.

    Args:
        study: The running study; ``study.stop()`` is called once
            ``no_improvement_count`` reaches ``args.optuna_patience``.
        trial: The trial that just finished; only ``trial.value`` is read.
    """
    global best_score, no_improvement_count
    current_score = trial.value

    # The None checks must run BEFORE the comparison: ``None < float`` raises
    # TypeError, and a trial that produced no value has ``value is None``.
    # Such a trial is counted as "no improvement". A ``None`` best score is
    # treated as "no best recorded yet".
    improved = current_score is not None and (
        best_score is None or current_score < best_score
    )

    if improved:
        best_score = current_score
        no_improvement_count = 0  # Reset the no improvement count
    else:
        no_improvement_count += 1

    # Early termination condition
    if no_improvement_count >= args.optuna_patience:
        print("Early termination")
        study.stop()


class DrugDataset(Dataset):
    """Map-style dataset of ``(drug, target)`` records, tokenized on access.

    Records come from, in order of precedence: the ``data`` list, the ``df``
    DataFrame (columns ``Drug`` and ``Y``), or the CSV at ``data_file``.
    When ``is_log`` is true every target is replaced by its natural log.
    """

    def __init__(self, data_file, tokenizer, max_length, df=None, data=None, is_log=False):
        if data is not None:
            records = data
        else:
            if df is None:
                print("Read data from", data_file)
                df = pd.read_csv(data_file)
                print("Data shape", df.shape)
            records = [(row['Drug'], row['Y']) for _, row in df.iterrows()]
        if is_log:
            # Build a new list so a caller-supplied ``data`` list is not mutated.
            records = [(record[0], math.log(record[1])) for record in records]
        self.data = records
        self.is_log = is_log
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self) -> int:
        """Return the number of records."""
        return len(self.data)

    def __getitem__(self, index: int):
        """Return ``(input_ids, attention_mask, target)`` for one record."""
        drug, target = self.data[index]
        # The tokenizer is called with a one-element batch; the leading
        # dimension is squeezed away again below.
        encoded = self.tokenizer.batch_encode([drug], max_len=self.max_length)
        input_ids = encoded["input_ids"].squeeze(0)
        attention_mask = encoded["attention_mask"].squeeze(0)
        return input_ids, attention_mask, target

    def split_to_folds(self, split_type, number_folds=5):
        """Split the records into ``(train, test)`` dataset pairs.

        ``split_type`` is one of ``"random"`` (shuffled), ``"preserved_order"``
        or ``"scaffold"`` (records are first re-ordered in place by
        ``sort_df_with_smiles_by_scaffold``). Any other value prints an error
        and returns ``None``.
        """
        if split_type not in ("random", "preserved_order", "scaffold"):
            print("Error: unknown split type", split_type)
            return None
        if split_type == "scaffold":
            self.data = sort_df_with_smiles_by_scaffold(self.data)
        # Only the random split shuffles; the other two keep record order.
        shuffle = split_type == "random"
        fold_indices = kfold_indices(self.data, n_splits=number_folds, is_shuffle=shuffle)

        def subset(indices):
            # Fold datasets reuse the already (optionally) log-scaled targets,
            # so ``is_log`` is left at its default of False.
            picked = [self.data[idx] for idx in indices]
            return DrugDataset(None, self.tokenizer, self.max_length, data=picked)

        return [(subset(train_index), subset(test_index))
                for train_index, test_index in fold_indices]


def _evaluate_loader(net, loader):
    """Score ``net`` on every batch of ``loader`` without tracking gradients.

    Predictions are passed through ``torch.exp`` when ``args.metrics`` is
    ``"spearman"`` before being handed to ``evaluation``.

    Returns:
        The result of ``evaluation(predictions, labels, args.metrics)``,
        which callers unpack as ``(metric, predictions, labels)``.
    """
    # NOTE(review): the network is left in whatever train/eval mode it is
    # already in; confirm whether ``net.eval()`` is wanted during scoring.
    predictions = []
    targets = []
    with torch.no_grad():
        for input_ids, attention_mask, labels in tqdm(loader):
            p = net(input_ids.to(device), attention_mask.to(device))
            predictions.append(p.squeeze(-1).detach().cpu())
            targets.append(labels.cpu())
        predictions = torch.cat(predictions)
        targets = torch.cat(targets)
        if args.metrics == "spearman":
            predictions = torch.exp(predictions)
        return evaluation(predictions, targets, args.metrics)


def train(trial, stats, folds):
    """Optuna objective: train one model per fold and return the mean score.

    For each ``(train, val)`` pair in ``folds`` a fresh
    ``DrugPropertyModelPooling`` is trained with the hyper-parameters sampled
    from ``trial``. Every ``args.val_steps`` steps the model is scored on the
    test file and on the validation fold; a lower validation metric counts as
    an improvement. A fold stops after ``args.steps`` steps or once more than
    ``args.val_tolerance`` checkpoints passed without a significant
    improvement.

    Relies on the module-level ``args``, ``tokenizer``, ``device`` and ``ant``.

    Args:
        trial: Optuna trial used to sample the search space.
        stats: Dict updated in place (keys ``best_val``, ``best_test``,
            ``best_label``) whenever this trial beats the stored
            ``best_val``; per-fold predictions are then also written as JSON
            files into ``args.output_folder``.
        folds: Iterable of ``(train_dataset, val_dataset)`` pairs.

    Returns:
        The mean over folds of the best validation metric.
    """
    set_seed(args.seeds)
    # search space (the order of the suggest_* calls is unchanged)
    learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-4, log=True)
    batch_size = 2 ** trial.suggest_int('batch_size', 2, 8)
    hidden_size = 2 ** trial.suggest_int('hidden_size', 5, 10)
    num_hidden_layers = trial.suggest_int('num_hidden_layers', 1, 4)

    test_file = args.data_folder + "/" + args.dataset.lower() + "/test.csv"
    test_data = DataLoader(DrugDataset(test_file, tokenizer, args.max_length),
                           batch_size=batch_size, shuffle=False)
    best_vals = []
    best_tests = []
    best_predictions = []
    best_label = None
    for fold_idx, (train_data, val_data) in enumerate(folds):
        train_data = DataLoader(train_data, batch_size=batch_size, shuffle=True)
        val_data = DataLoader(val_data, batch_size=batch_size, shuffle=False)
        # Each fold starts from its own copy of the pretrained encoder.
        net = DrugPropertyModelPooling(drug_size=args.drug_embed_size,
                                       model=copy.deepcopy(ant),
                                       hidden_size=hidden_size,
                                       num_hidden_layers=num_hidden_layers,
                                       task_type=args.task_type,
                                       num_class=args.num_class).to(device)

        optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
        moving_loss = 0.0
        n = 0
        run = True
        best_val = None
        best_test = None
        best_prediction = None
        tolerance_count = 0
        while run:
            for input_ids, attention_mask, labels in train_data:
                net.zero_grad()
                p = net(input_ids.to(device), attention_mask.to(device)).squeeze(-1)
                if args.task_type == 'regression':
                    # Cast once for both regression losses so the MAE branch
                    # uses float32 targets exactly like the MSE branch.
                    target = labels.to(dtype=torch.float32).to(device)
                    if args.metrics == 'mae' or args.metrics == 'spearman':
                        loss = torch.mean(torch.abs(p - target))
                    else:
                        loss = torch.nn.MSELoss()(p, target)
                else:
                    loss = torch.nn.CrossEntropyLoss()(p, labels.to(torch.int64).to(device))
                loss.backward()
                optimizer.step()
                moving_loss += loss.cpu().item()
                if n % args.val_steps == 0:
                    print("Step {}; Moving loss {}".format(n, moving_loss / (n + 1)))
                n += 1

                # Periodic checkpoint: score the test file and the validation fold.
                if n % args.val_steps == 0:
                    test_metrics, all_p_test, all_label_test = _evaluate_loader(net, test_data)
                    print(n, "Test metrics:", test_metrics)
                    val_metrics, _, _ = _evaluate_loader(net, val_data)
                    print(n, "Val metrics:", val_metrics)

                    if best_val is None:
                        best_val = val_metrics
                        best_test = test_metrics
                        best_prediction = all_p_test
                        best_label = all_label_test
                    elif best_val > val_metrics:
                        print("Val score improve:", best_val,
                              val_metrics, "Test score",
                              test_metrics)
                        best_test = test_metrics
                        relative_gain = abs(best_val - val_metrics) / (abs(best_val) + 0.000001)
                        if relative_gain > args.val_improve_tolerance:
                            tolerance_count = 0
                        else:
                            # A tiny gain is still recorded as the new best but
                            # does not reset the patience counter.
                            print("Improvement is not significant",
                                  relative_gain,
                                  args.val_improve_tolerance,
                                  best_val,
                                  val_metrics)
                            tolerance_count += 1
                        best_val = val_metrics
                        best_prediction = all_p_test
                        best_label = all_label_test
                    else:
                        tolerance_count += 1
                    if tolerance_count > args.val_tolerance:
                        print("Early termination")
                        run = False
                        break

                if n == args.steps:
                    run = False
                    break

        if best_val is None:
            # The step budget ran out before the first checkpoint (for example
            # args.steps < args.val_steps). Score the final model once so the
            # fold still contributes a result instead of failing on unset
            # variables below.
            best_test, best_prediction, best_label = _evaluate_loader(net, test_data)
            best_val, _, _ = _evaluate_loader(net, val_data)

        print("Best score", best_test, best_val, fold_idx)
        best_tests.append(best_test)
        best_vals.append(best_val)
        best_predictions.append(best_prediction)

    best_val = torch.mean(torch.tensor(best_vals)).cpu().item()
    best_test = torch.mean(torch.tensor(best_tests)).cpu().item()
    print("Best score all", best_test, best_val)
    if stats["best_val"] is None or stats["best_val"] > best_val:
        stats["best_val"] = best_val
        stats["best_test"] = best_test
        # Labels come from the last fold's best checkpoint on the test file.
        stats["best_label"] = best_label
        for i, fold_prediction in enumerate(best_predictions):
            out_path = (args.output_folder + "/" + args.dataset + "_prediction_" + str(i) + '_' +
                        str(args.seeds) + ".json")
            with open(out_path, "wt") as f:
                prediction = {"prediction": fold_prediction.cpu().tolist(),
                              "label": stats["best_label"].cpu().tolist()}
                json.dump(prediction, f, indent=4)
        stats_output = {"best_test": stats["best_test"], "best_val": stats["best_val"]}
        print("Best_val", stats_output)
    return best_val


def set_seed(seed: int = 42) -> None:
    """Seed NumPy, ``random`` and PyTorch (CPU and CUDA) with ``seed``.

    Also switches cuDNN to deterministic mode with benchmarking disabled and
    stores the seed in the ``PYTHONHASHSEED`` environment variable.
    """
    seeders = (
        np.random.seed,
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    # cuDNN needs both flags set for reproducible kernels.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # Fixed hash seed, exported through the process environment.
    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set as {seed}")


if __name__ == "__main__":
    # Script entry point: parse options, load the pretrained encoder, build
    # the cross-validation folds and run the Optuna search. The names `args`,
    # `best_score`, `no_improvement_count`, `device`, `tokenizer` and `ant`
    # stay module-level because `train` and `early_stopping_callback` read
    # them as globals.
    parser = argparse.ArgumentParser(description='TDC Admet training')
    parser.add_argument('--dataset', default='vdss_lombardo', type=str,
                        help='Name of the ADMET dataset')
    parser.add_argument("--mask_ratio", default=-1.0, type=float,
                        help='Random masking sequence.')
    parser.add_argument("--drug_embed_size", default=768, type=int,
                        help='Drug embedding size')
    parser.add_argument("--steps", default=10000, type=int,
                        help='Maximum number of training steps')
    parser.add_argument("--seeds", default=0, type=int,
                        help='Random seeds.')
    parser.add_argument("--device", default="cpu", type=str,
                        help='Device: cpu vs cuda.')
    parser.add_argument("--task_type", default="regression", type=str,
                        help='Task type: regression, classification')
    parser.add_argument("--metrics", default="spearman", type=str,
                        help='Report metrics')
    parser.add_argument("--num_class", default=2, type=int,
                        help='Number of classes')

    parser.add_argument("--val_steps", default=10, type=int,
                        help='Val steps.')
    parser.add_argument("--max_length", default=128, type=int,
                        help='Max length')
    parser.add_argument('--checkpoint', default='../data/checkpoint-1600000/pytorch_model.bin', type=str,
                        help='AntBrain checkpoint')
    parser.add_argument('--vocab', default='../data/vocab.txt', type=str,
                        help='Vocab file for the smiles tokenizer')
    parser.add_argument('--data_folder', default='data/', type=str,
                        help='Data folder')
    parser.add_argument("--val_tolerance", default=3, type=int,
                        help='Val tolerance.')
    parser.add_argument("--val_improve_tolerance", default=0.01, type=float,
                        help='Minimum relative validation improvement that resets the early-stopping counter.')
    parser.add_argument("--output_folder", default="./", type=str,
                        help='Output folder')
    parser.add_argument("--optuna_patience", default=25, type=int,
                        help='The number of trials of no improvement before terminating the optuna process.')
    # Restricting the choices rejects a misspelled split type at parse time;
    # otherwise split_to_folds returns None and the failure surfaces later
    # inside the first trial.
    parser.add_argument("--split_type", default="random", type=str,
                        choices=["random", "preserved_order", "scaffold"],
                        help='Type of splitting the train_val into folds: random, preserved_order, scaffold')
    parser.add_argument("--hash_size", default=256, type=int)

    args = parser.parse_args()
    print(args)
    # Initialize the state shared with early_stopping_callback
    best_score = float('inf')
    no_improvement_count = 0
    set_seed(args.seeds)
    device = torch.device(args.device)
    tokenizer = SmilesTokenizer(vocab_file=args.vocab)
    # The checkpoint is loaded onto the CPU; the model is moved to `device`
    # later, per fold, inside `train`.
    checkpoint = torch.load(args.checkpoint, map_location=torch.device('cpu'))
    ant = model.smilesgraph.AntBrain(model_name='bert-base-uncased', hash_size=args.hash_size)
    ant.load_state_dict(checkpoint)

    train_data = args.data_folder + "/" + args.dataset.lower() + "/train_val.csv"
    # Targets are log-transformed only when the reported metric is spearman.
    is_log = args.metrics == "spearman"
    train_data = DrugDataset(train_data, tokenizer, args.max_length, is_log=is_log)
    folds = train_data.split_to_folds(args.split_type)

    stats = {"best_test": None, "best_val": None, "best_prediction": None, "best_label": None}
    # Create a study object and optimize the objective function
    # pruner = MedianPruner(n_startup_trials=5, n_warmup_steps=10, interval_steps=1)
    study = optuna.create_study(direction='minimize')
    study.optimize(lambda trial: train(trial, stats, folds), n_trials=100, callbacks=[early_stopping_callback])

    # Print the best hyperparameters and the corresponding validation score
    best_params = study.best_params
    best_value = study.best_value
    print("Best Hyperparameters:", best_params)
    print("Best Val:", best_value)
    stats_output = {"best_test": stats["best_test"], "best_val": stats["best_val"]}
    print("Best stats", stats_output)
