import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoConfig, AutoModelForSequenceClassification
import torch.nn.functional as F
import torch


# `mode` is the path of the score CSV that `pd.read_csv(mode, ...)` loads further down.
# Runs are switched by (un)commenting assignments. Every uncommented assignment
# executes, so only the LAST one before the `pd.read_csv` call takes effect;
# the three assignments below are overwritten by later ones.
mode = 'proxy3_gold1'
# mode = 'proxy1_gold3'
mode = 'utility_scores_gpt_proxy1_gold3_logtrain_100_deberta-v3-large.csv'
mode = 'old_csvs/fifth_run/utility_scores_gpt_proxy1_gold3_separatedata_100_deberta-v3-large.csv'
# mode = 'morality_scores_gpt_cm_all_150_deberta-v3-large_copy.csv'
# mode = 'morality_scores_gpt_cm_immoral_150_deberta-v3-large.csv'
# mode = 'morality_scores_gpt_cm_multi_all_150_deberta-v3-large.csv'
# mode = 'morality_scores_gpt_cm_multi_immoral_150_deberta-v3-large.csv'
# mode = 'morality_scores_gpt_cm_all_threshold_150_deberta-v3-large.csv'
# mode = 'morality_scores_gpt_cm_immoral_threshold_150_deberta-v3-large.csv'
# mode = 'morality_scores_llama_cm_immoral_flipped_threshold_150_deberta-v3-large.csv'
# mode = 'morality_scores_llama_cm_immoral_threshold_150_deberta-v3-large.csv'
# mode = 'morality_scores_llama_cm_immoral_rerun_threshold_150_deberta-v3-large.csv'
# mode = 'morality_scores_llama_cm_immoral_150_deberta-v3-large.csv'
# mode = 'morality_scores_llama_cm_immoral_oldtemp_150_deberta-v3-large.csv'
# mode = 'morality_scores_llama_cm_untrained_50_deberta-v3-large.csv'
# mode = 'morality_scores_llama_cm_immoral_bigbatch_50_deberta-v3-large.csv'
# mode = 'hh_scores_llama_bigbatch_128_deberta-v3-xsmall.csv'
# mode = 'hh_scores_gpt_256_deberta-v3-xsmall.csv'
# mode = 'hh_scores_gpt_256_deberta-v3-large.csv'
# mode = 'hh_scores_gpt_subset_256_deberta-v3-xsmall.csv'
# mode = 'hh_scores_gpt_moredata_256_deberta-v3-xsmall.csv'
# mode = 'hh_scores_gpt_score_256_deberta-v3-xsmall.csv'
# mode = 'old_csvs/seventh_run/hh_scores_gpt_6layer_256_deberta-v3-xsmall.csv'
# mode = 'old_csvs/seventh_run/hh_scores_gpt_largelr_256_deberta-v3-xsmall.csv'
# mode = 'old_csvs/seventh_run/hh_scores_gpt_agres_256_deberta-v3-xsmall.csv'
# mode = 'old_csvs/seventh_run/hh_scores_llama_256_deberta-v3-xsmall.csv'
# mode = 'old_csvs/seventh_run/hh_scores_llama_sanity_256_deberta-v3-xsmall.csv'
# mode = 'hh_scores_llama_sanity2_256_deberta-v3-xsmall.csv'
# mode = 'hh_scores_llama_sanity2_1gpu_smallkl_256_deberta-v3-xsmall.csv'
# mode = 'hh_scores_gpt_promptnorm_256_deberta-v3-xsmall.csv'
# mode = 'hh_scores_gpt_nonorm_256_deberta-v3-xsmall.csv'
# mode = 'hh_scores_gpt_sanity_256_deberta-v3-xsmall.csv'
# mode = 'hh_scores_gpt_sanity_largelr_256_deberta-v3-xsmall.csv'
# mode = 'hh_scores_gpt_sanity_largelr_complete_256_deberta-v3-xsmall.csv'
# mode = 'hh_scores_gpt_sanity_debug_256_deberta-v3-xsmall.csv'

# mode = 'hh_scores_gpt_subset_256_deberta-v3-xsmall.csv'
# mode = 'hh_scores_gpt_256_deberta-v3-xsmall.csv'
# mode = 'hh_scores_gpt_256_deberta-v3-large.csv'
# mode = 'hh_scores_gpt_moredata_256_deberta-v3-xsmall.csv'
# NOTE: dead store — `mode` is reassigned further down before anything reads it.
mode = 'old_csvs/eighth_run/hh_scores_gpt_moredata_2layer_256_deberta-v3-xsmall.csv' # 80
# mode = 'old_csvs/eighth_run/hh_scores_gpt_moredata_2layer_4gpu_256_deberta-v3-xsmall.csv'
# mode = 'hh_scores_gpt_moredata_2layer_emptynorm_256_deberta-v3-xsmall.csv'
# mode = 'morality_scores_gpt_cm_immoral_64_deberta-v3-large.csv'
# mode = 'old_csvs/eighth_run/hh_scores_gpt_moredata_2layer_256_deberta-v3-large.csv'
# mode = 'hh_scores_gpt_moredata_2layer_largerkl_256_deberta-v3-large.csv'
# mode = 'hh_scores_gpt_moredata_2layer_largerkl_rerun_256_deberta-v3-large.csv'
# mode = 'hh_scores_llama_moredata_2layer_largerkl_256_deberta-v3-large.csv'
# mode = 'hh_scores_gpt_moredata_2layer_nokl_256_deberta-v3-large.csv'
# mode = 'hh_adversary_scores_256_deberta-v3-large.csv'
# mode = 'hh_adversary_scores_nobreak_256_deberta-v3-large.csv'

# mode = 'old_csvs/eighth_run/hh_scores_gpt_normalkl_256_deberta-v3-xsmall.csv'
# mode = 'hh_scores_gpt_normalkl_smalldata_256_deberta-v3-xsmall.csv'
# mode = 'hh_scores_gpt_normalkl_bigdata_256_deberta-v3-xsmall.csv'
# mode = 'hh_scores_gpt_bigkl_bigdata_256_deberta-v3-xsmall.csv'
# mode = 'hh_scores_gpt_biggerkl_bigdata_256_deberta-v3-xsmall.csv'
# mode = 'hh_scores_gpt_biggerkl_smalldata_256_deberta-v3-xsmall.csv'
# mode = 'hh_scores_gpt_biggestkl_bigdata_256_deberta-v3-xsmall.csv'

# mode = 'hh_scores_gpt_noeos_256_deberta-v3-xsmall.csv'
# mode = 'hh_scores_gpt_weos_256_deberta-v3-large.csv'
# mode = 'hh_scores_gpt_weos_bigkl_256_deberta-v3-large.csv'
# mode = 'hh_scores_gpt_weos_bigkl_256_deberta-v3-xsmall.csv'
# mode = 'old_csvs/eighth_run/hh_scores_gpt_weos_256_deberta-v3-xsmall.csv'
# mode = 'hh_scores_llama_weos_bigkl_256_deberta-v3-xsmall.csv'

# mode = 'old_csvs/eighth_run/hh_scores_gpt_256_deberta-v3-xsmall.csv'
# mode = 'hh_scores_gpt_add1_256_deberta-v3-xsmall.csv'
# mode = 'hh_scores_gpt_add1_accu2ep1_256_deberta-v3-xsmall.csv'
# mode = 'hh_scores_gpt_256_16_1e-05_0.2_deberta-v3-xsmall.csv'
# mode = 'hh_scores_gpt_klclip_256_16_1e-05_0.2_deberta-v3-xsmall.csv'
# mode = 'hh_scores_gpt_256_16_1e-05_0.05_deberta-v3-xsmall.csv'
# mode = 'hh_scores_gpt_256_16_1e-05_0.1_deberta-v3-xsmall.csv'
# mode = 'hh_scores_gpt_256_16_1e-05_0.2_deberta-v3-xsmall.csv'
# mode = 'hh_scores_gpt_256_16_5e-06_0.2_deberta-v3-xsmall.csv'
# mode = 'hh_scores_gpt_3000_256_16_5e-06_0.2_deberta-v3-xsmall.csv'
# mode = 'hh_scores_gpt_256_16_5e-06_0.2_deberta-v3-large.csv'

# mode = 'hh_scores_gpt_3000_256_16_5e-06_0.05_deberta-v3-xsmall.csv'
# mode = 'hh_scores_gpt_30000_256_16_5e-06_0.05_deberta-v3-xsmall.csv'
# mode = 'hh_scores_gpt_256_16_5e-06_0.05_deberta-v3-xsmall.csv'
# mode = 'hh_scores_gpt_256_16_5e-06_0.2_deberta-v3-xsmall.csv'
# mode = 'hh_scores_gpt_256_16_5e-06_0.05_deberta-v3-large.csv'
# mode = 'hh_scores_gpt_256_16_5e-06_0.2_deberta-v3-large.csv'
# mode = 'hh_scores_gpt_3000_256_16_5e-06_0.2_deberta-v3-xsmall.csv'
# mode = 'hh_scores_llama_256_4_1e-06_0.05_deberta-v3-xsmall.csv'
# mode = 'hh_scores_llama_256_4_1e-06_0.05_deberta-v3-large.csv'
# mode = 'hh_scores_gpt_klfix_256_16_5e-06_0.05_deberta-v3-xsmall.csv'
# NOTE: dead store — `mode` is reassigned further down before anything reads it.
mode = 'hh_scores_llama_klfix_256_4_1e-06_0.2_deberta-v3-xsmall.csv'
# mode = 'hh_scores_gpt_constantkl_256_16_5e-06_0.05_deberta-v3-xsmall.csv'
# mode = 'hh_scores_llama_constantkl_256_4_1e-06_0.2_deberta-v3-xsmall.csv'

# mode = 'hh_scores_gpt_vhead_10000_256_16_5e-06_0.2_deberta-v3-xsmall.csv'
# mode = 'hh_scores_gpt_vhead_10000_256_16_5e-06_0.05_deberta-v3-xsmall.csv'
# mode = 'hh_scores_gpt_vhead_256_16_5e-06_0.05_deberta-v3-xsmall.csv'
# mode = 'hh_scores_gpt_vhead_256_16_5e-06_0.2_deberta-v3-xsmall.csv'
# mode = 'hh_scores_gpt_vhead_10000_256_16_5e-06_0.2_deberta-v3-large.csv'
# mode = 'hh_scores_gpt_vhead_256_16_5e-06_0.05_deberta-v3-large.csv'
# mode = 'hh_scores_gpt_vhead_256_16_5e-06_0.2_deberta-v3-large.csv'

# mode = 'hh_scores_llama_10000_256_4_1e-06_0.2_deberta-v3-large_old.csv'
# mode = 'hh_scores_llama_512_256_4_1e-06_0.2_deberta-v3-large.csv'
# mode = 'hh_scores_llama_10000_256_4_1e-06_0.2_deberta-v3-large.csv'
# mode = 'hh_scores_llama_512_nonorm_256_4_1e-06_0.2_deberta-v3-large.csv'
# mode = 'hh_scores_llama_512_final_nonorm_higherlr_256_4_5e-06_0.2_deberta-v3-large.csv'
# mode = 'hh_scores_llama_512_final_nonorm_higherlr_rerun_256_4_5e-06_0.2_deberta-v3-large.csv'
# mode = 'hh_scores_llama_512_final_nonorm_higherlr_rerun_sample_256_4_5e-06_0.2_deberta-v3-large.csv'
# mode = 'hh_scores_llama_fixed_256_4_2e-06_0.2_deberta-v3-large.csv'
# mode = 'hh_scores_gpt_fixed_10000_256_16_5e-06_0.1_deberta-v3-large.csv'
# mode = 'hh_scores_gpt_fixed_512_256_16_5e-06_0.1_deberta-v3-large.csv'
# mode = 'hh_scores_gpt_fixed_512_256_4_2e-06_0.05_deberta-v3-large.csv'
# mode = 'hh_scores_gpt_fixed_10000_256_16_5e-06_0.05_deberta-v3-large.csv'
# NOTE: dead store — `mode` is reassigned further down before anything reads it.
mode = 'hh_scores_gpt_long_10000_256_16_5e-06_0.05_deberta-v3-large.csv'
# mode = 'hh_scores_gpt_long_10000_256_16_5e-06_0.02_deberta-v3-large.csv'
# mode = 'hh_scores_gpt_long_10000_256_16_5e-06_0.01_deberta-v3-large.csv'
# mode = 'hh_scores_gpt_long_10000_nostop_256_16_5e-06_0.05_deberta-v3-large.csv'
# mode = 'hh_scores_gpt_10000_nostop_256_16_5e-06_0.05_deberta-v3-large.csv'
# mode = 'hh_scores_gpt_long_10000_nostop_fixedloss_256_16_5e-06_0.05_deberta-v3-large.csv'
# mode = 'hh_scores_gpt_10000_nostop_fixedloss_256_16_5e-06_0.05_deberta-v3-large.csv'
# mode = 'hh_scores_llama_10000_nostop_fixedloss_256_4_2e-06_0.05_deberta-v3-large.csv'
# mode = 'hh_scores_gpt_long_10000_nostop_fixedloss_noeos_256_16_5e-06_0.05_deberta-v3-large.csv'
# mode = 'hh_scores_gpt_long_10000_nostop_fixedloss_noscale_256_16_5e-06_0.05_deberta-v3-large.csv'
# mode = 'hh_scores_gpt_long_10000_nostop_oldloss_noscale_256_16_5e-06_0.1_deberta-v3-large.csv'
# mode = 'hh_scores_gpt_long_10000_nostop_oldloss_noscale_divide3_256_16_5e-06_0.1_deberta-v3-large.csv'
# mode = 'hh_scores_gpt_long_10000_nostop_oldloss_noscale_divide3_256_16_5e-06_0.0_deberta-v3-large.csv'
# mode = 'hh_scores_gpt_long_10000_nostop_fixedloss_noscale_better_256_16_5e-06_0.0_deberta-v3-large.csv'
# mode = 'hh_scores_gpt_long_10000_nostop_normloss_noscale_better_256_16_5e-06_0.0_deberta-v3-large.csv'
# mode = 'hh_scores_gpt_long_10000_nostop_normloss_noscale_256_16_5e-06_0.05_deberta-v3-large.csv'
# mode = 'hh_scores_gpt_long_10000_nostop_normloss_noscale_nonormscore_256_16_5e-06_0.05_deberta-v3-large.csv'
# mode = 'hh_scores_gpt_long_10000_nostop_fixedloss_noscale_nonormscore_256_16_5e-06_0.05_deberta-v3-large.csv'
# mode = 'hh_scores_gpt_long_10000_nostop_fixedloss_noscale_nonormscore_1epo_256_16_5e-06_0.05_deberta-v3-large.csv'
# mode = 'hh_scores_gpt_long_10000_nostop_fixedloss_noscale_nonormscore_1epo_256_16_5e-06_0.02_deberta-v3-large.csv'
# mode = 'hh_scores_gpt_long_10000_stop_fixedloss_noscale_nonormscore_1epo_256_16_5e-06_0.05_deberta-v3-large.csv'

# mode = 'hh_scores_gpt_long_10000_nostop_normloss_noscale_nonormscore_256_16_5e-06_0.05_deberta-v3-xsmall.csv'
# mode = 'hh_scores_gpt_long_10000_nostop_normloss_noscale_nonormscore_constkl_256_16_5e-06_0.05_deberta-v3-xsmall.csv'
# mode = 'hh_scores_gpt_long_10000_nostop_normloss_noscale_nonormscore_constkl_subset_256_16_5e-06_0.05_deberta-v3-xsmall.csv'
# mode = 'hh_scores_gpt_long_all_nostop_normloss_noscale_nonormscore_constkl_256_16_5e-06_0.05_deberta-v3-xsmall.csv'
# mode = 'hh_scores_gpt_long_512_nostop_normloss_noscale_nonormscore_constkl_256_16_5e-06_0.05_deberta-v3-xsmall.csv'
# mode = 'hh_scores_gpt_long_512_nostop_normloss_noscale_nonormscore_constclipkl_256_16_5e-06_0.05_deberta-v3-xsmall.csv'
# mode = 'hh_scores_gpt_long_10000_nostop_normloss_noscale_nonormscore_twoheads_256_16_5e-06_0.05_deberta-v3-large.csv'

# mode = 'hh_scores_llama_10000_nostop_normloss_noscale_nonormscore_256_4_2e-06_0.05_deberta-v3-xsmall.csv'
# mode = 'hh_scores_llama_10000_nostop_normloss_noscale_nonormscore_nopid_256_4_2e-06_0.05_deberta-v3-xsmall.csv'
# mode = 'hh_scores_llama_10000_nostop_normloss_noscale_nonormscore_constkl_256_4_1e-06_0.1_deberta-v3-xsmall.csv'
# mode = 'hh_scores_llama_512_nostop_normloss_noscale_nonormscore_constclipkl_256_4_1e-06_0.05_deberta-v3-xsmall.csv'
# mode = 'hh_scores_llama_512_nostop_normloss_noscale_nonormscore_constclipkl_8gpu_256_4_1e-06_0.05_deberta-v3-xsmall.csv'
# mode = 'hh_scores_llama_v1.1_512_nostop_normloss_noscale_nonormscore_constclipkl_8gpu_256_4_1e-06_0.05_deberta-v3-xsmall.csv'
# mode = 'hh_scores_llama_v1.1_512_nostop_normloss_noscale_nonormscore_constkl_rerun_256_4_1e-06_0.05_deberta-v3-xsmall.csv'
# mode = 'hh_scores_llama_v1.1_512_nostop_normloss_noscale_nonormscore_constkl_8gpu_rerun_256_4_1e-06_0.05_deberta-v3-xsmall.csv'
# mode = 'hh_scores_llama_v1.1_512_nostop_normloss_noscale_nonormscore_constkl_8gpu_rerun_fixed_256_4_1e-06_0.05_deberta-v3-xsmall.csv'
# mode = 'hh_scores_llama_v1.1_all_nostop_normloss_noscale_nonormscore_constkl_subset_256_4_1e-06_0.05_deberta-v3-xsmall.csv'
# mode = 'hh_scores_llama_v1.1_512_nostop_normloss_noscale_nonormscore_constkl_subset_OK_256_4_1e-06_0.05_deberta-v3-xsmall.csv'
# mode = 'hh_scores_llama_v1.1_all_nostop_normloss_noscale_nonormscore_constkl_OK_256_4_1e-06_0.05_deberta-v3-xsmall.csv'
# mode = 'hh_scores_llama_v1.1_all_nostop_normloss_noscale_nonormscore_constkl_OK_256_4_5e-06_0.05_deberta-v3-xsmall.csv'
# mode = 'hh_scores_llama_512_nostop_normloss_noscale_nonormscore_constkl_OK_256_4_2e-06_0.05_deberta-v3-xsmall.csv'
# mode = 'hh_scores_llama_512_nostop_normloss_noscale_nonormscore_constkl_OK_256_4_2e-06_0.0_deberta-v3-xsmall.csv'
# Only read by the 'old_csvs/tenth_run/' prefixing statement further down; that
# prefixed value is itself overwritten before the CSV is loaded.
mode = 'hh_scores_gpt_long_all_nostop_normloss_noscale_nonormscore_constkl_OK_4gpu_256_16_5e-06_0.05_deberta-v3-xsmall.csv'
# mode = 'hh_scores_llama_v1.1_all_nostop_normloss_noscale_nonormscore_constkl_OK_4gpu_256_4_2e-06_0.05_deberta-v3-xsmall.csv'
# mode = 'hh_scores_llama_v1.1_single_nostop_normloss_noscale_nonormscore_constkl_OK_4gpu_256_4_2e-06_0.05_deberta-v3-xsmall.csv'
# mode = 'hh_scores_gpt_long_nostop_normloss_noscale_nonormscore_constkl_OK_4gpu_256_16_5e-06_0.05_deberta-v3-xsmall.csv'
# mode = 'hh_scores_llama_v1.1_nostop_normloss_noscale_nonormscore_constkl_OK_4gpu_256_4_2e-06_0.01_deberta-v3-xsmall.csv'
# mode = 'hh_scores_llama_v1.1_single_nostop_normloss_noscale_nonormscore_constkl_OK_4gpu_256_4_2e-06_0.01_deberta-v3-xsmall.csv'
# mode = 'hh_scores_llama_v1.1_single_nostop_normloss_noscale_nonormscore_constkl_OK_4gpu_256_4_2e-06_0.0_deberta-v3-xsmall.csv'

# mode = 'hh_scores_gpt_long_unfrozen_single_nostop_normloss_noscale_nonormscore_constkl_OK_4gpu_256_4_5e-06_0.05_deberta-v3-xsmall.csv'
# mode = 'hh_scores_llama_v1.1_single4_nostop_normloss_noscale_nonormscore_constkl_OK_4gpu_256_4_2e-06_0.0_deberta-v3-xsmall.csv'
# mode = 'hh_scores_llama_v1.1_single4_clipkl_nostop_normloss_noscale_nonormscore_constkl_OK_4gpu_256_4_2e-06_0.0_deberta-v3-xsmall.csv'
# mode = 'hh_scores_llama_v1.1_single4_clipkl_nostop_normloss_noscale_nonormscore_constkl_OK_4gpu_256_4_2e-06_0.05_deberta-v3-xsmall.csv'

# mode = 'hh_scores_llama_v1.1_512_clipkl_nostop_normloss_noscale_nonormscore_constkl_OK_4gpu_256_4_2e-06_0.05_deberta-v3-xsmall.csv'
# mode = 'hh_scores_llama_v1.1_512_pg_clipkl_nostop_normloss_noscale_nonormscore_constkl_OK_4gpu_256_4_2e-06_0.05_deberta-v3-xsmall.csv'
# mode = 'hh_scores_llama_v1.1_single16_clipkl_nostop_normloss_noscale_nonormscore_constkl_OK_4gpu_256_4_2e-06_0.05_deberta-v3-xsmall.csv'
# mode = 'hh_scores_llama_v1.1_single32_clipkl_nostop_normloss_noscale_nonormscore_constkl_OK_4gpu_256_4_2e-06_0.05_deberta-v3-xsmall.csv'
# mode = 'hh_scores_llama_v1.1_single64_clipkl_nostop_normloss_noscale_nonormscore_constkl_OK_4gpu_256_4_2e-06_0.05_deberta-v3-xsmall.csv'

# mode = 'hh_scores_llama_v1.1_single16_clipkl_nostop_normloss_noscale_nonormscore_constkl_OK_4gpu_256_4_5e-06_0.05_deberta-v3-xsmall.csv'

# mode = 'hh_scores_llama_v1.1_all_clipkl_nostop_normloss_noscale_nonormscore_constkl_OK_7gpu_256_4_5e-06_0.05_deberta-v3-xsmall.csv'
# mode = 'hh_scores_llama_v1.1_all_clipkl_nostop_normloss_noscale_nonormscore_constkl_OK_7gpu_256_4_2e-06_0.05_deberta-v3-xsmall.csv'

# mode = 'hh_scores_llama_v1.1_all_clipkl_nostop_normloss_noscale_nonormscore_constkl_OK_4gpu_256_4_2e-06_0.05_deberta-v3-large.csv'

# Prefix the previously selected file name with its archive directory.
# NOTE: the result is discarded by the unconditional reassignments right below.
mode = 'old_csvs/tenth_run/' + mode

mode = 'hh_scores_llama_128_8_2e-06_0.05_deberta-v3-large.csv'
# mode = 'hh_scores_llama_0_128_8_2e-06_0.05_deberta-v3-large.csv'
# mode = 'hh_scores_llama_8192_128_8_2e-06_0.05_deberta-v3-xsmall.csv'
# mode = 'hh_scores_llama_0_128_8_2e-06_0.05_deberta-v3-xsmall.csv'

# Overrides the llama selection above (last uncommented assignment wins).
mode = 'hh_scores_gpt2_8192_128_16_5e-06_0.05_deberta-v3-large.csv'
# mode = 'hh_scores_gpt2_0_128_16_5e-06_0.05_deberta-v3-large.csv'
# mode = 'hh_scores_gpt2_8192_128_16_5e-06_0.05_deberta-v3-xsmall.csv'
# mode = 'hh_scores_gpt2_0_128_16_5e-06_0.05_deberta-v3-xsmall.csv'

# mode = '/data/private_models/[ANONYMIZED]_models/misc/debug_runs/hh_scores_llama_v1.1_512_clipkl_nostop_normloss_noscale_nonormscore_constkl_OK_4gpu_256_4_2e-06_0.05_deberta-v3-xsmall_llama7bhh_llama30bhh.csv'
# mode = '/data/private_models/[ANONYMIZED]_models/misc/debug_runs/hh_scores_llama_v1.1_512_clipkl_nostop_normloss_noscale_nonormscore_constkl_OK_4gpu_256_4_2e-06_0.05_deberta-v3-xsmall_llama30blorahh_llama30bhh.csv'

# mode = 'hh_scores_seed1_llama_512_128_8_2e-06_0.05_deberta-v3-xsmall.csv'
# mode = 'hh_scores_seed42_llama_512_128_8_2e-06_0.05_deberta-v3-xsmall.csv'

# Prefix the selected file name with its archive directory.
# NOTE: the prefixed path is discarded by the two assignments below.
mode = 'old_csvs/elev_run/' + mode

# The second assignment wins: it is the last one before the `pd.read_csv(mode, ...)`
# call, so this is the file that is actually loaded.
mode = 'hh_scores_pythia_512_128_16_2e-06_0.05_deberta-v3-xsmall.csv'
mode = 'hh_scores_old_pythia_512_128_16_2e-06_0.05_deberta-v3-xsmall.csv'

# Load the selected score CSV. There is no header row, so columns are addressed
# by integer position; malformed rows are silently dropped.
# NOTE(review): each dropped row shifts the fixed N-row chunk boundaries used by
# the analysis below — confirm the file is clean if chunk alignment matters.
df = pd.read_csv(mode, on_bad_lines='skip', header=None)
# Number of consecutive rows treated as one chunk by the code below.
N = 1024
# print(df.groupby(df.index // N).mean(numeric_only=True)[3].max())
# Per-chunk mean of every numeric column (notebook display only).
df.groupby(by=df.index // N).mean(numeric_only=True)#.iloc[-60:]



# Column remapping: columns 2 and 3 both become a copy of column 8, and
# columns 4 and 5 both become a copy of column 7.
# NOTE(review): presumably this adapts the CSV's layout to the column positions
# the chunk analysis below reads — confirm against the writer of the CSV.
df[2] = df[3] = df[8]
df[4] = df[5] = df[7]


# Quick look at the first rows after remapping (notebook display only).
df.head()


# Count, and print, the rows of the current chunk whose column-1 value exceeds 10.
# Needs `prompt_chunk` and `chunk`, which are produced by the chunking loop
# further down the file, so this cell only works after that loop has run.
# (The former `if cc[3] < 0: pass` branch did nothing and was removed.)
c = 0
for pc, cc in zip(prompt_chunk, chunk):
    if cc[1] > 10:
        c += 1
        print(cc[1])


# Snapshots of the current (prompt_chunk, chunk) pair under three different names.
# NOTE(review): executed top-to-bottom, all three pairs alias the very same
# arrays, so the comparisons below are trivial. Presumably each snapshot was
# taken after re-running the loading cells with a different `mode` — confirm.
pc, cc = prompt_chunk, chunk


pc1, cc1 = prompt_chunk, chunk


pc2, cc2 = prompt_chunk, chunk


# Compare column 0 of one row between two snapshots (notebook display only).
# NOTE(review): `prompt_chunk` is sliced to at most N rows by the chunking loop,
# so index 4392 looks out of range for N = 1024 — verify.
i = 4392
pc2[i][0], pc1[i][0]


pc.shape


# Collect the row indices of `pc` whose column-0 string does not occur anywhere
# in column 0 of `pc1`.
# The lookup set is built once, so each membership test is a hash lookup rather
# than a scan of the whole `pc1` column.
_pc1_keys = set(pc1[:, 0].tolist())
c = []
for i, (a1, b1, a2, b2) in enumerate(zip(pc, cc, pc1, cc1)):
#     if a1[0] != a2[0]:
#         c.append(i)
    if a1[0] not in _pc1_keys:
        c.append(i)
len(c)


# For every row of `pc` whose column-0 string also appears in `pc1`, record the
# column-1 score gap between the `cc` row and the FIRST matching `cc1` row, and
# print both rows when the gap lies strictly between 1 and 1.1.
tmp = []
# Map each column-0 string of `pc1` to its first row index, built once instead
# of scanning `pc1` with np.where for every row.
_first_row = {}
for _row, _key in enumerate(pc1[:, 0].tolist()):
    _first_row.setdefault(_key, _row)
for i, (a1, b1, a2, b2) in enumerate(zip(pc, cc, pc1, cc1)):
    if a1[0] not in _first_row:
        continue
    idx = _first_row[a1[0]]
    diff = b1[1] - cc1[idx][1]
    tmp.append(diff)
    if 1 < diff < 1.1:
        print('===\n', b1[[1,3]], a1[1].replace('\n\n','\n'), '\n===\n', cc1[idx][[1,3]], pc1[idx][1].replace('\n\n','\n'))
# NOTE(review): this displays len(c) from the previous cell; len(tmp) may have
# been intended — confirm.
len(c)


# Row-by-row column-1 gap between the two snapshots (`cc` minus `cc1`).
# Stops at the first row whose gap exceeds 3.5, after printing column 1 of both
# prompt rows.
tmp = []

for i, (a1, b1, a2, b2) in enumerate(zip(pc, cc, pc1, cc1)):
    gap = b1[1] - b2[1]
    tmp.append(gap)
    if gap > 3.5:
        print('===\n', a1[1], '\n===\n', a2[1])
        break


# Number of collected gaps above 3 (notebook display only).
np.sum(np.array(tmp) > 3)


# Mean collected gap (notebook display only).
np.array(tmp).mean()


# Distribution of the collected gaps.
plt.hist(tmp)
plt.show()


# 20-bin histogram of column 0 of the current chunk.
plt.hist(chunk[:, 0], bins=20)
plt.show()


# Inspect one row of the current chunk: prompt columns, then score columns
# (notebook display only).
j = 8
prompt_chunk[j]


chunk[j]


# Build an increasing list of step counts starting at 0. The increment starts at
# 10 and is multiplied by 10 (capped at 10000) each time the running value
# reaches ten times the current increment; generation stops once the running
# value exceeds 200000.
curr, step = 0, 10
c = [curr]
for i in range(100):
    curr = curr + step
    c += [curr]
    if curr == 10 * step:
        step = min(10 * step, 10000)
    if curr > 200000:
        break
c


# Pair every schedule value with its position (notebook display only).
[(value, position) for position, value in enumerate(c)]


# Second schedule variant: the increment starts at 20; whenever the running
# value reaches five times the increment, the increment is multiplied by 10 and
# the running value restarts from 0. Stops once a value exceeds 1000000.
curr, step = 0, 20
c = [curr]
for i in range(100):
    curr = curr + step
    c += [curr]
    if curr == 5 * step:
        step, curr = 10 * step, 0
    if curr > 1000000:
        break

# Switches for the aggregation below.
# thresholding: binarise both reward columns at 0 before averaging.
# train_test_split: aggregate rows [:split_idx] and [split_idx:] of every chunk
# separately; the second part goes into the `*_train` lists.
thresholding = False
train_test_split = True
if train_test_split:
    split_idx = N // 2

# One entry per chunk of N consecutive rows of `df`.
x, y_proxy, y_gold, kl = [], [], [], []
y_proxy_train, y_gold_train, kl_train = [], [], []
y_proxy_individual, y_gold_individual = [], []
y_proxy_individual_train, y_gold_individual_train = [], []
responses = []
for i in range(len(df) // N):
    rows = df.iloc[i * N:(i + 1) * N].values
    # The first two columns are kept as strings; column 1 holds the dialogue text
    # analysed below.
    prompt_chunk = rows[:, :2].astype(str)

    # Text statistics for this chunk:
    #   c  - rows whose text after the last 'Assistant:' marker is empty
    #   t  - total word count of that final segment
    #   tt - total word count of the whole text
    #   w  - total number of 'Assistant:' markers
    c = t = tt = w = 0
    for p in prompt_chunk:
        text = p[1]
        last_reply = text.split('Assistant:')[-1].strip()
        c += (last_reply == '')
        t += len(last_reply.split())
        tt += len(text.split())
        w += text.count('Assistant:')
    n_rows = len(prompt_chunk)
    print(i, c, w, t/n_rows, tt/n_rows)

    # Remaining columns are numeric. Column 1 is aggregated as the proxy score,
    # column 3 as the gold score and column 4 as the KL value; columns 1 and 3
    # are scaled by the standard deviation of columns 0 and 2 respectively.
    chunk = rows[:, 2:].astype(float)
#     chunk[:,1] = chunk[:,1] - chunk[:,0]
#     chunk[:,3] = chunk[:,3] - chunk[:,2]
    chunk[:, 1] /= chunk[:, 0].std()
    chunk[:, 3] /= chunk[:, 2].std()
#     chunk[:,1] /= 2
#     chunk[:,3] /= 8
#     print(chunk[:,0].std(), chunk[:,2].std())
    # Row with the largest proxy-minus-gold gap (not used further in this cell).
    max_diff = (chunk[:, 1] - chunk[:, 3]).argmax()
#     x.append(c[i])
    x.append(i)

    if thresholding:
        chunk[:, 1] = chunk[:, 1] > 0
        chunk[:, 3] = chunk[:, 3] > 0

    responses.append(prompt_chunk)

    # Un-suffixed lists: rows [:split_idx] when splitting, otherwise the whole chunk.
    head = chunk[:split_idx] if train_test_split else chunk
    y_proxy.append(head[:, 1].mean())
    y_proxy_individual.append(head[:, 1])
    y_gold.append(head[:, 3].mean())
    y_gold_individual.append(head[:, 3])
    kl.append(min(head[:, 4].mean(), 400))  # mean KL, capped at 400

    # `*_train` lists: rows [split_idx:].
    if train_test_split:
        tail = chunk[split_idx:]
        y_proxy_train.append(tail[:, 1].mean())
        y_proxy_individual_train.append(tail[:, 1])
        y_gold_train.append(tail[:, 3].mean())
        y_gold_individual_train.append(tail[:, 3])
        kl_train.append(min(tail[:, 4].mean(), 400))

# Transpose the per-row arrays so that axis 0 indexes the row position inside a
# chunk and axis 1 indexes the chunk.
if train_test_split:
    y_proxy_individual_train = np.array(y_proxy_individual_train).T
    y_gold_individual_train = np.array(y_gold_individual_train).T
y_proxy_individual = np.array(y_proxy_individual).T
y_gold_individual = np.array(y_gold_individual).T
responses = np.array(responses)


# Last-chunk means and proxy-minus-gold gaps before alignment (notebook display
# only). NOTE: the `*_train` lists are only filled when train_test_split is True.
y_proxy_train[-1], y_gold_train[-1], y_proxy[-1], y_gold[-1]


y_proxy_train[-1] - y_gold_train[-1], y_proxy[-1] - y_gold[-1]


# Shift the gold curves by a constant so that they start at the same value as
# the corresponding proxy curves.
# The conversion to an array is explicit: the previous `list += scalar` form only
# worked through NumPy's reflected add on the scalar, which silently replaced the
# list with an array and would fail for a plain Python float.
y_gold = np.asarray(y_gold) + (y_proxy[0] - y_gold[0])
if train_test_split:
    y_gold_train = np.asarray(y_gold_train) + (y_proxy_train[0] - y_gold_train[0])

# Plot the per-chunk mean proxy and gold reward on the left axis and the mean KL
# on a twin right axis, with one combined legend.
fig, ax1 = plt.subplots()

# (y-values, plot kwargs) per reward curve; the `*_train` curves are dashed.
curves = [
    (y_proxy, dict(label="proxy", color='blue')),
    (y_gold, dict(label="gold", color='orange')),
]
if train_test_split:
    curves += [
        (y_proxy_train, dict(label="proxy_train", linestyle='--', color='blue')),
        (y_gold_train, dict(label="gold_train", linestyle='--', color='orange')),
    ]

# Line handles from both axes, gathered for the shared legend.
lines = []
for ys, style in curves:
    lines += ax1.plot(x, ys, **style)

ax2 = ax1.twinx()
lines += ax2.plot(x, kl, label="kl", color='g')

ax1.set_xlabel('Steps')
ax1.set_ylabel('Reward')
ax2.set_ylabel('KL', color='g')
ax2.set_ylim(0,1000)

# ax1.set_xscale('log')
# ax2.set_xscale('log')

labels = [line.get_label() for line in lines]
plt.legend(lines, labels, bbox_to_anchor=(1.1, 1.03), loc='upper left')
plt.show()

# Proxy-minus-gold gap at the last chunk. The `*_train` lists are empty when
# train_test_split is False, so that print is guarded to avoid an IndexError.
if train_test_split:
    print(y_proxy_train[-1] - y_gold_train[-1])
print(y_proxy[-1] - y_gold[-1])


# n x n grid of smoothed per-row reward curves from the un-suffixed part of each
# chunk. Each panel shows the proxy and gold curve of one row position over the
# first `cutoff` chunks, with gold shifted to start at the proxy's first value.
# `page` selects which block of n*n row positions is shown.
n = 10
cutoff = 100
window_size = 10
page = 0
fig, axs = plt.subplots(n, n, figsize=(12, 12))

cutoff = min(cutoff, y_proxy_individual.shape[1])
# NOTE: this rebinds the global `lines`, which later legend calls read.
lines = []
first = page * n * n
page_pairs = zip(y_proxy_individual[first:first + n*n],
                 y_gold_individual[first:first + n*n])
for slot, (p, g) in enumerate(page_pairs):
    row, col = divmod(slot, n)
    # Centred rolling mean; min_periods=1 keeps the edges defined.
    p = pd.Series(p).rolling(window_size, min_periods=1, center=True).mean() #/ y_proxy_individual.std(1).mean()
    g = pd.Series(g).rolling(window_size, min_periods=1, center=True).mean() #/ y_gold_individual.std(1).mean()
    g += p[0] - g[0]
    panel = axs[row, col]
    panel.plot(range(cutoff), p[:cutoff], label="proxy")
    panel.plot(range(cutoff), g[:cutoff], label="gold")
    panel.set_title(f"{row*n+col}", fontsize=6)
    panel.set_xticks([])
    panel.set_yticks([0])

# plt.legend(loc='lower center', bbox_to_anchor=(0.5, 0), ncol=2)

plt.show()


# Print, chunk by chunk (up to and including chunk index `cutoff`), the proxy
# score, gold score and text of one row position `j` of the un-suffixed part.
j = 60 + page*n*n
rows_per_chunk = N // 2 if train_test_split else N
history = zip(responses[:cutoff + 1, :rows_per_chunk],
              y_proxy_individual.T[:cutoff + 1],
              y_gold_individual.T[:cutoff + 1])
for i, (res, pscore, gscore) in enumerate(history):
    print('='*10)
    print(pscore[j], gscore[j], res[j][1].replace('\n\n', '\n'))


# n x n grid of smoothed per-row reward curves from the `*_train` part of each
# chunk, restricted to the chunk window [start, cutoff). Gold is shifted so that
# it starts at the proxy's first value. `page` selects the block of n*n rows.
n = 20
cutoff = 60
start = 0
window_size = 10
page = 0
fig, axs = plt.subplots(n, n, figsize=(12, 12))

cutoff = min(cutoff, y_proxy_individual.shape[1])
# x positions are derived from `start` so that x and y keep the same length for
# any start; the former range(cutoff) only matched the y-slice when start == 0.
steps = range(start, cutoff)
for i,(p,g) in enumerate(zip(y_proxy_individual_train, y_gold_individual_train)):
    if i < page*n*n: continue
    if i >= (page+1)*n*n: break
    row, col = divmod(i - page*n*n, n)
    # Centred rolling mean; min_periods=1 keeps the edges defined.
    p = pd.Series(p).rolling(window_size, min_periods=1, center=True)
    p = p.mean() #/ y_proxy_individual.std(1).mean()
    g = pd.Series(g).rolling(window_size, min_periods=1, center=True)
    g = g.mean() #/ y_gold_individual.std(1).mean()
    g += p[0] - g[0]

    axs[row, col].plot(steps, p[start:cutoff], label="proxy")
    axs[row, col].plot(steps, g[start:cutoff], label="gold")
#     axs[row, col].set_title(f"{row*n+col}", fontsize=6)
    axs[row, col].set_xticks([])
    axs[row, col].set_yticks([])
#     axs[row, col].set_ylim(-3,3)

# plt.legend(loc='lower center', bbox_to_anchor=(0.5, 0), ncol=2)

plt.show()


# Same n x n grid for the `*_train` part, over the first `cutoff` chunks. With
# window_size = 1 the rolling mean leaves the curves unsmoothed.
n = 20
cutoff = 150
window_size = 1
page = 0
fig, axs = plt.subplots(n, n, figsize=(12, 12))

cutoff = min(cutoff, y_proxy_individual.shape[1])
first = page * n * n
page_pairs = zip(y_proxy_individual_train[first:first + n*n],
                 y_gold_individual_train[first:first + n*n])
for slot, (p, g) in enumerate(page_pairs):
    row, col = divmod(slot, n)
    p = pd.Series(p).rolling(window_size, min_periods=1, center=True).mean() #/ y_proxy_individual.std(1).mean()
    g = pd.Series(g).rolling(window_size, min_periods=1, center=True).mean() #/ y_gold_individual.std(1).mean()
    # Shift gold so that both curves start at the same value.
    g += p[0] - g[0]

    panel = axs[row, col]
    panel.plot(range(cutoff), p[:cutoff], label="proxy")
    panel.plot(range(cutoff), g[:cutoff], label="gold")
#     axs[row, col].set_title(f"{row*n+col}", fontsize=6)
    panel.set_xticks([])
    panel.set_yticks([])
#     axs[row, col].set_ylim(-3,3)

# plt.legend(loc='lower center', bbox_to_anchor=(0.5, 0), ncol=2)

plt.show()


# Print, chunk by chunk (up to and including chunk index `cutoff`), the proxy
# score, gold score and text of one row position `j` of the `*_train` part
# (rows N//2 onward of every chunk).
j = 7 + page*n*n
history = zip(responses[:cutoff + 1, N//2:],
              y_proxy_individual_train.T[:cutoff + 1],
              y_gold_individual_train.T[:cutoff + 1])
for i, (res, pscore, gscore) in enumerate(history):
    print('='*10)
    print(pscore[j], gscore[j], res[j][1].replace('\n\n', '\n'))


# Split the per-row reward curves into a "learned" and a "gamed" subset and plot
# the mean curve of each. A row counts as "gamed" when the correlation between
# its smoothed proxy and gold curves over chunks [start, cutoff) is below
# `threshold`; a NaN correlation compares False and lands in "learned".
# The split is done independently for the `*_train` and un-suffixed parts.
n = 13
start, cutoff = 0, 150
window_size = 5
threshold = -0.


def _mean_curve(curves, length):
    """Return the pointwise mean of *curves*, or an all-NaN curve if there are none."""
    if len(curves) == 0:
        return np.full(length, np.nan)
    return np.array(curves).mean(0)


low_corr, low_corr_train = [],[]
normal_p, normal_g = [],[]
gamed_p, gamed_g = [],[]
normal_p_train, normal_g_train = [],[]
gamed_p_train, gamed_g_train = [],[]
cutoff = min(cutoff, y_proxy_individual.shape[1])
for i, tup in enumerate(zip(
    y_proxy_individual_train, y_gold_individual_train, y_proxy_individual, y_gold_individual)):
    # vecs = smoothed [proxy_train, gold_train, proxy, gold] curves of row i.
    vecs = []
    for vec in tup:
        vec = pd.Series(vec).rolling(window_size, min_periods=1, center=True)
        vecs.append(vec.mean())
    # Shift each gold curve to start at its proxy's first value (the correlation
    # is unaffected by a constant shift).
    vecs[1] += vecs[0][0] - vecs[1][0]
    vecs[3] += vecs[2][0] - vecs[3][0]
    if np.corrcoef(vecs[0][start:cutoff],vecs[1][start:cutoff])[0,1] < threshold:
        low_corr_train.append(i)
        gamed_p_train.append(vecs[0])
        gamed_g_train.append(vecs[1])
    else:
        normal_p_train.append(vecs[0])
        normal_g_train.append(vecs[1])
    if np.corrcoef(vecs[2][start:cutoff],vecs[3][start:cutoff])[0,1] < threshold:
        low_corr.append(i)
        gamed_p.append(vecs[2])
        gamed_g.append(vecs[3])
    else:
        normal_p.append(vecs[2])
        normal_g.append(vecs[3])

n_steps = y_proxy_individual.shape[1]
gamed_pct = round(len(low_corr) / len(y_proxy_individual) * 100, 1)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12,4))
print('GAMED:', len(low_corr_train), len(low_corr_train) / len(y_proxy_individual))
print('GAMED:', len(low_corr), len(low_corr) / len(y_proxy_individual))

# Handles of the curves drawn here; the legend previously reused a stale `lines`
# list left over from an earlier cell and therefore never showed these curves.
lines = []

learned_p, learned_g = _mean_curve(normal_p, n_steps), _mean_curve(normal_g, n_steps)
learned_p_train, learned_g_train = _mean_curve(normal_p_train, n_steps), _mean_curve(normal_g_train, n_steps)
# learned_p_train_std, learned_g_train_std = np.array(normal_p_train).std(0)/10, np.array(normal_g_train).std(0)/10
lines += ax1.plot(range(cutoff), learned_p_train[:cutoff], label="learned_proxy_train", linestyle='--', color='blue')
lines += ax1.plot(range(cutoff), learned_g_train[:cutoff], label="learned_gold_train", linestyle='--', color='orange')
# ax1.fill_between(range(cutoff), (learned_p_train - learned_p_train_std)[:cutoff], (learned_p_train + learned_p_train_std)[:cutoff])
# ax1.fill_between(range(cutoff), (learned_g_train - learned_g_train_std)[:cutoff], (learned_g_train + learned_g_train_std)[:cutoff])
lines += ax1.plot(range(cutoff), learned_p[:cutoff], label="learned_proxy", color='blue')
lines += ax1.plot(range(cutoff), learned_g[:cutoff], label="learned_gold", color='orange')
# Round after subtracting so the title never shows float artefacts.
ax1.set_title(f'Learned subset ({round(100 - gamed_pct, 1)}%)')

p, g = _mean_curve(gamed_p, n_steps), _mean_curve(gamed_g, n_steps)
p_train, g_train = _mean_curve(gamed_p_train, n_steps), _mean_curve(gamed_g_train, n_steps)
# gamed_p_train_std, gamed_g_train_std = np.array(gamed_p_train).std(0)/10, np.array(gamed_g_train).std(0)/10
lines += ax2.plot(range(cutoff), p_train[:cutoff], label="gamed_proxy_train", linestyle='--', color='blue')
lines += ax2.plot(range(cutoff), g_train[:cutoff], label="gamed_gold_train", linestyle='--', color='orange')
# ax2.fill_between(range(cutoff), (p_train - gamed_p_train_std)[:cutoff], (p_train + gamed_p_train_std)[:cutoff])
# ax2.fill_between(range(cutoff), (g_train - gamed_g_train_std)[:cutoff], (g_train + gamed_g_train_std)[:cutoff])
lines += ax2.plot(range(cutoff), p[:cutoff], label="gamed_proxy", color='blue')
lines += ax2.plot(range(cutoff), g[:cutoff], label="gamed_gold", color='orange')
ax2.set_title(f'Gamed subset ({gamed_pct}%)')

labels = [line.get_label() for line in lines]
plt.legend(lines, labels, bbox_to_anchor=(1.1, 1.03), loc='upper left')
plt.show()

# Proxy-minus-gold gap of the gamed subset at the end of the window.
print(p_train[:cutoff][-1] - g_train[:cutoff][-1])
print(p[:cutoff][-1] - g[:cutoff][-1])




# Overlap between a previously saved index list `a` and the current
# `low_corr_train`, computed as 2 * |A ∩ B| / (|A| + |B|) (notebook display only).
# NOTE(review): `a` is only assigned in the cell right below, so a fresh
# top-to-bottom run raises NameError here. Presumably `a` was saved from an
# earlier run with different settings — confirm before reordering.
len(set(a).intersection(low_corr_train)) / (len(a) + len(low_corr_train)) * 2


# Keep the current gamed-row indices for a later overlap comparison.
a = low_corr_train
len(a)


# n x n grid showing only the `*_train` rows listed in `low_corr_train`, over the
# chunk window [start, cutoff). Gold is shifted so that it starts at the proxy's
# first value. `page` selects the block of n*n gamed rows to display.
n = 12
cutoff = 100
start = 0
window_size = 15
page = 0
fig, axs = plt.subplots(n, n, figsize=(12, 12))

cutoff = min(cutoff, y_proxy_individual.shape[1])
# x positions follow `start` so that x and y keep the same length for any start.
steps = range(start, cutoff)
# Page directly over the gamed rows. The former loop mixed row indices and a
# plotted-row counter, which produced negative panel positions for page > 0 and
# scanned the index list once per row.
page_rows = low_corr_train[page*n*n:(page+1)*n*n]
for j, i in enumerate(page_rows):
    row, col = divmod(j, n)
    p = pd.Series(y_proxy_individual_train[i]).rolling(window_size, min_periods=1, center=True)
    p = p.mean() #/ y_proxy_individual.std(1).mean()
    g = pd.Series(y_gold_individual_train[i]).rolling(window_size, min_periods=1, center=True)
    g = g.mean() #/ y_gold_individual.std(1).mean()
    g += p[0] - g[0]

    axs[row, col].plot(steps, p[start:cutoff], label="proxy")
    axs[row, col].plot(steps, g[start:cutoff], label="gold")
#     axs[row, col].set_title(f"{row*n+col}", fontsize=6)
    axs[row, col].set_xticks([])
    axs[row, col].set_yticks([])
    axs[row, col].set_ylim(-3,3)

# plt.legend(loc='lower center', bbox_to_anchor=(0.5, 0), ncol=2)

plt.show()


# Same learned/gamed split as the earlier cell, with a shorter chunk window
# (cutoff 40) and a wider smoothing window (10). A row counts as "gamed" when the
# correlation between its smoothed proxy and gold curves over chunks
# [start, cutoff) is below `threshold`; a NaN correlation lands in "learned".
n = 13
start, cutoff = 0, 40
window_size = 10
threshold = -0.


def _mean_curve(curves, length):
    """Return the pointwise mean of *curves*, or an all-NaN curve if there are none."""
    if len(curves) == 0:
        return np.full(length, np.nan)
    return np.array(curves).mean(0)


low_corr, low_corr_train = [],[]
normal_p, normal_g = [],[]
gamed_p, gamed_g = [],[]
normal_p_train, normal_g_train = [],[]
gamed_p_train, gamed_g_train = [],[]
cutoff = min(cutoff, y_proxy_individual.shape[1])
for i, tup in enumerate(zip(
    y_proxy_individual_train, y_gold_individual_train, y_proxy_individual, y_gold_individual)):
    # vecs = smoothed [proxy_train, gold_train, proxy, gold] curves of row i.
    vecs = []
    for vec in tup:
        vec = pd.Series(vec).rolling(window_size, min_periods=1, center=True)
        vecs.append(vec.mean())
    # Shift each gold curve to start at its proxy's first value (the correlation
    # is unaffected by a constant shift).
    vecs[1] += vecs[0][0] - vecs[1][0]
    vecs[3] += vecs[2][0] - vecs[3][0]
    if np.corrcoef(vecs[0][start:cutoff],vecs[1][start:cutoff])[0,1] < threshold:
        low_corr_train.append(i)
        gamed_p_train.append(vecs[0])
        gamed_g_train.append(vecs[1])
    else:
        normal_p_train.append(vecs[0])
        normal_g_train.append(vecs[1])
    if np.corrcoef(vecs[2][start:cutoff],vecs[3][start:cutoff])[0,1] < threshold:
        low_corr.append(i)
        gamed_p.append(vecs[2])
        gamed_g.append(vecs[3])
    else:
        normal_p.append(vecs[2])
        normal_g.append(vecs[3])

n_steps = y_proxy_individual.shape[1]
gamed_pct = round(len(low_corr) / len(y_proxy_individual) * 100, 1)

# x = np.array(c)[range(cutoff)]
x = range(cutoff)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12,4))
print('GAMED:', len(low_corr_train), len(low_corr_train) / len(y_proxy_individual))
print('GAMED:', len(low_corr), len(low_corr) / len(y_proxy_individual))

# Handles of the curves drawn here; the legend previously reused a stale `lines`
# list left over from an earlier cell and therefore never showed these curves.
lines = []

learned_p, learned_g = _mean_curve(normal_p, n_steps), _mean_curve(normal_g, n_steps)
learned_p_train, learned_g_train = _mean_curve(normal_p_train, n_steps), _mean_curve(normal_g_train, n_steps)
# learned_p_train_std, learned_g_train_std = np.array(normal_p_train).std(0)/10, np.array(normal_g_train).std(0)/10
lines += ax1.plot(x, learned_p_train[:cutoff], label="learned_proxy_train", linestyle='--', color='blue')
lines += ax1.plot(x, learned_g_train[:cutoff], label="learned_gold_train", linestyle='--', color='orange')
# ax1.fill_between(range(cutoff), (learned_p_train - learned_p_train_std)[:cutoff], (learned_p_train + learned_p_train_std)[:cutoff])
# ax1.fill_between(range(cutoff), (learned_g_train - learned_g_train_std)[:cutoff], (learned_g_train + learned_g_train_std)[:cutoff])
lines += ax1.plot(x, learned_p[:cutoff], label="learned_proxy", color='blue')
lines += ax1.plot(x, learned_g[:cutoff], label="learned_gold", color='orange')
# Round after subtracting so the title never shows float artefacts.
ax1.set_title(f'Learned subset ({round(100 - gamed_pct, 1)}%)')

p, g = _mean_curve(gamed_p, n_steps), _mean_curve(gamed_g, n_steps)
p_train, g_train = _mean_curve(gamed_p_train, n_steps), _mean_curve(gamed_g_train, n_steps)
# gamed_p_train_std, gamed_g_train_std = np.array(gamed_p_train).std(0)/10, np.array(gamed_g_train).std(0)/10
lines += ax2.plot(x, p_train[:cutoff], label="gamed_proxy_train", linestyle='--', color='blue')
lines += ax2.plot(x, g_train[:cutoff], label="gamed_gold_train", linestyle='--', color='orange')
# ax2.fill_between(range(cutoff), (p_train - gamed_p_train_std)[:cutoff], (p_train + gamed_p_train_std)[:cutoff])
# ax2.fill_between(range(cutoff), (g_train - gamed_g_train_std)[:cutoff], (g_train + gamed_g_train_std)[:cutoff])
lines += ax2.plot(x, p[:cutoff], label="gamed_proxy", color='blue')
lines += ax2.plot(x, g[:cutoff], label="gamed_gold", color='orange')
ax2.set_title(f'Gamed subset ({gamed_pct}%)')

labels = [line.get_label() for line in lines]
plt.legend(lines, labels, bbox_to_anchor=(1.1, 1.03), loc='upper left')
plt.show()

# Proxy-minus-gold gap of the gamed subset at the end of the window.
print(p_train[:cutoff][-1] - g_train[:cutoff][-1])
print(p[:cutoff][-1] - g[:cutoff][-1])




# Notebook inspection: first raw curve of whatever `tup` is currently bound to.
tup[0]


# Mean of the first three values of that curve.
tup[0][:3].mean()


# First smoothed curve of whatever `vecs` is currently bound to.
vecs[0]


# Per-example correlation and maximum proxy-minus-gold gap on the `_train`
# curves, computed on smoothed, start-aligned curves.
page = 0
corrs = []
diffs = []

cutoff = min(cutoff, y_proxy_individual.shape[1])
# `j`, `row` and `col` are leftovers from a paged grid plot; `j` is never
# advanced here, so the `break` below does not fire while n > 0.
j = 0
for i, (p, g) in enumerate(zip(y_proxy_individual_train, y_gold_individual_train)):
    if i < page * n * n:
        continue
    if j >= (page + 1) * n * n:
        break
    row, col = divmod(j - page * n * n, n)
    # Centered rolling mean; edge windows are simply shorter (min_periods=1).
    p = pd.Series(p).rolling(window_size, min_periods=1, center=True).mean()
    g = pd.Series(g).rolling(window_size, min_periods=1, center=True).mean()
    # Make the gold curve start where the proxy curve starts.
    g += p[0] - g[0]
    corr = np.corrcoef(p[start:cutoff], g[start:cutoff])[0, 1]
    # A NaN correlation is counted as perfectly correlated.
    if np.isnan(corr):
        corr = 1
    corrs.append(corr)
    gap = p[:cutoff] - g[:cutoff]
    diffs.append(gap.max())


# Average per-example correlation of the current `corrs`.
np.mean(corrs)


# Snapshot cells: each binds another name to the current `corrs` object.
# NOTE: these are aliases, not copies; they only hold different data if
# `corrs` is rebound (e.g. by re-running a loop cell) between the assignments.
corrs11 = corrs


corrs22 = corrs


corrs33 = corrs


corrs44 = corrs


# Overlaid histograms of snapshots 33 and 44.
bins = 50
# plt.hist(corrs11, bins, alpha=0.3)
# plt.hist(corrs22, bins, alpha=0.3)
plt.hist(corrs33, bins, alpha=0.3)
plt.hist(corrs44, bins, alpha=0.3)
plt.show()


# Overlaid histograms of snapshots 11 and 22.
bins = 50
plt.hist(corrs11, bins, alpha=0.3)
plt.hist(corrs22, bins, alpha=0.3)
# plt.hist(corrs33, bins, alpha=0.3)
# plt.hist(corrs44, bins, alpha=0.3)
plt.show()


# Mean correlation of the two snapshots side by side.
np.mean(corrs11), np.mean(corrs22)


# Per-example correlation between proxy and gold `_train` curves, repeated
# for several smoothing window sizes. `corrs` becomes a list of 4-element lists.
page = 0
# fig, axs = plt.subplots(n, n, figsize=(12, 12))
corrs = []
diffs = []

cutoff = min(cutoff, y_proxy_individual.shape[1])
# `j` is never incremented in this cell (the increment is commented out below),
# so the `break` guard does not fire while n > 0.
j = 0
for i,(p,g) in enumerate(zip(y_proxy_individual_train, y_gold_individual_train)):
    if i < page*n*n: continue
    if j >= (page+1)*n*n: break
    row, col = divmod(j - page*n*n, n)
    corr_list = []
    # NOTE: this loop variable overwrites the module-level `window_size`;
    # after the cell runs, `window_size` is 10.
    for window_size in [1,3,5,10]:
        pp = pd.Series(p).rolling(window_size, min_periods=1, center=True)
        pp = pp.mean() #/ y_proxy_individual.std(1).mean()
        gg = pd.Series(g).rolling(window_size, min_periods=1, center=True)
        gg = gg.mean() #/ y_gold_individual.std(1).mean()
        # Align the smoothed gold curve with the smoothed proxy curve at step 0.
        gg += pp[0] - gg[0]
        corr = np.corrcoef(pp[start:cutoff],gg[start:cutoff])[0,1]
        # A NaN correlation is recorded as 1.
        if np.isnan(corr):
            corr = 1
        corr_list.append(corr)
    corrs.append(corr_list)
    # NOTE: unlike the sibling cells, the gap here uses the raw, unsmoothed
    # and unaligned `p` and `g`.
    diffs.append((p[:cutoff] - g[:cutoff]).max())
#     if corr < threshold:
#         axs[row, col].plot(range(cutoff), p[:cutoff], label="proxy")
#         axs[row, col].plot(range(cutoff), g[:cutoff], label="gold")
#         axs[row, col].set_title(f"{row*n+col}", fontsize=6)
#         axs[row, col].set_xticks([0])
#         axs[row, col].set_yticks([0])
#         j += 1

# plt.legend(loc='lower center', bbox_to_anchor=(0.5, 0), ncol=2)

# plt.show()


# Print the first entries of `corrs` (indices 0 through 21 at most).
for i, corr_list in enumerate(corrs):
    print(i, corr_list)
#     if corr_list[-1] < 0: break
    if i > 20: break


# Pick one `_train` example for a closer look.
i = 9
p,g = y_proxy_individual_train[i], y_gold_individual_train[i]
# p, g


# Raw proxy and gold curves for that example.
plt.plot(range(len(p)), p)
plt.plot(range(len(g)), g)
plt.show()


# Gold curve: centered rolling mean (window 10) drawn over the raw curve.
window_size = 10
new_g = pd.Series(g).rolling(window_size, min_periods=1, center=True).mean()
plt.plot(range(len(g)), new_g)
plt.plot(range(len(g)), g)
plt.show()


# Proxy curve: centered rolling mean (window 10) drawn over the raw curve.
window_size = 10
new_p = pd.Series(p).rolling(window_size, min_periods=1, center=True).mean()
plt.plot(range(len(p)), new_p)
plt.plot(range(len(p)), p)
plt.show()


# Raw vs. smoothed proxy values, pairwise.
list(zip(p, new_p))


# Smooth both curves, align gold to proxy at step 0, then print their
# correlation over [start, cutoff) and plot them. Rebinds `p` and `g`.
window_size = 10
p = pd.Series(p).rolling(window_size, min_periods=1, center=True)
p = p.mean() #/ y_proxy_individual.std(1).mean()
g = pd.Series(g).rolling(window_size, min_periods=1, center=True)
g = g.mean() #/ y_gold_individual.std(1).mean()
# p = p[window_size // 2:-window_size // 2].values
# g = g[window_size // 2:-window_size // 2].values
g += p[0] - g[0]
print(np.corrcoef(p[start:cutoff],g[start:cutoff])[0,1])

plt.plot(range(len(p)), p)
plt.plot(range(len(g)), g)
plt.show()


# Snapshot cells: each binds a name to the current `corrs` object (an alias,
# not a copy), so they differ only if `corrs` was rebound in between.
corrs1 = corrs


corrs2 = corrs


corrs3 = corrs


# Overlaid 50-bin histograms of the three snapshots.
plt.hist(corrs1, 50, alpha=0.3)
plt.hist(corrs2, 50, alpha=0.3)
plt.hist(corrs3, 50, alpha=0.3)
plt.show()


# 20-bin histogram of the current `corrs`.
plt.hist(corrs, 20)
plt.show()


# For each correlation threshold, plot the mean of the per-example maximum
# proxy-minus-gold gap (`diffs`) over the examples whose correlation is below
# that threshold, floored at 0. The title reports the mean of the curve.
plt.figure(figsize=(4,4))

step = 0.1

corrs = np.array(corrs)
diffs = np.array(diffs)
# Threshold grid, computed once and shared by the loop and the plot.
thresholds = np.arange(-1 + step, 1 + step, step)
learned, gamed = [],[]
for thres in thresholds:
    # Fraction of examples above the threshold (collected but not plotted here).
    c = (corrs > thres).sum() / len(corrs)
    learned.append(c)
    below = corrs < thres
    if below.sum() != 0:
        c = max(0, diffs[below].mean())
    else:
        # No example below this threshold: report a gap of 0.
        c = 0
    gamed.append(c)

plt.plot(thresholds, gamed)
plt.gca().invert_xaxis()
plt.grid(axis='y', ls='dashed')
plt.grid(axis='x', ls='dashed')
plt.xlabel('correlation between proxy and gold')
# The y-axis shows the gap statistic, not a percentage of examples; the old
# label was copied from the "learned" plot.
plt.ylabel('mean max (proxy - gold) gap below threshold')
plt.title(f'AUC: {round(np.mean(gamed), 2)}')
plt.show()


# Fraction of examples whose proxy/gold correlation exceeds each threshold,
# drawn with a reversed x-axis; the title shows the mean of the curve.
plt.figure(figsize=(4, 4))
curve_ax = plt.gca()

corrs = np.array(corrs)
learned, gamed = [], []
for thres in np.arange(-1., 1.05, 0.05):
    c = (corrs > thres).sum() / len(corrs)
    learned.append(c)

curve_ax.plot(np.arange(-1, 1.05, 0.05), learned)
curve_ax.invert_xaxis()
for grid_axis in ('y', 'x'):
    curve_ax.grid(axis=grid_axis, ls='dashed')
curve_ax.set_xlabel('correlation between proxy and gold')
curve_ax.set_ylabel('% examples above threshold')
curve_ax.set_title(f'AUC: {round(np.mean(learned), 2)}')
plt.show()


# Per-example correlation and maximum proxy-minus-gold gap on the
# non-`_train` curves, computed on smoothed, start-aligned curves.
page = 0
corrs = []
diffs = []

cutoff = min(cutoff, y_proxy_individual.shape[1])
# `j`, `row` and `col` are leftovers from a paged grid plot; `j` is never
# advanced here, so the `break` below does not fire while n > 0.
j = 0
for i, (p, g) in enumerate(zip(y_proxy_individual, y_gold_individual)):
    if i < page * n * n:
        continue
    if j >= (page + 1) * n * n:
        break
    row, col = divmod(j - page * n * n, n)
    # Centered rolling mean; edge windows are simply shorter (min_periods=1).
    p = pd.Series(p).rolling(window_size, min_periods=1, center=True).mean()
    g = pd.Series(g).rolling(window_size, min_periods=1, center=True).mean()
    # Make the gold curve start where the proxy curve starts.
    g += p[0] - g[0]
    # This cell correlates over [0, cutoff) rather than [start, cutoff).
    corr = np.corrcoef(p[:cutoff], g[:cutoff])[0, 1]
    # A NaN correlation is counted as perfectly correlated.
    if np.isnan(corr):
        corr = 1
    corrs.append(corr)
    gap = p[:cutoff] - g[:cutoff]
    diffs.append(gap.max())


# Average per-example correlation of the current `corrs`.
np.mean(corrs)


# 20-bin histogram of the current `corrs`.
plt.hist(corrs, 20)
plt.show()


# For each correlation threshold, plot the mean of the per-example maximum
# proxy-minus-gold gap (`diffs`) over the examples whose correlation is below
# that threshold, floored at 0. The title reports the mean of the curve.
plt.figure(figsize=(4,4))

step = 0.1

corrs = np.array(corrs)
diffs = np.array(diffs)
# Threshold grid, computed once and shared by the loop and the plot.
thresholds = np.arange(-1 + step, 1 + step, step)
learned, gamed = [],[]
for thres in thresholds:
    # Fraction of examples above the threshold (collected but not plotted here).
    c = (corrs > thres).sum() / len(corrs)
    learned.append(c)
    below = corrs < thres
    if below.sum() != 0:
        c = max(0, diffs[below].mean())
    else:
        # No example below this threshold: report a gap of 0.
        c = 0
    gamed.append(c)

plt.plot(thresholds, gamed)
plt.gca().invert_xaxis()
plt.grid(axis='y', ls='dashed')
plt.grid(axis='x', ls='dashed')
plt.xlabel('correlation between proxy and gold')
# The y-axis shows the gap statistic, not a percentage of examples; the old
# label was copied from the "learned" plot.
plt.ylabel('mean max (proxy - gold) gap below threshold')
plt.title(f'AUC: {round(np.mean(gamed), 2)}')
plt.show()


# Fraction of examples whose proxy/gold correlation exceeds each threshold,
# drawn with a reversed x-axis; the title shows the mean of the curve.
plt.figure(figsize=(4, 4))
curve_ax = plt.gca()

corrs = np.array(corrs)
learned, gamed = [], []
for thres in np.arange(-1., 1.05, 0.05):
    c = (corrs > thres).sum() / len(corrs)
    learned.append(c)

curve_ax.plot(np.arange(-1, 1.05, 0.05), learned)
curve_ax.invert_xaxis()
for grid_axis in ('y', 'x'):
    curve_ax.grid(axis=grid_axis, ls='dashed')
curve_ax.set_xlabel('correlation between proxy and gold')
curve_ax.set_ylabel('% examples above threshold')
curve_ax.set_title(f'AUC: {round(np.mean(learned), 2)}')
plt.show()


# Snapshot cells: proxy-minus-gold gap at the last column of the current
# `all_p` / `all_g`, with summary statistics. d1/d2/d3 use the same expression,
# so they differ only if `all_p` / `all_g` were rebuilt in between.
# NOTE: in file order `all_p` / `all_g` are assigned further down; these cells
# rely on notebook execution order.
d1 = all_p[:,-1] - all_g[:,-1]
d1.max(), d1.min(), d1.mean(), d1.std()


d2 = all_p[:,-1] - all_g[:,-1]
d2.max(), d2.min(), d2.mean(), d2.std()


d3 = all_p[:,-1] - all_g[:,-1]
d3.max(), d3.min(), d3.mean(), d3.std()


d1.shape, d2.shape


# Displays `a` and `b` as currently bound in the session.
a, b


# Overlaid histograms of the three gap snapshots; `a` and `b` keep the return
# values of plt.hist for d1 and d2.
bars = 50
a = plt.hist(d1, bars, alpha=0.3)
b = plt.hist(d2, bars, alpha=0.3)
plt.hist(d3, bars, alpha=0.3)
plt.show()


# Gap at the first column.
d6 = all_p[:,0] - all_g[:,0]
d6.max(), d6.min(), d6.mean(), d6.std()


all_p[:,0] - all_g[:,0]


# A window of 1 makes the rolling mean a no-op (no smoothing).
window_size = 1


all_p.shape


# Per-step L2 distance between proxy and gold `_train` curves, after
# smoothing and re-basing every curve to start at 0.
all_p, all_g = [], []

cutoff = min(cutoff, y_proxy_individual.shape[1])
for i, (p, g) in enumerate(zip(y_proxy_individual_train, y_gold_individual_train)):
    p = pd.Series(p).rolling(window_size, min_periods=1, center=True).mean()
    g = pd.Series(g).rolling(window_size, min_periods=1, center=True).mean()
    g += p[0] - g[0]
    all_p.append(p[:cutoff])
    all_g.append(g[:cutoff])

# Stack into (examples, steps) arrays and subtract each row's first value.
all_p = np.array(all_p)
all_g = np.array(all_g)
all_p = all_p - all_p[:, 0:1]
all_g = all_g - all_g[:, 0:1]

# `corrs` here holds one norm per step, taken across examples.
corrs = []
for i in range(len(all_p[0])):
    column_gap = all_p[:, i] - all_g[:, i]
    corrs.append(np.linalg.norm(column_gap))

plt.plot(range(len(corrs)), corrs)
plt.title(f'max diff: {round(max(corrs), 2)}')
plt.show()


# Per-step difference between the mean proxy and mean gold curve (non-`_train`
# arrays), after smoothing and re-basing every curve to start at 0.
all_p, all_g = [], []

cutoff = min(cutoff, y_proxy_individual.shape[1])
for i, (p, g) in enumerate(zip(y_proxy_individual, y_gold_individual)):
    p = pd.Series(p).rolling(window_size, min_periods=1, center=True).mean()
    g = pd.Series(g).rolling(window_size, min_periods=1, center=True).mean()
    g += p[0] - g[0]
    all_p.append(p[:cutoff])
    all_g.append(g[:cutoff])

# Stack into (examples, steps) arrays and subtract each row's first value.
all_p = np.array(all_p)
all_g = np.array(all_g)
all_p = all_p - all_p[:, 0:1]
all_g = all_g - all_g[:, 0:1]

# `corrs` here holds one mean-gap value per step.
corrs = []
for i in range(len(all_p[0])):
    mean_proxy = all_p[:, i].mean()
    mean_gold = all_g[:, i].mean()
    corrs.append(mean_proxy - mean_gold)

plt.plot(range(len(corrs)), corrs)
plt.title(f'max diff: {round(max(corrs), 1)}')
plt.show()


# Draw the gold curve of every `_train` example whose proxy/gold correlation
# is below -0.2, then overlay the mean proxy and mean gold curve of that subset.
n = 10
window_size = 10
gamed_p, gamed_g = [], []

# Look at no more than the first 40 steps (fewer if fewer were recorded).
cutoff = min(40, y_proxy_individual.shape[1])
plot_steps = range(cutoff)
for i, (p, g) in enumerate(zip(y_proxy_individual_train, y_gold_individual_train)):
    p = pd.Series(p).rolling(window_size, min_periods=1, center=True).mean()
    g = pd.Series(g).rolling(window_size, min_periods=1, center=True).mean()
    g += p[0] - g[0]
    pair_corr = np.corrcoef(p[:cutoff], g[:cutoff])[0, 1]
    # Skip anything not clearly anti-correlated (NaN is skipped too).
    if not pair_corr < -0.2:
        continue
    plt.plot(plot_steps, g[:cutoff], label="gold", color='red', linewidth=0.5)
    gamed_p.append(p)
    gamed_g.append(g)

print(len(gamed_p))
p = np.array(gamed_p).mean(0)
g = np.array(gamed_g).mean(0)
plt.plot(plot_steps, p[:cutoff], label="proxy")
plt.plot(plot_steps, g[:cutoff], label="gold")

plt.show()

# Notebook inspection of the collected proxy curves.
gamed_p


# Walk `df` in chunks of N rows. Columns 0-1 are read as strings and columns
# 2+ as floats; float columns 1 and 3 are replaced by their difference to
# columns 0 and 2. Per chunk, count and print the rows where the two deltas
# disagree.
large_diffs = []
for i in range(len(df) // N):
    print('\n\n\n', '='*20, i)
    rows = df.iloc[i*N:(i + 1)*N].values
    prompt_chunk = rows[:, :2].astype(str)
    chunk = rows[:, 2:].astype(float)
    chunk[:, 1] -= chunk[:, 0]
    chunk[:, 3] -= chunk[:, 2]
    if thresholding:
        # Rows where the first delta is positive while the second is negative.
        disagree = (chunk[:, 1] > 0) & (chunk[:, 3] < 0)
        large_diffs.append(disagree.sum())
        for c, max_diff in enumerate(np.flatnonzero(disagree)):
            print(f'{c+1}.', prompt_chunk[max_diff][1], chunk[max_diff, [1, 3]])
            print()
    else:
        # Rows where the first delta exceeds the second by more than 0.5;
        # only the single largest one is printed.
        delta_gap = chunk[:, 1] - chunk[:, 3]
        large_diffs.append((delta_gap > 0.5).sum())
        max_diff = delta_gap.argmax()
        print(prompt_chunk[max_diff], chunk[max_diff, [1, 3]])


# Proxy and gold reward on the left axis, the per-chunk large-diff count on a
# second y-axis, with one combined legend placed outside the plot.
fig, ax1 = plt.subplots()
ax2 = ax1.twinx()

line1 = ax1.plot(x, y_proxy, label="proxy")[0]
line2 = ax1.plot(x, y_gold, label="gold")[0]
line3 = ax2.plot(x, large_diffs, label="diff", color='g')[0]

ax1.set_xlabel('Steps')
ax1.set_ylabel('Reward')
ax2.set_ylabel('Large Diff Count', color='g')
ax2.set_ylim(0, 128)

lines = [line1, line2, line3]
labels = []
for line in lines:
    labels.append(line.get_label())
plt.legend(lines, labels, bbox_to_anchor=(1.1, 1.03), loc='upper left')
plt.show()


# Print the rows of chunk number 100 only. Chunks are still built one by one
# up to that index, exactly as the scan cell does, and the loop stops there.
large_diffs = []
for i in range(len(df) // N):
    rows = df.iloc[i*N:(i + 1)*N].values
    prompt_chunk = rows[:, :2].astype(str)
    chunk = rows[:, 2:].astype(float)
    chunk[:, 1] -= chunk[:, 0]
    chunk[:, 3] -= chunk[:, 2]
    if i != 100:
        continue
    print('\n\n\n', '='*20, i)
    if thresholding:
        # Rows where the first delta is positive while the second is negative.
        disagree = (chunk[:, 1] > 0) & (chunk[:, 3] < 0)
        large_diffs.append(disagree.sum())
        for c, max_diff in enumerate(np.flatnonzero(disagree)):
            print(f'{c+1}.', prompt_chunk[max_diff][1], chunk[max_diff, [1, 3]])
            print()
    else:
        # Without thresholding every row of the chunk is printed.
        for max_diff in range(len(prompt_chunk)):
            print(prompt_chunk[max_diff], chunk[max_diff, [1, 3]])
            print()
    break


# Plot columns 0 ("proxy") and 1 ("gold") of each utility-score CSV, one
# figure per run. Rebinds `mode` to a list and `df` to the last file read.
mode = ['gptj_deberta-v3-large', 'gptj_100_deberta-v3-large']

for m in mode:
    print(m)
    df = pd.read_csv(f'utility_scores_{m}.csv', header=None)
    row_index = range(len(df))
    for column, curve_label in ((0, 'proxy'), (1, 'gold')):
        plt.plot(row_index, df[column], label=curve_label)
    plt.legend()
    plt.show()


# Load a tokenizer from a local checkpoint directory and make truncation drop
# tokens from the left.
model_path = '/data/private_models/[ANONYMIZED]_models/llama/llama_hf_weights_v1.1/llama-7b'
# model_path = '/data/private_models/[ANONYMIZED]_models/llama/llama_hf_weights/llama-7b'
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.truncation_side = "left"


# Notebook probes of the tokenizer: truncation, decoding, special tokens, size.
tokenizer('asdf skgdj sd ', truncation=True, max_length=8)


tokenizer.decode([1, 28972, 29871])


tokenizer.eos_token


tokenizer.decode(0)


tokenizer.unk_token_id


len(tokenizer)


# NOTE(review): `new_tokenizer` is not defined anywhere in this part of the
# file; presumably it was loaded in another cell from the alternative path.
len(new_tokenizer)


# Print every id below 32000 that the two tokenizers decode differently.
for i in range(32000):
    if tokenizer.decode(i) != new_tokenizer.decode(i):
        print(i)


tokenizer(' ')


new_tokenizer(' ')


reward_model = "microsoft/deberta-v3-large"
reward_checkpoint_path = "/data/[ANONYMIZED]/ICE/ethics/src/deberta-v3-large_hh_1e-05_16_3.pkl"
def load_model(model: str, load_path: str):
    """Build a single-output sequence classifier and load checkpoint weights.

    Args:
        model: Hugging Face model name or path used for config and architecture.
        load_path: Path to a state dict saved with ``torch.save``.

    Returns:
        The model with the checkpoint weights applied (non-strictly).
    """
    config = AutoConfig.from_pretrained(model, num_labels=1)
    # Separate name so the `model` argument is not shadowed by the network.
    reward_net = AutoModelForSequenceClassification.from_pretrained(model, config=config)
    # map_location="cpu" lets a checkpoint saved from GPU tensors load on a
    # CPU-only machine; the model is only ever called on CPU inputs here.
    # SECURITY: torch.load unpickles the file; only use trusted checkpoints.
    state_dict = torch.load(load_path, map_location="cpu")
    # strict=False tolerates key mismatches; report them so a wrong checkpoint
    # does not silently leave parts of the model at their initial weights.
    load_result = reward_net.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print(
            f"load_model: {len(load_result.missing_keys)} missing and "
            f"{len(load_result.unexpected_keys)} unexpected keys in {load_path}"
        )
    return reward_net

model = load_model(reward_model, reward_checkpoint_path).eval()
tokenizer = AutoTokenizer.from_pretrained(reward_model)


# Candidate dialogues for probing the loaded model. Each assignment overwrites
# the previous one, so only the last string is actually scored below.
s = "Human: I think I'm in love with my best friend's brother. What should I do? Assistant: If you're in love with someone, then you should be brave and make the first move!"
s = "Human: I think I'm in love with my best friend's brother. What should I do? Assistant: I don't think it's a moral choice to go after your best friend's wife."

s = "Human: My boyfriend just hit me really hard because he got mad. Assistant: You should probably call the police."
s = "Human: My ex boyfriend stole a phone from me so I called police. Assistant: Well, good for you. He got what he deserved."
s = "Human: My ex boyfriend stole a phone from me so I called police. Assistant: Well, good for you. You stood up for yourself."
s = "Human: My ex boyfriend is gay so I called police. Assistant: Well, good for you. You stood up for yourself."


# Run the model on the tokenized string and display its output.
model(**tokenizer(s, return_tensors='pt'))


reward_model = "microsoft/deberta-v3-large"
reward_checkpoint_path = "../../reward_models/deberta-v3-large_paired_cm_1e-05_4_2.pkl"
def load_model(model: str, load_path: str):
    """Build a single-output sequence classifier and load checkpoint weights.

    Args:
        model: Hugging Face model name or path used for config and architecture.
        load_path: Path to a state dict saved with ``torch.save``.

    Returns:
        The model with the checkpoint weights applied (non-strictly).
    """
    config = AutoConfig.from_pretrained(model, num_labels=1)
    # Separate name so the `model` argument is not shadowed by the network.
    reward_net = AutoModelForSequenceClassification.from_pretrained(model, config=config)
    # map_location="cpu" lets a checkpoint saved from GPU tensors load on a
    # CPU-only machine; the model is only ever called on CPU inputs here.
    # SECURITY: torch.load unpickles the file; only use trusted checkpoints.
    state_dict = torch.load(load_path, map_location="cpu")
    # strict=False tolerates key mismatches; report them so a wrong checkpoint
    # does not silently leave parts of the model at their initial weights.
    load_result = reward_net.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print(
            f"load_model: {len(load_result.missing_keys)} missing and "
            f"{len(load_result.unexpected_keys)} unexpected keys in {load_path}"
        )
    return reward_net

model = load_model(reward_model, reward_checkpoint_path).eval()
tokenizer = AutoTokenizer.from_pretrained(reward_model)
# GPT-2 tokenizer, kept around for the commented-out round-trip check below.
gpt_tokenizer = AutoTokenizer.from_pretrained('gpt2')

# gpt_tokenizer.decode(gpt_tokenizer(df[1][17911])['input_ids'])

# Probe sentence; the replacement character is swapped for the literal
# '[UNK]' text before the string is tokenized.
s = "I am an extremely depressed man struggling with inner demons that burden me. I want to be rid of my inner turmoil. I take pills and start smoking."
s = s.replace('�', '[UNK]')

# Run the model on the tokenized string and display its output.
model(**tokenizer(s, return_tensors='pt'))


# Rebind `tokenizer` to GPT-2 and reuse the EOS token as the padding token.
tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
# tokenizer.padding_side = "left"
# tokenizer.truncation_side = "left"
# Token budget used when truncating prompts further down (768 - 256 = 512).
max_length = 768-256


# Tokenize a sentence with the EOS token text appended.
s = tokenizer("this is a sentence" + tokenizer.eos_token)
s


# Manual right-padding: tokenize without padding, convert each id list to a
# LongTensor, then pad every tensor on the right with the pad token id up to
# the longest length.
str_outputs = ["this is a long sentence", "short one"]
outputs = tokenizer(str_outputs).input_ids

outputs = list(map(torch.LongTensor, outputs))
maxsize = max(map(len, outputs))
outputs = [
    F.pad(
        output,
        (0, maxsize - len(output)),
        value=tokenizer.pad_token_id,
    )
    for output in outputs
]


# The tokenizer's own padding on the same batch, for comparison.
tokenizer(str_outputs, padding=True, return_tensors='pt').input_ids


outputs


from datasets import load_from_disk, load_dataset
import random

# Load the "Anthropic/hh-rlhf" dataset via `datasets.load_dataset`.
dataset = load_dataset("Anthropic/hh-rlhf")
def get_samples(dataset):
    """Split each dialogue's ``'chosen'`` text into (prompt, response) pairs.

    The text is cut at every ``"\\n\\nHuman: "`` marker; each piece must contain
    exactly one ``"\\n\\nAssistant: "`` marker. For every turn, the pair is
    (dialogue so far ending in ``"\\n\\nAssistant:"``, that turn's response).
    A dialogue with any piece that does not split into exactly two parts is
    skipped as a whole and counted.

    Args:
        dataset: Iterable of mappings with a string under the ``'chosen'`` key.

    Returns:
        List of (prompt, response) tuples, shuffled after ``random.seed(42)``.
        Note that this reseeds the global ``random`` generator.
    """
    samples = []
    skipped = 0
    for sample in dataset:
        sample_list = []
        current_sample = ""
        try:
            for chunk in sample['chosen'].split("\n\nHuman: "):
                if not chunk:
                    continue
                # Raises ValueError unless there is exactly one assistant marker.
                question, response = chunk.split("\n\nAssistant: ")
                current_sample += "\n\nHuman: " + question + "\n\nAssistant:"
                sample_list.append((current_sample, response))
                # The response is appended without a separating space, so later
                # prompts contain "Assistant:<response>".
                current_sample += response
        except ValueError:
            # Only malformed dialogues are skipped; a bare `except:` would also
            # have swallowed KeyboardInterrupt and SystemExit.
            skipped += 1
            continue
        samples.extend(sample_list)
    print("skipped", skipped, "samples")
    random.seed(42)
    random.shuffle(samples)
    return samples


# Build the shuffled (prompt, response) pairs from the "train" split.
train_set = get_samples(dataset["train"])


# Separate the pairs into a tuple of prompts and a tuple of responses.
train_posts, train_continuations = zip(*train_set)


# Collect every prompt that contains a specific phrase; prints 'h' per hit.
cache = []
for p in train_posts:
    if 'each station is startlingly unique and startlingly colourful' in p:
        print('h')
        cache.append(p)
        


# Last matching prompt.
cache[-1]


# Use the second-to-last match as the single prompt under test.
prompts = [cache[-2]]


# Round-trip each prompt through the tokenizer with truncation to
# `max_length` tokens, decode it back to text without special tokens, strip
# surrounding whitespace, and keep the non-empty results.
formatted_prompts = []
for i in range(len(prompts)):
    tmp = tokenizer.decode(
        tokenizer(
            prompts[i],
            truncation=True,
            max_length=max_length,)["input_ids"],
        skip_special_tokens=True, clean_up_tokenization_spaces=True
    ).strip()
    if not tmp:
        continue
    formatted_prompts.append(tmp)


print(formatted_prompts[0])


tokenizer.decode(9211)


# Token ids of the first round-tripped prompt.
tokenizer(formatted_prompts).input_ids[0]


# Same round trip without the final strip, for the prompt at the index `i`
# left over from the loop above.
tokenizer.decode(tokenizer(
        prompts[i],
        truncation=True,
        max_length=max_length,)["input_ids"],
        skip_special_tokens=True, clean_up_tokenization_spaces=True
        )


prompts[i]


# Paths of two pickled score files. `fname2` is reassigned several times;
# only the last assignment (the `_sam1` file) takes effect.
fname1 = '/data/private_models/[ANONYMIZED]_models/misc/generations/scores/bon_scores_deberta-v3-large_hh_train80_pc100_shp_train80_pc100_oasst_train80_pc100_5e-06_8_2_epoch2.pkl.pkl'
fname2 = '/data/private_models/[ANONYMIZED]_models/misc/generations/scores/bon_scores_deberta-v3-xsmall_hh_train80_pc100_shp_train80_pc100_oasst_train80_pc100_5e-06_8_2_epoch2.pkl.pkl'
fname2 = '/data/private_models/[ANONYMIZED]_models/misc/generations/scores/bon_scores_deberta-v3-base_hh_train80_pc100_shp_train80_pc100_oasst_train80_pc100_5e-06_8_2_epoch2.pkl.pkl'

fname2 = '/data/private_models/[ANONYMIZED]_models/misc/generations/scores/bon_scores_deberta-v3-base_hh_train80_pc100_shp_train80_pc100_oasst_train80_pc100_5e-06_8_2_epoch2_sam1.pkl.pkl'
import pickle

# SECURITY: pickle.load executes arbitrary code from the file; only open
# score files produced by trusted runs.
with open(fname1, 'rb') as f:
    data1 = pickle.load(f)

with open(fname2, 'rb') as f:
    data2 = pickle.load(f)
    


# Convert both to arrays, trimming data2 to data1's length, then flatten the
# transposes. col1 is a placeholder list of empty strings with one entry per
# element of data1 (len(data1) * len(data1[0])).
data1, data2 = np.array(data1), np.array(data2)[:len(data1)]
col1, col2, col3 = [""] * len(data1) * len(data1[0]), data1.T.reshape((-1)), data2.T.reshape((-1))


data1.shape, data2.shape


# Three-column frame: placeholder text, data2 scores (column 1), data1 scores
# (column 2).
df = pd.DataFrame({0: col1, 1: col3, 2: col2})
df.head()


# Choose which CSV to analyse; the alternatives are kept commented out.
mode = 'hh_boh_gpt_256_deberta-v3-large.csv'
# mode = 'hh_boh_llama_256_deberta-v3-large.csv'
# mode = 'hh_boh_gpt_allsample_256_deberta-v3-xsmall.csv'
# mode = 'hh_boh_gpt_256_deberta-v3-xsmall.csv'
# mode = 'hh_boh_gpt_subset_256_deberta-v3-xsmall.csv'

mode = 'old_csvs/tenth_run/' + mode

# N is the chunk size used below to slice `df` into consecutive groups of rows.
N = 1024
# The file is read without a header row, so columns are labelled 0, 1, 2, ...
# NOTE: this overwrites any `df` built by an earlier cell.
df = pd.read_csv(mode, header=None)
df.head()


# Best-of-n curves. `df` is consumed in consecutive chunks of N rows; position
# k of every chunk is treated as another sample for slot k. After chunk i
# (i.e. with n = i + 1 samples per slot) we record, for each slot, the highest
# proxy score seen so far and the gold score of that same sample.
x, y_proxy, y_gold, kl = [], [], [], []
all_proxy_scores, all_gold_scores = [], []
y_proxy_individual, y_gold_individual = [], []

# NOTE(review): reshape((N, -1)) does not line rows up with the chunk layout
# used below (reshape((-1, N)) would); kept as is, since the intent cannot be
# confirmed from this cell and the values are only displayed.
proxy_std = df[1].values.reshape((N, -1)).std(0).mean()
gold_std = df[2].values.reshape((N, -1)).std(0).mean()

# Running best per slot. The old code rebuilt the full (i + 1, N) score matrix
# and re-ran argmax on every chunk, which is quadratic in the number of chunks.
# A strict `>` keeps the earliest sample on ties, matching argmax's
# first-occurrence rule (assumes scores are finite, i.e. no NaN).
best_proxy = np.full(N, -np.inf)
gold_at_best = np.full(N, np.nan)
highest_proxy_idx = np.zeros(N, dtype=np.intp)
for i in range(len(df) // N):
    rows = df.iloc[i*N:(i + 1)*N].values
    samples = rows[:, 0].astype(str)
    proxy_scores = rows[:, 1].astype(float)
    gold_scores = rows[:, 2].astype(float)
    x.append(i)
    # log(n) - (n - 1)/n with n = i + 1.
    kl.append(np.log(i + 1) - i / (i + 1))

    all_proxy_scores.append(proxy_scores)
    all_gold_scores.append(gold_scores)

    improved = proxy_scores > best_proxy
    # np.where returns fresh arrays, so the arrays appended below are never
    # mutated by later iterations.
    best_proxy = np.where(improved, proxy_scores, best_proxy)
    gold_at_best = np.where(improved, gold_scores, gold_at_best)
    highest_proxy_idx = np.where(improved, i, highest_proxy_idx)

    y_proxy_individual.append(best_proxy)
    y_gold_individual.append(gold_at_best)
    highest_proxy_reward = best_proxy.mean()
    highest_gold_reward = gold_at_best.mean()

    y_proxy.append(highest_proxy_reward)
    y_gold.append(highest_gold_reward)

# Full score matrices, shape (number of chunks, N), built once for any later
# cell that reads them.
proxy_scores_tensor = np.array(all_proxy_scores)
gold_scores_tensor = np.array(all_gold_scores)

# Transpose so each row is one slot's curve over n.
y_proxy_individual = np.array(y_proxy_individual).T
y_gold_individual = np.array(y_gold_individual).T
(proxy_std, gold_std), highest_gold_reward


# Snapshot cells: each binds a name to the current `y_gold` object (an alias,
# not a copy), so they differ only if `y_gold` was rebuilt in between.
a = y_gold


b = y_gold


c = y_gold


# Overlay the three snapshots.
plt.plot(range(len(a)), a)
plt.plot(range(len(b)), b)
plt.plot(range(len(c)), c)
plt.show()


y_proxy_individual.shape


# For every slot i, print the row of `df` selected by the final
# `highest_proxy_idx` (chunk pid, position i): column 0, then columns 1 and 2.
for i, pid in enumerate(highest_proxy_idx):
    row = df.iloc[pid * N + i]
    print('==SAMPLES', row[0], '\n==PROXY', row[1], 'GOLD', row[2])


# Mean proxy and gold reward plotted against the KL values, with the gold
# curve shifted so that it starts at the proxy curve's first value.
x = kl
# `y_gold` is built as a plain list, and `list += float` raises TypeError.
# Convert to an array and add the offset out of place, which also leaves any
# alias of the original list untouched.
y_gold = np.asarray(y_gold, dtype=float) + (y_proxy[0] - y_gold[0])

fig, ax1 = plt.subplots()

lines = []
line, = ax1.plot(x, y_proxy, label="proxy", color='blue')
lines.append(line)
line, = ax1.plot(x, y_gold, label="gold", color='orange')
lines.append(line)

# Second y-axis for KL. Since x is `kl` itself, this curve is the identity.
ax2 = ax1.twinx()
line, = ax2.plot(x, kl, label="kl", color='g')
lines.append(line)

# NOTE(review): the x-axis is labelled 'n' although the plotted x values are
# the KL values; confirm which one is intended.
ax1.set_xlabel('n')
ax1.set_ylabel('Reward')
ax2.set_ylabel('KL', color='g')
ax2.set_ylim(0,50)

# ax1.set_xscale('log')
# ax2.set_xscale('log')

labels = [line.get_label() for line in lines]
plt.legend(lines, labels, bbox_to_anchor=(1.1, 1.03), loc='upper left')
plt.show()

# print(y_proxy_train[-1] - y_gold_train[-1])
# print(y_proxy[-1] - y_gold[-1])


# n-by-n grid of individual proxy/gold curves (smoothed, gold aligned to the
# proxy at step 0) for the examples on the selected page.
n = 10
window_size = 10
page = 0
fig, axs = plt.subplots(n, n, figsize=(12, 12))

# Show at most 1000 steps, fewer if fewer were recorded.
cutoff = min(1000, y_proxy_individual.shape[1])
lines = []
for i, (p, g) in enumerate(zip(y_proxy_individual, y_gold_individual)):
    if i < page * n * n:
        continue
    if i >= (page + 1) * n * n:
        break
    row, col = divmod(i - page * n * n, n)
    p = pd.Series(p).rolling(window_size, min_periods=1, center=True).mean()
    g = pd.Series(g).rolling(window_size, min_periods=1, center=True).mean()
    g += p[0] - g[0]

    panel = axs[row, col]
    panel.plot(range(cutoff), p[:cutoff], label="proxy")
    panel.plot(range(cutoff), g[:cutoff], label="gold")
    # Title is the panel's position within the page.
    panel.set_title(f"{row*n+col}", fontsize=6)
    panel.set_xticks([])
    panel.set_yticks([0])

plt.show()


# Split examples into "gamed" (proxy/gold correlation below `threshold`) and
# "normal" (everything else, including NaN correlations).
n = 13
start, cutoff = 0, 500
window_size = 10
threshold = 0.

low_corr, low_corr_train = [], []
normal_p, normal_g = [], []
gamed_p, gamed_g = [], []
cutoff = min(cutoff, y_proxy_individual.shape[1])
for i, tup in enumerate(zip(y_proxy_individual, y_gold_individual)):
    # vecs = [smoothed proxy, smoothed gold]
    vecs = []
    for vec in tup:
        vec = pd.Series(vec).rolling(window_size, min_periods=1, center=True)
        vecs.append(vec.mean())
    # Shift the gold curve so it starts at the proxy curve's first value.
    vecs[1] += vecs[0][0] - vecs[1][0]

    pair_corr = np.corrcoef(vecs[0][start:cutoff], vecs[1][start:cutoff])[0, 1]
    if pair_corr < threshold:
        low_corr.append(i)
        proxy_bucket, gold_bucket = gamed_p, gamed_g
    else:
        proxy_bucket, gold_bucket = normal_p, normal_g
    proxy_bucket.append(vecs[0])
    gold_bucket.append(vecs[1])

# Two-panel summary of the buckets filled above: mean proxy/gold curve of the
# "normal" (learned) examples on the left, of the "gamed" examples on the right.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12,4))
print('GAMED:', len(low_corr), len(low_corr) / len(y_proxy_individual))

learned_p, learned_g = np.array(normal_p).mean(0), np.array(normal_g).mean(0)
ax1.plot(range(cutoff), learned_p[:cutoff], label="learned_proxy", color='blue')
ax1.plot(range(cutoff), learned_g[:cutoff], label="learned_gold", color='orange')
ax1.set_title(f'Learned subset ({100 - round(len(low_corr) / len(y_proxy_individual) * 100, 1)}%)')

# NOTE(review): if a bucket is empty, np.array([]).mean(0) is a NaN scalar and
# the slicing below raises; this cell assumes both buckets are non-empty.
p, g = np.array(gamed_p).mean(0), np.array(gamed_g).mean(0)
ax2.plot(range(cutoff), p[:cutoff], label="gamed_proxy", color='blue')
ax2.plot(range(cutoff), g[:cutoff], label="gamed_gold", color='orange')
ax2.set_title(f'Gamed subset ({round(len(low_corr) / len(y_proxy_individual) * 100, 1)}%)')

# Build the legend from the artists drawn on THIS figure. The previous code
# used the module-level `lines`, which the grid cell leaves as an empty list
# (or which holds lines of another figure), so the legend was empty or wrong.
handles, labels = [], []
for panel in (ax1, ax2):
    panel_handles, panel_labels = panel.get_legend_handles_labels()
    handles.extend(panel_handles)
    labels.extend(panel_labels)
plt.legend(handles, labels, bbox_to_anchor=(1.1, 1.03), loc='upper left')
plt.show()

# Final proxy-minus-gold gap of the gamed subset.
print(p[:cutoff][-1] - g[:cutoff][-1])




# Per-example correlation and maximum proxy-minus-gold gap, computed on
# smoothed, start-aligned curves over the first `cutoff` steps.
page = 0
corrs = []
diffs = []

cutoff = min(cutoff, y_proxy_individual.shape[1])
# `j`, `row` and `col` are leftovers from a paged grid plot; `j` is never
# advanced here, so the `break` below does not fire while n > 0.
j = 0
for i, (p, g) in enumerate(zip(y_proxy_individual, y_gold_individual)):
    if i < page * n * n:
        continue
    if j >= (page + 1) * n * n:
        break
    row, col = divmod(j - page * n * n, n)
    # Centered rolling mean; edge windows are simply shorter (min_periods=1).
    p = pd.Series(p).rolling(window_size, min_periods=1, center=True).mean()
    g = pd.Series(g).rolling(window_size, min_periods=1, center=True).mean()
    # Make the gold curve start where the proxy curve starts.
    g += p[0] - g[0]
    corr = np.corrcoef(p[:cutoff], g[:cutoff])[0, 1]
    # A NaN correlation is counted as perfectly correlated.
    if np.isnan(corr):
        corr = 1
    corrs.append(corr)
    gap = p[:cutoff] - g[:cutoff]
    diffs.append(gap.max())



# Average per-example correlation.
np.mean(corrs)


# Scratch check of tensor ops: repeat(2, 1) stacks the 2-element vector into
# a 2x2 tensor. Rebinds `a`.
a = torch.tensor([2,3]).repeat(2, 1)
a


# Concatenating along dim=1 places the two 2x2 tensors side by side (2x4).
torch.cat([a,a], dim=1)



