from adalead_akash import Adalead
import sys
import numpy as np
import argparse
import time
sys.path.append("./psipred/s4pred")
# sys.path.append("/home/apa2237/generative_model_work/oracle_models/amp")
sys.path.append("/home/**/generative_model_work/datasets/toy_motifs_1/")
sys.path.append("/home/**/generative_model_work/importance_models/ss")
from predict_ss import oracle_prediction

# from produce_output import initalize,test
# from generate_property import output_property_oracle
from contributions_score_generative import initalize as initalize_contri


import json

import pickle
# parser = argparse.ArgumentParser()

# Command-line argument parsing
parser = argparse.ArgumentParser(description='Run Adalead with specified temperature.')
parser.add_argument('--temp', type=float, default=1, help='Temperature parameter for Adalead')
parser.add_argument('--device', type=str, default='cuda:6')
input_args = parser.parse_args()

print(f'========= Temperature = {input_args.temp} ===========')

alphabet = np.load('/home/**/generative_model_work/importance_models/ss/model/categorical_variables.npy', allow_pickle=True).tolist()
data_path = "/home/**/generative_model_work/importance_models/ss/data/"

top_per = 40

seq_start = np.load(f'{data_path}/seq_test.npy', allow_pickle=True).reshape((-1,1))
y_start = np.load(f'{data_path}/y_test.npy', allow_pickle=True).reshape((-1,1))

print('Before filtering',len(seq_start), len(y_start))
print('Mean before', np.mean(y_start))
print('Min before', np.min(y_start))
print('Max before', np.max(y_start))

cap = max(y_start)
floor = 0.1#min(y_start)
cutoff = (floor)*(100+top_per)/100
print('Cutoff', cutoff)
below_idx = (y_start<cutoff)
# print('Below_idx', np.sum(below_idx*1))
print(below_idx.shape,seq_start.shape)

seq_start = seq_start[below_idx]
y_start = y_start[below_idx]

print('Filtering==================')
print('After filtering',len(seq_start), len(y_start))
print('Mean after', np.mean(y_start))
print('Max after', np.max(y_start))

sorted_idx = np.argsort(y_start)[::-1]
seq_start = seq_start[sorted_idx]
print(seq_start.shape)
y_start = y_start[sorted_idx]



class args():
    device = input_args.device
    max_len = 52
model_args = args()

oracle = oracle_prediction(model_args.device)

# print(model)
print(model_args.device)

seq_start = seq_start.tolist()
start_fitness = oracle.batch_ss_pred(seq_start)#.tolist()
print('Mean after', np.mean(start_fitness))
print('Max after', np.max(start_fitness))
start_fitness = start_fitness.tolist()

start_num_samples = len(seq_start)

### loading contribution score model
model_contri, criterion_contri,optimizer_contri = initalize_contri(model_args.device)

# all_ng = [3,5,10,20,50,100]
all_ng = [20,50,100]
num_trials = 10 # number of trials with one ng value for reproducibility
rounds = 10 ## number of generation iteration in each trials

adalead_imp_code_book = {}
adalead_new_sequences = {}

def set_seeds(trial):
    seed = trial
    print(f"Seed: {seed}")
    np.random.seed(seed)

print('============= Are you happy with the choice of motif size ===========')
time.sleep(1.0)

for count_ng,ng in enumerate(all_ng):
    adalead_algo = Adalead(oracle,
    model_args,
    rounds = 1,
    sequences_batch_size = ng, #keeping sequences_batch_size == model_queries_per_batch
    model_queries_per_batch = ng,
    starting_sequence = None,
    alphabet = alphabet,
    mu = 1,
    recomb_rate = 0.2,
    threshold =  0.05,
    rho = 0,
    eval_batch_size = 20,
    model_contri = model_contri, ## model for contribution
    criterion_contri = criterion_contri, ## criterion for contribution
    optimizer_contri =  optimizer_contri, ## optimizer for contribution
    motif_size = 1, ## size of the motif to be mutated
    motif_based=True, ## if doing motif-based mutation
    )
    
    is_imp_based = True
    temp = input_args.temp ## temp value


    small_code_book = np.zeros((num_trials, rounds+1, 3))
    
    new_sequences = {trial: {} for trial in range(num_trials)}
    
    for trial in range(num_trials):
        starting_mean = np.mean(y_start[0:ng])
        set_seeds(trial)
        print(f'For Num_samples:{ng}, Iteration: {trial+1}')
        # print(np.mean(start_fitness))
        bundle = {"sequence":seq_start[0:start_num_samples], "true_score":start_fitness[0:start_num_samples]}
        m, maxx = np.mean(np.array(bundle["true_score"])), max(np.array(bundle["true_score"]))
        small_code_book[trial, 0, :] = [m,maxx,starting_mean]
        print(f'At the beginning of starting, Mean:{m}, Max:{maxx} ')
        for i in range(1,rounds+1):
            new_seq, new_prop = adalead_algo.propose_sequences(bundle,\
                is_imp_based, temp)
            bundle["sequence"].extend(new_seq.tolist())
            bundle["true_score"].extend(new_prop.tolist())
            m, maxx, new_mean = np.mean(np.array(bundle["true_score"])), \
                max(np.array(bundle["true_score"])), \
                np.mean(new_prop)

            # store the actual new sequences for this trial & round
            new_sequences[trial][i] = new_seq.tolist()

            small_code_book[trial,i,:] = [m,maxx,new_mean]
            print(f'Iteration {i}, Mean:{m}, Max:{maxx}, New Samples: {new_mean} ')   

        # print(new_seq)
            
    adalead_imp_code_book[ng] = small_code_book
    adalead_new_sequences[ng] = new_sequences  # <-- add this
    # print(aaaa)

    np.save(f'./generative_results/our_{input_args.temp}', adalead_imp_code_book)
    np.save(f'./generative_results/sequences_our_{input_args.temp}', adalead_new_sequences)