"""
Important things to note when using this script. You will need to 
argue a keyword argument to search over the systematic dataset,
and you will need to argue a min and max value corresponding to this
keyword.
"""
import pickle
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import os
import sys
if sum([int(f=="dl_utils") for f in os.listdir("./")])==0:
    sys.path.append("../")
else:
    sys.path.append("./")
import time

import seq_models as smods

import dl_utils
from dl_utils.utils import (
    pad_to, get_mask_past_id, num2base, device_fxn, get_datetime_str
)
import dl_utils.save_io as savio

from utils import check_correct_count, get_counts, run_til_idx
from datas import make_systematic_dataset
import datas
import causal_models.num_equivalence as numeqv
from automated_utils import collect_activations

DEVICE = 0 if torch.cuda.is_available() else "cpu"

if __name__=="__main__":
    # Each command line argument is classified independently:
    #   model folder       -> evaluated directly
    #   experiment folder  -> every folder from savio.get_model_folders
    #   *checkpt*.pt path  -> evaluated as a single checkpoint
    #   *.yaml / *.json    -> merged over the defaults below
    #   key=value          -> overrides one exp_config entry
    #   anything else      -> stored in og_file_name
    model_folders = []
    og_file_name = None
    exp_config = {
        "save_layer": None, # "rnns.0", # if not None, will save a pickle of the argued layer. argue None to ignore
        "overwrite": True,
        "all_checkpts": False,
        "n_samples": 15,
        "keyword": "targ_count",
        "min_val": 1,
        "max_val": 30,
        "use_best": True,
    }

    # Read in Command line arguments in to exp_config
    args = sys.argv[1:]
    for arg in args:
        if savio.is_model_folder(arg):
            model_folders.append(arg)
        elif savio.is_exp_folder(arg):
            model_folders.extend(savio.get_model_folders(
                arg, incl_full_path=True, incl_empty=False))
        elif "checkpt" in arg and ".pt" in arg:
            model_folders.append(arg)
        elif ".yaml" in arg or ".json" in arg:
            exp_config = {**exp_config, **savio.load_json_or_yaml(arg)}
        elif "=" in arg:
            # Split on the first "=" only so that values may contain "="
            key,val = arg.split("=", 1)
            if val=="None":
                val = None
            elif val.lower()=="true":
                val = True
            elif val.lower()=="false":
                val = False
            else:
                # Command line values arrive as strings. Settings whose
                # current value is numeric (n_samples, min_val, max_val)
                # are used in arithmetic later, so cast them here.
                default = exp_config.get(key, None)
                if type(default) in {int, float}:
                    try:
                        val = type(default)(val)
                    except ValueError:
                        print("failed to convert", key, "with value of",
                            val)
            exp_config[key] = val
        else:
            og_file_name = arg

    if og_file_name is not None:
        print("Saving to", og_file_name)

    # Loop through argued model folders
    for model_folder in model_folders:
        print()
        print("Beginning", model_folder)
        checkpt = savio.load_checkpoint(model_folder)
        config = checkpt["config"]

        # Only fall back to an identity vocabulary (which requires
        # n_tokens) when the checkpoint does not carry a word2id
        word2id = config.get("word2id", None)
        if word2id is None:
            word2id = {i:i for i in range(config["n_tokens"])}
        id2word = {v:k for k,v in word2id.items()}

        # Fill in settings from the checkpoint but don't overwrite
        # command line arguments
        conf = {**exp_config}
        for k,ckpt_val in config.items():
            if k not in conf:
                conf[k] = ckpt_val
                continue
            cli_val = conf[k]
            # None is an explicit "disabled" value (e.g. save_layer).
            # Casting it would give "None"/False or raise a TypeError.
            if cli_val is None or type(cli_val)==type(ckpt_val):
                continue
            # Cast the argued value to the type used by the checkpoint
            if type(ckpt_val) in {int, float, str, bool}:
                conf[k] = type(ckpt_val)(cli_val)
            else:
                try:
                    conf[k] = int(cli_val)
                except (TypeError, ValueError, OverflowError):
                    print("failed to convert", k, "with value of",
                        cli_val)

        # Flatten task_config into conf without overwriting existing keys
        if "task_config" in conf:
            task_config = conf["task_config"] or {}
            for k in task_config:
                if k not in conf: conf[k] = task_config[k]
            conf["task_config"] = None
        print("Config:")
        for k in sorted(conf.keys()):
            print("\t",k, conf[k])

        # Organize validation settings
        n_samples = conf.get("n_samples", 30) # Samps per target quant
        conf["n_samples"] = n_samples
        hold_outs = set() #conf.get("hold_outs", {4,9,14,17}) # {17,18,19,20}
        conf["hold_outs"] = set()
        conf["task_name"] = conf.get("task_name", "num_equivalence")

        # Bound up front so that task names other than the two handled
        # below reach the fallback selection instead of a NameError
        causal_model = None
        if conf["task_name"]=="num_equivalence":
            if "min_val" not in conf:
                conf["min_val"] = conf.get("min_count", 1)
            if "max_val" not in conf:
                conf["max_val"] = conf.get("max_count", 20) + 10
            causal_model = numeqv.demo_resp_count_cmod
        elif conf["task_name"]=="arithmetic":
            if "min_val" not in conf:
                conf["min_val"] = -100
            if "max_val" not in conf:
                conf["max_val"] = 100
            conf["step_size"] = conf.get("step_size", 7)
            # No causal model exists for this task yet (basic_arith_cmod)
            raise NotImplementedError(
                "the arithmetic task is not supported by this script")
        conf["output_layer"] = conf.get("output_layer", "lm_head")
        conf["seq_len"] = None

        # Potentially save the intermediate outputs at these layers
        layers = []
        if conf.get("save_layer", None) is not None:
            layers.append(conf["save_layer"])

        # Optionally argue a high level model to tag the tokens for
        # later analysis
        if causal_model is None:
            multi_pre_trigger = config.get("multi_trigger", False) and\
                                config.get("pre_trigger", True)
            if multi_pre_trigger:
                ## Multi Pre Trigger
                causal_model = numeqv.trigid_count_only_cmod # numeqv.trigid_demo_resp_count_cmod
            else:
                ## Single trigger or Multi Post Trigger
                causal_model = numeqv.count_only_cmod # numeqv.demo_resp_count_cmod
            intr_var_key = "count" # "demo_idx", "demo_count", "resp_count"

        # Collect Data
        data_kwargs = {**conf}
        data_kwargs["hold_outs"] = set()
        data_kwargs["concat"] = False
        dataset,task_mask,info,samp_types,_ = make_systematic_dataset(
            **data_kwargs)

        # Collect causal variables from high level model
        kwargs = {**config}
        kwargs["trigger_id"] = None
        kwargs = {**kwargs, **info}
        vbls = []
        for seq in dataset:
            seq_vars,_ = run_til_idx(causal_model, seq, **kwargs)
            vbls.append(seq_vars)
        if len(vbls)==0:
            raise ValueError(
                "make_systematic_dataset returned no sequences")

        # The per-step tags below are broadcast using the length of the
        # first sequence.
        # NOTE(review): assumes all sequences share that length — confirm
        seq_len = len(vbls[0])

        # Collect meta data for later PCA analyses. Every entry is
        # flattened to one value per (sequence, step) pair.
        step_data = dict()
        for k in vbls[-1][0].keys():
            step_data[k] = np.asarray(
                [[v[k] for v in vbl] for vbl in vbls]).reshape(-1)
        step_data["step"] = np.asarray(
            [list(range(len(vbl))) for vbl in vbls]).reshape(-1)
        step_data["token_id"] = np.asarray(dataset).reshape(-1)
        step_data["task_mask"] = np.asarray(task_mask).reshape(-1)
        step_data["ep_idx"] = np.asarray(
            [[j]*len(vbl) for j,vbl in enumerate(vbls)]).reshape(-1)
        if conf["task_name"]=="num_equivalence":
            min_count = conf["min_val"]
            max_count = conf["max_val"]
            nums = np.asarray(sorted(
                set(range(min_count,max_count+1))-hold_outs
            ))
            trigger_ids = info.get("trigger_ids", [7])
            # NOTE(review): presumes the dataset is ordered by target
            # count, then trigger id, then sample — verify against
            # make_systematic_dataset
            step_data["n_targs"] = nums.repeat(
                (len(trigger_ids)*n_samples,))[:,None].repeat(
                    (seq_len,))
            step_data["trigger_id"] = np.tile(
                np.asarray(trigger_ids).repeat((n_samples,)),
                (max_count-min_count+1-len(hold_outs),))[:,None].repeat(
                    (seq_len,))
        for k in samp_types:
            # Broadcast each per-sequence tag to every step
            step_data[k] = np.asarray(
                [[stype]*seq_len for stype in samp_types[k]]
            ).reshape(-1)

        if exp_config["all_checkpts"]:
            checkpt_files = savio.get_checkpoints(model_folder)
        else:
            checkpt_files = [model_folder]

        # The inputs are the same for every checkpoint, so the tensors
        # are built once rather than once per loop iteration
        torch_dataset = torch.LongTensor(dataset)
        task_mask = torch.BoolTensor(task_mask)

        df_list = []
        lowgran_dfs = []
        for cfile in checkpt_files:
            use_best = exp_config.get("use_best", True) and\
                       cfile==model_folder
            checkpt = savio.load_checkpoint(
                cfile,
                use_best=use_best,
            )
            config = checkpt["config"]
            try:
                temp = smods.make_model(config)
                temp.load_state_dict(checkpt["state_dict"])
            except Exception as e:
                # Probably deleted during training. Report the skip so
                # that it is not mistaken for a missing result.
                print("Skipping", cfile, "- failed to build model:",
                    repr(e))
                continue
            model = temp.model
            # Collect model responses
            model.eval()
            # DEVICE falls back to "cpu" when cuda is unavailable
            model.to(DEVICE)
            with torch.no_grad():
                actvs = collect_activations(
                    model=model,
                    input_ids=torch_dataset,
                    task_mask=task_mask,
                    pad_mask=None,
                    layers=layers+[conf["output_layer"]],
                    batch_size=None,
                    ret_pred_ids=True,
                    to_cpu=True)

            # Score predictions against the next token: the prediction
            # at step t is compared with the input token at step t+1.
            # Only positions flagged by the task mask count.
            shifted_targs = torch_dataset[:,1:].cpu()
            shifted_preds = actvs["pred_ids"][:,:-1].reshape(
                shifted_targs.shape).cpu()
            shifted_mask = task_mask[:,1:].cpu()
            corrects = torch.zeros(shifted_targs.shape).long()
            corrects[shifted_mask] = (
                shifted_preds[shifted_mask]==shifted_targs[shifted_mask]
            ).long()
            # A trial is correct only if every masked token is correct
            eqs = corrects.sum(-1)==shifted_mask.long().sum(-1)
            tacc = corrects.sum(-1)/shifted_mask.float().sum(-1)
            trial_acc = eqs.float().mean()
            accs = {
                **samp_types,
                "correct": eqs.tolist(),
                "tok_acc": tacc.tolist(),
            }

            # Flatten every activation to 2 dims, keeping the last dim
            print("Actvs:")
            for k in list(actvs.keys()):
                print("  ", k, actvs[k].shape)
                actvs[k] = actvs[k].reshape(-1,actvs[k].shape[-1])
                print("  ", " "*len(k), actvs[k].shape)

            pred_rows = actvs["pred_ids"]
            sep = ","
            print("Examples:")
            for i in range(min(3,len(torch_dataset))):
                try:
                    l = max(len(torch_dataset[i]),len(pred_rows[i]))
                    print("Idxs :", sep.join(["{:2}".format(_) for _ in range(l)]))
                    # str() so the integer fallback vocabulary prints too
                    print("Targs:", sep.join(
                        ["{:2}".format(str(id2word[s])[:2]) for s in torch_dataset[i].tolist()]))
                    # The leading " 1" pads the row so that each
                    # prediction sits under the token it predicts
                    print("Preds:", sep.join(
                        [" 1"]+["{:2}".format(str(id2word[s])[:2]) for s in pred_rows[i].tolist()]))
                    print("Tmask:",
                        sep.join(["{:2}".format(s) for s in task_mask[i].tolist()]))
                    print()
                except Exception:
                    print("Failed to print example")

            print("\nTrials Correct:", trial_acc.item())
            if conf["task_name"]=="num_equivalence":
                # NOTE(review): the first samp_types entry is presumably
                # the target count, and 20 presumably the largest
                # trained count (the max_count default) — confirm
                sts = torch.LongTensor(next(iter(samp_types.values())))
                train_idx = sts<=20
                interp = eqs[train_idx].float().mean().item()
                extrap = eqs[~train_idx].float().mean().item()
                print("Interp Acc:", interp)
                print("Extrap Acc:", extrap)

            step_data["pred_id"] = actvs["pred_ids"].cpu().numpy().reshape(-1)

            # Save intermediate computations if desired
            if conf.get("save_layer", None) is not None:
                X = actvs[layers[0]].numpy()
                a = { "X": X, "step_data": step_data }
                end_pt = "layer-"+layers[0]+"-actvs.pickle"
                if savio.is_model_folder(model_folder):
                    path = os.path.join(model_folder, end_pt)
                else:
                    # A checkpoint file was argued: save next to it,
                    # prefixed with the checkpoint's file name
                    ckpt_dir,ckpt_file = os.path.split(model_folder)
                    stem = os.path.splitext(ckpt_file)[0]
                    path = os.path.join(ckpt_dir, stem+"_"+end_pt)
                with open(path, 'wb') as handle:
                    pickle.dump(a, handle, protocol=pickle.HIGHEST_PROTOCOL)

            # Construct dataframe of results
            # Shift entries so that preds are aligned with target tokens
            df = pd.DataFrame({
                k: (v[:-1] if k=="pred_id" else v[1:])
                for k,v in step_data.items()
            })
            epoch = checkpt.get("epoch", cfile.split("_")[-1].split(".")[0])
            df["equal"] = df["pred_id"]==df["token_id"]
            df["cfile"] = cfile
            df["epoch"] = epoch
            acc_df = df.loc[df["task_mask"].astype(bool)]
            acc_df = acc_df.groupby(["ep_idx"])["equal"].mean()
            acc = (acc_df.reset_index()["equal"]>=1).mean()
            print("Acc:", acc)
            df_list.append(df)

            # Collect low granularity accuracies
            lowgran_df = pd.DataFrame(accs)
            lowgran_df["cfile"] = cfile
            lowgran_df["epoch"] = epoch
            lowgran_dfs.append(lowgran_df)

            # Save recorded accuracy to checkpoint
            checkpt["val_correct"] = acc
            save_folder = model_folder
            if not savio.is_model_folder(save_folder):
                save_folder = "/".join(model_folder.split("/")[:-1])
            savio.save_checkpt(
                save_dict=checkpt,
                save_folder=save_folder,
                save_name="checkpt",
                epoch=checkpt["epoch"],
                ext=".pt",
                del_prev_sd=False,
            )

        # Save the results of this model folder as two csvs: the low
        # granularity (per trial) accuracies and the main (per token)
        # dataframe
        if len(df_list)==0:
            # Every checkpoint was skipped, so there is nothing to
            # concatenate or save
            print("No checkpoints were evaluated for", model_folder)
        else:
            # Resolve the destination once so that both csvs share it.
            # When a checkpoint file was argued instead of a folder, the
            # csvs go next to it, prefixed with the checkpoint's name.
            if os.path.isdir(model_folder):
                save_dir = model_folder
                prefix = ""
            else:
                splt = model_folder.split("/")
                save_dir = "/".join(splt[:-1])
                prefix = ".".join(splt[-1].split(".")[:-1])+"_"

            outputs = (
                (lowgran_dfs, "accuracies.csv", "Coarse accuracies csv to"),
                (df_list, "validation.csv", "Saved csv to"),
            )
            for frames,end_name,msg in outputs:
                out_df = pd.concat(frames, sort=True)
                path = os.path.join(save_dir, prefix+end_name)
                # Append without a header only when the file exists and
                # overwriting is disabled
                if os.path.exists(path) and not exp_config["overwrite"]:
                    out_df.to_csv(path, header=False, index=False, mode="a")
                else:
                    out_df.to_csv(path, header=True, index=False)
                print(msg, path)

        #plot_df = df.loc[df["token_id"].isin(targs)]
        #avg = plot_df.groupby(["targ_count"])["equal"].mean()
        #fig = plt.figure()
        #sns.lineplot(x="targ_count", y="equal", data=plot_df)
        #mc = config.get("max_count",20)
        #plt.plot([mc,mc],[0,1],"k--")
        #plt.ylabel("Accuracy")
        #plt.xlabel("Target Count")
        #plt.tight_layout()
        #path = path.replace(".csv", ".png")
        #plt.savefig(path, bbox_inches="tight")
        #print("Saved figure to", path)
        ##except:
        ##    print("Exception Ocurred! Moving on to next model")

