import argparse
import json
import math

import pandas as pd
from downstream.downstream_model import DrugPropertyModelPooling
import torch
import numpy as np
import random
import os
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from model.tokenizer import SmilesTokenizer
from model.antbrain import AntBrain
from downstream.utils import evaluation, kfold_indices, sort_df_with_smiles_by_scaffold
import optuna
import copy


def early_stopping_callback(study, trial):
    """Stop the Optuna study after too many consecutive non-improving trials.

    Meant to be passed in the ``callbacks`` list of ``study.optimize``. The
    function reads and updates the module-level ``best_score`` and
    ``no_improvement_count`` and reads ``args.optuna_patience``.

    Args:
        study: The running study; ``study.stop()`` is called once the number
            of consecutive non-improving trials reaches the patience.
        trial: The trial that just finished; only ``trial.value`` is read.
    """
    global best_score, no_improvement_count
    current_score = trial.value

    # The None checks have to run BEFORE the comparison: ``trial.value`` can be
    # None (a trial that ended without an objective value) and ``None < float``
    # raises TypeError. A missing score simply counts as "no improvement".
    improved = current_score is not None and (
        best_score is None or current_score < best_score
    )
    if improved:
        best_score = current_score
        no_improvement_count = 0  # Reset the no improvement count
    else:
        no_improvement_count += 1

    # Early termination condition
    if no_improvement_count >= args.optuna_patience:
        print("Early termination")
        study.stop()


class DrugDataset(Dataset):
    """In-memory dataset of ``(drug, label)`` pairs that are tokenized on access.

    The pairs come from one of three sources, checked in this order: an
    explicit ``data`` list, a ``df`` data frame, or a CSV file at
    ``data_file``. The frame / CSV must provide ``Drug`` and ``Y`` columns.
    """

    def __init__(self, data_file, tokenizer, max_length, df=None, data=None, is_log=False):
        """Build the dataset.

        Args:
            data_file: CSV path; only read when both ``df`` and ``data`` are None.
            tokenizer: Object exposing ``batch_encode_plus`` (used in ``__getitem__``).
            max_length: Padding / truncation length handed to the tokenizer.
            df: Optional data frame with ``Drug`` and ``Y`` columns.
            data: Optional ready-made list of ``(drug, label)`` pairs; stored
                by reference, not copied.
            is_log: When true, every label is replaced by ``math.log(label)``.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.is_log = is_log

        if data is not None:
            self.data = data
        else:
            if df is None:
                print("Read data from", data_file)
                df = pd.read_csv(data_file)
                print("Data shape", df.shape)
            self.data = [(row['Drug'], row['Y']) for _, row in df.iterrows()]

        if self.is_log:
            # Builds a new list, so a caller-supplied ``data`` list is left untouched.
            self.data = [(pair[0], math.log(pair[1])) for pair in self.data]

    def __len__(self) -> int:
        """Return the number of stored pairs."""
        return len(self.data)

    def __getitem__(self, index: int):
        """Return ``(input_ids, attention_mask, label)`` for the pair at ``index``."""
        drug, label = self.data[index]
        # The tokenizer is called with a batch of one; squeeze(0) removes that
        # leading batch dimension again.
        encoded = self.tokenizer.batch_encode_plus(
            [drug],
            max_length=self.max_length,
            padding='max_length',
            return_tensors='pt',
            truncation=True,
        )
        input_ids = encoded.input_ids.squeeze(0)
        attention_mask = encoded.attention_mask.squeeze(0)
        return input_ids, attention_mask, label

    def merge(self, dataset):
        """Append the pairs of ``dataset`` to this dataset in place."""
        self.data += dataset.data

    def split_to_folds(self, split_type, number_folds=5):
        """Split the data into ``number_folds`` train/test dataset pairs.

        Args:
            split_type: ``"random"`` (shuffled), ``"preserved_order"`` (no
                shuffle) or ``"scaffold"`` (data is first reordered with
                ``sort_df_with_smiles_by_scaffold``, then split without shuffle).
            number_folds: Number of folds requested from ``kfold_indices``.

        Returns:
            A list of ``(train_dataset, test_dataset)`` tuples, or ``None``
            (after printing an error) when ``split_type`` is not recognised.
        """
        if split_type == "random":
            shuffle = True
        elif split_type == "preserved_order":
            shuffle = False
        elif split_type == "scaffold":
            # Note: this replaces self.data with the reordered data.
            self.data = sort_df_with_smiles_by_scaffold(self.data)
            shuffle = False
        else:
            print("Error: unknown split type", split_type)
            return None

        fold_indices = kfold_indices(self.data, n_splits=number_folds, is_shuffle=shuffle)

        def subset(indices):
            # Child datasets receive the already-prepared pairs, so is_log is
            # deliberately left at its default (no second log transform).
            picked = [self.data[idx] for idx in indices]
            return DrugDataset(None, self.tokenizer, self.max_length, data=picked)

        return [(subset(train_index), subset(test_index))
                for train_index, test_index in fold_indices]


def _evaluate_loader(net, loader, metrics):
    """Run ``net`` over ``loader`` in eval mode and score the predictions.

    Args:
        net: The model; called as ``net(input_ids, attention_mask)``.
        loader: Iterable yielding ``(input_ids, attention_mask, labels)`` batches.
        metrics: Metric name forwarded to ``evaluation``; for ``'spearman'``
            the predictions are exponentiated before scoring.

    Returns:
        Whatever ``evaluation`` returns for the concatenated predictions and
        labels (unpacked by the caller as ``(metric, predictions, labels)``).
    """
    # Remember the mode of every sub-module so training resumes exactly as it
    # was, even if some sub-module had been put in eval mode on purpose.
    previous_modes = [(module, module.training) for module in net.modules()]
    # torch.no_grad() only disables gradient tracking; train-mode layers such
    # as dropout stay active unless the model is switched to eval mode.
    net.eval()
    try:
        predictions = []
        targets = []
        with torch.no_grad():
            for input_ids, attention_mask, labels in tqdm(loader):
                p = net(input_ids.to(device), attention_mask.to(device))
                predictions.append(p.squeeze(-1).detach().cpu())
                targets.append(labels.cpu())
            predictions = torch.cat(predictions)
            if metrics == 'spearman':
                predictions = torch.exp(predictions)
            targets = torch.cat(targets)
            return evaluation(predictions, targets, metrics)
    finally:
        for module, was_training in previous_modes:
            module.training = was_training


def train(trial, stats, folds):
    """Optuna objective: train one model per fold and return the mean best validation score.

    Hyperparameters (learning rate, batch size, head hidden size and number
    of head layers) are sampled from ``trial``. For every fold a fresh model
    is built around a deep copy of the module-level ``ant`` backbone and
    trained until ``args.steps`` steps are reached or the validation score
    stops improving. Every ``args.val_steps`` steps the model is scored on the
    test set and on the fold's validation set; a lower validation score is
    treated as better, and the test score/predictions recorded at the best
    validation point are kept.

    Relies on the module-level ``args``, ``tokenizer``, ``device`` and ``ant``.

    Args:
        trial: Optuna trial used to sample the hyperparameters.
        stats: Dict with ``"best_val"``, ``"best_test"`` and ``"best_label"``
            entries; updated in place when this trial beats ``stats["best_val"]``.
            In that case the per-fold test predictions are also written as
            JSON files into ``args.output_folder``.
        folds: Iterable of ``(train_dataset, val_dataset)`` pairs.

    Returns:
        The mean over folds of the best validation score.

    Raises:
        RuntimeError: If a fold finishes without a single validation pass.
    """
    set_seed(args.seeds)
    # search space (batch size and hidden size are sampled as powers of two)
    learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-4, log=True)
    batch_size = 2 ** trial.suggest_int('batch_size', 2, 7)
    hidden_size = 2 ** trial.suggest_int('hidden_size', 5, 10)
    num_hidden_layers = trial.suggest_int('num_hidden_layers', 1, 4)
    test_data = args.data_folder + "/" + args.dataset.lower() + "/test.csv"
    test_data = DrugDataset(test_data, tokenizer, args.max_length)
    test_data = DataLoader(test_data, batch_size=batch_size, shuffle=False)
    best_vals = []
    best_tests = []
    best_predictions = []
    best_label = None
    for fold_idx, (train_data, val_data) in enumerate(folds):
        train_data = DataLoader(train_data, batch_size=batch_size, shuffle=True)
        val_data = DataLoader(val_data, batch_size=batch_size, shuffle=False)
        # Each fold starts from its own copy of the pretrained backbone.
        net = DrugPropertyModelPooling(drug_size=args.drug_embed_size,
                                       model=copy.deepcopy(ant),
                                       hidden_size=hidden_size,
                                       num_hidden_layers=num_hidden_layers,
                                       task_type=args.task_type,
                                       num_class=args.num_class).to(device)

        optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
        moving_loss = 0.0
        n = 0
        run = True
        best_val = None
        best_test = None
        best_prediction = None
        tolerance_count = 0
        while run:
            for input_ids, attention_mask, labels in train_data:
                net.zero_grad()
                p = net(input_ids.to(device), attention_mask.to(device))
                if args.task_type == 'regression':
                    if args.metrics == 'mae' or args.metrics == 'spearman':
                        loss = torch.mean(torch.abs(p.squeeze(-1) - labels.to(device)))
                    else:
                        loss = torch.nn.MSELoss()(p.squeeze(-1), labels.to(dtype=torch.float32).to(device))
                else:
                    loss = torch.nn.CrossEntropyLoss()(p.squeeze(-1), labels.to(torch.int64).to(device))
                loss.backward()
                optimizer.step()
                moving_loss += loss.cpu().item()
                if n % args.val_steps == 0:
                    print("Step {}; Moving loss {}".format(n, moving_loss / (n + 1)))
                n += 1
                # Periodic evaluation: test set first, then the validation fold.
                if n % args.val_steps == 0:
                    test_metrics, all_p_test, all_label_test = _evaluate_loader(
                        net, test_data, args.metrics)
                    print(n, "Test metrics:", test_metrics)

                    val_metrics, all_p, all_label = _evaluate_loader(
                        net, val_data, args.metrics)
                    print(n, "Val metrics:", val_metrics)

                    if best_val is None:
                        best_val = val_metrics
                        best_test = test_metrics
                        best_prediction = all_p_test
                        best_label = all_label_test
                    elif best_val > val_metrics:
                        print("Val score improve:", best_val,
                              val_metrics, "Test score",
                              test_metrics)
                        best_test = test_metrics
                        # Only a sufficiently large relative improvement
                        # resets the early-stopping counter.
                        relative_gain = abs(best_val - val_metrics) / (abs(best_val) + 0.000001)
                        if relative_gain > args.val_improve_tolerance:
                            tolerance_count = 0
                        else:
                            print("Improvement is not significant",
                                  relative_gain,
                                  args.val_improve_tolerance,
                                  best_val,
                                  val_metrics)
                            tolerance_count += 1
                        best_val = val_metrics
                        best_prediction = all_p_test
                        best_label = all_label_test
                    else:
                        tolerance_count += 1
                    if tolerance_count > args.val_tolerance:
                        print("Early termination")
                        run = False
                        break

                if n == args.steps:
                    run = False
                    break
        if best_val is None:
            # Happens when training stops before the first validation pass,
            # e.g. args.steps < args.val_steps; there is nothing to report.
            raise RuntimeError(
                "Fold {} finished after {} steps without any validation pass "
                "(val_steps={}).".format(fold_idx, n, args.val_steps))
        print("Best score", best_test, best_val, fold_idx)
        best_tests.append(best_test)
        best_vals.append(best_val)
        best_predictions.append(best_prediction)
    best_val = torch.mean(torch.tensor(best_vals)).cpu().item()
    best_test = torch.mean(torch.tensor(best_tests)).cpu().item()
    print("Best score all", best_test, best_val)
    if stats["best_val"] is None or stats["best_val"] > best_val:
        stats["best_val"] = best_val
        stats["best_test"] = best_test
        # Labels recorded in the last fold; every fold is scored on the same
        # unshuffled test loader.
        stats["best_label"] = best_label
        for i, fold_prediction in enumerate(best_predictions):
            with open(args.output_folder + "/" + args.dataset + "_prediction_" + str(i) + '_' +
                      str(args.seeds) + ".json", "wt") as f:
                prediction = {"prediction": fold_prediction.cpu().tolist(),
                              "label": stats["best_label"].cpu().tolist()}
                json.dump(prediction, f, indent=4)
        stats_output = {"best_test": stats["best_test"], "best_val": stats["best_val"]}
        print("Best_val", stats_output)
    return best_val


def set_seed(seed: int = 42) -> None:
    """Seed the NumPy, ``random`` and PyTorch (CPU and CUDA) generators.

    Also switches cuDNN to deterministic mode with benchmarking disabled,
    stores the seed in the ``PYTHONHASHSEED`` environment variable and prints
    a confirmation line.

    Args:
        seed: Seed value applied to every generator.
    """
    seeders = (
        np.random.seed,
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Seeding alone is not enough on the cuDNN backend: these two flags must
    # be set as well.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # The environment only holds strings, hence the conversion.
    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set as {seed}")


# Script entry point: parse the CLI, load the pretrained backbone, split the
# train/val data into folds and run an Optuna search over train().
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='TDC Admet training')
    parser.add_argument('--dataset', default='DILI', type=str,
                        help='Name of the ADMET dataset')
    # NOTE: --lr, --mask_ratio, --batch_size, --num_hidden_layers and
    # --hidden_size are not read anywhere in this script; learning rate, batch
    # size and head shape come from the Optuna search space in train(). They
    # are kept so existing command lines keep working.
    parser.add_argument("--lr", default=5e-06, type=float,
                        help='Learning rate.')
    parser.add_argument("--mask_ratio", default=-1.0, type=float,
                        help='Random masking sequence.')
    parser.add_argument("--drug_embed_size", default=768, type=int,
                        help='Drug embedding size')
    parser.add_argument("--steps", default=10000, type=int,
                        help='Maximum number of training steps')
    parser.add_argument("--seeds", default=0, type=int,
                        help='Random seeds.')
    parser.add_argument("--batch_size", default=128, type=int,
                        help='Mini batch size.')
    parser.add_argument("--device", default="cpu", type=str,
                        help='Device: cpu vs cuda.')
    parser.add_argument("--task_type", default="classification", type=str,
                        help='Task type: regression, classification')
    parser.add_argument("--metrics", default="auroc", type=str,
                        help='Report metrics')
    parser.add_argument("--num_class", default=2, type=int,
                        help='Number of classes')
    parser.add_argument("--num_hidden_layers", default=1, type=int,
                        help='Number of layers of  the classifier or regression head')
    parser.add_argument("--hidden_size", default=512, type=int,
                        help='Hidden size of  the classifier or regression head')
    parser.add_argument("--val_steps", default=10, type=int,
                        help='Val steps.')
    parser.add_argument("--max_length", default=128, type=int,
                        help='Max length')
    parser.add_argument('--checkpoint', default='../data/pytorch_model.bin', type=str,
                        help='AntBrain checkpoint')
    parser.add_argument('--vocab', default='../data/vocab.txt', type=str,
                        help='Vocab file for the smiles tokenizer')
    parser.add_argument('--data_folder', default='data/', type=str,
                        help='Data folder')
    parser.add_argument("--val_tolerance", default=3, type=int,
                        help='Number of validation rounds without a significant improvement '
                             'tolerated before training of a fold stops early.')
    parser.add_argument("--val_improve_tolerance", default=0.01, type=float,
                        help='Minimum relative improvement of the validation score that '
                             'counts as significant and resets the early-stopping counter.')
    parser.add_argument("--output_folder", default="./", type=str,
                        help='Output folder')
    parser.add_argument("--optuna_patience", default=25, type=int,
                        help='The number of trials of no improvement before terminating the optuna process.')
    # Restricting the choices rejects a typo at parse time; otherwise
    # split_to_folds() returns None and train() fails later with a TypeError.
    parser.add_argument("--split_type", default="random", type=str,
                        choices=["random", "preserved_order", "scaffold"],
                        help='Type of splitting the train_val into folds: random, preserved_order, scaffold')

    args = parser.parse_args()
    print(args)
    # State shared with early_stopping_callback through module globals.
    best_score = float('inf')
    no_improvement_count = 0
    set_seed(args.seeds)
    device = torch.device(args.device)
    tokenizer = SmilesTokenizer(vocab_file=args.vocab)
    # SECURITY NOTE: torch.load unpickles the file, so only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(args.checkpoint, map_location=torch.device('cpu'))
    ant = AntBrain(model_name='bert-base-uncased', tokenizer=tokenizer)
    # strict=False: keys that do not match between checkpoint and model are
    # ignored instead of raising.
    ant.load_state_dict(checkpoint, strict=False)

    # For the spearman metric the training labels are log-transformed.
    is_log = args.metrics == "spearman"
    train_data = args.data_folder + "/" + args.dataset.lower() + "/train_val.csv"
    train_data = DrugDataset(train_data, tokenizer, args.max_length, is_log=is_log)
    folds = train_data.split_to_folds(args.split_type)

    stats = {"best_test": None, "best_val": None, "best_prediction": None, "best_label": None}
    # Create a study object and optimize the objective function
    # pruner = MedianPruner(n_startup_trials=5, n_warmup_steps=10, interval_steps=1)
    study = optuna.create_study(direction='minimize')
    study.optimize(lambda trial: train(trial, stats, folds), n_trials=100, callbacks=[early_stopping_callback])

    # Print the best hyperparameters and the corresponding validation score
    best_params = study.best_params
    best_value = study.best_value
    print("Best Hyperparameters:", best_params)
    print("Best Val:", best_value)
    stats_output = {"best_test": stats["best_test"], "best_val": stats["best_val"]}
    print("Best stats", stats_output)
