"""
This script takes a model and a trained interchange module and validates
the interchange module on a dataset. This script tracks the demo count,
resp count, and phase of both the source sequence and the destination
sequence at the interchange index as well as the interchange key and the
final response count starting from the trigger token regardless of
interchange index (this allows us to track cases where we transfer from
response into demo and we want to know how many additional demo steps
occurred before the response).

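Use this script in the following way:

    $ python3 das_validation.py /path/to/model/ kwarg1=val1 kwarg2=val2

Positional arguments may be model folders, experiment folders, checkpoint
files, or yaml/json config files; `key=val` arguments override config
values. For every intervention file found for a checkpoint, the results are
written next to that file as a CSV whose name is the intervention file's
name with its extension replaced by `_val.csv`.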

PSEUDO:
    read in model files and loop over them
    collect data
    loop over das trainings
    evaluate das
    record results
"""
import pickle
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import os
import sys
if sum([int(f=="dl_utils") for f in os.listdir("./")])==0:
    sys.path.append("../")
else:
    sys.path.append("./")
import time

import seq_models as smods

import dl_utils
from dl_utils.utils import (
    pad_to, get_mask_past_id, num2base, device_fxn, get_datetime_str
)
import dl_utils.save_io as savio

from utils import check_correct_count, get_counts, run_til_idx
import datas
from datas import (
    get_sequence, sample_sequence, make_dataset,
    make_systematic_intrv_dataset,
    extract_metrics
)
from causal_models import *
from automated_utils import collect_activations, das_eval
from das import load_alignment

# Torch device used for the model and intervention modules: CUDA device
# index 0 when a GPU is available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = 0
else:
    DEVICE = "cpu"

def read_in_args(args, config, model_folders=None, intrv_files=None):
    """
    Sort command line arguments into model paths, intervention files,
    config overrides, and an optional output file name.

    Arguments are classified in this order (first match wins): model
    folder, experiment folder, checkpoint file ("checkpt" and ".pt" in the
    name), intervention file ("intervene.p" in the name), yaml/json config
    file, `key=val` override, and finally anything else.

    Args:
        args: list of str
            the list of command line args (probably want sys.argv[1:])
        config: dict
            default config values. This dict is not modified; an updated
            copy is returned.
        model_folders: list or None
            optional list that model folder / checkpoint paths are appended
            to (in place). A new list is created when None.
        intrv_files: list or None
            optional list that "intervene.p" paths are appended to (in
            place). These paths are NOT part of the return value, so pass a
            list in if you need them.
    Returns:
        config: dict
            copy of `config` updated by yaml/json files and `key=val`
            arguments, applied in command-line order. `key=val` values are
            kept as strings, except the literal "None", which becomes None.
        model_folders: list of str
        og_file_name: str or None
            the last argument that matched no other category
    """
    if model_folders is None: model_folders = []
    if intrv_files is None: intrv_files = []
    # Work on a copy so the caller's defaults dict is never mutated.
    config = {**config}
    og_file_name = None

    for arg in args:
        if savio.is_model_folder(arg):
            model_folders.append(arg)
        elif savio.is_exp_folder(arg):
            model_folders += savio.get_model_folders(
                arg,incl_full_path=True)
        elif "checkpt" in arg and ".pt" in arg:
            model_folders.append(arg)
        elif "intervene.p" in arg:
            # TODO: offer a way to read in multiple intermodules; consider
            # using paths to match models with their intervention files.
            intrv_files.append(arg)
        elif ".yaml" in arg or ".json" in arg:
            config = {**config, **savio.load_json_or_yaml(arg)}
        elif "=" in arg:
            # Split only on the first "=" so values may contain "=" too.
            key,val = arg.split("=", 1)
            config[key] = None if val=="None" else val
        else:
            og_file_name = arg
    # End read in command line args
    return config, model_folders, og_file_name

def get_intrv_files(model_folder, checkpt):
    """
    Find the intervention ("intervene.p") files saved for a checkpoint.

    Args:
        model_folder: str
            path to the model folder. If it is not recognized as a model
            folder (e.g. it is a path to a checkpoint file), its parent
            directory is searched instead.
        checkpt: str
            full or partial path to a specific checkpoint. Only its final
            path component is used for matching.
    Returns:
        paths: list of str
            paths (the searched directory joined with the file name) of
            every file whose name contains both "intervene.p" and the
            checkpoint's file name, in sorted file-name order.
    """
    checkpt_name = os.path.basename(checkpt)
    search_dir = model_folder
    if not savio.is_model_folder(search_dir):
        # Fall back to the parent directory. A bare file name has no parent
        # component (dirname gives ""), which os.listdir rejects, so search
        # the current directory in that case.
        search_dir = os.path.dirname(search_dir) or "."
    return [
        os.path.join(search_dir, fname)
        for fname in sorted(os.listdir(search_dir))
        if "intervene.p" in fname and checkpt_name in fname
    ]


if __name__=="__main__":
    exp_config, model_folders, og_file_name = read_in_args(
        sys.argv[1:], {"n_samples":10})

    if og_file_name is not None:
        print("Saving to", og_file_name)

    for model_folder in model_folders:
        print()
        print("STARTING NEW MODEL FOLDER", model_folder)
        checkpt = savio.load_checkpoint(model_folder)
        config = checkpt["config"]

        model_conf = {**exp_config}
        for k in config:
            if k not in model_conf:
                model_conf[k] = config[k]
            elif type(model_conf[k])!=type(config[k]):
                if type(config[k]) in {int, float, str, bool}:
                    model_conf[k] = type(config[k])(model_conf[k])
                else:
                    try:
                        model_conf[k] = int(model_conf[k])
                    except:
                        print("failed to convert", k, "with value of",
                            model_conf[k])
        print("Config:")
        for k in sorted(model_conf.keys()):
            print("\t",k, model_conf[k])

        print("MAKING DATA...")
        kwargs = {
            "n_samples": model_conf.get("n_samples", 15),
            "min_count": model_conf.get("min_count", 1),
            "max_count": model_conf.get("max_count", 20),
            "max_source_count": model_conf.get("max_source_count", 20),
            "max_demo_tokens": model_conf.get("max_demo_tokens", None),
            "n_demo_types": model_conf.get("n_demo_types", 3),
            "multi_trigger": model_conf.get("multi_trigger", False),
            "pre_trigger": model_conf.get("pre_trigger", False),
            "seq_len": exp_config.get("seq_len", None),
            "hold_outs": model_conf.get("hold_outs", {4,9,14,17}),
            "offset": model_conf.get("offset", 0),
            "pad_id": model_conf.get("pad_id", 0),
            "trigger_ids": model_conf.get("trigger_ids", [7]),
            "copy_task": model_conf.get("copy_task", False),
            "base_hmod": model_conf.get("high_model", count_only_hmod),
            "intr_var_key": model_conf.get("intr_var_key","count"),
            "trigger_steps": model_conf.get("trigger_steps", {0,1,2,3}),
            "dest_count_step": int(model_conf.get("dest_count_step", 1)),
            "source_count_step": int(model_conf.get("source_count_step",4)),
            "ret_info": True,
        }
        data_dict, info =   make_systematic_intrv_dataset(**kwargs)
        base_data     = data_dict["base_data"]
        base_idxs     = data_dict["base_idxs"]
        base_vars     = data_dict["base_vars"]
        base_counts   = data_dict["base_targ_counts"]
        source_data   = data_dict["source_data"]
        source_idxs   = data_dict["source_idxs"]
        source_vars   = data_dict["source_vars"]
        source_counts = data_dict["source_targ_counts"]
        print("Data Shape:", base_data.shape)

        dest_df = extract_metrics(
            seqs=base_data,
            idxs=base_idxs,
            vars=base_vars,
            trg_counts=base_counts,
            info=info)
        dest_df = pd.DataFrame(dest_df)
        source_df = extract_metrics(
            seqs=source_data,
            idxs=source_idxs,
            vars=source_vars,
            trg_counts=source_counts,
            info=info)
        source_df = pd.DataFrame(source_df)
        dest_df.columns = ["dest_"+c for c in dest_df.columns]
        source_df.columns = ["source_"+c for c in source_df.columns]
        data_df = pd.merge(
            left=dest_df,
            right=source_df,
            left_index=True,
            right_index=True)

        pad_id = info["pad_id"]
        bos_id = info["bos_id"] 
        eos_id = info["eos_id"] 
        base_data = {
          "input_ids": base_data[:,:-1],
          "pad_mask": (base_data[:,:-1]==pad_id)|\
                      (base_data[:,:-1]==eos_id),
          "output_ids": base_data[:,1:],
          "output_mask": (base_data[:,1:]==pad_id)|\
                         (base_data[:,1:]==bos_id),
        }
        source_data = {
          "input_ids": source_data[:,:-1],
          "pad_mask": (source_data[:,:-1]==pad_id)|\
                      (source_data[:,:-1]==eos_id),
        }

        checkpts = savio.get_checkpoints(model_folder)
        for checkpt_file in checkpts:
            intrv_files = get_intrv_files(model_folder,checkpt_file)
            print(intrv_files)
            if len(intrv_files)==0: continue
            checkpt = savio.load_checkpoint(checkpt_file)
            config = checkpt["config"]
            temp = smods.make_model(config)
            temp.load_state_dict(checkpt["state_dict"])
            model = temp.model
            model.eval()
            model.to(DEVICE)
            layer = "rnns.0" #assumes recurrent. breaks transformers

            try:
                batch_size = None if "batch_size" not in exp_config else\
                                    int(exp_config["batch_size"])
                with torch.no_grad():
                    source_actvs = collect_activations(
                        model=model,
                        input_ids=source_data["input_ids"],
                        pad_mask=source_data.get("pad_mask", None),
                        layers=[layer],
                        batch_size=batch_size,
                        to_cpu=True)[layer]
            except:
                source_actvs = None
                print("Could not precompute source activations")

            for intrv_file in intrv_files:
                print("Checkpt:", checkpt_file)
                print("Intrv:", intrv_file)
                try:
                    intr_modu, intr_conf = load_alignment(
                        intrv_file,ret_config=True) 
                except RuntimeError:
                    print("Failed to load intervention module, continuing...")
                    continue
                intr_conf = {
                    **intr_conf["config"],
                    **intr_conf.get("meta_config",{})}

                intr_modu.to(DEVICE)
                intr_modu.eval()

                el = intr_conf.get("layer", "rnns.0")
                if el!=layer: sactvs = None
                else: sactvs = source_actvs
                metrics = das_eval(
                    model=model,
                    layer=layer,
                    intr_modu=intr_modu,
                    base_data=base_data,
                    base_idxs=base_idxs,
                    source_data=source_data,
                    source_idxs=source_idxs,
                    source_actvs=sactvs,
                    info=info,)

                df = pd.DataFrame(metrics)
                intr_df = pd.merge(
                    left=data_df,
                    right=df,
                    left_index=True,
                    right_index=True)
                save_path = ".".join(intrv_file.split(".")[:-1])+"_val.csv"
                intr_df.to_csv(save_path, header=True, index=False, mode="w")

