import matplotlib.pyplot as plt
import numpy as np
import csv
from collections import namedtuple
import os
from pathlib import Path
from matplotlib.lines import Line2D
import tqdm

import argparse
# Command-line switches (all default to off):
#   --fast / -f        coarse threshold grid and the approximate pass@k estimator
#   --mlp / -m         pass@k line plot: include the MLP-scored model
#   --individual / -i  pass@k line plot: also rank individual programs
parser = argparse.ArgumentParser(description="")
parser.add_argument("--fast", "-f", action="store_true", default=False)
parser.add_argument("--mlp", "-m", action="store_true", default=False)
parser.add_argument("--individual", "-i", action="store_true", default=False)
arguments = parser.parse_args()

# seaborn is optional: it only changes the look of the plots, so a missing
# (or too old, lacking set_theme) installation must not stop the script.
try:
    import seaborn as sns
    sns.set_theme()
except (ImportError, AttributeError):
    print("try installing seaborn")
# try:
#     from tueplots import bundles
#     plt.rcParams.update(bundles.icml2022(family="sans-serif", usetex=False, column="full", nrows=2))
# except:
#     print("install tueplots for the best looking plots")

# Values of k drawn in the precision/recall figure, each mapped to the
# matplotlib line style used for it (e.g. 2: "--" or 3: ":" could be added).
k_values = {
    1: "-",
}

# Scored data files: (dataset title, [(method label, csv path), ...]).
# Each dataset is titled after the data the scores were computed on (the test data).
files = [
    (
        "HumanEval",
        [
            ("full model",
             "learning_weights/humaneval_0.8_top_p_1/formatted_df_scores_kfold_combined_formal_specs.csv"),
            ("full model, MLP",
             "learning_weights/humaneval_0.8_top_p_1/formatted_df_scores_kfold_combined_formal_specs_neural_net.csv"),
            ("tests only",
             "learning_weights/humaneval_0.8_top_p_1/formatted_df_scores_kfold_normal_features_100.csv"),
            ("rels only",
             "learning_weights/humaneval_0.8_top_p_1/formatted_df_scores_formal_specs_only.csv"),
        ],
    ),
    (
        "MBPP",
        [
            ("full model",
             "learning_weights/mbpp_0.8_top_p_1/formatted_df_scores_kfold_mbpp_combined_formal_specs.csv"),
            ("full model, MLP",
             "learning_weights/mbpp_0.8_top_p_1/formatted_df_scores_kfold_mbpp_combined_formal_specs_trained_mlp.csv"),
            ("tests only",
             "learning_weights/mbpp_0.8_top_p_1/formatted_df_scores_kfold_mbpp_normal_100.csv"),
            ("rels only",
             "learning_weights/mbpp_0.8_top_p_1/formatted_df_scores_kfold_formal_specs_100.csv"),
        ],
    ),
]



class Cluster(namedtuple("Cluster", "score c n percentage")):
    """
    A group of sampled programs, stored as an immutable record.

    Fields: ``score`` (ranking score), ``c`` (how many programs in the group
    are correct), ``n`` (how many programs the group holds) and
    ``percentage`` (test-case pass rate as read from the data file).
    A single sample is represented as a cluster with n == 1.
    """

    @property
    def correct(self):
        """Fraction of this cluster's programs that are correct."""
        return self.c / self.n

    def success_probability(self, k):
        """Chance that k draws without replacement from this cluster hit a correct program."""
        return pass_at_k(self.n, self.c, k)

    @property
    def codet(self):
        """Basic CodeT-style score: pass rate times the square root of the cluster size."""
        return self.percentage * self.n ** 0.5

# memo of already-parsed data files: path -> result of _process(path)
file_loading_cache = {}


def process(fn):
    """Return ``_process(fn)``, parsing each file only once per run."""
    if fn not in file_loading_cache:
        file_loading_cache[fn] = _process(fn)
    return file_loading_cache[fn]


def _process(fn):
    """Load a CSV data file and return its programs grouped and ranked several ways.

    Each CSV row is one sampled program. Rows are grouped by ``'Prompt id'``;
    within a prompt, programs whose ``'Trustworthy Score'`` is exactly equal
    are merged into a single Cluster (c = #correct, n = #programs).

    Returns a 4-tuple of lists, each holding one inner list per prompt (prompts
    in order of first appearance, so the four lists are aligned):
        individual_programs    -- one singleton cluster per program, best score first
        clusters               -- score-merged clusters, best score first
        ranked_by_cluster_size -- merged clusters, largest first, score replaced by size
        ranked_by_codet        -- merged clusters, highest codet first, score replaced by codet

    Raises ValueError if a row's score is not numeric, or if the file is in the
    unsupported "compare some model" format.
    """
    individual_programs = {}
    clusters = {}

    # newline="" is what the csv module expects, so quoted fields may contain newlines
    with open(fn, newline="") as handle:
        for r in csv.DictReader(handle):
            prompt = r['Prompt id']
            raw_score = r['Trustworthy Score']
            try:
                score = float(raw_score)
            except (TypeError, ValueError) as err:
                # fail loudly: an unscored program cannot be ranked, and reusing
                # the previous row's score would silently corrupt the clusters
                raise ValueError(f"{fn}: prompt {prompt!r} has a non-numeric "
                                 f"'Trustworthy Score' ({raw_score!r})") from err

            percentagepassed = float(r['% test cases passed'])

            if "Proportion of correct generated program" in r:
                raise ValueError('loading compare some model not yet supported')

            correct = r['Is generated program correct'] == "True"

            # every program is also its own singleton cluster
            individual_programs.setdefault(prompt, []).append(
                Cluster(score, int(correct), 1, percentagepassed))

            # programs sharing an identical score are merged into one cluster
            prompt_clusters = clusters.setdefault(prompt, {})
            if score in prompt_clusters:
                existing = prompt_clusters[score]
                prompt_clusters[score] = existing._replace(c=int(correct) + existing.c,
                                                           n=1 + existing.n)
            else:
                prompt_clusters[score] = Cluster(score,
                                                 int(correct),
                                                 1,
                                                 percentage=percentagepassed)

    # list of lists of clusters, one inner list per prompt
    clusters = [sorted(samples.values(), key=lambda s: s.score, reverse=True)
                for samples in clusters.values()]

    individual_programs = [sorted(samples, key=lambda s: s.score, reverse=True)
                           for samples in individual_programs.values()]

    ranked_by_cluster_size = [[c._replace(score=c.n)
                               for c in sorted(cs, key=lambda z: z.n, reverse=True)]
                              for cs in clusters]

    # the CodeT baseline ranks the merged clusters, not individual programs
    ranked_by_codet = [[c._replace(score=c.codet)
                        for c in sorted(cs, key=lambda z: z.codet, reverse=True)]
                       for cs in clusters]

    print(fn, "has", len(clusters), "prompts, of which the following have a satisfying sample:",
          sum(any(c.c > 0 for c in cs) for cs in clusters))
    return individual_programs, clusters, ranked_by_cluster_size, ranked_by_codet


def pass_at_k(n, c, k):
    """
    Unbiased pass@k estimate: probability that k samples drawn without
    replacement from n samples, c of them correct, include a correct one.

    :param n: total number of samples
    :param c: number of correct samples
    :param k: k in pass@$k$
    """
    if c == 0:
        return 0
    incorrect = n - c
    if incorrect < k:
        # too few incorrect samples to fill all k draws
        return 1.0
    denominators = np.arange(incorrect + 1, n + 1)
    return 1.0 - np.prod(1.0 - k / denominators)

def product(l):
    """Multiply all items of the iterable ``l`` together (1 when it is empty)."""
    accumulated = 1
    for factor in l:
        accumulated *= factor
    return accumulated
    
def topk(samples, k):
    """Return k entries from ``samples``, cycling from the front when k exceeds its length."""
    if not samples:
        return []
    size = len(samples)
    return [samples[index % size] for index in range(k)]


def predicted_positives(data, threshold):
    """Count the prompts having at least one sample scored at or above ``threshold``."""
    count = 0
    for samples in data:
        if any(e.score >= threshold for e in samples):
            count += 1
    return count


def actual_positives(data):
    """Count the prompts for which at least one sample is (partly) correct."""
    solvable = [samples for samples in data
                if any(e.correct > 0 for e in samples)]
    return len(solvable)


def true_positives(data, threshold, k=1):
    """Expected number of prompts solved within k ranked draws, using only the
    clusters that survive ``threshold`` filtering in reranked_pass_at_k."""
    solved_fraction = reranked_pass_at_k(data, k, threshold=threshold)
    return solved_fraction * len(data)


def precision(data, threshold, k=1):
    """Precision at ``threshold``: expected true positives per predicted positive."""
    tp = true_positives(data, threshold, k)
    return tp / predicted_positives(data, threshold)


def recall(data, threshold, k=1):
    """Recall at ``threshold``: expected true positives per solvable prompt."""
    tp = true_positives(data, threshold, k)
    return tp / actual_positives(data)


def remove_dominated_points(data):
    """Keep only the (precision, recall) points on the Pareto frontier.

    A point is dropped when some other point is strictly larger in both
    coordinates; surviving points keep their original order.
    """
    frontier = []
    for x, y in data:
        dominated = any(other_x > x and other_y > y for other_x, other_y in data)
        if not dominated:
            frontier.append((x, y))
    return frontier

def precision_recall_curve(data, k):
    """Compute the Pareto frontier of (precision, recall) points for ``data``.

    Every distinct score in ``data`` is tried as a decision threshold. With
    ``--fast``, 201 evenly spaced thresholds spanning the score range are used
    instead.

    Returns an array of (precision, recall) rows with dominated points removed.
    """
    thresholds = sorted({s.score for samples in data for s in samples})

    if arguments.fast and thresholds:
        smallest, biggest = thresholds[0], thresholds[-1]
        # linspace includes both endpoints exactly. arange with a float step can
        # overshoot by one element, and raises when all scores are equal (step 0).
        if biggest > smallest:
            thresholds = list(np.linspace(smallest, biggest, 201))

    print(len(thresholds), "thresholds")

    ap = actual_positives(data)
    curve = []
    for t in tqdm.tqdm(thresholds):
        tp = true_positives(data, t, k)
        curve.append((tp / predicted_positives(data, t), tp / ap))

    return np.array(remove_dominated_points(curve))

        
def random_baseline_curve(fn, k):
    """Precision/recall curve of a baseline that accepts a random fraction of prompts.

    fn: either a CSV path (loaded with process(), using its individual-program
        view) or an already-loaded list of per-prompt sample lists.

    Accepting a fraction p of the prompts and drawing k of each prompt's samples
    without replacement gives p * sum_i pass@k_i expected true positives and
    p * len(data) predicted positives. Precision is therefore constant and
    recall grows linearly with p.

    Returns an array of [precision, recall] rows, from the largest p down to p = 0.001.
    """
    if isinstance(fn, list):
        data = fn
    else:
        data = process(fn)[0]

    # Both quantities are independent of p, so compute them once.
    ap = actual_positives(data)
    expected_solved = sum(pass_at_k(len(samples),
                                    sum(e.correct for e in samples),
                                    k)
                          for samples in data)

    curve = []
    for p in np.arange(0.001, 1, 1/100):
        n_predicted = p * len(data)
        n_true = p * expected_solved
        curve.append([n_true / n_predicted, n_true / ap])
    curve.reverse()
    return np.array(curve)

def reranked_pass_at_k(data, k, threshold=float("-inf")):
    """Expected pass@k when programs are submitted following a ranking.

    :param data: one list of clusters per prompt, each list already sorted
        best-first by the ranking function being evaluated
    :param k: number of programs that may be submitted per prompt
    :param threshold: clusters scored below this are discarded first; a prompt
        left with nothing counts as a failure, so ``result * len(data)`` is the
        expected number of true positives at this threshold
    :return: mean success probability over all prompts in ``data``

    Draws go round-robin over the ranked clusters, sampling without replacement
    inside each cluster, until k programs are drawn or none remain. With
    ``--fast`` a cheaper approximation is used instead.
    """
    # Keep scores *at* the threshold: predicted_positives() treats a prompt as
    # predicted-positive when some score >= threshold. True positives must come
    # from the same set, or precision/recall are biased downwards.
    data = [[c for c in samples if c.score >= threshold] for samples in data]

    if arguments.fast:
        # old way of doing it: the k round-robin picks are treated as
        # independent draws, each succeeding with the cluster's correct fraction
        return sum(1 - product(1 - e.correct
                               for e in topk(samples, k))
                   for samples in data) / len(data)

    passes = 0
    for clusters in data:
        if not clusters:
            # only possible when the threshold removed every cluster of this prompt
            assert threshold > float("-inf")
            continue

        # simulate how many programs are drawn from each cluster
        drawn = [0] * len(clusters)
        cursor = 0
        remaining = sum(c.n for c in clusters)

        for _ in range(k):
            # skip clusters whose programs have all been drawn already
            while drawn[cursor] == clusters[cursor].n:
                cursor = (cursor + 1) % len(clusters)
            drawn[cursor] += 1
            cursor = (cursor + 1) % len(clusters)

            remaining -= 1
            if remaining == 0:
                break

        passes += 1 - product(1 - cluster.success_probability(n_drawn)
                              for cluster, n_drawn in zip(clusters, drawn))

    return passes / len(data)
    

    
def draw_precision_recall_curve(axis, data, color, style, name=None):
    """Plot a precision/recall curve on ``axis`` and, if named, print its statistics.

    data: (precision, recall) rows, e.g. the output of precision_recall_curve().
    The curve is closed off at (0, max recall) and (max precision, 0) before
    plotting, so its area can be integrated.

    When ``name`` is given, prints the area under the curve and the best F1 over
    the plotted points. Except for the "random" baseline, it also prints the
    recall still reachable at >99% and >90% precision.

    Returns whatever ``axis.plot`` returns (the created lines).
    """
    data = np.asarray(data)
    max_recall = np.max(data[:, 1])
    max_precision = np.max(data[:, 0])
    X = [0] + list(data[:, 0]) + [max_precision]
    Y = [max_recall] + list(data[:, 1]) + [0]
    ln = axis.plot(X, Y,
                   color=color, ls=style, label=name)

    axis.set_ylim([0, 1.1])
    axis.set_xlim([0, 1])

    if name:
        # np.trapz was renamed np.trapezoid in NumPy 2.0
        trapezoid = getattr(np, "trapezoid", None) or np.trapz
        print("PRECISION/RECALL STATISTICS:", name)
        print("\tAUC\t", trapezoid(Y, X))
        # F1 is 0/0 at points where precision and recall are both zero; skip them
        f1_scores = [2 * (x * y) / (x + y) for x, y in zip(X, Y) if x + y > 0]
        print("\tF1<\t", max(f1_scores, default=0.0))

        if name != "random":
            peak = max(Y)
            near_perfect = [y for x, y in zip(X, Y) if x > 0.99]
            if near_perfect:
                best = max(near_perfect)
                lost = 1 - best / peak if peak else float("nan")
                print("\tperfect precision:\tRecall",
                      peak, "->", best, lost)
            high_precision = [y for x, y in zip(X, Y) if x > 0.9]
            if high_precision:
                print("\t90% precision:\tRecall",
                      peak, "->", max(high_precision))

    return ln

def raw_accuracy_bars():
    """Draw grouped pass@k bars (k = 1 and 10) with one subplot per dataset in ``files``.

    Bars compare our ranker and its ablations, the CodeT replication, ranking
    by cluster size and random sampling. A black line marks the oracle: the
    fraction of prompts with at least one correct sample. Saves
    figures/pass_at_k_bar.pdf and trims it with the ``gm`` (GraphicsMagick) CLI.
    """
    f, axes = plt.subplots(1, len(files), figsize=(4*len(files), 3.5))
    if not isinstance(axes, np.ndarray):
        axes = [axes]

    ks = [1, 10]
    W = 1  # distance between neighbouring groups of bars
    nbars = 6  # bars reserved per group
    bar_size = (W - 0.1) / nbars

    # label -> first bar drawn with that label, so a method keeps its color in every subplot
    label_cache = {}
    for j, (title, fns) in enumerate(files):
        axis = axes[j]
        offset = -bar_size * nbars / 2

        def plot(ys, label):
            """Add one bar per k for ``label`` at the next offset within each group."""
            nonlocal offset
            color = label_cache[label].get_facecolor() if label in label_cache else None
            label_cache[label] = axis.bar(np.arange(len(ks)) + offset, ys,
                                          width=bar_size, label=label, color=color)[0]
            offset += bar_size

        full_model = full_model_cluster = alpha_baseline = None
        codet = ours_only_test = ours_only_relations = None
        for model_name, data_file in fns:
            if model_name == "full model":
                full_model, full_model_cluster, alpha_baseline, _ = process(data_file)
            elif model_name == "tests only":
                _, ours_only_test, _, codet = process(data_file)
            elif model_name == "rels only":
                _, ours_only_relations, _, _ = process(data_file)

        oracle = sum(any(e.correct for e in samples) for samples in full_model) / len(full_model)
        axis.plot([-bar_size*nbars/2, (len(ks)-1)*W + bar_size*nbars/2],
                  [oracle]*2, label="oracle", color="k")

        plot([reranked_pass_at_k(full_model_cluster, k) for k in ks], label="ours")
        if ours_only_test:
            plot([reranked_pass_at_k(ours_only_test, k) for k in ks],
                 label="ours, i/o only")
        if ours_only_relations:
            plot([reranked_pass_at_k(ours_only_relations, k) for k in ks],
                 label="ours, rels only")

        random_performance = [sum(pass_at_k(len(samples),
                                            sum(e.correct for e in samples),
                                            k)
                                  for samples in full_model) / len(full_model)
                              for k in ks]

        axis.set_title(title)
        axis.set_xticks(range(len(ks)))
        axis.set_xticklabels([f"k={k}" for k in ks])
        axis.set_ylabel("pass@k")

        # the CodeT replication only exists when a "tests only" file was listed
        if codet:
            plot([reranked_pass_at_k(codet, k) for k in ks], label="CodeT replication")
        plot([reranked_pass_at_k(alpha_baseline, k) for k in ks], label="cluster")
        plot(random_performance, label="random")

        # hand-tuned y ranges: first subplot vs. the others
        axis.set_ylim([0.3, 0.95] if j == 0 else [0.4, 0.85])

    axes[-1].legend(bbox_to_anchor=(1.04, 0.8), borderaxespad=0)

    plt.tight_layout()
    suffix = "_bar"
    plt.savefig(f"figures/pass_at_k{suffix}.pdf", format='pdf')
    os.system(f"gm convert figures/pass_at_k{suffix}.pdf -trim figures/pass_at_k{suffix}.pdf")
    
def raw_accuracy_figure():
    """Plot pass@k for k = 1..10 as lines, one subplot per dataset in ``files``.

    Always shows the oracle, our ranker and the random baseline. ``--mlp`` adds
    the MLP-scored model, and ``--individual`` adds ranking of individual
    programs. Without either flag, the ablations, the CodeT replication, the
    cluster-size baseline and published numbers are shown instead. Saves
    figures/pass_at_k{suffix}.png (suffix "_mlp", "_individual" or "") and
    trims it with the ``gm`` (GraphicsMagick) CLI.
    """
    from matplotlib.ticker import MaxNLocator

    f, axes = plt.subplots(1, len(files), figsize=(4*len(files), 4))
    if not isinstance(axes, np.ndarray):
        axes = [axes]

    # published results plotted for reference: (paper, ks, pass@k values)
    quoted_data = {
        "HumanEval":
        [("Chen 2022", [1,2,10], [65.8/100,75.1/100,86.6/100]),
         ("Inala 2022", [1,5], [32.3/100,61.6/100])],
        "MBPP":
        [("Chen 2022", [1,2,10], [67.7/100,74.6/100,81.5/100]),
         ("Inala 2022", [1,5], [41.8/100,62.4/100])]}

    show_ablations = not arguments.mlp and not arguments.individual
    ks = list(range(1, 10+1))

    # label -> first line drawn with that label, so a method keeps its color in every subplot
    label_cache = {}
    for j, (title, fns) in enumerate(files):
        axis = axes[j]

        def plot(ys, label, marker=None, x=None):
            """Plot ``ys`` against ``x`` (default: ks), reusing the label's earlier color."""
            color = label_cache[label].get_color() if label in label_cache else None
            label_cache[label] = axis.plot(ks if x is None else x, ys, label=label,
                                           marker=marker, color=color)[0]

        full_model = full_model_cluster = alpha_baseline = None
        full_model_mlp = full_model_mlp_cluster = None
        codet = ours_only_test = ours_only_relations = None
        for model_name, data_file in fns:
            if model_name == "full model":
                full_model, full_model_cluster, alpha_baseline, _ = process(data_file)
            elif model_name == "full model, MLP" and arguments.mlp:
                full_model_mlp, full_model_mlp_cluster, _, _ = process(data_file)
            elif model_name == "tests only":
                _, ours_only_test, _, codet = process(data_file)
            elif model_name == "rels only":
                _, ours_only_relations, _, _ = process(data_file)

        oracle = sum(any(e.correct for e in samples) for samples in full_model) / len(full_model)
        axis.plot(ks, [oracle]*len(ks), color="k", label="oracle")

        axis.plot(ks, [reranked_pass_at_k(full_model_cluster, k) for k in ks],
                  label="ours, ranking clusters" if arguments.individual else "ours", zorder=3)

        if full_model_mlp:
            plot([reranked_pass_at_k(full_model_mlp_cluster, k) for k in ks], label="ours, MLP")

        if arguments.individual:
            plot([reranked_pass_at_k(full_model, k) for k in ks], label="ours, ranking programs")

        if ours_only_test and show_ablations:
            plot([reranked_pass_at_k(ours_only_test, k) for k in ks],
                 label="tests only")

        if ours_only_relations and show_ablations:
            plot([reranked_pass_at_k(ours_only_relations, k) for k in ks],
                 label="rels only")

        random_performance = [sum(pass_at_k(len(samples),
                                            sum(e.correct for e in samples),
                                            k)
                                  for samples in full_model) / len(full_model)
                              for k in ks]

        axis.set_title(title)
        axis.set_xlabel("k")
        axis.xaxis.set_major_locator(MaxNLocator(integer=True))
        axis.set_ylabel("pass@k")

        if show_ablations:
            if codet:
                plot([reranked_pass_at_k(codet, k) for k in ks], label="CodeT replication")
            plot([reranked_pass_at_k(alpha_baseline, k) for k in ks], label="cluster")

            # datasets without published numbers simply get none
            for jj, (paper_name, X, Y) in enumerate(quoted_data.get(title, [])):
                plot(Y, label=paper_name, x=X, marker="^v"[jj])

        plot(random_performance, label="random")

    first_lines = axes[0].get_lines()
    axes[-1].legend(first_lines, [ln.get_label() for ln in first_lines],
                    bbox_to_anchor=(1.04, 0.8), borderaxespad=0)

    plt.tight_layout()
    if arguments.mlp:
        suffix = "_mlp"
    elif arguments.individual:
        suffix = "_individual"
    else:
        suffix = ""

    plt.savefig(f"figures/pass_at_k{suffix}.png")
    os.system(f"gm convert figures/pass_at_k{suffix}.png -trim figures/pass_at_k{suffix}.png")

    
def precision_recall_figure():
    """Plot precision/recall curves, one subplot per dataset in ``files``.

    For every k in ``k_values`` (line style taken from ``k_values``), it draws
    our ranker, the cluster-size baseline, random sampling and the CodeT
    replication. For k > 1 only our ranker is drawn, using its merged clusters.
    Saves figures/precision_recall.png and trims it with the ``gm``
    (GraphicsMagick) CLI.
    """
    vertical = False
    if vertical:
        f, axes = plt.subplots(len(files), 1, figsize=(3, 3*len(files)))
    else:
        f, axes = plt.subplots(1, len(files), figsize=(4*len(files), 3))
    if not isinstance(axes, np.ndarray):
        axes = [axes]

    for j, (title, fns) in enumerate(files):
        full_model = full_model_cluster = alpha_baseline = codet = None
        for model_name, data_file in fns:
            if model_name == "full model":
                full_model, full_model_cluster, alpha_baseline, _ = process(data_file)
            elif model_name == "tests only":
                _, _, _, codet = process(data_file)

        axis = axes[j]
        axis.set_title(title)

        for k in sorted(k_values):
            for name, data in [("ours", full_model),
                               ("Cluster", alpha_baseline),
                               ("random", full_model),
                               ("CodeT", codet)]:
                if k > 1:
                    if name != "ours":
                        continue
                    data = full_model_cluster

                if name == "random":
                    curve = random_baseline_curve(data, k)
                elif data:
                    curve = precision_recall_curve(data, k=k)
                else:
                    # method not available for this dataset
                    continue
                draw_precision_recall_curve(axis, curve, color=None,
                                            style=k_values[k], name=name)

        if vertical:
            axis.set_ylabel("Recall")
        else:
            axis.set_xlabel("Precision")

    if vertical:
        axes[-1].set_xlabel("Precision")
    else:
        for axis in axes:
            axis.set_ylabel("Recall")

    first_lines = axes[0].get_lines()
    axes[-1].legend(first_lines, [ln.get_label() for ln in first_lines],
                    bbox_to_anchor=(1.04, 0.8), borderaxespad=0)
    plt.tight_layout()
    plt.savefig("figures/precision_recall.png")
    os.system("gm convert figures/precision_recall.png -trim figures/precision_recall.png")
    



def analyze_generalization():
    """Print LaTeX table cells summarising cross-dataset generalization.

    Judging by the file names, rows of ``file_matrix`` index the test set
    (HumanEval, MBPP) and columns the training set (HumanEval, MBPP). For each
    file, it prints pass@1 plus the AUC and best F1 of the k=1 precision/recall
    curve, all computed on the individual-program view.
    """
    file_matrix = [["learning_weights/humaneval_0.8_top_p_1/formatted_df_scores_kfold_combined_formal_specs.csv",
                    "learning_weights/mbpp_0.8_top_p_1/formatted_df_scores_trained_mbpp_tested_humaneval_combined_formal_specs.csv"],
                   ["learning_weights/humaneval_0.8_top_p_1/domain_generalization_to_mbpp/formatted_df_scores_trained_humaneval_tested_mbpp_combined_formal_specs.csv",
                    "learning_weights/mbpp_0.8_top_p_1/formatted_df_scores_kfold_mbpp_combined_formal_specs.csv"]]

    def statistics(fn):
        """Return a LaTeX tabular with pass@1, AUC and max F1 for one scores file."""
        programs = process(fn)[0]
        pk = reranked_pass_at_k(programs, k=1)
        curve = precision_recall_curve(programs, k=1)
        max_recall = np.max(curve[:, 1])
        max_precision = np.max(curve[:, 0])
        # close the curve at (0, max recall) and (max precision, 0), as in the plots
        X = [0] + list(curve[:, 0]) + [max_precision]
        Y = [max_recall] + list(curve[:, 1]) + [0]
        # np.trapz was renamed np.trapezoid in NumPy 2.0
        trapezoid = getattr(np, "trapezoid", None) or np.trapz
        auc = trapezoid(Y, X)
        # skip points where precision and recall are both zero (F1 would be 0/0)
        f1 = max((2*(x*y)/(x+y) for x, y in zip(X, Y) if x + y > 0), default=0.0)

        # backslashes are escaped explicitly ("\e" is an invalid escape sequence)
        return """
\\begin{tabular}{rl}
pass@1 & %f\\\\
AUC & %0.2f\\\\
max F1 & %0.2f\\\\
\\end{tabular}
"""%(pk, auc, f1)

    for row in file_matrix:
        print("&")
        print(" & \n".join(statistics(fn) for fn in row) + "\\\\")
            
# Build every figure and table only when run as a script, not on import.
if __name__ == "__main__":
    raw_accuracy_bars()
    raw_accuracy_figure()
    precision_recall_figure()
    analyze_generalization()
