import os
import random

# Train the S2S model in batches

def train_model(train_data: List[Example], input_indexer, output_indexer,
                name='ml_model', epochs=2, k=0):
    """Fine-tune the checkpoint stored at ``cwd + 'models/' + name`` chunk by chunk.

    ``train_data`` is split into consecutive chunks of ``len(train_data) // 10``
    examples (when the length is not a multiple of 10 a shorter trailing chunk
    follows). For every chunk the checkpoint is reloaded from disk, trained for
    ``epochs`` passes over that chunk with a fresh Adam optimizer, and written
    back to the same path.

    Args:
        train_data: examples exposing ``x_tok`` and ``y_tok``.
        input_indexer: passed through to ``make_padded_input_tensor``.
        output_indexer: passed through to ``make_padded_output_tensor``.
        name: checkpoint file name under ``cwd + 'models/'``; the file must
            already exist because it is loaded before any training happens.
        epochs: number of shuffled passes over each chunk.
        k: unused; kept so existing callers keep working.

    Returns:
        The model as it was saved after the last chunk.

    Raises:
        ValueError: if ``train_data`` is empty.
    """
    if not train_data:
        raise ValueError('train_data is empty')

    # At least one example per chunk: with fewer than 10 examples the integer
    # division yields 0, which range() rejects as a step.
    chunk_size = max(1, len(train_data) // 10)

    for start in range(0, len(train_data), chunk_size):
        # Reload from disk so each chunk continues from the last saved state.
        model = torch.load(cwd+'models/'+name)
        optimizer = torch.optim.Adam(model.parameters(), lr=LR)
        # Slicing copies the list, so shuffling below leaves train_data intact.
        chunk = train_data[start: start+chunk_size]
        # Honour the caller-supplied epoch count (callers pass epochs=10, epochs=i).
        for _ in range(epochs):
            random.shuffle(chunk)
            for j in range(0, len(chunk), BATCH_SIZE):
                model.zero_grad()
                batch_exs = chunk[j: j+BATCH_SIZE]

                x_inp_len = [len(ex.x_tok) for ex in batch_exs]
                y_inp_len = [len(ex.y_tok)+1 for ex in batch_exs]  # include EOS
                x_tensor = make_padded_input_tensor(batch_exs, input_indexer, max(x_inp_len), reverse_input=False)
                y_tensor = make_padded_output_tensor(batch_exs, output_indexer, max(y_inp_len))
                x_inp_len, y_inp_len = torch.tensor(x_inp_len), torch.tensor(y_inp_len)

                loss = model.forward(torch.tensor(x_tensor), x_inp_len, torch.tensor(y_tensor), y_inp_len)
                loss.backward()
                optimizer.step()
        torch.save(model, cwd+'models/'+name)
    return model

# Train on the full indexed training set using train_model's defaults.
model = train_model(
    train_data=train_data_indexed,
    input_indexer=input_indexer,
    output_indexer=output_indexer,
)

# S2S performance vs attack accuracy
# Train one model per training-set size (500, 1000, ..., 5000 examples).
models = []
attack_model = torch.load(cwd+'models/att_model')
for size in range(500, 5001, 500):
    models.append(train_model(train_data_indexed[:size], input_indexer,
                              output_indexer, 'ml_ds_'+str(size), epochs=10))

for m in models:
    evaluate_ml(dev_data_indexed, m)
    print()

# The audit classifier depends only on the shadow model, not on the model
# under evaluation, so it is fitted once here instead of once per model.
shadow_vec = output_vectors(train_data_indexed[:500], dev_data_indexed, shadow_model)
shadow_gold = [ex.y_indexed for ex in train_data_indexed[:500]] + [ex.y_indexed for ex in dev_data_indexed]
shadow_ranks = get_ranks(shadow_vec, shadow_gold)
# First 500 rows come from the training slice, the rest from the dev set.
shadow_data, shadow_labels = get_att_data(shadow_ranks[:500], shadow_ranks[500:])
# Fit on the (data, labels) pair returned together by get_att_data, matching
# how eval_audit_model is fed below.
clf = audit_model(shadow_data, shadow_labels)

ben_accs = []
for m in models:
    data_vec = output_vectors(train_data_indexed[500:1000], dev_data_indexed, m)
    gold = [ex.y_indexed for ex in train_data_indexed[500:1000]] + [ex.y_indexed for ex in dev_data_indexed]
    ranks = get_ranks(data_vec, gold)
    data, labels = get_att_data(ranks[:500], ranks[500:])
    ben_accs.append(eval_audit_model(clf, data, labels))

att_accs = []
for m in models:
    label_vec = label_vectors(train_data_indexed[:500], dev_data_indexed, m)
    data_vec = output_vectors(train_data_indexed[:500], dev_data_indexed, m)
    # Attack features: label vector followed by output vector for each example.
    all_data = [np.concatenate([label_vec[idx], data_vec[idx]], axis=0) for idx in range(len(label_vec))]
    v_data, v_labels = get_att_data(all_data[:500], all_data[500:])
    att_accs.append(evaluate_attack(v_data, v_labels, attack_model))

# S2S epochs vs attack accuracy
# Train one model per epoch budget (10, 20, ..., 100) on a training slice
# that has as many examples as the dev set.
epoch_counts = list(range(10, 101, 10))
n_dev = len(dev_data_indexed)
models = []
att_accs = []
for n_epochs in epoch_counts:
    models.append(train_model(train_data_indexed[:n_dev], input_indexer,
                              output_indexer, 'ml_epoch_'+str(n_epochs), epochs=n_epochs))

for m in models:
    evaluate_ml(dev_data_indexed, m)
    print()

# Pair every model with the attack checkpoint of the same epoch budget; a
# fixed counter here would load 'att_epoch_10' for all of them.
for n_epochs, m in zip(epoch_counts, models):
    attack_model = torch.load(cwd+'models/att_epoch_'+str(n_epochs))
    label_vec = label_vectors(train_data_indexed[:n_dev], dev_data_indexed, m)
    data_vec = output_vectors(train_data_indexed[:n_dev], dev_data_indexed, m)
    # Attack features: label vector followed by output vector for each example.
    all_data = [np.concatenate([label_vec[idx], data_vec[idx]], axis=0) for idx in range(len(label_vec))]
    # First n_dev rows come from the training slice, the rest from the dev set.
    v_data, v_labels = get_att_data(all_data[:n_dev], all_data[n_dev:])
    att_accs.append(evaluate_attack(v_data, v_labels, attack_model))

# RL epochs vs attack accuracy
net = 'cnn'
max_frames = 204800

# Train victim models with 20480, 40960, ..., 204800 frames. The stop value is
# inclusive so that 'valid/mr_<net>100' exists; the attack loop below reads
# the probabilities of suffixes 10 through 100.
for frames in range(20480, max_frames + 1, 20480):
    os.system('python -m scripts.train --env MiniGrid-Multirooms-v00 --algo ppo --model valid/mr_' + net + str(frames//2048) + ' --frames ' + str(frames))

# Run every victim model in test mode (--test 1).
for frames in range(20480, max_frames + 1, 20480):
    os.system('python -m scripts.train --env MiniGrid-Multirooms-v00 --algo ppo --model valid/mr_' + net + str(frames//2048) + ' --frames ' + str(max_frames) + ' --test 1')


att_accs = []
for i in range(10, 101, 10):
    # Train each shadow model up to i*2048 frames, then run it in test mode.
    for env_id in (1, 5):
        shadow = 'python -m scripts.train --env MiniGrid-Multirooms-v' + str(env_id) + ' --algo ppo --model mr_shadow_' + str(env_id) + ' --frames '
        os.system(shadow + str(i*2048) + ' --test 0')
        os.system(shadow + str(max_frames) + ' --test 1')

    df = pd.read_csv(cwd+'data/mr_label_1/probabilities.csv')
    df0 = pd.read_csv(cwd+'data/mr_label_5/probabilities.csv')
    df1 = pd.read_csv(cwd+'data/mr_shadow_1/probabilities.csv')
    df2 = pd.read_csv(cwd+'data/mr_shadow_5/probabilities.csv')

    data0 = reshape_data(df)
    data00 = reshape_data(df0)
    data1 = reshape_data(df1)
    data2 = reshape_data(df2)

    attack_model = build_att()
    # NOTE(review): `data` and `labels` are not built in this section; they are
    # whatever an earlier section left at module level, while data0, data00
    # and data1 above are never used. The attack training set presumably
    # should be derived from those frames - confirm and fix.
    attack_model.fit(x=data, y=labels, batch_size=64, epochs=15, verbose=False)
    victim_df = pd.read_csv(cwd+'data/valid/mr_'+net+str(i)+'/probabilities.csv')
    victim = reshape_data(victim_df)
    # NOTE(review): attack_model is passed first here but last in the
    # attack_accuracy calls of the "RL batches" section - verify against the
    # function's signature.
    att_accs.append(attack_accuracy(attack_model, victim, reshape_data(pd.read_csv(cwd+'data/mr1/probabilities.csv')), data2))

# RL batches vs attack accuracy

# Train and then test one label model per environment variant v0..v17.
for i in range(18):
    os.system('python -m scripts.train --env MiniGrid-Multirooms-v'+str(i)+' --algo ppo --model mr_label_'+str(i)+' --frames ' + str(204800) + ' --test 0')
    os.system('python -m scripts.train --env MiniGrid-Multirooms-v'+str(i)+' --algo ppo --model mr_label_'+str(i)+' --frames ' + str(204800) + ' --test 1')

# random.shuffle works in place and returns None, so build the list first and
# shuffle it afterwards; the same order is reused by both loops below.
rand = list(range(16))
random.shuffle(rand)
for i in rand:
    os.system('python -m scripts.train --env MiniGrid-Multirooms-v'+str(i)+' --algo ppo --model mr --frames ' + str(204800//16) + ' --test 0 --recurrence 16')

for i in rand:
    # '--recurrence' needs its leading dashes, as in the training call above.
    os.system('python -m scripts.train --env MiniGrid-Multirooms-v'+str(i)+' --algo ppo --model rnn/mr_rnn_'+str(i)+' --frames ' + str(204800) + ' --test 1 --recurrence 16')

attack_model = build_att()
attack_model.load_weights(cwd+'data/attack_model')

df_cnn = reshape_data(pd.read_csv(cwd+'data/rnn/mr_cnn_1/probabilities.csv'))
df_cnn_2 = reshape_data(pd.read_csv(cwd+'data/rnn/mr_cnn_2/probabilities.csv'))
df_rnn = reshape_data(pd.read_csv(cwd+'data/rnn/mr_rnn_1/probabilities.csv'))
df_rnn_2 = reshape_data(pd.read_csv(cwd+'data/rnn/mr_rnn_2/probabilities.csv'))

# The CNN and RNN comparisons read the same two label files (mr_label_1 and
# mr_label_17); they are loaded separately so each pair stays independent.
df_c1 = reshape_data(pd.read_csv(cwd+'data/mr_label_1/probabilities.csv'))
df_c2 = reshape_data(pd.read_csv(cwd+'data/mr_label_17/probabilities.csv'))
df_r1 = reshape_data(pd.read_csv(cwd+'data/mr_label_1/probabilities.csv'))
df_r2 = reshape_data(pd.read_csv(cwd+'data/mr_label_17/probabilities.csv'))

acc1 = attack_accuracy(df_cnn, df_c2, df_c1, attack_model)
acc2 = attack_accuracy(df_cnn_2, df_c1, df_c2, attack_model)

acc1r = attack_accuracy(df_rnn, df_r2, df_r1, attack_model)
acc2r = attack_accuracy(df_rnn_2, df_r1, df_r2, attack_model)
