import torch
import numpy as np
import time
from tqdm import tqdm
import os
import sys
try:
    import accelerate
except:
    print("install accelerate")

from dl_utils.save_io import (
    save_checkpt, load_json_or_yaml, record_session, load_init_checkpt,
    load_checkpoint,
)
from dl_utils.schedulers import DecayScheduler
from dl_utils.utils import package_versions

import datas
from datas import get_datasets
from seq_models import make_model
from utils import check_correct_count, print_tensors

"""
This script runs a toy sequence training to ensure that your model
classes are working. The sequence is a starting number that can take
k possible forms, a string of N ordered digits ranging somewhere between
1-100, and a final output of the starting number.
"""

def train(rank, config, verbose=True, *args, **kwargs):
    """
    Builds the datasets, model, optimizer and scheduler described by
    `config`, then runs the epoch loop: train, (on rank 0) validate,
    and optionally save a checkpoint plus a text training log.

    Args:
        rank: int
            index of this process. It is added to the random seed and
            stored in config["rank"]. Validation and checkpoint saving
            only happen when rank is 0.
        config: dict
            hyperparameters. The dict is updated in place (e.g.
            "save_folder", "seed", "rank", "lr", "n_params", and
            whatever `get_datasets` adds).
        verbose: bool
            if true, progress and example predictions are printed and
            collected into the training log string.
        *args, **kwargs:
            accepted and ignored.
    Returns:
        model: torch module
            the trained model (wrapped by accelerate when the
            accelerator was set up successfully).
    """
    torch.cuda.empty_cache()

    # Hyperparameters
    config = config_error_catching(config) # Make sure we have valid config
    config["save_folder"] = config.get("save_folder", "./mytraining")
    config["seed"] = config.get("seed", int(time.time()))
    if config["seed"] is None: config["seed"] = int(time.time())
    torch.manual_seed(config["seed"]+rank)
    np.random.seed(   config["seed"]+rank)
    config["rank"] = rank

    # Dataset/Tokenizer
    #######################################
    if verbose and rank==0: print("Making Data")
    # This function updates the config dict and returns DataSet objects
    tokenizer, train_dataset, val_dataset = get_datasets(config)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        batch_size=config.get("batch_size", 128)
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        shuffle=True,
        batch_size=config.get("batch_size", 1000)
    )
    if verbose and rank==0:
        print("Train Samples:", len(train_dataset))
        print("Val Samples:", len(val_dataset))
        print("Using Sequence Length:", config["seq_len"])

    # Model
    #######################################
    model = make_model(config)
    model = load_init_checkpt(model, config)
    model = load_embeddings(model, config)
    n_params = 0
    for p in model.parameters():
        if hasattr(p, "data"):
            n_params += p.data.numel()
    config["n_params"] = n_params
    print("NParameters:", n_params)

    # Optimizer
    #######################################
    if verbose and rank==0:
        print("Creating Optimizer")
    config["lr"] = config.get("lr", 0.001)
    optimizer = getattr(torch.optim, config.get("optim_type","Adam"))(
        model.parameters(),
        lr=config["lr"],
        weight_decay=config.get("l2", 0),
    )

    # Scheduler
    #######################################
    scheduler = DecayScheduler( optimizer, **config )

    # Distributed Wrapper
    #######################################
    if rank==0 and verbose and torch.cuda.device_count()>1:
        print("Handling multiple GPUs")
    # `accelerator` stays None when accelerate is not installed or fails
    # to initialize; training then falls back to plain loss.backward().
    accelerator = None
    try:
        accelerator = accelerate.Accelerator()
        model, optimizer, train_loader = accelerator.prepare(
            model, optimizer, train_loader
        )
        val_loader = accelerator.prepare(val_loader)
    except Exception as e:
        print("error with accelerator", repr(e))

    #############################################################
    # Save Configuration
    #############################################################
    record_session(config, model, globals_dict=globals())

    #############################################################
    # Training
    #############################################################
    n_epochs = config.get("n_epochs", 100)
    best_val_correct = 0
    best_train_correct = 0
    for epoch in range(n_epochs):
        epochtime = time.time()
        torch.cuda.empty_cache()
        # Always start the epoch with a log string so that later
        # appends and the log save never hit an undefined name,
        # regardless of rank/verbose.
        logstr = ""
        if rank==0 and verbose:
            print()
            s = "Beginning Epoch {} - {}".format(
                epoch, config.get("save_folder", "No Save Folder")
            )
            print(s)
            logstr = s + "\n"

        #############################################################
        # Train Loop
        #############################################################
        model.train()
        avg_loss = 0
        avg_acc = 0
        avg_correct = 0
        nloops = config.get("n_train_loops", len(train_loader))
        nloops = min(nloops,len(train_loader))
        checkpt_mod = config.get( "checkpt_mod", np.inf )
        val_mod = config.get( "val_mod", 1)
        # Number of batches whose gradients are accumulated before each
        # optimizer step. A missing/None/0 value means no accumulation.
        n_grad_loops = config.get("n_grad_loops", 1) or 1
        optimizer.zero_grad()
        for i,data in enumerate(train_loader):
            starttime = time.time()
            package = model(
                data,
                ret_preds=True,
                tforce=config.get("tforce_train", True),
            )
            loss = package["loss"]
            acc = package["acc"]
            corrects = package["corrects"]

            if accelerator is not None:
                accelerator.backward(loss)
            else:
                loss.backward()

            avg_acc += acc.item()
            avg_loss += loss.item()
            avg_correct += corrects.float().mean().item()

            # Step once n_grad_loops batches have been accumulated, and
            # on the last batch actually processed this epoch so that no
            # accumulated gradient is thrown away.
            if (i+1)%n_grad_loops==0 or i==nloops-1:
                if config.get("grad_clip",0) > 0:
                    torch.nn.utils.clip_grad_norm_(
                        model.parameters(), config["grad_clip"]
                    )
                optimizer.step()
                optimizer.zero_grad()
                # Best effort per-step schedule update.
                # NOTE(review): the original swallowed every error here;
                # confirm which scheduler failures are actually expected.
                try:
                    scheduler.step()
                except Exception:
                    pass

            if verbose and i%10==0 and rank==0:
                dec = 4
                l = round(loss.item(), dec)
                a = round(acc.item(), dec)
                c = round(100*i/nloops, 2)
                t = round(time.time()-starttime, 3)
                s = "Loss: {} -Acc: {}".format(l,a)
                s += " - {}% {}s   ".format(c,t)
                print(s, end=int(len(s)/2)*" " + "\r")


            if config.get("exp_name","deleteme")=="test" and i>=30: break
            if i>=(nloops-1): break
        div = (i+1)
        dec = 5
        train_loss = round(avg_loss/div, dec)
        train_acc  = round(avg_acc/div, dec)
        train_correct = round(avg_correct/div, dec)
        if verbose:
            s = "Example Train Preds:"
            print()
            print(s)
            logstr += s+"\n"

            # Examples come from the last training batch of the epoch
            print("Token Map:", tokenizer.id2word)
            preds = package["pred_ids"]
            targs = data["output_ids"]
            for j in range(min(3,len(preds))):
                logstr += print_tensors(
                    targ=targs[j],
                    pred=preds[j],
                    mask=data["task_mask"][j],
                    tokenizer=tokenizer)
                print()

            incorrects = ~(corrects==1)
            if incorrects.float().sum()>0:
                s = "Wrong Train Examples:"
                print(s)
                logstr += s+"\n"
                arr = torch.arange(len(incorrects)).long()
                # Batch indices of the incorrectly predicted samples
                wrong_idxs = arr[incorrects.cpu()]
                preds = package["pred_ids"]
                targs = data["output_ids"]
                for j in range(min(3,incorrects.long().sum().item())):
                    idx = wrong_idxs[j]
                    logstr += print_tensors(
                        targ=targs[idx],
                        pred=preds[idx],
                        mask=data["task_mask"][idx],
                        tokenizer=tokenizer)
                    print()


        #############################################################
        # Validation Loop
        #############################################################
        val_loss =     0
        val_acc =      0
        val_correct =  0
        if rank==0 and (epoch%val_mod==0 or epoch==n_epochs-1):
            model.eval()
            if verbose: print("Validating...")
            with torch.no_grad():
                nloops = config.get("max_val_loops",len(val_loader))
                nloops = min(nloops, len(val_loader))
                avg_loss = 0
                avg_acc = 0
                avg_correct = 0
                for i,data in enumerate(val_loader):
                    starttime = time.time()
                    package = model(
                        data,
                        ret_preds=True,
                        tforce=False,
                        temperature=config.get(
                            "sampling_temperature", None
                        )
                    )
                    loss = package["loss"]
                    acc = package["acc"]
                    corrects = package["corrects"]

                    avg_loss += loss.item()
                    avg_acc += acc.item()
                    avg_correct += corrects.float().mean().item()

                    if verbose:
                        p = round(100*(i+1)/nloops, 2)
                        t = round(time.time()-starttime, 4)
                        print("{}% -- {}s".format(p,t), end="         \r")
                    if i>=nloops-1: break
            div = (i+1)
            dec = 5
            val_loss = round(avg_loss/div, dec)
            val_acc =  round(avg_acc/div, dec)
            val_correct =  round(avg_correct/div, dec)
            scheduler.step(val_loss)
            # Test runs stop after the first validation pass
            if config.get("exp_name", "deleteme")=="test": break
            if verbose:
                print()
                s = "Example Val Preds:"
                print(s)
                logstr += s+"\n"
                preds = package["pred_ids"]
                targs = data["output_ids"]
                for j in range(min(3,len(preds))):
                    logstr += print_tensors(
                        targ=targs[j],
                        pred=preds[j],
                        mask=data["task_mask"][j],
                        tokenizer=tokenizer)
                    print()
                print()

                s = "Final Stats, Epoch: {}".format(epoch)
                print(s)
                logstr += "\n" + s + "\n"

                s = "Train Loss: {} - Train Acc: {} - Train Correct: {}".format(
                    train_loss,train_acc,train_correct
                )
                logstr += s + "\n"
                print(s)

                s = "Val Loss: {} Val Acc: {} - Val Correct: {}".format(
                    val_loss,val_acc,val_correct)
                logstr += s + "\n"
                print(s)

                s = "Epoch Dur: {}s".format(round(time.time()-epochtime))
                logstr += s + "\n\n\n\n"
                print(s)

                print()
                print()

        ##############################################################
        #### SAVING
        ##############################################################
        save_mod = config.get("sd_save_mod", np.inf)
        if save_mod is None or save_mod<0: save_mod = np.inf
        if rank==0 and (epoch%val_mod==0 or epoch%save_mod==0):
            if config.get( "save", False ):
                save_dict = {
                    "mid_epoch": False,
                    "epoch":       epoch,
                    "train_loss":  train_loss,
                    "train_acc":   train_acc,
                    "train_correct": train_correct,
                    "val_loss":    val_loss,
                    "val_acc":     val_acc,
                    "val_correct": val_correct,
                    "state_dict":  model.state_dict(),
                    "optim_dict":  optimizer.state_dict(),
                    "lr": optimizer.param_groups[0]["lr"],
                    "config":        config,
                }
                # Determine whether to keep the previous save
                keep_prev_sd = save_mod and epoch%save_mod==0

                if keep_prev_sd:
                    # Double the saving increment. The check uses the
                    # same "sd_save_mod" key that is doubled, so a
                    # config without it is simply left alone.
                    dmod = config.get("sd_save_double_every", None)
                    double = dmod and epoch%dmod == 0
                    if config.get("sd_save_mod",None) is not None and double:
                        config["sd_save_mod"] *= 2

                # A checkpoint is "best" when train correctness has not
                # dropped by more than 0.001 below its best and val
                # correctness matches or beats its best so far.
                best = False
                if train_correct>=best_train_correct-0.001:
                    best_train_correct = train_correct
                    if val_correct>=best_val_correct:
                        best = True
                        best_val_correct = val_correct
                save_checkpt(
                    save_dict=save_dict,
                    save_folder=config["save_folder"],
                    save_name="checkpt",
                    epoch=epoch,
                    ext=".pt",
                    del_prev_sd=not keep_prev_sd,
                    best=best,
                )
                save_training_log(config, logstr)

        # Clean up
        keys = list(package.keys())
        for k in keys: del package[k]
        if config.get("exp_name", "deleteme")=="test" and epoch>2: break
    if verbose:
        print("Ending model", config["save_folder"])
    return model


def save_training_log(config,
                      logstr,
                      fname="training_log.txt",
                      reset=False):
    """
    Writes logstr to a file inside config["save_folder"].

    config: dict
        must contain the key "save_folder"
    logstr: str
        the text to write
    fname: str
        name of the log file within the save folder
    reset: bool
        if true, the file is overwritten with logstr. otherwise logstr
        is appended to the end of the file
    """
    log_path = os.path.join(config["save_folder"], fname)
    if reset:
        open_mode = "w"
    else:
        open_mode = "a"
    with open(log_path, open_mode) as log_file:
        log_file.write(logstr)

def load_embeddings(model, config):
    """
    Copies embedding rows from a previously saved checkpoint into the
    model's embedding table for every word that both vocabularies
    share. Does nothing when config has no usable "init_checkpt".

    Args:
        model: torch module
            must expose its embeddings at `model.model.embeddings`
        config: dict
            "init_checkpt": str or None
                path to the checkpoint. If the path does not exist it
                is retried relative to config["save_root"].
            "word2id": dict
                maps words to embedding rows of the current model
    Returns:
        model: torch module
            the same object that was passed in; its embedding weights
            are modified in place.
    """
    ckpt_path = config.get("init_checkpt", None)
    # Nothing to load when the path is missing or blank
    if ckpt_path is None or ckpt_path.strip()=="":
        return model

    if not os.path.exists(ckpt_path):
        ckpt_path = os.path.join(config["save_root"], ckpt_path)
    checkpt = load_checkpoint(ckpt_path)
    word2id = config["word2id"]
    init_word2id = checkpt["config"]["word2id"]

    # Show both sets of parameter names for comparison
    ckpt_names = "\n".join([str(k) for k in checkpt["state_dict"].keys()])
    print("checkpt:", ckpt_names)
    print()
    model_names = "\n".join([str(k) for k in model.state_dict().keys()])
    print("model:", model_names)

    init_emb_weight = checkpt["state_dict"]["model.embeddings.weight"]
    embs = model.model.embeddings
    # (row in the current model, row in the checkpoint) per shared word
    shared_rows = (
        (word2id[word], old_id)
        for word, old_id in init_word2id.items()
        if word in word2id
    )
    for new_id, old_id in shared_rows:
        embs.weight.data[new_id] = init_emb_weight[old_id]
    return model

def extract_task_config(config):
    """
    Copies task settings found at the top level of config into
    config["task_config"]. This makes it easier to search over task
    settings using the hyperparameter search paradigm in this project.

    A key is copied when it is one of the known task keys listed below
    or when it already appears in config["task_config"]; the top level
    value overrides the existing task_config value. The deprecated key
    "chain_of_count" inside task_config is renamed to "chain_of_num".

    Args:
        config: dict
            updated in place. A missing or None "task_config" entry is
            treated as an empty dict.
    Returns:
        config: dict
            the same dict, with config["task_config"] set.
    """
    task_keys = {
        # Task Agnostic
        "sep_digits", "reverse_digits", "numeral_base",
        # Numeric Equivalence
        "n_demo_types", "chain_of_num", "strategy", "incl_trigger",
        "pre_trigger", "multi_trigger", "max_count", "hold_outs",
        "max_demo_tokens",
        # Arithmetic
        "min_ops", "max_ops", "max_val", "min_val", "max_new",
        "sep_every", "ops", "n_ops",

        # Induction Heads
        "min_resp_dist", "max_resp_dist", "n_trigs", "n_trig_types",
        "n_token_types", "max_first_idx", "trig_first", "incl_sep",
        "allow_dupls",
    }
    # An explicit null (e.g. an empty yaml entry) would otherwise fail
    # when iterated, so treat it like a missing task_config.
    task_config = config.get("task_config", None)
    if task_config is None: task_config = dict()
    task_keys.update(task_config)
    # Sorted so that the printed overrides come out in a stable order
    for k in sorted(task_keys, key=str):
        if k in config:
            print("Making changes to", k,"!!!!",
                task_config.get(k,None), "->", config[k])
            task_config[k] = config[k]
    if "chain_of_count" in task_config:
        print("Chain of count is deprecated. Using chain_of_num instead!!")
        task_config["chain_of_num"] = task_config.pop("chain_of_count")
    config["task_config"] = task_config
    return config

def config_error_catching(config):
    """
    Normalizes the hyperparameter dict before training. Currently the
    only step is `extract_task_config`, which gathers task settings
    into config["task_config"].

    Args:
        config: dict
    Returns:
        config: dict
            the value returned by `extract_task_config`
    """
    return extract_task_config(config)

if __name__=="__main__":
    # An optional first command line argument names a json or yaml
    # config file; without it training runs on an empty config.
    has_config_path = len(sys.argv)>1
    config = load_json_or_yaml(sys.argv[1]) if has_config_path else { }
    train(0, config)

