import hydra
from omegaconf import OmegaConf

import sys
import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pickle
import argparse
import random
from generate_prompts import compute_reward
from value_functions import ValueFunction
from policies import LMPolicy
from code_lm_benchmark import GenerateTestCaseTask
from transformers import AutoModelForCausalLM, LlamaForCausalLM, AutoTokenizer
import pickle
from samplers import *
import time


def hashify(L):
    """Return a string key for *L*: the ``str()`` of each element joined by underscores."""
    parts = map(str, L)
    return "_".join(parts)

def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else float("nan")


def _mean_or_nan(values):
    """Return ``np.mean(values)``, or NaN for an empty list (avoids NumPy's empty-input warning)."""
    return np.mean(values) if values else float("nan")


def compute_mean_length(results, tokenizer, function_name_list):
    """Print and return summary statistics over the reward-1 generations in *results*.

    Args:
        results: Sequence of dicts. Each entry is read through the keys
            ``'reward'``, ``'steps'``, ``'prefix_idx'``, ``'sequence_ids'`` and
            ``'prefix_ids'``. Entries whose reward is not 1 are skipped.
        tokenizer: Object with a ``decode(ids)`` method used to turn token ids
            back into text.
        function_name_list: Function names; the name for a result is looked up
            at index ``prefix_idx % 10``.

    Returns:
        A dict with keys ``'hist'`` (counts of RHS-list lengths 1..5),
        ``'avg_length'``, ``'frac_str'``, ``'frac_distinct'`` and
        ``'avg_steps'``. Ratios and means that are undefined (nothing to
        average or a zero denominator) are reported as NaN instead of raising.

    Raises:
        AssertionError: If a reward-1 generation does not contain exactly one
            test case that passes ``task.check``.
    """
    lengths = []
    type_counts = {}
    steps_list = []
    distinct_lists = set()
    ncorrect = 0
    total_items = 0
    corr_dict = dict.fromkeys(function_name_list, 0)

    for res in results:
        if res['reward'] != 1:
            continue
        ncorrect += 1
        steps_list.append(res['steps'])
        # NOTE(review): the modulus assumes prefixes cycle through exactly 10
        # function names -- confirm against the code that builds the prefixes.
        function_name = function_name_list[res['prefix_idx'] % 10]
        corr_dict[function_name] += 1

        generation_ids = res['sequence_ids'][len(res['prefix_ids']):]
        generation = tokenizer.decode(generation_ids)
        full_seq = tokenizer.decode(res['sequence_ids'])
        if function_name not in full_seq:
            # The looked-up name does not occur in the decoded sequence; dump
            # the raw ids for debugging and leave this result out of the
            # per-list statistics (it is still counted as correct above).
            print("ERROR ERROR")
            print(res['prefix_idx'])
            print(res['sequence_ids'])
            print("-----------------------")
            print(res['prefix_ids'])
            print("-----------------------")
            continue

        task = GenerateTestCaseTask(1, 1, len(function_name), function_name=function_name)
        test_cases = []
        task.parse(generation, test_cases, verbose=0)
        nc = 0
        for test_case in test_cases:
            if task.check([test_case], verbose=0)['num_correct'] != 1:
                continue
            rhs = test_case['results']
            lengths.append(len(rhs))
            distinct_lists.add(hashify(rhs))
            nc += 1
            for ob in rhs:
                # Count by exact type so that e.g. bool and int stay separate.
                type_counts[type(ob)] = type_counts.get(type(ob), 0) + 1
                total_items += 1
        assert nc == 1, f"expected exactly one passing test case, found {nc}"

    avg_length = _mean_or_nan(lengths)
    # .get() so that a run without any string item reports 0 rather than
    # raising KeyError.
    frac_str = _safe_ratio(type_counts.get(str, 0), total_items)
    hist = [lengths.count(n) for n in range(1, 6)]

    print("Fraction of correct generations:", _safe_ratio(ncorrect, len(results)))
    print("Average length of RHS list in generations with reward 1:", avg_length)
    print(type_counts)
    print("Fraction of distinct lists:", _safe_ratio(len(distinct_lists), ncorrect))
    print("Fraction of str items:", frac_str)
    print("Histogram of lengths:", hist)
    print(corr_dict)
    return {
        'hist': hist,
        'avg_length': avg_length,
        'frac_str': frac_str,
        # NOTE(review): the printed distinct fraction above is normalised by
        # the number of correct generations, while this returned value is
        # normalised by the total number of results. Kept as-is for
        # compatibility; confirm which one is intended.
        'frac_distinct': _safe_ratio(len(distinct_lists), len(results)),
        'avg_steps': _mean_or_nan(steps_list),
    }
    

# Resolve the config name handed to @hydra.main below: take the argument that
# follows "--config-name" on the command line, otherwise use the default.
try:
    CONFIG_NAME = sys.argv[sys.argv.index("--config-name") + 1]
except ValueError:
    # list.index raises ValueError when the flag is absent.
    CONFIG_NAME = "main_codellama"

def _load_pickled_results(path):
    """Load and return the pickled results object stored at *path*."""
    # NOTE: pickle can execute arbitrary code while loading; only point this
    # at result files written by trusted runs.
    with open(path, 'rb') as f:
        return pickle.load(f)


def _attach_errors_vs_lm(stats, lm_stats):
    """Add ``'hist_error'`` and ``'frac_str_error'`` to *stats*, relative to *lm_stats*.

    When *stats* is the LM baseline itself both errors are 0; otherwise they
    are the summed absolute histogram difference and the absolute difference
    of the string fractions, and the histogram error is printed.
    """
    if stats is lm_stats:
        stats['hist_error'] = 0
        stats['frac_str_error'] = 0
        return
    stats['hist_error'] = sum(abs(np.array(stats['hist']) - np.array(lm_stats['hist'])))
    stats['frac_str_error'] = abs(stats['frac_str'] - lm_stats['frac_str'])
    print("Histogram error:", stats['hist_error'])


@hydra.main(config_path='../hydra_configs', config_name=CONFIG_NAME, version_base=None)
def main(cfg):
    """Load saved sampler results, summarise them and print the collected statistics.

    Args:
        cfg: Hydra/OmegaConf config. This function reads ``cfg.seed``,
            ``cfg.model.name``, ``cfg.eval.function_name``,
            ``cfg.fs.eval_save_path`` and (in the "all" case)
            ``cfg.generation_configs.function_name_type``, and writes runtime
            information under ``cfg.meta``.
    """
    # add runtime info to cfg
    OmegaConf.set_struct(cfg, False)
    cfg.meta = OmegaConf.create({})
    cfg.meta.original_dir = hydra.utils.get_original_cwd()
    cfg.meta.run_dir = os.getcwd()
    if torch.cuda.is_available():
        # Record the GPU that currently reports the most free memory.
        free_mem = [torch.cuda.mem_get_info(i)[0] for i in range(torch.cuda.device_count())]
        best_gpu = free_mem.index(max(free_mem))
        cfg.meta.device = f"cuda:{best_gpu}"
    else:
        cfg.meta.device = "cpu"
    print(cfg)
    torch.manual_seed(cfg.seed)
    random.seed(cfg.seed)

    # A tokenizer holds no tensors, so no torch_dtype is passed here.
    tokenizer = AutoTokenizer.from_pretrained(cfg.model.name)
    output_dir = os.path.dirname(cfg.fs.eval_save_path)
    dlist = []
    if cfg.eval.function_name == "all":
        if cfg.generation_configs.function_name_type == 'random':
            function_name_list = ['ovs', 'cyk', 'mcl', 'heh', 'fgu', 'knk', 'zmf', 'bgz', 'cub', 'dfn']
        else:
            function_name_list = ['pop', 'add', 'sub', 'mul', 'div', 'max', 'min', 'std', 'avg', 'exp']
        for seed in range(20):
            # Results of other seeds are found by swapping the seed tag in the
            # configured output directory.
            seed_dir = output_dir.replace(f"seed{cfg.seed}", f"seed{seed}")
            # "LM" must come first: it is the baseline the other samplers are
            # compared against within each seed.
            for alg in ["LM", "TW", "JS"]:
                print(f"------------Results from seed {seed}, alg {alg}-----------------")
                output_path = os.path.join(seed_dir, f"Oracle{alg}_K1_all.pkl")
                print(output_path)
                results = _load_pickled_results(output_path)
                print("Loaded results", flush=True)
                d = compute_mean_length(results, tokenizer, function_name_list)
                d['alg'] = alg
                if alg == 'LM':
                    lm_stats = d
                _attach_errors_vs_lm(d, lm_stats)
                dlist.append(d)
    else:
        # NOTE(review): this branch always analyses 'min' and does not read
        # cfg.eval.function_name -- confirm that this is intended.
        for function_name in ['min']:
            function_name_list = [function_name] * 10
            # "LM" must come first: it is the baseline for the comparison.
            for alg in ['LM', 'JS']:
                fname = f"Oracle{alg}_K1_{function_name}.pkl"
                print(f"------------Results from {fname}-----------------")
                results = _load_pickled_results(os.path.join(output_dir, fname))
                d = compute_mean_length(results, tokenizer, function_name_list)
                d['function_name'] = function_name
                d['alg'] = fname.split("_")[0]
                if alg == 'LM':
                    lm_stats = d
                _attach_errors_vs_lm(d, lm_stats)
                dlist.append(d)
                print(d['avg_steps'])
    print(dlist)

# Run the analysis only when this file is executed as a script.
if __name__ == "__main__":
    main()
