"""
Use this script to see a model's accuracy on a given causal model's ouputs
"""
import pickle
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import os
import sys
if sum([int(f=="dl_utils") for f in os.listdir("./")])==0:
    sys.path.append("../")
else:
    sys.path.append("./")
import time

import seq_models as smods

import dl_utils
from dl_utils.utils import (
    pad_to, get_mask_past_id, num2base, device_fxn, get_datetime_str
)
import dl_utils.save_io as savio
from dl_utils.tokenizer import Tokenizer

from utils import check_correct_count, get_counts, run_til_idx, print_tensors
from datas import make_systematic_dataset
import datas
import causal_models.num_equivalence_models as numeqv
#import causal_models.arithmetic as arith
from automated_utils import collect_activations

DEVICE = 0 if torch.cuda.is_available() else "cpu"

if __name__=="__main__":
    model_folders = []
    og_file_name = None
    exp_config = {
        "save_layer": None, # "rnns.0", # if true, will save a pickle of the argued layer. argue None to ignore
        "overwrite": True,
        "n_samples": 15,
        "keyword": "targ_count",
        "min_val": 1,
        "max_val": 30,
        "use_best": True,
        "causal_model": "count_only_cmod"
    }

    # Read in Command line arguments in to exp_config
    args = sys.argv[1:]
    for arg in args:
        if savio.is_model_folder(arg):
            model_folders.append(arg)
        elif savio.is_exp_folder(arg):
            mfs = savio.get_model_folders(
                arg, incl_full_path=True, incl_empty=False)
            for f in mfs:
                model_folders.append(f)
        elif "checkpt" in arg and ".pt" in arg:
            model_folders.append(arg)
        elif ".yaml" in arg or ".json" in arg:
            exp_config = {**exp_config, **savio.load_json_or_yaml(arg)}
        elif "=" in arg:
            key,val = arg.split("=")
            if str(val)=="None":
                val = None
            elif str(val).lower()=="true":
                val = True
            elif str(val).lower()=="false":
                val = False
            exp_config[key] = val
        else:
            og_file_name = arg

    if og_file_name is not None:
        print("Saving to", og_file_name)

    # Loop through argued model folders
    for model_folder in model_folders:
        print()
        print("Beginning", model_folder)
        checkpt = savio.load_checkpoint(
            model_folder,
            use_best=exp_config.get("use_best", True)
        )
        config = checkpt["config"]
        temp = smods.make_model(config)
        temp.load_state_dict(checkpt["state_dict"])
        model = temp.model
        tokenizer = Tokenizer(
            words=set(),
            unk_token=None,
            word2id=config["word2id"])

        # Fill in settings from the checkpoint but don't overwrite
        # command line arguments
        conf = {**exp_config}
        for k in config:
            if k not in conf: conf[k] = config[k]
            elif type(conf[k])!=type(config[k]):
                if type(config[k]) in {int, float, str, bool}:
                    conf[k] = type(config[k])(conf[k])
                else:
                    try:
                        conf[k] = int(conf[k])
                    except:
                        print("failed to convert", k, "with value of",
                            conf[k])
        if "task_config" in conf:
            for k in conf["task_config"]:
                if k not in conf: conf[k] = conf["task_config"][k]
            conf["task_config"] = None
        print("Config:")
        for k in sorted(conf.keys()):
            print("\t",k, conf[k])

        # Organize validation settings
        n_samples = conf.get("n_samples", 30) # Samps per target quant
        conf["n_samples"] = n_samples
        hold_outs = set() #conf.get("hold_outs", {4,9,14,17}) # {17,18,19,20}
        conf["hold_outs"] = set()
        conf["task_name"] = conf.get("task_name", "num_equivalence")
        if conf["task_name"]=="num_equivalence":
            if "min_val" not in conf:
                conf["min_val"] = conf.get("min_count", 1)
            if "max_val" not in conf:
                conf["max_val"] = conf.get("max_count", 20)
            causal_model = getattr(numeqv, conf["causal_model"])
        elif conf["task_name"]=="arithmetic":
            if "min_val" not in conf:
                conf["min_val"] = -100
            if "max_val" not in conf:
                conf["max_val"] = 100
            conf["step_size"] = conf.get("step_size", 7)
            causal_model = getattr(arith, conf["causal_model"])
            raise NotImplemented
        conf["output_layer"] = conf.get("output_layer", "lm_head")
        conf["seq_len"] = config.get("seq_len", None)

        # Potentially save the intermediate outputs at these layers
        layers = []
        if conf.get("save_layer", None) is not None:
            layers.append(conf["save_layer"])

        # Optionally argue a high level model to tag the tokens for
        # later analysis
        if causal_model is None:
            if not config.get("multi_trigger", False):
                causal_model =  numeqv.count_only_cmod # numeqv.demo_resp_count_cmod
                intr_var_key = "count" # "demo_count", "resp_count"
            ## Multi Pre Trigger
            elif config.get("pre_trigger", True):
                causal_model = numeqv.trigid_count_only_cmod # numeqv.trigid_demo_resp_count_cmod
                intr_var_key = "count" # "demo_idx", "count", "resp_count", "demo_count"
            ## Multi Post Trigger
            else:
                causal_model = numeqv.count_only_cmod # numeqv.demo_resp_count_cmod 
                intr_var_key = "count" # "demo_count", "resp_count"

        # Collect Data
        data_kwargs = {**conf}
        data_kwargs["hold_outs"] = set()
        data_kwargs["concat"] = False
        dataset,task_mask,info,samp_types,_ = make_systematic_dataset(
            **data_kwargs)

        # Collect causal variables from high level model
        kwargs = {**config}
        kwargs["trigger_id"] = None
        kwargs = {**kwargs, **info}
        vbls = []
        for sidx,seq in enumerate(dataset):
            vars,outpts = run_til_idx(causal_model, seq, **kwargs)
            vbls.append(vars)
            dataset[sidx][1:len(outpts)] = outpts[:-1]



        # Collect meta data for later PCA analyses
        step_data = dict()
        for k in vars[0].keys():
            step_data[k] = np.asarray(
                [[v[k] for v in vbl] for vbl in vbls]).reshape(-1)
        step_data["step"] = np.asarray(
            [[i for i in range(len(vbl))] for vbl in vbls]).reshape(-1)
        step_data["token_id"] = np.asarray(dataset).reshape(-1)
        step_data["task_mask"] = np.asarray(task_mask).reshape(-1)
        step_data["ep_idx"] = np.asarray(
            [[j for i in range(len(vbl))] for j,vbl in enumerate(vbls)]
            ).reshape(-1)
        if conf["task_name"]=="num_equivalence":
            min_count = conf["min_val"]
            max_count = conf["max_val"]
            nums = np.asarray(sorted(list(
                set(range(min_count,max_count+1))-hold_outs
            )))
            trigger_ids = info.get("trigger_ids", [7])
            step_data["n_targs"] = nums.repeat(
                (len(trigger_ids)*n_samples,))[:,None].repeat(
                    (len(vbls[0]),))
            step_data["trigger_id"] = np.tile(
                np.asarray(trigger_ids).repeat((n_samples,)),
                (max_count-min_count+1-len(hold_outs),))[:,None].repeat(
                    (len(vbls[0]),))
        for k in samp_types:
            step_data[k] = np.asarray([
              [stype for _ in range(len(vbls[0]))] for stype in samp_types[k]
            ]).reshape(-1)

        # Collect model responses
        model.eval()
        model.cuda()
        torch_dataset = torch.LongTensor(dataset)
        task_mask = torch.BoolTensor(task_mask)
        print("dset:", torch_dataset.shape)
        print("tmask:", task_mask.shape)
        with torch.no_grad():
            actvs = collect_activations(
                model=model,
                input_ids=torch_dataset,
                task_mask=task_mask,
                pad_mask=None,
                layers=layers+[conf["output_layer"]],
                batch_size=None,
                ret_pred_ids=True,
                to_cpu=True)
        print("Actvs:")
        accs = {
            **samp_types,
            "correct": [],
            "tok_acc": [],
        }
        for k in actvs.keys():
            print("  ", k, actvs[k].shape)
            if k=="pred_ids":
                output_ids = torch_dataset[:,1:].cpu()
                pred_ids = actvs[k][:,:-1].reshape(output_ids.shape).cpu()
                tmask = task_mask[:,1:].cpu()
                corrects = torch.zeros(output_ids.shape).long()
                corrects[tmask] = (pred_ids[tmask]==output_ids[tmask]).long()
                eqs = corrects.sum(-1)==tmask.long().sum(-1)
                accs["correct"] = eqs.tolist()
                tacc = corrects.sum(-1)/tmask.float().sum(-1)
                accs["tok_acc"] = tacc.tolist()
                acc = eqs.float().mean()
            actvs[k] = actvs[k].reshape(-1,actvs[k].shape[-1])
            print("  ", " "*len(k), actvs[k].shape)

        print("Examples:")
        for _ in range(5):
            print_tensors(
                output_ids[_],
                actvs["pred_ids"][_],
                task_mask[_],
                tokenizer)
        print("\nTrials Correct:", acc.item())
        if conf["task_name"]=="num_equivalence":
            sts = torch.LongTensor(samp_types[list(samp_types.keys())[0]])
            train_idx = sts<=20
            interp = eqs[train_idx].float().mean().item()
            extrap = eqs[~train_idx].float().mean().item()
            print("Interp Acc:", interp)
            print("Extrap Acc:", extrap)

        step_data["pred_id"] = actvs["pred_ids"].cpu().numpy().reshape(-1)

        # Save intermediate computations if desired
        if conf.get("save_layer", None) is not None:
            X = actvs[layers[0]].numpy()
            a = { "X": X, "step_data": step_data }
            end_pt = "layer-"+layers[0]+"-actvs.pickle"
            if savio.is_model_folder(model_folder):
                path = os.path.join(model_folder, end_pt)
            else: path = path+end_pt
            with open(path, 'wb') as handle:
                pickle.dump(a, handle, protocol=pickle.HIGHEST_PROTOCOL)

        # Construct dataframe of results
        df = {}
        for k in step_data.keys():
            # Shift entries so that preds are aligned with target tokens
            if k!= "pred_id": df[k] = step_data[k][1:]
            else: df[k] = step_data[k][:-1]
        df = pd.DataFrame(df)
        if conf["task_name"]=="num_equivalence":
            if "count" in step_data:
                count_key = "count"
            elif "demo_count" in step_data:
                count_key = "demo_count"
        df["equal"] = df["pred_id"]==df["token_id"]
        acc_df = df.loc[df["task_mask"].astype(bool)]
        acc_df = acc_df.groupby(["ep_idx"])["equal"].mean()
        acc = (acc_df.reset_index()["equal"]>=1).mean()
        print("Acc:", acc)

        # Save low granularity accuracies to model folder
        lowgran_df = pd.DataFrame(accs)
        end_name = "accuracies.csv"
        if not os.path.isdir(model_folder):
            splt = model_folder.split("/")
            model_folder = "/".join(splt[:-1])
            end_name = ".".join(splt[-1].split(".")[:-1])+"_"+end_name
        path = os.path.join(model_folder,end_name)
        if os.path.exists(path) and not exp_config["overwrite"]:
            lowgran_df.to_csv(path, header=False, index=False, mode="a")
        else:
            lowgran_df.to_csv(path, header=True, index=False)
        print("Coarse accuracies csv to", path)

        # Save recorded accuracy to checkpoint
        checkpt["val_correct"] = acc
        f = model_folder
        if not savio.is_model_folder(f):
            f = "/".join(model_folder.split("/")[:-1])
        savio.save_checkpt(
            save_dict=checkpt,
            save_folder=f,
            save_name="checkpt",
            epoch=checkpt["epoch"],
            ext=".pt",
            del_prev_sd=False,
        )

        # Save dataframe to model folder
        end_name = "validation.csv"
        if not os.path.isdir(model_folder):
            splt = model_folder.split("/")
            model_folder = "/".join(splt[:-1])
            end_name = ".".join(splt[-1].split(".")[:-1])+"_"+end_name
        path = os.path.join(model_folder,end_name)
        if os.path.exists(path) and not exp_config["overwrite"]:
            df.to_csv(path, header=False, index=False, mode="a")
        else:
            df.to_csv(path, header=True, index=False)
        print("Saved csv to", path)

        #plot_df = df.loc[df["token_id"].isin(targs)]
        #avg = plot_df.groupby(["targ_count"])["equal"].mean()
        #fig = plt.figure()
        #sns.lineplot(x="targ_count", y="equal", data=plot_df)
        #mc = config.get("max_count",20)
        #plt.plot([mc,mc],[0,1],"k--")
        #plt.ylabel("Accuracy")
        #plt.xlabel("Target Count")
        #plt.tight_layout()
        #path = path.replace(".csv", ".png")
        #plt.savefig(path, bbox_inches="tight")
        #print("Saved figure to", path)
        ##except:
        ##    print("Exception Ocurred! Moving on to next model")

