
import os
from math import ceil
import random
import pandas as pd
import numpy as np
from tqdm.auto import tqdm
import torch
from config import config
from dataset import ActiveLearningDataset
from model import Ensemble
from acquisition import acquire
from utils.utils import Evaluate, random_baseline, random_baseline_pcba



def active_learning(args):
    DS = ActiveLearningDataset(args)

    ### Start
    DS.get_start_data()

    eval_screen, eval_train = Evaluate(), Evaluate()
    total_hit_discover, total_mol_screen = [], []

    total_hit_discover.append(sum(DS.y_train).item())
    total_mol_screen.append(len(DS.y_train))
    print("Current Hit Num:", total_hit_discover[-1])


    cycles = ceil((args.max_screen_size - args.start_num) / args.batch_size)
    for cycle_i in tqdm(range(1, cycles+1), desc="cycle"):


        ### Train
        train_loader, train_loader_balanced, screen_loader = DS.construct_dataloader()

        print("Training Model!")
        if (args.retrain == 1 or cycle_i == 1):
            model = Ensemble(args)
            model.train(train_loader_balanced)
        
            if args.mode == "e":
                checkpoint = model.models[0].model
                torch.save(checkpoint.state_dict(), args.model_save_file)
        

        ### Eval
        print("Inferring Screen Data!")
        train_logits_N_K_C = model.predict(train_loader)
        eval_train.eval(train_logits_N_K_C, DS.y_train)

        screen_logits_N_K_C = model.predict(screen_loader)
        eval_screen.eval(screen_logits_N_K_C, DS.y_screen)


        ### Query
        if len(DS.idx_train) + args.batch_size < args.max_screen_size:
            batch_size = args.batch_size
        else:
            batch_size = args.max_screen_size - len(DS.idx_train)

        if args.strategy == "grpo":   # Strategy is grpo, then sample data by policy model
            mean_probs_hits = torch.mean(torch.exp(screen_logits_N_K_C), dim=1)[:, 1].cpu()

            # Both mean_probs_hits and random_val in [0, 1], greater mean_probs_hits then pick_flag will more likely be 1
            random_val = torch.rand(mean_probs_hits.shape[0])
            pick_flag = torch.where(mean_probs_hits - random_val > 0, torch.ones_like(mean_probs_hits), torch.zeros_like(mean_probs_hits))

            # mean_probs_hits_with_flag of selected mol must be greater than unselected even if mean_probs_hits is less, since pick_flag of selected mol is 1
            mean_probs_hits_with_flag = mean_probs_hits + pick_flag

            # sort the mean_probs_hits_with_flag to get top-n selected mol
            idx_pick = torch.argsort(mean_probs_hits_with_flag, descending=True)[:batch_size]
            smiles_pick = DS.smiles_screen[idx_pick.cpu()]
        else:   # Strategy is not grpo, then acquire data using acquisition function
            smiles_pick = acquire(acquisition=args.strategy, logits_N_K_C=screen_logits_N_K_C, smiles_screen=DS.smiles_screen, n=batch_size, smiles_hit=DS.smiles_train[np.where(DS.y_train == 1)])
        

        ### Expand
        DS.add_data(smiles_pick)

        total_hit_discover.append(sum(DS.y_train).item())
        total_mol_screen.append(len(DS.y_train))
        print("Current Hit Num:", total_hit_discover[-1])


    ### Information
    train_results = eval_train.to_dataframe("train_")
    screen_results = eval_screen.to_dataframe('screen_')

    results = pd.concat([train_results, screen_results], axis=1)
    results['total_hit_discover'] = total_hit_discover
    results['total_mol_screen'] = total_mol_screen

    if args.mode in ["a", "e"]:
        tot_dataset_hits = {"ALDH1": 4986, "PKM2": 223, "VDR": 239, "Enamine50k": 500, "EnamineHTS": 1000}
        baseline = random_baseline(tot_dataset_hits[args.dataset], args.batch_size, args.start_num, DS.y.shape[0], args.max_screen_size)
    elif args.mode == "d":
        tot_dataset_hits = {"ADRB2": 17, "ALDH1": 5363, "ESR1_ago": 13, "ESR1_ant": 88, "FEN1": 360, "GBA": 163, "IDH1": 39, "KAT2A": 194, "MAPK1": 308, "MTORC1": 97, "OPRK1": 24, "PKM2": 546, "PPARG": 24, "TP53": 64, "VDR": 655}
        baseline = random_baseline_pcba(tot_dataset_hits[args.dataset], args.batch_size, args.start_num, DS.y.shape[0], args.max_screen_size)
    
    results["enrichment_factor"] = list(np.array(total_hit_discover) / np.array(baseline))

    return results



if __name__ == '__main__':
    # Load the run configuration and restrict the visible GPU(s) before the
    # campaign starts. torch is already imported at this point, so this
    # relies on CUDA being initialised lazily.
    args = config()
    os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda

    results = active_learning(args)

    # Results go to one CSV per (architecture, strategy, dataset, seed).
    # Rows are appended; the header is written only when the file is new.
    out_dir = os.path.join("experiments", args.output_folder)
    os.makedirs(out_dir, exist_ok=True)
    log_file = f"experiments/{args.output_folder}/{args.architecture}_{args.strategy}_{args.dataset}_{args.seed}_results.csv"
    write_header = not os.path.isfile(log_file)
    results.to_csv(log_file, mode='a', index=False, header=write_header)









