import os
import warnings

warnings.filterwarnings("ignore")

import math
import torch
import time

from config import PARAMS_CONFIG
from data import get_train_val_test_data
from models import TransformerSeq
from trainer import train_iteration, full_eval
import datetime
import wandb
from utils import (
    AlphaScheduler,
    get_params,
    set_up_env,
    get_optimizer_and_scheduler,
    load_checkpoint,
    save_checkpoint,
    create_exp_dir,
    freeze_gate_weight,
    Logger,
)


def _is_bpc_dataset(data_path):
    """Return True if results for ``data_path`` are reported in bits-per-character.

    enwik8 and text8 results are reported in BPC; every other dataset is
    reported as perplexity (PPL).
    """
    return ("enwik8" in data_path) or ("text8" in data_path)


def _reduce_mean(values, distributed, env_params, device):
    """Average scalar losses across ranks and return them as Python floats.

    Args:
        values: sequence of scalar losses (Python numbers or 0-dim tensors).
        distributed: whether torch.distributed is in use.
        env_params: environment dict; ``"world_size"`` is read when distributed.
        device: device the reduction buffer is moved to.

    Returns:
        A list of floats, one per input value. When ``distributed`` is True
        this is a collective call that EVERY rank must make; the sum is reduced
        onto rank 0 only, so the returned values are meaningful on rank 0 alone.
    """
    floats = [float(v) for v in values]
    if not distributed:
        return floats
    stats = torch.tensor(floats).to(device)
    torch.distributed.reduce(stats, 0)
    return (stats / env_params["world_size"]).tolist()


def _run_full_eval(
    model,
    optimizer,
    scheduler,
    val_data,
    test_data,
    model_params,
    data_params,
    env_params,
    distributed,
    device,
    is_main,
    wandb_flag,
    log,
):
    """Evaluate on the full validation and test sets and report the results.

    Losses are averaged across ranks when distributed. Only the main process
    writes to the experiment log and to wandb.
    """
    with torch.no_grad():
        losses = [
            full_eval(
                model,
                optimizer,
                scheduler,
                split,
                model_params["block_size"],
                model_params["hidden_size"],
                gate_alpha=model_params["gate_alpha"],
            )
            for split in (val_data, test_data)
        ]
    # Collective in distributed mode: every rank must reach this call.
    loss_val, loss_test = _reduce_mean(losses, distributed, env_params, device)

    if is_main:
        if _is_bpc_dataset(data_params["data_path"]):
            log("Val: {:.3f} BPC".format(loss_val / math.log(2)))
            log("Test: {:.3f} BPC".format(loss_test / math.log(2)))
        else:
            log("Val: {:.3f} PPL".format(math.exp(loss_val)))
            log("Test: {:.3f} PPL".format(math.exp(loss_test)))

    if distributed:
        # Sync all processes before returning.
        torch.distributed.barrier()

    if wandb_flag and is_main:
        wandb.log(
            {
                "best_valid_ppl": float(math.exp(loss_val)),
                "best_test_ppl": float(math.exp(loss_test)),
            }
        )


def _format_epoch_message(iter_no, loss_train, loss_val, elapsed, bpc):
    """Build the one-line per-iteration summary (BPC or PPL flavour)."""
    if bpc:
        return "Epochs: {} | loss_train: {:.3f} ~ {:.3f} BPC | loss_val: {:.3f} ~ {:.3f} BPC | elapsed: {:.1f}".format(
            iter_no,
            loss_train,
            float(loss_train / math.log(2)),
            loss_val,
            float(loss_val / math.log(2)),
            elapsed,
        )
    return "Epochs: {} | loss_train: {:.3f} ~ {:.3f} PPL | loss_val: {:.3f} ~ {:.3f} PPL | elapsed: {:.1f}".format(
        iter_no,
        loss_train,
        float(math.exp(loss_train)),
        loss_val,
        float(math.exp(loss_val)),
        elapsed,
    )


def launch(
    env_params,
    model_params,
    adapt_span_params,
    optim_params,
    data_params,
    trainer_params,
    wandb_params,
):
    """Build data, model and optimizer from config dicts, then train or evaluate.

    If ``trainer_params["full_eval_mode"]`` is set, the model (after any
    checkpoint loading) is evaluated on the full validation and test sets and
    the function returns. Otherwise it trains from the iteration returned by
    ``load_checkpoint`` up to ``trainer_params["nb_iter"]``. It saves a
    checkpoint to ``<checkpoint_path>/bests.pt`` whenever the validation loss
    improves.

    In distributed mode, losses are averaged across ranks every iteration, and
    only rank 0 logs per-iteration results, reports to wandb and writes
    checkpoints. This avoids several processes writing the same file.

    Args:
        env_params: environment config. Reads ``"device"`` and ``"distributed"``.
            When distributed, also reads ``"rank"``, ``"local_rank"`` and
            ``"world_size"``.
        model_params: keyword arguments for ``TransformerSeq``. These keys are
            also read directly here: ``"block_size"``, ``"hidden_size"``,
            ``"gate_name"``, ``"gate_alpha"``, ``"load_balance"`` and
            ``"smoe_dropout"``.
        adapt_span_params: forwarded to ``TransformerSeq``.
        optim_params: forwarded to ``get_optimizer_and_scheduler``.
        data_params: data config. Reads ``"vocab_size"`` and ``"data_path"``.
        trainer_params: trainer config (batch sizes, iteration counts,
            checkpoint path, resume and full-eval flags).
        wandb_params: reads ``"wandb_flag"``, ``"project_name"``,
            ``"wandb_entity"`` and ``"job_name"``.
    """
    # ENVIRONMENT (device, distributed, etc.)
    set_up_env(env_params)
    device = env_params["device"]
    distributed = env_params["distributed"]
    resume = trainer_params["resume"]
    # Single guard for "this process reports results"; rank is only read when distributed.
    is_main = (not distributed) or env_params["rank"] == 0

    wandb_flag = wandb_params["wandb_flag"]
    if wandb_flag and is_main:
        wandb.init(
            project=wandb_params["project_name"],
            entity=wandb_params["wandb_entity"],
            name=wandb_params["job_name"],
        )
        wandb.config.update(model_params)

    if is_main:
        print("data_params:\t", data_params)
        print("model_params:\t", model_params)
        print("optim_params:\t", optim_params)
        print("trainer_params:\t", trainer_params)
        print("adapt_span_params:\t", adapt_span_params)

    # DATA
    train_data, val_data, test_data = get_train_val_test_data(
        data_params=data_params,
        env_params=env_params,
        batch_size=trainer_params["batch_size"],
        device=device,
    )

    # MODEL
    model = TransformerSeq(
        vocab_size=data_params["vocab_size"],
        **model_params,
        adapt_span_params=adapt_span_params,
    )
    if is_main:
        print(model)
    if distributed:
        local_rank = env_params["local_rank"]
        model = model.to(device)
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True,
        )
    else:
        model = torch.nn.DataParallel(model)
        model = model.to(device)

    # OPTIMIZER AND SCHEDULER
    optimizer, scheduler = get_optimizer_and_scheduler(
        model=model, optim_params=optim_params
    )

    # LOGGING
    logger = Logger()
    log = create_exp_dir(trainer_params["checkpoint_path"])
    checkpoint_path = os.path.join(trainer_params["checkpoint_path"], "bests.pt")
    log(f"Training Parameters:\n {trainer_params}")
    log(f"Models Parameters:\n {model_params}")
    log(str(datetime.datetime.now()))
    log(str(model))
    log(f"Total of Parameters: {sum(p.numel() for p in model.parameters())}")
    log(
        f"Total of Trainable Parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}"
    )

    # Resume training from the best checkpoint, if one exists.
    iter_init = load_checkpoint(
        checkpoint_path,
        model,
        optimizer,
        scheduler,
        logger,
        distributed,
        resume,
    )
    if model_params["smoe_dropout"]:
        freeze_gate_weight(model)
    start_time = time.time()

    if trainer_params["full_eval_mode"]:
        _run_full_eval(
            model,
            optimizer,
            scheduler,
            val_data,
            test_data,
            model_params,
            data_params,
            env_params,
            distributed,
            device,
            is_main,
            wandb_flag,
            log,
        )
        return

    # Index 0 is the train stream, index 1 the validation stream.
    data_pos = [0] * 2
    hid_cache = [
        [
            torch.zeros(
                train_data.size(0),
                model.module.layers[layer_i].attn.attn.get_cache_size(),
                model_params["hidden_size"],
            ).to(device)
            for layer_i in range(model.module.attn_layer_count)
        ]
        for _ in range(2)
    ]

    nb_batches_per_iter = trainer_params["nb_batches_per_iter"]
    use_bpc = _is_bpc_dataset(data_params["data_path"])

    # Linear-decay alpha scheduler, only for the graph gate. At test time
    # gate_alpha is the final target, i.e. model_params["gate_alpha"].
    alpha_scheduler = None
    if model_params["gate_name"] == "graph_global_prevadjinfo":
        alpha_scheduler = AlphaScheduler(
            init_alpha=1,
            final_alpha=model_params["gate_alpha"],
            end_epoch=int(0.85 * trainer_params["nb_iter"]),
            total_epochs=trainer_params["nb_iter"],
        )

    best_val_loss = None
    for iter_no in range(iter_init, trainer_params["nb_iter"]):
        t_sta = time.time()
        is_gate_adj_warmup = iter_no < trainer_params["nb_gate_adj_warmup_iters"]
        if alpha_scheduler is not None:
            gate_alpha = alpha_scheduler(iter_no)
            if wandb_flag and is_main:
                # commit=False: buffered until the next committing wandb.log call.
                wandb.log({"gate_alpha": gate_alpha}, commit=False)
        else:
            gate_alpha = None

        loss_train, data_pos[0], hid_cache[0] = train_iteration(
            model,
            model_params["load_balance"],
            optimizer,
            scheduler,
            train_data,
            nb_batches_per_iter,
            model_params["block_size"],
            False,
            data_pos[0],
            hid_cache[0],
            trainer_params["batch_split"],
            is_gate_adj_warmup,
            gate_alpha,
        )
        # Milliseconds per training batch.
        elapsed = 1000 * (time.time() - t_sta) / nb_batches_per_iter

        with torch.no_grad():
            loss_val, data_pos[1], hid_cache[1] = train_iteration(
                model,
                model_params["load_balance"],
                optimizer,
                scheduler,
                val_data,
                nb_batches_per_iter,
                model_params["block_size"],
                True,
                data_pos[1],
                hid_cache[1],
                trainer_params["batch_split"],
                False,
                gate_alpha=gate_alpha,
            )

        # Collective in distributed mode: all ranks must call it before
        # non-main ranks skip reporting below.
        loss_train, loss_val = _reduce_mean(
            [loss_train, loss_val], distributed, env_params, device
        )
        if not is_main:
            continue

        if wandb_flag:
            wandb.log(
                {
                    "train_ppl": float(math.exp(loss_train)),
                    "Epoch": iter_no,
                    "valid_ppl": float(math.exp(loss_val)),
                }
            )

        log(f"=================== EPOCHS {iter_no} ======================")
        log(_format_epoch_message(iter_no, loss_train, loss_val, elapsed, use_bpc))

        # Save the model if the validation loss is the best seen so far.
        if best_val_loss is None or loss_val < best_val_loss:
            best_val_loss = loss_val
            save_checkpoint(
                checkpoint_path,
                iter_no,
                model,
                optimizer,
                scheduler,
                logger,
            )

    end_time = time.time()
    if wandb_flag and is_main:
        summary = {"training_time": (end_time - start_time)}
        # best_val_loss stays None when no iteration ran (e.g. resumed past nb_iter).
        if best_val_loss is not None:
            # Key kept for dashboard compatibility; the value is exp(loss) (perplexity).
            summary["best_val_loss"] = float(math.exp(best_val_loss))
        wandb.log(summary)
    log(f"Training time total: {(end_time - start_time)/3600} h")


if __name__ == "__main__":
    # Parse the run configuration once, then hand every section to launch().
    run_params = get_params(params_config=PARAMS_CONFIG)
    launch(**run_params)
