import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import os
import sys
# Extend the import search path: use the current directory when it has an
# entry named "dl_utils", otherwise fall back to the parent directory.
sys.path.append("./" if "dl_utils" in os.listdir("./") else "../")
import time

import seq_models as smods

import dl_utils
from dl_utils.utils import (
    pad_to, get_mask_past_id, num2base, device_fxn, get_datetime_str
)
import dl_utils.save_io as savio

from utils import check_correct_count, get_counts, run_til_idx
import datas
from datas import (
    get_sequence, sample_sequence, make_dataset,
    get_intervention_dataset,
)
from causal_models import *
from automated_utils import collect_activations

# Default device identifier: 0 when CUDA is available (presumably the first
# GPU index), otherwise the string "cpu".
DEVICE = "cpu" if not torch.cuda.is_available() else 0

def parse_cli_value(text):
    """Interpret the right-hand side of a ``key=value`` command line argument.

    The text is read as a Python literal when possible, so ``"5"`` becomes
    ``5``, ``"[7, 8]"`` becomes a list, ``"False"`` becomes ``False`` and
    ``"None"`` becomes ``None``. Text that is not a valid literal (paths,
    names, ...) is returned unchanged as a string.

    Args:
        text: the raw string found after the first ``=`` of the argument.
    Returns:
        The parsed literal, or ``text`` itself when it is not a literal.
    """
    import ast  # local import: only needed for command line parsing

    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError):
        return text


def coerce_override(key, value, reference):
    """Convert a user supplied override to the type of the model's value.

    Args:
        key: name of the config entry, used only in the failure message.
        value: the override coming from the command line or a config file.
        reference: the value stored under ``key`` in the model's config.
    Returns:
        ``value`` converted to ``type(reference)`` when the reference is an
        int, float, str or bool. ``None`` and values that already have the
        reference's type are returned untouched. For any other reference
        type, string overrides are converted to int when possible and
        everything else is returned as is.
    Raises:
        ValueError or TypeError when an int/float conversion of the
        override is impossible (same as the plain constructor call).
    """
    # An explicit None means "unset"; never turn it back into "None"/False.
    if value is None or type(value) is type(reference):
        return value
    target = type(reference)
    if target is bool:
        # bool("False") is True, so textual negatives are handled by hand.
        if isinstance(value, str):
            if value.strip().lower() in {"false", "f", "0", "no", "n", "off", ""}:
                return False
        return bool(value)
    if target in (int, float, str):
        return target(value)
    # Reference is None or a container: best-effort int conversion of
    # leftover strings only, so typed overrides are not truncated.
    if isinstance(value, str):
        try:
            return int(value)
        except ValueError:
            print("failed to convert", key, "with value of", value)
    return value


if __name__=="__main__":
    # Command line arguments are classified one by one:
    #   model folder / experiment folder / checkpoint file -> models to run
    #   *.yaml or *.json                                   -> override file
    #   key=value                                          -> single override
    #   anything else                                      -> og_file_name
    model_folders = []
    og_file_name = None
    overrides = dict()
    for arg in sys.argv[1:]:
        if savio.is_model_folder(arg):
            model_folders.append(arg)
        elif savio.is_exp_folder(arg):
            model_folders.extend(savio.get_model_folders(
                arg, incl_full_path=True, incl_empty=False))
        elif "checkpt" in arg and ".pt" in arg:
            model_folders.append(arg)
        elif ".yaml" in arg or ".json" in arg:
            overrides = {**overrides, **savio.load_json_or_yaml(arg)}
        elif "=" in arg:
            # Split on the first "=" only so values may contain "=".
            key, val = arg.split("=", 1)
            overrides[key] = parse_cli_value(val)
        else:
            og_file_name = arg

    if og_file_name is not None:
        print("Saving to", og_file_name)

    for model_folder in model_folders:
        print()
        print("Beginning", model_folder)
        checkpt = savio.load_checkpoint(model_folder)
        config = checkpt["config"]
        temp = smods.make_model(config)
        temp.load_state_dict(checkpt["state_dict"])
        model = temp.model

        # Build the effective config from a fresh copy of the overrides so
        # one model's config values never leak into the next model's run.
        exp_config = dict(overrides)
        for k in config:
            if k not in exp_config:
                exp_config[k] = config[k]
            else:
                exp_config[k] = coerce_override(k, exp_config[k], config[k])
        print("Config:")
        for k in sorted(exp_config.keys()):
            print("\t",k, exp_config[k])

        n_samples = exp_config.get("n_samples", 30) # Samps per target quant
        hold_outs = set()
        max_count = exp_config.get("max_count", 20)
        min_count = exp_config.get("min_count", 1)
        max_demo_tokens = exp_config.get("max_demo_tokens", None)

        layers = ["rnns.0"]
        output_layer = "lm_head"
        seq_len = 1+max_count*2+3
        if max_demo_tokens: seq_len = seq_len-max_count+max_demo_tokens

        trigger_ids = exp_config.get("trigger_ids", [7])
        multi_trigger = exp_config.get("multi_trigger", False)
        pre_trigger = exp_config.get("pre_trigger", False)

        # Set high_model to None to fall back to the count-only models below.
        high_model = demo_resp_count_hmod
        if high_model is None:
            if (config.get("multi_trigger", False)
                    and config.get("pre_trigger", True)):
                high_model = trigid_count_only_hmod
            else:
                high_model = count_only_hmod

        dataset, info = datas.make_systematic_dataset(
                n_samples=n_samples,
                min_count=min_count,
                max_count=max_count,
                max_demo_tokens=max_demo_tokens,
                n_demo_types=config.get("n_demo_types", 3),
                multi_trigger=multi_trigger,
                pre_trigger=pre_trigger,
                seq_len=seq_len,
                hold_outs=hold_outs,
                trigger_ids=trigger_ids,
                copy_task=config.get("copy_task", False),
        )

        # Keyword arguments for the high-level model: the model config plus
        # defaults for the token ids it needs.
        kwargs = {**config}
        kwargs.setdefault("pad_id", 0)
        kwargs.setdefault("demo_ids", {4,5,6})
        kwargs.setdefault("resp_id", 3)
        kwargs.setdefault("eos_id", 2)
        kwargs["trigger_id"] = None

        # One list of per-step variable dicts for every sequence.
        vbls = []
        for seq in dataset:
            seq_vars,_ = run_til_idx(high_model, seq, **kwargs)
            vbls.append(seq_vars)
        n_steps = len(vbls[0])

        # Flatten everything to one row per (episode, step).
        step_data = dict()
        for k in vbls[-1][0].keys():
            step_data[k] = np.asarray(
                [[v[k] for v in vbl] for vbl in vbls]).reshape(-1)
        step_data["step"] = np.asarray(
            [list(range(len(vbl))) for vbl in vbls]).reshape(-1)
        step_data["token_id"] = np.asarray(dataset).reshape(-1)

        # NOTE(review): the bookkeeping columns below assume the dataset is
        # ordered by target count, then trigger id, then sample -- confirm
        # against datas.make_systematic_dataset.
        nums = np.asarray(sorted(
            set(range(min_count,max_count+1))-hold_outs
        ))
        step_data["n_targs"] = nums.repeat(
            len(trigger_ids)*n_samples).repeat(n_steps)
        step_data["trigger_id"] = np.tile(
            np.asarray(trigger_ids).repeat(n_samples),
            max_count-min_count+1-len(hold_outs)).repeat(n_steps)
        m = len(nums)*n_samples*len(trigger_ids)
        step_data["ep_idx"] = np.arange(m).repeat(n_steps)

        model.eval()
        torch_dataset = torch.LongTensor(dataset)
        with torch.no_grad():
            actvs = collect_activations(
                model=model,
                input_ids=torch_dataset,
                pad_mask=None,
                layers=layers+[output_layer],
                batch_size=None,
                to_cpu=True)

        print("actvs:")
        for k in actvs.keys():
            # Collapse all leading dimensions, keeping the last one.
            actvs[k] = actvs[k].reshape(-1,actvs[k].shape[-1])
            print("  ", k, actvs[k].shape)

        step_data["pred_id"] = torch.argmax(
            actvs[output_layer], dim=-1).numpy()

        # Offset by one so the prediction made at a step sits on the same
        # row as the following token. NOTE(review): the offset is applied
        # to the flattened arrays, so it also crosses episode boundaries.
        df = {}
        for k in step_data.keys():
            if k!= "pred_id": df[k] = step_data[k][1:]
            else: df[k] = step_data[k][:-1]
        df = pd.DataFrame(df)

        if "count" in step_data:
            count_key = "count"
        elif "demo_count" in step_data:
            count_key = "demo_count"
        else:
            raise KeyError(
                "high-level model produced neither 'count' nor 'demo_count'")
        # String name instead of the builtin: pandas deprecates callables here.
        df["targ_count"] = df.groupby(["ep_idx"])[count_key].transform("max")
        df["equal"] = df["pred_id"]==df["token_id"]

        # An episode counts as correct only when every eos/resp token in it
        # was predicted correctly.
        eos_id = config.get("eos_id", 2)
        resp_id = config.get("resp_id", 3)
        targs = np.asarray([eos_id, resp_id])
        acc_df = df.loc[df["token_id"].isin(targs)]
        acc_df = acc_df.groupby(["ep_idx"])["equal"].mean()
        acc = (acc_df.reset_index()["equal"]>=1).mean()
        print("Acc:", acc)
        checkpt["val_correct"] = acc
        savio.save_checkpt(
            save_dict=checkpt,
            save_folder=checkpt["config"]["save_folder"],
            save_name="checkpt",
            epoch=checkpt["epoch"],
            ext=".pt",
            del_prev_sd=False,
        )

        # When given a checkpoint file rather than a folder, save next to the
        # file and prefix the outputs with the file name minus its extension.
        end_name = "validation.csv"
        out_dir = model_folder
        if not os.path.isdir(model_folder):
            splt = model_folder.split("/")
            out_dir = "/".join(splt[:-1])
            end_name = ".".join(splt[-1].split(".")[:-1])+"_"+end_name
        path = os.path.join(out_dir,end_name)
        df.to_csv(path, header=True, index=False)
        print("Saved csv to", path)

        plot_df = df.loc[df["token_id"].isin(targs)]
        fig = plt.figure()
        sns.lineplot(x="targ_count", y="equal", data=plot_df)
        # Dashed vertical line at the max_count stored in the model's config.
        mc = config.get("max_count",20)
        plt.plot([mc,mc],[0,1],"k--")
        plt.ylabel("Accuracy")
        plt.xlabel("Target Count")
        plt.tight_layout()
        path = path.replace(".csv", ".png")
        plt.savefig(path, bbox_inches="tight")
        # Close the figure so figures do not accumulate across models.
        plt.close(fig)
        print("Saved figure to", path)

