import os, sys
import warnings

import torch.distributed

warnings.filterwarnings("ignore")

import argparse
import math, random
import torch
import time

from config import PARAMS_CONFIG
from data import get_train_val_test_data
from models_u import TransformerSeq
from trainer_eval import get_expert_choices
import datetime
from utils import (
    get_params,
    set_up_env,
    get_optimizer_and_scheduler,
    load_checkpoint,
    save_checkpoint,
    create_exp_dir,
    freeze_gate_weight,
    Logger,
    set_freq_optimal_search,
)



def _expert_choices_path(checkpoint_path):
    """Return the file path used to store expert choices for a checkpoint.

    ``"_expert_choices"`` is inserted between the path's root and its final
    extension, e.g. ``ckpt/smoe.pt`` -> ``ckpt/smoe_expert_choices.pt``. When
    the path has no extension, ``.pt`` is appended, so the result can never be
    ``checkpoint_path`` itself (which would overwrite the checkpoint).
    """
    root, ext = os.path.splitext(checkpoint_path)
    return f"{root}_expert_choices{ext or '.pt'}"


def launch(
    env_params,
    model_params,
    adapt_span_params,
    optim_params,
    data_params,
    trainer_params,
):
    """Build the model, restore its checkpoint and dump expert choices.

    The function sets up the environment, loads the data splits, builds a
    ``TransformerSeq`` (wrapped in ``DistributedDataParallel`` or
    ``DataParallel``), creates the optimizer/scheduler and calls
    ``load_checkpoint``. When ``trainer_params["full_eval_mode"]`` is truthy
    it runs ``get_expert_choices`` on the test split and saves the collected
    expert choices next to the checkpoint; otherwise nothing more is done.

    Args:
        env_params: mapping read for ``"device"``, ``"distributed"`` and, in
            distributed runs, ``"rank"`` and ``"local_rank"``. It is passed to
            ``set_up_env`` first (presumably that call fills in these keys --
            verify against ``utils.set_up_env``).
        model_params: keyword arguments forwarded to ``TransformerSeq``; also
            read for ``"smoe_dropout"``, ``"block_size"`` and
            ``"hidden_size"``.
        adapt_span_params: forwarded to ``TransformerSeq`` unchanged.
        optim_params: forwarded to ``get_optimizer_and_scheduler``.
        data_params: forwarded to ``get_train_val_test_data``; read for
            ``"vocab_size"``.
        trainer_params: read for ``"batch_size"``, ``"checkpoint_path"`` and
            ``"full_eval_mode"``.

    Returns:
        None.
    """
    # ENVIRONMENT (device, distributed, etc.)
    set_up_env(env_params)
    device = env_params["device"]
    distributed = env_params["distributed"]

    # Print the configuration once: always in single-process runs, and only
    # on rank 0 in distributed runs.
    if not distributed or env_params["rank"] == 0:
        print("data_params:\t", data_params)
        print("model_params:\t", model_params)
        print("optim_params:\t", optim_params)
        print("trainer_params:\t", trainer_params)
        print("adapt_span_params:\t", adapt_span_params)

    # DATA
    # Only the test split is consumed below; the other two are still
    # produced by the loader.
    train_data, val_data, test_data = get_train_val_test_data(
        data_params=data_params,
        env_params=env_params,
        batch_size=trainer_params["batch_size"],
        device=device,
    )

    # MODEL
    model = TransformerSeq(
        vocab_size=data_params["vocab_size"],
        **model_params,
        adapt_span_params=adapt_span_params,
    )
    print(model)
    if distributed:
        local_rank = env_params["local_rank"]
        model = model.to(device)
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True,
        )
    else:
        model = torch.nn.DataParallel(model)
        model = model.to(device)

    # OPTIMIZER AND SCHEDULER
    optimizer, scheduler = get_optimizer_and_scheduler(
        model=model, optim_params=optim_params
    )

    # create logger
    logger = Logger()

    # Resume from the last checkpoint if it exists. The call is made for its
    # effect on model/optimizer/scheduler/logger; its return value is not
    # needed by this script.
    load_checkpoint(
        trainer_params["checkpoint_path"],
        model,
        optimizer,
        scheduler,
        logger,
        distributed,
    )

    # fix gate
    if model_params["smoe_dropout"]:
        freeze_gate_weight(model)

    # eval model
    if trainer_params["full_eval_mode"]:
        # Run the model over the test data without tracking gradients.
        with torch.no_grad():
            Y, auxiliaries = get_expert_choices(
                model,
                test_data,
                model_params["block_size"],
                model_params["hidden_size"],
            )

        # The first element of every auxiliary entry is kept as that entry's
        # expert choice (layout defined by trainer_eval.get_expert_choices).
        expert_choices = [aux[0] for aux in auxiliaries]

        print(f'{Y=}')

        # NOTE(review): in distributed runs every rank reaches this point and
        # writes to the same file -- confirm whether this should be limited
        # to rank 0 or use a per-rank file name.
        torch.save(
            expert_choices,
            _expert_choices_path(trainer_params["checkpoint_path"]),
        )


if __name__ == "__main__":
    # Resolve the parameter groups described by PARAMS_CONFIG, then hand them
    # to launch() as keyword arguments.
    params = get_params(params_config=PARAMS_CONFIG)
    launch(**params)
