"""
This script takes a model and a trained interchange module and validates
the interchange module on a dataset. This script tracks the demo count,
resp count, and phase of both the source sequence and the destination
sequence at the interchange index as well as the interchange key and the
final response count starting from the trigger token regardless of
interchange index (this allows us to track cases where we transfer from
response into demo and we want to know how many additional demo steps
occurred before the response).

Use this script in the following way:

    $ python3 das_validation.py /path/to/model/ kwarg1=val1 kwarg2=val2

PSEUDO:
    read in model files and loop over them
    collect data
    loop over das trainings
    evaluate das
    record results
"""
import pickle
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import os
import sys
if sum([int(f=="dl_utils") for f in os.listdir("./")])==0:
    sys.path.append("../")
else:
    sys.path.append("./")
import time

import seq_models as smods

import dl_utils
from dl_utils.utils import (
    pad_to, get_mask_past_id, num2base, device_fxn, get_datetime_str
)
import dl_utils.save_io as savio

from utils import check_correct_count, get_counts, run_til_idx
import datas
from datas import (
    make_systematic_intrv_dataset, extract_metrics
)
from causal_models.num_equivalence import *
from automated_utils import collect_activations, das_eval
from das import load_alignment

DEVICE = 0 if torch.cuda.is_available() else "cpu"

def read_in_args(args, config, model_folders=None, intrv_files=None):
    """
    Sorts the command line arguments into model folders, intervention
    files, config overrides, and a leftover file name.

    Each argument is tested against the following cases in order and
    is consumed by the first one that matches:
        1. a model folder (according to savio.is_model_folder)
        2. an experiment folder (expanded into its model folders)
        3. a path containing both "checkpt" and ".pt"
        4. a path containing "intervene.p"
        5. a path containing ".yaml" or ".json" (merged into config)
        6. a `key=value` pair (stored in config)
        7. anything else, which is kept as the leftover file name

    Args:
        args: list of str
            the list of command line args (probably want sys.argv[1:])
        config: dict
            a dict to collect the command line args. `key=value` pairs
            are written into this dict in place. Loading a yaml/json
            file rebinds to a new merged dict, so always use the
            returned config rather than the one passed in.
        model_folders: list of str or None
            optional list to extend in place with the found model
            folders and checkpoint paths. A new list is made if None.
        intrv_files: list of str or None
            optional list to extend in place with the found
            intervention files. This list is NOT returned, so callers
            that want these paths must pass in their own list.
    Returns:
        config: dict
            the collected config values. The strings "None", "true",
            and "false" (case insensitive for the booleans) are
            converted to None, True, and False. All other values are
            left as strings.
        model_folders: list of str
        og_file_name: str or None
            the last argument that matched none of the other cases
    """
    if model_folders is None: model_folders = []
    if intrv_files is None: intrv_files = []
    og_file_name = None

    for arg in args:
        if savio.is_model_folder(arg):
            model_folders.append(arg)
        elif savio.is_exp_folder(arg):
            model_folders += savio.get_model_folders(
                arg,incl_full_path=True)
        elif "checkpt" in arg and ".pt" in arg:
            model_folders.append(arg)
        elif "intervene.p" in arg:
            # offer way to read in multiple intermodules
            # consider using paths to find model and intrv matches
            intrv_files.append(arg)
        elif ".yaml" in arg or ".json" in arg:
            config = {**config, **savio.load_json_or_yaml(arg)}
        elif "=" in arg:
            # Only split on the first "=" so that values that contain
            # an "=" themselves (paths, expressions) stay intact
            # instead of raising a ValueError on unpacking.
            key,val = arg.split("=", 1)
            if val=="None":
                val = None
            elif val.lower() in {"false", "true"}:
                val = val.lower()=="true"
            config[key] = val
        else:
            og_file_name = arg
    # End read in command line args
    return config, model_folders, og_file_name

def get_intrv_files(model_folder, use_best=True):
    """
    Finds the saved intervention files that live next to a model.

    Args:
        model_folder: str
            path to a model folder. If savio.is_model_folder rejects
            the path, everything after the last "/" is dropped and
            the remaining prefix is searched instead.
        use_best: bool
            if true, collects files whose name contains
            "intervene.best.p". Otherwise collects files whose name
            contains "intervene.p".
    Returns:
        paths: list of str
            the matching file paths joined onto the searched folder,
            ordered by sorted file name
    """
    folder = model_folder
    if not savio.is_model_folder(folder):
        # Keep only the portion of the path before the final "/"
        folder = folder.rpartition("/")[0]

    pattern = "intervene.p"
    if use_best:
        pattern = "intervene.best.p"

    return [
        os.path.join(folder, fname)
        for fname in sorted(os.listdir(folder))
        if pattern in fname
    ]


if __name__=="__main__":
    # Local import: only the script entry point needs it, to report
    # evaluation failures without aborting the whole sweep.
    import traceback

    exp_config, model_folders, og_file_name = read_in_args(
        sys.argv[1:], {"n_samples":1})

    if og_file_name is not None:
        # NOTE(review): og_file_name is only printed here. The csv
        # paths below are derived from each intervention file's path,
        # so this message may be misleading -- confirm the intent.
        print("Saving to", og_file_name)

    # use last intervention checkpt if false
    use_best = exp_config.get("use_best", True)
    # will not analyze checkpoint intervention saves if true
    no_checkpts = exp_config.get("no_checkpts", False)
    # Parsed once up front so that a malformed batch_size fails before
    # any model is loaded. `batch_size=None` on the command line is
    # treated the same as not providing a batch size.
    bsize = exp_config.get("batch_size", None)
    if bsize is not None:
        bsize = int(bsize)

    for model_folder in model_folders:
        print()
        print("STARTING NEW MODEL FOLDER", model_folder)
        checkpt = savio.load_checkpoint(model_folder)
        config = checkpt["config"]
        if config["model_type"]=="Transformer":
            print("Still not implemented for Transformers")
            continue
        temp = smods.make_model(config)
        temp.load_state_dict(checkpt["state_dict"])
        model = temp.model
        model.eval()
        model.to(DEVICE)

        # Command line values take priority over the saved training
        # config. Overrides arrive as strings, so they are cast to the
        # type of the saved value when the types differ.
        model_conf = {**exp_config}
        for k in config:
            if k not in model_conf:
                model_conf[k] = config[k]
            elif type(model_conf[k])!=type(config[k]):
                if type(config[k]) in {int, float, str, bool}:
                    model_conf[k] = type(config[k])(model_conf[k])
                else:
                    # The saved value has no simple type to cast to
                    # (e.g. None), so fall back on an int cast and
                    # keep the raw value if that fails.
                    try:
                        model_conf[k] = int(model_conf[k])
                    except (TypeError, ValueError):
                        print("failed to convert", k, "with value of",
                            model_conf[k])
        # Validation overrides of the task settings
        model_conf["task_config"]["hold_outs"] = set()
        model_conf["hold_outs"] = set()
        model_conf["task_config"]["min_count"] = int(
            model_conf.get(
                "min_count",
                model_conf["task_config"].get("min_count", 1)))
        # Widen the max count by source_incr-1
        model_conf["task_config"]["max_count"] += int(
            model_conf.get("source_incr",3))-1
        model_conf["seq_len"] = None
        model_conf["task_config"]["seq_len"] = None

        print("Config:")
        for k in sorted(model_conf.keys()):
            if type(model_conf[k])==dict:
                print("\t",k+":")
                for kk in sorted(model_conf[k].keys()):
                    print("\t\t", kk,"-", model_conf[k][kk])
            else:
                print("\t",k,"-", model_conf[k])

        print("MAKING DATA...")
        kwargs = {
            # NOTE(review): n_samples is the only value here that is
            # not cast to int, so a command line override may arrive
            # as a string -- confirm what the dataset fxn expects.
            "n_samples":   model_conf.get("n_samples", 1),
            "dest_incr":   int(model_conf.get("dest_incr", 1)),# increments intrv count
            "source_incr": int(model_conf.get("source_incr",3)),# increments intrv count
            "dest_step":   int(model_conf.get("dest_step", 1)), # increments target count
            "source_step": int(model_conf.get("source_step", 1)),# increments target count
            "seq1_kwargs": {
                **model_conf,
                "causal_model": model_conf.get("causal_model", count_only_cmod),
                "intr_var_key": model_conf.get("intr_var_key","count"),
            },
            "trigger_steps": model_conf.get("trigger_steps", {0,1,2}),
            "ret_info": True,
        }
        data_dict, info = make_systematic_intrv_dataset(**kwargs)
        base_data     = data_dict["base_data"]
        base_idxs     = data_dict["base_idxs"]
        base_vars     = data_dict["base_vars"]
        base_counts   = data_dict["base_types"]
        base_tmasks   = data_dict["base_tmasks"]
        source_data   = data_dict["source_data"]
        source_idxs   = data_dict["source_idxs"]
        source_vars   = data_dict["source_vars"]
        source_counts = data_dict["source_types"]
        source_tmasks = data_dict["source_tmasks"]
        print("Data Shape:", base_data.shape)


        # TODO: debugging dump of the first samples. Prints the first
        # `x` tokens of each base and source sequence with the
        # interchange position replaced by "##", followed by the base
        # task mask.
        sep = ","
        x = 40
        for i in range(0, len(source_data), 1):
            bidx = base_idxs[i]
            sidx = source_idxs[i]

            bse = ["{:2}".format(s) for s in base_data[i].tolist()]
            bse[bidx] = "##"
            print("Bse:", sep.join(bse[:x]))

            src = ["{:2}".format(s) for s in source_data[i].tolist()]
            src[sidx] = "##"
            print("Src:", sep.join(src[:x]))

            tsk = ["{:2}".format(s) for s in base_tmasks[i].tolist()]
            print("Tsk:", sep.join(tsk[:x]))
            if i>100: break


        # Per sample metrics for the destination (base) and source
        # sequences, joined side by side on the sample index with
        # "dest_" and "source_" column prefixes.
        dest_df = extract_metrics(
            seqs=base_data,
            idxs=base_idxs,
            vars=base_vars,
            trg_counts=base_counts,
            info=info)
        dest_df = pd.DataFrame(dest_df)
        source_df = extract_metrics(
            seqs=source_data,
            idxs=source_idxs,
            vars=source_vars,
            trg_counts=source_counts,
            info=info)
        source_df = pd.DataFrame(source_df)
        dest_df.columns = ["dest_"+c for c in dest_df.columns]
        source_df.columns = ["source_"+c for c in source_df.columns]
        data_df = pd.merge(
            left=dest_df,
            right=source_df,
            left_index=True,
            right_index=True)

        # Inputs drop the last token and outputs drop the first token.
        # The input pad mask flags pad and eos ids, the output pad
        # mask flags pad and bos ids.
        pad_id = info.get("pad_id", 0)
        bos_id = info.get("bos_id", 1)
        eos_id = info.get("eos_id", 2)
        base_data = {
          "input_ids": base_data[:,:-1],
          "pad_mask": (base_data[:,:-1]==pad_id)|\
                      (base_data[:,:-1]==eos_id),
          "task_mask": base_tmasks,
          "output_ids": base_data[:,1:],
          "output_pad_mask": (base_data[:,1:]==pad_id)|\
                         (base_data[:,1:]==bos_id),
        }
        source_data = {
          "input_ids": source_data[:,:-1],
          "pad_mask": (source_data[:,:-1]==pad_id)|\
                      (source_data[:,:-1]==eos_id),
          "task_mask": source_tmasks,
        }

        intrv_files = get_intrv_files(model_folder, use_best=use_best)
        for intrv_file in intrv_files:
            if no_checkpts and "checkpt_" in intrv_file:
                continue
            print("Model:", model_folder)
            print("Intrv:", intrv_file)
            # Load exactly once and inside the guard so that a module
            # that fails to load is skipped rather than ending the run
            try:
                intr_modu, intr_conf = load_alignment(
                    intrv_file,ret_config=True)
            except RuntimeError:
                print("Failed to load intervention module, continuing...")
                continue
            intr_conf = {
                **intr_conf["config"],
                **intr_conf.get("meta_config",{})}

            intr_modu.to(DEVICE)
            intr_modu.eval()

            try:
                metrics = das_eval(
                    model=model,
                    layer=intr_conf.get("layer", None),
                    intr_modu=intr_modu,
                    base_data=base_data,
                    base_idxs=base_idxs,
                    source_data=source_data,
                    source_idxs=source_idxs,
                    batch_size=bsize,
                    info=info,)
            except Exception:
                # Best effort: report the failure and move on to the
                # next intervention file. KeyboardInterrupt and
                # SystemExit are deliberately not caught so that the
                # script can still be stopped.
                print("Error in evaluation... skipping", intrv_file)
                traceback.print_exc()
                continue

            # Join the evaluation metrics onto the per sample data
            # metrics and save them alongside the intervention file,
            # replacing its final extension with "_val.csv".
            df = pd.DataFrame(metrics)
            intr_df = pd.merge(
                left=data_df,
                right=df,
                left_index=True,
                right_index=True)
            save_path = ".".join(intrv_file.split(".")[:-1])+"_val.csv"
            intr_df.to_csv(save_path, header=True, index=False, mode="w")

