import torch
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os
import sys
import time
import copy

import seq_models as smods
import datas
from datas import (
    get_intervention_dataset, make_systematic_intrv_dataset,
)
from utils import (
    check_correct_count, get_counts,
    run_til_idx, run_for_n_steps, choice,
    pretty_string,
)

import causal_models.samplers as samplers
from causal_models.samplers import *

import causal_models.num_equivalence as numeqv
from causal_models.num_equivalence import *
from automated_utils import automated_das

import dl_utils
from dl_utils.utils import (
    pad_to, get_mask_past_id, num2base, device_fxn, get_datetime_str,
    get_git_revision_hash
)
import dl_utils.save_io as save_io

# Widen pandas' console rendering so DataFrames printed by this script
# (e.g. the metrics table printed at the end of each run) are not truncated.
_PANDAS_DISPLAY = {
    "display.max_rows": 500,
    "display.max_columns": 500,
    "display.width": 1000,
}
for _opt_name, _opt_val in _PANDAS_DISPLAY.items():
    pd.set_option(_opt_name, _opt_val)
del _PANDAS_DISPLAY, _opt_name, _opt_val

# Device index 0 when CUDA is available, otherwise run on the CPU.
DEVICE = "cpu" if not torch.cuda.is_available() else 0
# TODO: figure out how to optimize.
# To force CPU execution, uncomment:
#print("Using CPU")
#DEVICE = "cpu"

if __name__=="__main__":
    # Script entry point. For every queued model: load it, build base/source
    # intervention datasets, verify the (source) model's accuracy on them,
    # write run metadata (<name>.txt, <name>_expconfig.json), run
    # automated_das, and save its metrics (<name>.csv) and interchange
    # modules (<name>_intrneurons*.p).
    #
    # Each CLI argument is classified, first match wins:
    #   1. a model folder (save_io.is_model_folder)       -> queued
    #   2. an experiment folder (save_io.is_exp_folder)   -> each non-empty
    #      model folder inside it is queued
    #   3. an argument containing "checkpt" and ".pt"     -> queued, unless
    #      it contains "=" and is not an existing path (so a kwarg such as
    #      source_model=/x/checkpt_3.pt is parsed as a kwarg instead)
    #   4. an argument containing ".yaml" or ".json"      -> loaded and
    #      merged into the command kwargs
    #   5. overwrite=<value>                              -> "false" makes
    #      runs pick a fresh numbered output name instead of overwriting
    #   6. key=value                                      -> command kwarg
    #      ("None", "True"/"true", "False"/"false" are converted)
    #   7. anything else                                  -> output name
    model_folders = []
    og_file_name = None
    command_kwargs = dict()
    overwrite = True
    args = sys.argv[1:]
    for arg in args:
        # Split at the first "=" only, so values may themselves contain "=".
        key, has_eq, val = arg.partition("=")
        if save_io.is_model_folder(arg):
            model_folders.append(arg)
        elif save_io.is_exp_folder(arg):
            mfs = save_io.get_model_folders(
                arg, incl_full_path=True, incl_empty=False)
            model_folders.extend(mfs)
        elif "checkpt" in arg and ".pt" in arg and\
                (not has_eq or os.path.exists(arg)):
            model_folders.append(arg)
        elif ".yaml" in arg or ".json" in arg:
            command_kwargs = {**command_kwargs,
                              **save_io.load_json_or_yaml(arg)}
        elif has_eq and key.lstrip("-")=="overwrite":
            # Consumed here so it never leaks into the DAS kwargs/file name.
            overwrite = val.lower()!="false"
        elif has_eq:
            if val=="None":
                val = None
            elif val=="True" or val=="true":
                val = True
            elif val=="False" or val=="false":
                val = False
            command_kwargs[key] = val
        else:
            og_file_name = arg

    # RENAME LAYER TO LAYERS TO AVOID MISTAKES
    if "layer" in command_kwargs:
        print("layer is not a valid key, changing to layers")
        command_kwargs["layers"] = command_kwargs["layer"]
        del command_kwargs["layer"]
    if og_file_name is not None:
        print("Saving to", og_file_name)

    def _make_das_split(seqs, tmasks, pad_id, bos_id, eos_id):
        """Build one model-ready data dict from token sequences.

        Inputs are seqs[:, :-1] and targets are seqs[:, 1:]. Input positions
        equal to pad_id or eos_id, and target positions equal to pad_id or
        bos_id, are set in the corresponding pad masks. tmasks is converted
        to bool and returned unsliced.
        """
        inpts = seqs[:,:-1]
        outpts = seqs[:,1:]
        return {
          "input_ids": inpts,
          "pad_mask": (inpts==pad_id)|(inpts==eos_id),
          "output_ids": outpts,
          "output_pad_mask": (outpts==pad_id)|(outpts==bos_id),
          "task_mask": tmasks.bool(),
        }

    for model_folder in model_folders:
        print()
        print("Beginning", model_folder)
        use_best = command_kwargs.get("use_best", False)
        if use_best:
            print("Using best checkpt instead of last checkpt")
        checkpt = dl_utils.save_io.load_checkpoint(model_folder, use_best=use_best)
        config = checkpt["config"]
        temp = smods.make_model(config)
        temp.load_state_dict(checkpt["state_dict"])
        model = temp.model

        defaults = {
            "layers": None, # Okay to argue a string here instead of list
            # TODO"/mnt/fs0/grantsrb/das_saves/nolnorm_gru/nolnorm_gru_0_seed111111/", # 
            "source_model": None,
            "boundless_das": False,
            "n_neurons": None, #[1,12],
            "hook_type": None,
            "n_epochs": 1000,
            "lr": 0.001,
            "rech_reg": 0.01, # l2 regularization on the final h vector in recurrent rotation settings
            "optim_type": "Adam",
            "batch_size": 512,
            "n_train_samples": 10000,
            "n_val_samples": 1000,
            "early_stopping": 60, # num epochs of no val loss decrease
            "mask_temperature": 0.1,
            "normalize": False, # center DAS data using the mean and std of the training data
            "relaxed": False, # relax the orthonormal constraint
            "double_rot": False, # if true, applies two rot mtxs in a row
            "rot_first": False, # if relaxed, choose to apply rotation or scaling first
            "identity_init": False, # initialize rotation matrix to identity
            "identity_rot": False, # reinitialize the rotation matrix to identity every forward pass (for debugging)
            "rot_bias": False, # use bias in the das rotation

            "min_count": 1,  # config.get("min_count", 1)
            "max_count": 20, # config.get("max_count", 10)
            "max_demo_tokens": config.get("max_demo_tokens", None),
            "unk_p": 0,
            # Can use hold_outs to test generalization capabilities of das
            "hold_outs": None, # {17,18,19,20}

            # The following only apply if systematic data is true
            "systematic_data": False, # optionally pick your training data systematically
            "max_source_count": None, # none defaults to max_count
            "dest_count_step": 1,
            "source_count_step": 3,
            "base_hold_outs": None,
            "source_hold_outs": None,

            "trigger_steps": { 0,1,2 }, # the possible number of steps in the demo after the swap
            "resp_signal_only": True, # only train using response phase signal

            "trigger_id": None,
            "base_sampler":   samplers.any_sampler,
            "source_sampler": samplers.any_sampler,
            "hacky_eos_replacement": False,
        }
        exp_config = copy.deepcopy(command_kwargs)
        # Default sequence length, derived from max_count (and from
        # max_demo_tokens when that is set).
        mc = int(exp_config.get("max_count", defaults["max_count"]))
        mdt = exp_config.get("max_demo_tokens", defaults["max_demo_tokens"])
        sl = 1+mc*2+3 if not mdt else 1+int(mdt)+mc+3
        defaults["seq_len"] = sl

        # The multi-trigger branches currently select the same causal model;
        # they are kept separate so each variant can be swapped independently.
        ## Single Trigger
        if not config.get("multi_trigger", False):
            defaults["causal_model"] = numeqv.count_only_cmod # demo_resp_count_cmod 
            defaults["intr_var_key"] = "count" # "demo_count", "resp_count"
        ## Multi Pre Trigger
        elif config.get("pre_trigger", True):
            defaults["causal_model"] = numeqv.trigid_count_only_cmod # trigid_demo_resp_count_cmod
            defaults["intr_var_key"] = "count" # "demo_idx", "count", "resp_count", "demo_count"
        ## Multi Post Trigger
        else:
            defaults["causal_model"] = numeqv.trigid_count_only_cmod # trigid_demo_resp_count_cmod
            defaults["intr_var_key"] = "count" # "demo_idx", "count", "resp_count", "demo_count"

        # Samplers/causal model may be given by name; resolve names against
        # this module's globals (which include the star-imported helpers).
        sampler_set = ["source_sampler", "base_sampler", "causal_model"]
        for k in sampler_set:
            if k not in exp_config:
                exp_config[k] = defaults[k]
            if type(exp_config[k])==str:
                exp_config[k] = globals()[exp_config[k]]
            print(k, "-", exp_config[k])

        # Build an output name from the user-specified settings unless an
        # explicit name was given on the command line.
        if og_file_name is None:
            s = set(exp_config.keys())-set(sampler_set)
            file_name = exp_config["causal_model"].__name__+"-"
            if len(s)>0:
                for k in sorted(list(s)):
                    has_len = hasattr(exp_config[k],"__len__")
                    if type(exp_config[k])!=str and has_len:
                        val = "_".join(list([str(e) for e in exp_config[k]])[:5])
                    else:
                        if k=="source_model" and exp_config[k] is not None:
                            m = str(exp_config[k])
                            val = str(dl_utils.save_io.get_exp_num(m))
                            val = m.split("/")[-1].split(f"_{val}_")[0]+val
                        elif k=="layers" and type(exp_config[k])==str:
                            val = "".join([e[:4] for e in exp_config[k].split(".")])
                        else:
                            val = str(exp_config[k])[:5]
                    file_name += k.replace("_","")[:7]+val+"-"
                file_name = file_name[:-1]
            base_sampler = exp_config["base_sampler"]
            source_sampler = exp_config["source_sampler"]
            file_name += "-"
            file_name += source_sampler.__name__.replace("_sampler", "")
            file_name += "2"
            file_name += base_sampler.__name__.replace("_sampler", "")
        else:
            file_name = og_file_name

        if os.path.isdir(model_folder):
            file_name = os.path.join(model_folder, file_name)
        else:
            checkpt = model_folder.split("/")[-1]
            folder = "/".join(model_folder.split("/")[:-1])
            file_name = os.path.join(folder, checkpt+file_name)
        if not overwrite:
            # Only suffixed files (.txt, .csv, _expconfig.json, .p) are ever
            # written, never the bare stem, so probe those to detect a
            # previous run with the same name.
            out_suffixes = ("", ".txt", ".csv", "_expconfig.json")
            stem = file_name
            loop = 0
            file_name = stem+"_"+str(loop)
            while any(os.path.exists(file_name+sfx) for sfx in out_suffixes):
                loop += 1
                file_name = stem+"_"+str(loop)
                print("new fname:", file_name)

        print("Saving to", file_name)
        config["das_save_name"] = file_name.split("/")[-1]
        meta_string = (
            "Description:\n"
        )

        ###################################################################
        # Fill in defaults and coerce user-given values to the default's type.
        for k in defaults:
            if k not in exp_config: exp_config[k] = defaults[k]
            elif type(defaults[k])==list:
                exp_config[k] = [type(defaults[k][0])(exp_config[k])]
            elif k=="hook_type":
                # Keep an explicit None rather than turning it into "None".
                if exp_config[k] is not None:
                    exp_config[k] = str(exp_config[k])
            elif type(exp_config[k])!=type(defaults[k]):
                if type(defaults[k]) in {int, float, str}:
                    exp_config[k] = type(defaults[k])(exp_config[k])
                elif type(defaults[k])==bool:
                    # Accept yaml/json style truthy spellings such as 1/"TRUE".
                    exp_config[k] = str(exp_config[k]).strip().lower() in {"true", "1"}
                else:
                    try:
                        exp_config[k] = int(exp_config[k])
                    except (TypeError, ValueError, OverflowError):
                        print("failed to convert", k, "with value of",
                            exp_config[k])
        if exp_config["intr_var_key"]=="phase":
            assert not exp_config["resp_signal_only"]
        if exp_config["hold_outs"] is None:
            exp_config["hold_outs"] = []

        if type(exp_config["layers"])==str:
            exp_config["layers"] = {exp_config["layers"]}
        assert exp_config.get("unk_p", 0)<=0
        if exp_config["layers"] is not None:
            if "embeddings" in exp_config["layers"]:
                if config["model_type"] not in {"Transformer"}:
                    # Keep a set (as above) so the lookup below is an exact
                    # name match rather than a substring match.
                    exp_config["layers"] = {"inpt_identity"}
            exists = False
            for layer, _ in model.named_modules():
                if layer in exp_config["layers"] and len(layer.strip())>0:
                    print("Found:", layer)
                    exists = True
            if not exists:
                print("Could not find", exp_config["layers"], "in model")
                continue
            else:
                print("Found", exp_config["layers"])

        for k in exp_config:
            if k not in defaults and k!="use_best":
                print(f"{k} is not a valid key")

        for k in sorted(list(exp_config.keys())):
            print(k, "-", exp_config[k])

        source_model = None
        sconfig = None
        if exp_config["source_model"] is not None:
            sfolder = exp_config["source_model"]
            scheckpt = dl_utils.save_io.load_checkpoint(sfolder)
            sconfig = scheckpt["config"]
            temp = smods.make_model(sconfig)
            temp.load_state_dict(scheckpt["state_dict"])
            source_model = temp.model

        if exp_config["intr_var_key"]=="phase":
            trig_ids = config.get("trigger_ids",[7])
            resp_id = config.get("resp_id",config["demo_ids"][0])
            model.trigger_ids.data = torch.LongTensor(
                [*trig_ids,resp_id])

        tconfig = {**config.get("task_config", dict()), **exp_config}
        tconfig["hold_outs"] = {*exp_config["hold_outs"]}
        seq1_kwargs = {
            "task_type": config.get("task_type", "num_equivalence"),
            "task_config": tconfig,
            "idx_sampler": exp_config["base_sampler"],
            "pad_id": config.get("pad_id", 0),
            "bos_id": config.get("bos_id", 1),
            "eos_id": config.get("eos_id", 2),
            "word2id": config.get("word2id", None),
        }
        keys = [
            "seq_len", "trigger_steps",
            "causal_model", "intr_var_key", "trigger_id"
        ]
        for k in keys: seq1_kwargs[k] = exp_config[k]

        # val_hold_outs: every count in [min_count, max_count] that is NOT a
        # training hold-out. It is passed as the val/test seq kwargs'
        # "hold_outs" below. NOTE(review): presumably get_intervention_dataset
        # excludes these counts, so val/test cover only the training
        # hold-outs -- confirm.
        val_hold_outs = None
        train_hold_outs = {*exp_config["hold_outs"]}
        if len(train_hold_outs)>0:
            all_nums = set(range(
                exp_config["min_count"],
                exp_config["max_count"]+1))
            val_hold_outs = all_nums - train_hold_outs

        # Train Data
        if sconfig is not None:
            tconfig = {**sconfig["task_config"], **exp_config}
            tconfig["hold_outs"] = {*exp_config["hold_outs"]}
            seq2_kwargs = {
                "task_type": sconfig.get("task_type","num_equivalence"),
                "task_config": tconfig,
                "idx_sampler": exp_config["source_sampler"],
                "pad_id": sconfig.get("pad_id", 0),
                "bos_id": sconfig.get("bos_id", 1),
                "eos_id": sconfig.get("eos_id", 2),
                "word2id": sconfig.get("word2id", None),
            }
            for k in keys: seq2_kwargs[k] = exp_config[k]
        else:
            sconfig = config
            seq2_kwargs = {**seq1_kwargs}
            seq2_kwargs["idx_sampler"] = exp_config["source_sampler"]
        train_data = get_intervention_dataset(
            n_samples=exp_config["n_train_samples"],
            trigger_steps=exp_config["trigger_steps"],
            seq1_kwargs=seq1_kwargs,
            seq2_kwargs=seq2_kwargs,)

        # VAL Data: same sequence lengths as the training data
        n_samples = exp_config["n_val_samples"]
        seq1_kwargs["hold_outs"] = val_hold_outs
        seq2_kwargs["hold_outs"] = val_hold_outs
        seq1_kwargs["seq_len"] = train_data["base_data"].shape[1]
        seq2_kwargs["seq_len"] = train_data["source_data"].shape[1]
        val_data = get_intervention_dataset(
            n_samples=n_samples,
            trigger_steps=exp_config["trigger_steps"],
            seq1_kwargs=seq1_kwargs,
            seq2_kwargs=seq2_kwargs,)

        # TEST Data (same settings and size as val)
        test_data = get_intervention_dataset(
            n_samples=n_samples,
            trigger_steps=exp_config["trigger_steps"],
            seq1_kwargs=seq1_kwargs,
            seq2_kwargs=seq2_kwargs,)

        data_keys = [
            "base_data","source_data","base_idxs","source_idxs",
            "base_tmasks", "source_tmasks",
        ]
        base_data = dict()
        base_idxs = dict()
        source_data = dict()
        source_idxs = dict()
        # Same fallback ids as seq1_kwargs/seq2_kwargs above, rather than
        # raising KeyError when a config omits them.
        base_tok_ids = (config.get("pad_id", 0), config.get("bos_id", 1),
                        config.get("eos_id", 2))
        source_tok_ids = (sconfig.get("pad_id", 0), sconfig.get("bos_id", 1),
                          sconfig.get("eos_id", 2))
        names = ["train", "val", "test"]
        data_sets = [train_data, val_data, test_data]
        for name, data in zip(names, data_sets):
            seq1s,seq2s,idx1s,idx2s,tmask1s,tmask2s = [data[k] for k in data_keys]
            base_data[name] = _make_das_split(seq1s, tmask1s, *base_tok_ids)
            base_idxs[name] = idx1s
            source_data[name] = _make_das_split(seq2s, tmask2s, *source_tok_ids)
            source_idxs[name] = idx2s
            if exp_config["hacky_eos_replacement"]:
                # Clone first: output_ids (seq[:,1:]) shares storage with
                # input_ids (seq[:,:-1]), so editing it in place would also
                # rewrite the inputs.
                oids = base_data[name]["output_ids"].clone()
                if "resp_id" in config:
                    resp_id = config["resp_id"]
                else:
                    resp_id = config["demo_ids"][0]
                oids[oids==base_tok_ids[2]] = resp_id
                base_data[name]["output_ids"] = oids

        # Verify Trained Model: token accuracy and whole-sequence correctness
        # over task-masked positions of the source data, computed with the
        # source model when one was given.
        model.eval()
        model.to(DEVICE)
        if source_model is not None:
            source_model.eval()
            source_model.to(DEVICE)
            print("Val accuracy will be collected using source model")
        bsize = 1000

        temp_model = model if source_model is None else source_model
        temp_model.eval()
        for tipe in ["train", "val", "test"]:
            input_ids =      source_data[tipe]["input_ids"]
            output_ids =     source_data[tipe]["output_ids"]
            full_pad_mask =  source_data[tipe]["pad_mask"]
            full_targ_mask = source_data[tipe]["output_pad_mask"]
            task_mask = source_data[tipe]["task_mask"]
            print(tipe, "inpt:", input_ids.shape)
            print(tipe, "outpt:", output_ids.shape)
            avg_acc = 0
            avg_correct = 0
            n_loops = 0
            with torch.no_grad():
                for b in range(0,input_ids.shape[0],bsize):
                    inpts = input_ids[b:b+bsize].to(DEVICE)
                    pad_mask = full_pad_mask[b:b+bsize].to(DEVICE)
                    targs = output_ids[b:b+bsize].to(DEVICE)
                    targ_mask = full_targ_mask[b:b+bsize].to(DEVICE)
                    tmask = task_mask[b:b+bsize].bool().to(DEVICE)

                    ret_dict = temp_model(
                        inpts=inpts,
                        pad_mask=pad_mask,
                        task_mask=tmask[:,:-1],
                        n_steps=0,
                        tforce=False,)
                    pred_ids = ret_dict["pred_ids"]

                    # Token accuracy over task-masked target positions
                    mask = tmask[:,1:]
                    eqs = (pred_ids[mask]==targs[mask])
                    acc = eqs.float().mean()
                    avg_acc += acc.item()

                    # A row is correct only if every masked token matches
                    corrects = torch.zeros_like(pred_ids).long()
                    corrects[mask] = eqs.long()
                    corrects = corrects.sum(-1)==mask.long().sum(-1)
                    avg_correct += corrects.float().mean().item()
                    n_loops += 1

            s = f"Verification {tipe} Acc: {avg_acc/n_loops}"
            meta_string += "\n" + s
            print(s)
            s = f"Verification {tipe} Correct: {avg_correct/n_loops}"
            meta_string += "\n" + s
            print(s)

        print(f"Saving to {file_name}.csv")

        meta_string += "\n\nModel Config:"
        keys = sorted(list(config.keys()))
        for k in keys:
            meta_string += f"\n  {k}: {config[k]}"

        meta_string += "\n\nDAS Config:"
        keys = sorted(list(exp_config.keys()))
        for k in keys:
            meta_string += f"\n  {k}: {exp_config[k]}"

        meta_string += "\n\n  Seq1 Kwargs:"
        keys = sorted(list(seq1_kwargs.keys()))
        for k in keys:
            meta_string += f"\n    {k}: {seq1_kwargs[k]}"

        if source_model is not None:
            meta_string += "\n\n  Seq2 Kwargs:"
            keys = sorted(list(seq2_kwargs.keys()))
            for k in keys:
                meta_string += f"\n    {k}: {seq2_kwargs[k]}"

        for k in seq1_kwargs:
            if k not in exp_config: exp_config[k] = seq1_kwargs[k]
        save_config = {
            **exp_config,
            "seq1_kwargs": seq1_kwargs,
            "seq2_kwargs": seq2_kwargs,}
        # Replace non-JSON values: callables by their __name__, iterables by
        # a comma-joined string, anything else by str().
        # NOTE(review): the nested seq1_kwargs/seq2_kwargs are not sanitized
        # here; presumably save_io.save_json tolerates them -- confirm.
        for k,v in exp_config.items():
            if not save_io.is_jsonable(v):
                if hasattr(v, "__name__"):
                    save_config[k] = v.__name__
                else:
                    try:
                        save_config[k] = ",".join([str(e) for e in v])
                    except TypeError:
                        save_config[k] = str(v)
        save_io.save_json(save_config, file_name+"_expconfig.json")

        with open(file_name+".txt", "w") as f:
            f.write(get_datetime_str()+"\n")
            f.write(get_git_revision_hash()+"\n")
            f.write(meta_string)

        # Every exp_config entry (layers, n_neurons, lr, batch_size, ...) is
        # forwarded to automated_das as a keyword argument.
        kwargs = {
            **exp_config,
            "model": model,
            "source_model": source_model,
            "base_data": base_data,
            "base_idxs": base_idxs,
            "source_data": source_data,
            "source_idxs": source_idxs,
            "info": config,
        }
        intr_modus, metrics = automated_das( **kwargs )
        df = pd.DataFrame(metrics)
        df.to_csv(file_name+".csv", header=True, index=False)
        for n_neurons,modu_dict in intr_modus.items():
            best_modu = modu_dict["best_val"]
            best_modu.cpu()
            intr_modu = modu_dict["last"]
            intr_modu.cpu()
            save_dict = { 
                "config": {
                    "size": intr_modu.rot_mtx.size,
                    "temperature": getattr(intr_modu.swap_module,"temperature"),
                    "fixed": n_neurons,
                    "double_rot": getattr(intr_modu, "double_rot"),
                    "mu": intr_modu.mu,
                    "sigma": intr_modu.sigma,
                    "relaxed": intr_modu.relaxed,
                    "sep_rot": intr_modu.sep_rot,
                },
                "meta_config": exp_config,
                "state_dict": intr_modu.state_dict(),
            }
            try:
                # Last
                path = file_name+f"_intrneurons{n_neurons}.intervene.p"
                torch.save(save_dict, path)
                # Best: same config, best-validation weights and epoch
                path = file_name+f"_intrneurons{n_neurons}.intervene.best.p"
                save_dict["state_dict"] = best_modu.state_dict()
                save_dict["epoch"] = best_modu.epoch
                torch.save(save_dict, path)
            except Exception as e:
                # Best-effort: report the cause and keep saving the rest.
                print("failed to save interchange module "+str(n_neurons), e)
        print(df)
        print("End", model_folder)
        print("Saved to", file_name)
        print()




















