from eval_utils import *
import warnings, os, argparse
warnings.filterwarnings("ignore")

def eval(exp_name, seed, P=100):
    """Evaluate one experiment seed, cache the results to disk and display them.

    The evaluation is computed only when ``<seed_path>/Eval/eval.pt`` does not
    exist yet; otherwise the cached file is left untouched. In both cases
    ``display_eval(exp_name, seed)`` is called at the end.

    NOTE: this function shadows the builtin ``eval``; the name is kept because
    it is the module's public entry point.

    Args:
        exp_name: Name of the experiment; used to build the results path.
        seed: Seed identifier; used to build the results path and forwarded to
            ``display_eval``. The RNGs are seeded from ``config["seed"]`` as
            loaded by ``load_exp``, not from this argument.
        P: Forwarded as ``nb_epochs`` to every ``eval_population`` call.

    The saved dictionary contains the keys ``train_outcomes``,
    ``history_basic``, ``steps``, ``results_basic`` and, for each of
    ``"descriptive"`` and ``"discriminative"``: ``results_compo``,
    ``compo_fails``, ``compo_matrix``, ``topography_corr``, ``tsne_basics``
    and ``tsne_compos``.
    """
    print(f"##### EVALUATION - EXPERIMENT {exp_name} - SEED {seed}")
    # Path to results. `path` is not defined in this file; presumably it is
    # provided by `from eval_utils import *` -- TODO confirm.
    seed_path = path+exp_name+"/seed"+str(seed)+"/"
    path_eval = seed_path+"/Eval/"
    print(f"--------------- Recording results in {path_eval}")

    # Agents evaluation (skipped when a cached eval.pt already exists).
    if not os.path.exists(path_eval+"eval.pt"):
        # Load results.
        eval_dict    = {}
        exp, config  = load_exp(seed_path)

        N = config["action_it"]

        # Seed the RNGs from the stored configuration. A dedicated local name
        # is used so that the `seed` argument is not overwritten: it must reach
        # display_eval() unchanged whether or not the cache existed.
        rng_seed = config["seed"]
        torch.manual_seed(rng_seed)
        np.random.seed(rng_seed)
        random.seed(rng_seed)

        # Composition bins: 1 .. nb_features.
        bins = list(range(1,config["nb_features"]+1))

        # Dataset and, for each bin, the indexes of the rows whose sum along
        # dim 1 equals that bin value.
        dataset      = torch.load(generate_systematic_dataset(n=config["nb_features"], m_list=bins))
        dataset_idxs = torch.arange(0,dataset.shape[0])
        row_sums     = torch.sum(dataset,1)
        bin_idxs     = {}
        for bin_size in bins:
            bin_idxs[bin_size] = dataset_idxs[row_sums==bin_size]

        # Basic history.
        print(f"### BASIC HISTORY...")

        eval_dict["train_outcomes"] = exp["logs"]["graph_outcomes"]    # Basic games outcomes (1 per iteration)

        # One entry is appended to each list every `steps` iterations.
        eval_dict["history_basic"] = {}
        eval_dict["history_basic"]["cP"]       = []                    # Basic P-Coherences
        eval_dict["history_basic"]["cA"]       = []                    # Basic A-Coherences
        eval_dict["history_basic"]["cR"]       = []                    # Basic R-Coherences
        eval_dict["history_basic"]["lexicon"]  = []                    # Lexicons

        # Checkpoint stride: 1000 with image perspectives, 100 otherwise.
        steps = 1000 if config["use_img_perspectives"] else 100

        eval_dict["steps"] = steps

        if steps==100:
            print("ONE-HOT EVAL")
            N = 1

        # Largest multiple of `steps` not exceeding the number of recorded
        # training outcomes; the loop below includes that last checkpoint.
        nb_iterations = (len(eval_dict["train_outcomes"])//steps) * steps
        for i in range(0,nb_iterations+1,steps):
            print(f"> Iteration {i} / {nb_iterations}")

            agents = load_history(seed_path,config,i)[:10] # Only evaluate a subset of 10 agents maximum

            ### Basic
            results_basic_h = eval_population(agents, dataset, bin_idxs[1], nb_epochs = P, use_p = config["use_img_perspectives"], shared_p = config["shared_perspective"], n=N)
            lexicon         = get_lexicon_example(results_basic_h)
            cA, cP, cR      = get_coherences(results_basic_h)

            ### Update history
            eval_dict["history_basic"]["cP"].append(cP)
            eval_dict["history_basic"]["cA"].append(cA)
            eval_dict["history_basic"]["cR"].append(cR)
            eval_dict["history_basic"]["lexicon"].append(lexicon)

        ### Last population results: `agents` and `results_basic_h` still hold
        ### the values of the final loop iteration.
        eval_dict["results_basic"] = results_basic_h
        ####################################################################################################################

        results_basic = eval_dict["results_basic"]

        eval_dict["descriptive"]    = {}
        eval_dict["discriminative"] = {}

        for key1 in ["descriptive","discriminative"]:

            # Compo (2-feats) performances.
            eval_dict[key1]["results_compo"] = {"auto":{},"social":{},"utts":None}

            results_compo = eval_population(agents, dataset, bin_idxs[2], nb_epochs = P, use_p = config["use_img_perspectives"], shared_p = config["shared_perspective"], n=N, gen=key1)

            for key2 in ["auto","social"]:
                o, p, r, f1  = get_performances(results_compo,type = key2)
                eval_dict[key1]["results_compo"][key2]["o"]  = o       # Compo mean outcomes
                eval_dict[key1]["results_compo"][key2]["p"]  = p       # Compo mean precision
                eval_dict[key1]["results_compo"][key2]["r"]  = r       # Compo mean recall (presumably, from the name)
                eval_dict[key1]["results_compo"][key2]["f1"] = f1      # Compo mean F1 (presumably, from the name)

            # Compo (2-feats) failure cases (of agent 0).
            eval_dict[key1]["compo_fails"] = {"auto":results_compo[0]["auto"]["failure-cases"],"social":results_compo[0]["social"]["failure-cases"]}

            # Additional figures (for agent 0).
            eval_dict[key1]["results_compo"]["utts"] = results_compo[0]["utts"]
            eval_dict[key1]["results_compo"]["refs"] = results_compo[0]["refs"]

            # Compo (2-feats) matrix (agent 0).
            print("### COMPOSITIONAL MATRIX...")

            # NOTE(review): the sizes 5 and 52x52 are hard-coded; this looks
            # like it assumes config["nb_features"] == 5 and 52x52 utterances
            # -- confirm before running other configurations.
            matrix      = torch.zeros(5,5,52,52)
            matrix_refs = []
            for i in range(5):
                for j in range(5):
                    referent    = torch.zeros(5)
                    referent[i] = 1
                    referent[j] = 1
                    ref = ref_str(referent)

                    # Diagonal cells are single-feature referents and are
                    # looked up in the basic results; the others are
                    # two-feature referents looked up in the compo results.
                    # Each cell keeps the utterance with the highest cosine.
                    if (i==j):
                        matrix_refs.append(ref)
                        mask           = (np.array(results_basic[0]["refs"]) == ref)
                        max_cosine_idx = np.argmax(np.array(results_basic[0]["cosines"])[mask])
                        matrix[i][j]   = results_basic[0]["utts"][mask][max_cosine_idx]
                    else:
                        mask           = (np.array(results_compo[0]["refs"]) == ref)
                        max_cosine_idx = np.argmax(np.array(results_compo[0]["cosines"])[mask])
                        matrix[i][j]   = results_compo[0]["utts"][mask][max_cosine_idx]

            eval_dict[key1]["compo_matrix"]         = {}
            eval_dict[key1]["compo_matrix"]["utts"] = matrix
            eval_dict[key1]["compo_matrix"]["refs"] = matrix_refs

            # Topography correlation.
            print("### TOPOGRAPHIC UTTERANCES ANALYSIS...")

            eval_dict[key1]["topography_corr"] = get_topo_per_compo(results_basic, results_compo)

            # t-SNEs - basics & compo (2-feats) (agent 0).
            print("### EMBEDDINGS T-SNEs...")

            ### BIN 1
            unique_referents_basic = np.unique(results_basic[0]["refs"])
            embeddings_refs_basic, embeddings_utts_basic = results_basic[0]["reps_refs"], results_basic[0]["reps_utts"]

            tsne_basics = {"refs":None, "utts":None, "colors":None}

            tsne_basics["refs"]    = TSNE().fit_transform(embeddings_refs_basic.numpy())
            tsne_basics["utts"]    = TSNE().fit_transform(embeddings_utts_basic.numpy())
            # Color of a sample = index of its referent among the unique basic referents.
            tsne_basics["colors"]  = [np.where(unique_referents_basic == ref)[0][0] for ref in results_basic[0]["refs"]]
            tsne_basics["str"]     = results_basic[0]["refs"]

            ### BIN 2
            # The basic embeddings fetched for BIN 1 are reused below.
            unique_referents_compo = np.unique(results_compo[0]["refs"])

            tsne_compos = {}
            for unique_ref in unique_referents_compo.tolist():
                tsne_compos[unique_ref] = {"refs":None,"utts":None,"colors":None}

                ref_mask = (np.array(results_compo[0]["refs"]) == unique_ref)
                embeddings_ref_compo = results_compo[0]["reps_refs"][ref_mask]
                embeddings_utt_compo = results_compo[0]["reps_utts"][ref_mask]

                # Basic embeddings come first, then this referent's compo
                # embeddings; the compo samples get the color -1.
                tsne_compos[unique_ref]["refs"]   = TSNE().fit_transform( torch.cat((embeddings_refs_basic, embeddings_ref_compo)).numpy())
                tsne_compos[unique_ref]["utts"]   = TSNE().fit_transform( torch.cat((embeddings_utts_basic, embeddings_utt_compo)).numpy())
                tsne_compos[unique_ref]["colors"] = tsne_basics["colors"] + [-1]*np.sum(ref_mask)

            eval_dict[key1]["tsne_basics"] = tsne_basics
            eval_dict[key1]["tsne_compos"] = tsne_compos

        # Save results.
        print(f"##### SAVING RESULTS IN {path_eval} ...")
        # torch.save does not create missing parent directories.
        os.makedirs(path_eval, exist_ok=True)
        torch.save(eval_dict,path_eval+"eval.pt")
        print("##### EVALUATION RESULTS SAVED! :)")

    display_eval(exp_name,seed)

################################################################################
# Command-line interface. The parser is built at module level so it stays
# importable, but arguments are parsed only when the file is run as a script:
# parsing at import time would consume the importing process's sys.argv and
# could terminate it on unrecognized arguments.
parser = argparse.ArgumentParser(description='Evaluation')

parser.add_argument('--exp_name', type=str, default="Base",
                    help="experiment name, used to build the results path")
parser.add_argument('--seed',     type=int, default=0,
                    help="seed identifier, used to build the results path")
parser.add_argument('--P',     type=int, default=100,
                    help="value forwarded as the P argument of eval()")

if __name__ == "__main__":
    args = parser.parse_args()
    eval(args.exp_name,args.seed,args.P)
