import argparse
import collections
import logging
import os
import pathlib
import re
import sys
import time

import accelerate
import numpy as np
import torch.cuda
import yaml
from flash_attn.losses.cross_entropy import CrossEntropyLoss
from torch.utils.tensorboard import SummaryWriter

from anGPT.data.data_pipeline import get_data_pipeline
from anGPT.model import anTransformerForCausalLM, anGPTConfig
from anGPT.utils.configuration import Config
from anGPT.utils.configure_optimizer import configure_optimizer
from anGPT.utils.folder_manager import get_experiment_folder


def bold(msg):
    """Return ``msg`` wrapped in the ANSI escape codes that switch bold on and off."""
    bold_on, reset = "\033[1m", "\033[0m"
    return "{}{}{}".format(bold_on, msg, reset)


def normalize_fp32(x, dim=-1):
    """Scale ``x`` to unit L2 norm along ``dim``.

    The division is carried out in float32 and the result is cast back to
    the dtype ``x`` came in with.
    """
    original_dtype = x.dtype
    x32 = x.float()
    length = x32.norm(p=2, dim=dim, keepdim=True)
    unit = x32 / length
    return unit.to(dtype=original_dtype)


def bound_fp32(x, dim=-1):
    """Shrink ``x`` along ``dim`` so that its L2 norm is at most 1.

    Slices whose norm is already <= 1 are divided by 1 and therefore left
    unchanged. The arithmetic runs in float32; the result is cast back to
    the dtype ``x`` came in with.
    """
    original_dtype = x.dtype
    x32 = x.float()
    length = x32.norm(p=2, dim=dim, keepdim=True)
    divisor = torch.clamp(length, min=1.0)
    return (x32 / divisor).to(dtype=original_dtype)


def normalize_matrices(model, out_norm_dim_0=True, bounded=False):
    """Re-normalise the weight matrices of the transformer in place.

    ``model`` is first unwrapped from ``DataParallel`` /
    ``DistributedDataParallel`` and, if it has a ``model`` attribute, from
    that outer object as well.

    Args:
        model: Object exposing ``transformer.wte``, ``lm_head`` and the
            blocks in ``transformer["h"]`` (each with ``Wqkv``,
            ``att_c_proj``, ``c_fc`` and ``mlp_c_proj``).
        out_norm_dim_0: If true, the two output projections
            (``att_c_proj``, ``mlp_c_proj``) are normalised along dim 0,
            otherwise along dim 1. All other matrices always use dim 1.
        bounded: If true use ``bound_fp32`` (norm capped at 1), otherwise
            ``normalize_fp32`` (norm forced to 1).
    """
    parallel_wrappers = (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)
    if isinstance(model, parallel_wrappers):
        model = model.module
    if hasattr(model, 'model'):
        model = model.model

    rescale = bound_fp32 if bounded else normalize_fp32
    proj_dim = 0 if out_norm_dim_0 else 1

    def _renorm(layer, dim):
        # Normalise a copy of the weights, then write the result back in place.
        weight = layer.weight.data
        weight.copy_(rescale(weight.clone(), dim))

    _renorm(model.transformer.wte, 1)
    _renorm(model.lm_head, 1)

    for block in model.transformer["h"]:
        _renorm(block.Wqkv, 1)
        _renorm(block.att_c_proj, proj_dim)
        _renorm(block.c_fc, 1)
        _renorm(block.mlp_c_proj, proj_dim)


def evaluate(model, valid_dl, cfg, logger, tb_logger, step, accelerator, loss_func):
    """Run one pass over ``valid_dl`` and log loss, perplexity and accuracy.

    The model is put into eval mode for the pass and back into train mode
    afterwards. Per-process sums are gathered through ``accelerator``; only
    the main process reduces them and writes to ``logger`` / ``tb_logger``.

    Args:
        model: Model called as ``model(src_seq)``; its output must expose ``.logits``.
        valid_dl: Iterable of batches with the keys ``src_seq``, ``trg_seq`` and ``seqlens``.
        cfg: Config; only ``cfg.data_base.ignore_index`` is read here.
        logger: Logger used for the summary line on the main process.
        tb_logger: TensorBoard writer; only used on the main process.
        step: Global step the scalars are logged under.
        accelerator: Accelerate ``Accelerator`` providing device, autocast and gather.
        loss_func: Callable ``(flat_logits, flat_targets) -> scalar loss``.

    Returns:
        Wall-clock seconds spent inside this function.
    """
    device = accelerator.device
    is_rank_zero = accelerator.is_main_process
    tic = time.time()
    model.eval()

    def _zero():
        return torch.tensor(0, device=device, dtype=torch.float)

    with torch.no_grad():
        loss_sum = _zero()
        batch_total = _zero()
        correct_total = _zero()
        log_prob_total = _zero()
        token_total = _zero()

        for batch_val in valid_dl:
            inputs = batch_val['src_seq']
            targets = batch_val['trg_seq']
            with accelerator.autocast():
                logits = model(inputs).logits
                val_loss = loss_func(logits.view(-1, logits.size(-1)), targets.view(-1))

            n_tokens = torch.sum(batch_val['seqlens'], dtype=torch.float)

            # Token-level accuracy, skipping positions marked with ignore_index.
            flat_preds = logits.argmax(dim=-1).view(-1)
            flat_targets = targets.view(-1)
            keep = flat_targets != cfg.data_base.ignore_index
            n_correct = torch.sum(flat_preds[keep] == flat_targets[keep])

            loss_sum += val_loss
            correct_total += n_correct
            # Mean batch loss weighted by the batch's token count.
            log_prob_total += val_loss * n_tokens
            token_total += n_tokens
            batch_total += 1

        # Every process takes part in the gathers; only the main one reduces and logs.
        all_loss = accelerator.gather(loss_sum)
        all_batches = accelerator.gather(batch_total)
        all_log_probs = accelerator.gather(log_prob_total)
        all_correct = accelerator.gather(correct_total)
        all_tokens = accelerator.gather(token_total)

        if is_rank_zero:
            token_count = torch.sum(all_tokens)
            mean_val_loss = torch.sum(all_loss) / torch.sum(all_batches)
            ppl = torch.exp(torch.sum(all_log_probs) / token_count)
            accuracy = torch.sum(all_correct) / token_count

            logger.info(
                f"Validation at step {step} - Mean Loss: {mean_val_loss.item()}"
            )
            for tag, metric in (("valid/loss", mean_val_loss),
                                ("valid/ppl", ppl),
                                ("valid/accuracy", accuracy)):
                tb_logger.add_scalar(tag, metric.item(), step)

    model.train()
    return time.time() - tic


def main(cfg):
    """
    Launch pretraining.

    Sets up Accelerate, the experiment folder, logging, the data pipeline, the model and the
    optimizer from ``cfg``, then runs the training loop with periodic validation, metric logging,
    checkpointing and model export until ``cfg.train.max_steps`` or ``cfg.train.max_epochs`` ends it.
    Afterwards it optionally runs a final validation and saves a final checkpoint / model.

    Args:
        cfg: Project ``Config`` object. The sections read here are ``accelerate``, ``train``,
            ``resume``, ``experiment``, ``model``, ``optim`` and ``data_base``.

    Raises:
        ValueError: If neither ``cfg.train.max_steps`` nor ``cfg.train.max_epochs`` is set.
    """

    torch.set_float32_matmul_precision("medium")

    accelerator = accelerate.Accelerator(**cfg.accelerate)
    device = accelerator.device

    world_size = accelerator.num_processes
    is_rank_zero = accelerator.is_main_process
    rank = accelerator.process_index

    accelerate.utils.set_seed(cfg.train.seed)

    # Only the main process resolves the experiment folder / resume checkpoint; the result is
    # broadcast to all other processes below.
    if is_rank_zero:
        if cfg.resume.resume_before_annealing:
            has_checkpoint = any(d.startswith("checkpoint_pre_annealing_") for d in os.listdir(cfg.resume.expt_dir))
            assert has_checkpoint
            # Resume from the pre-annealing checkpoint with the highest step number.
            latest_step = max(
                [int(re.search(r"checkpoint_pre_annealing_(\d+)", d).group(1)) for d in os.listdir(cfg.resume.expt_dir)
                 if d.startswith("checkpoint_pre_annealing_")], )
            exp_folder = pathlib.Path(cfg.resume.expt_dir)
            checkpoint = exp_folder / f"checkpoint_pre_annealing_{latest_step}"
            do_resume_training = True
        else:
            exp_folder = get_experiment_folder(**cfg.experiment, new_folder=is_rank_zero)
            do_resume_training = False
            checkpoint = None
            latest_step = 0
    else:
        exp_folder, do_resume_training, checkpoint, latest_step = None, False, None, 0

    exp_obj = accelerate.utils.broadcast_object_list([exp_folder, do_resume_training, checkpoint, latest_step],
                                                     from_process=0)
    exp_folder, do_resume_training, checkpoint, latest_step = exp_obj

    logging.getLogger().handlers.clear()

    # The main process logs to stdout and a file; all other processes get no handlers.
    if is_rank_zero:
        logging.basicConfig(
            format=f"[%(asctime)s][rank {rank}][%(levelname)s][%(name)s] - %(message)s",
            datefmt="%d/%m/%Y %H:%M:%S",
            level=logging.INFO,
            handlers=[
                logging.StreamHandler(sys.stdout),
                logging.FileHandler(exp_folder / "logfile.txt"),
            ],
        )
    else:
        logging.basicConfig(
            format=f"[%(asctime)s][rank {rank}][%(levelname)s][%(name)s] - %(message)s",
            datefmt="%d/%m/%Y %H:%M:%S",
            level=logging.INFO,
            handlers=[],
        )
    logger = logging.getLogger(__name__)

    if is_rank_zero:

        logger.info(f"#### Load logger on rank {rank}")
        tb_logger = SummaryWriter(
            log_dir=exp_folder,
        )

        if do_resume_training:
            cfg.save_config(exp_folder, file_name=f"config_resume_after_{latest_step}.yml")
        else:
            cfg.save_config(exp_folder)
        fh = logging.FileHandler(exp_folder / "info.log")
        fh.setLevel(logging.INFO)
        logger.addHandler(fh)

        logger.info(bold("######################################################"))
        logger.info(bold("########          START   TRAINING          ##########"))
        logger.info(bold("######################################################"))

        logger.info(f"########  Project:    {cfg.experiment.project_name}")
        logger.info(f"########  Session:    {cfg.experiment.session_name}")
        logger.info(f"########  Experiment: {cfg.experiment.experiment_name}")
        logger.info(f"save logs and checkpoints in: {exp_folder.as_posix()}")

        logger.info(bold(f"############### ACCELERATE"))
        logger.info(f"World size {world_size}")
        logger.info(f"Rank {rank}")
        logger.info(f"Device {device}")

        logger.info(bold("############### CONFIGURATION"))
        cfg_dict = cfg.get_dict()
        for k, v in cfg_dict.items():
            logger.info(f"{k}: {v}")

    else:
        tb_logger = None

    logger.info(bold(f"############### LOAD DATA on rank {rank}"))

    train_dl, valid_dl, tokenizer = get_data_pipeline(cfg, accelerator.num_processes, accelerator.process_index,
                                                      accelerator.device)
    vocab_size = tokenizer.vocab_size_128

    if is_rank_zero and not do_resume_training:
        tokenizer.save_pretrained(exp_folder / f"tokenizer")

    logger.info(f"vocab size {vocab_size}")

    logger.info(bold(f"############### LOAD MODEL on rank {rank}"))

    hf_config = anGPTConfig(**cfg.model)
    hf_config.vocab_size = vocab_size

    model = anTransformerForCausalLM(hf_config)

    if is_rank_zero:
        def count_parameters(parameters):
            return sum(p.numel() for p in parameters if p.requires_grad)

        model_params = count_parameters(model.parameters())
        logger.info(f"#### trainable_parameters {model_params:,}")
        tb_logger.add_scalar("model/trainable_parameters", model_params, 0)

    model = accelerator.prepare_model(model)
    if cfg.model.use_compile:
        # With more than one process the prepared model is wrapped, so compile via ``.module``.
        if world_size == 1:
            model.compile_model()
        else:
            model.module.compile_model()

    loss_func = CrossEntropyLoss(ignore_index=cfg.data_base.ignore_index, reduction='mean')
    optimizer, lr_scheduler = configure_optimizer(cfg.optim, model, vocab_size, logger)

    # Whichever limit is unset is replaced by a practically unreachable value.
    if cfg.train.max_steps is None and cfg.train.max_epochs is None:
        raise ValueError("You must specify either max_steps or max_epochs")
    if cfg.train.max_steps is None:
        cfg.train.max_steps = 1e15
        logger.info(f"Start training for {cfg.train.max_epochs} epochs")
    if cfg.train.max_epochs is None:
        cfg.train.max_epochs = 1e15
        logger.info(f"Start training for {cfg.train.max_steps} steps")

    optimizer = accelerator.prepare_optimizer(optimizer)

    if do_resume_training:
        logger.info(
            bold(f"############### RESUME TRAINING on rank {rank} and step {latest_step} with checkpoint {checkpoint}"))
        accelerator.load_state(checkpoint)

    optimizer.zero_grad()

    RUN_TRAIN = True
    epoch = 0
    step = 0  # counts micro-batches
    optim_step = 0  # counts completed gradient-accumulation cycles
    total_tokens = 0
    tokens_count = []
    step_times = []

    if cfg.model.explicit_norm:
        normalize_matrices(model, cfg.model.out_norm_dim_0, cfg.model.explicit_norm_bounded)

    while RUN_TRAIN:
        logger.info(f"Start epoch {epoch}")

        model.train()

        if is_rank_zero:
            start_time = time.time()

        for batch in train_dl:

            if not RUN_TRAIN:
                break

            # Validate on the first micro-batch of an optimizer step that hits the interval.
            if cfg.train.val_interval > 0 and optim_step != 0 and optim_step % cfg.train.val_interval == 0 and step % cfg.train.grad_accumulation == 0:
                eval_time = evaluate(model, valid_dl, cfg, logger, tb_logger, optim_step + latest_step, accelerator,
                                     loss_func)
                if is_rank_zero:
                    # Shift the timer so validation time is not counted as training step time.
                    start_time += eval_time

            src_seq = batch['src_seq']
            trg_seq = batch['trg_seq']

            # Count the target positions that are not ignore_index on this process.
            pads = (batch['trg_seq'] == cfg.data_base.ignore_index).sum().item()
            tokens = batch['trg_seq'].numel() - pads
            tokens_count.append(tokens)
            total_tokens += tokens

            with accelerator.autocast():
                m_out = model(src_seq)
                logits = m_out.logits

            loss = loss_func(logits.view(-1, logits.size(-1)), trg_seq.view(-1))
            loss = loss / cfg.train.grad_accumulation
            accelerator.backward(loss)

            if (step + 1) % cfg.train.grad_accumulation == 0:

                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(
                        model.parameters(), cfg.train.clip_value
                    )
                optimizer.step()
                if lr_scheduler is not None:
                    lr_scheduler.step()

                if cfg.model.explicit_norm:
                    normalize_matrices(model, cfg.model.out_norm_dim_0, cfg.model.explicit_norm_bounded)

                if is_rank_zero:
                    step_times.append(time.time() - start_time)
                    start_time = time.time()

                # Undo the grad-accumulation scaling so the logged loss is the per-batch mean.
                gathered_loss = accelerator.gather(loss.detach()) * cfg.train.grad_accumulation

            if optim_step % cfg.train.log_interval == 0 and (
                    step + 1) % cfg.train.grad_accumulation == 0 and is_rank_zero:
                mean_loss = torch.mean(gathered_loss).item()

                log_step = optim_step + latest_step

                logger.info(f"Step {log_step} - Loss: {mean_loss}")

                mean_step_time = np.mean(step_times)
                mean_tokens = np.mean(tokens_count) * world_size
                throughput = mean_tokens / mean_step_time
                step_times = []
                tokens_count = []

                tb_logger.add_scalar("train/mean_step_time", mean_step_time, log_step)
                tb_logger.add_scalar("data/mean_token_per_step", mean_tokens, log_step)
                tb_logger.add_scalar("train/token_throughput_per_sec", throughput, log_step)

                tb_logger.add_scalar("data/batch_size", src_seq.size(0), log_step)
                tb_logger.add_scalar("data/seq_len", src_seq.size(1), log_step)
                tb_logger.add_scalar("data/tokens", total_tokens * world_size, log_step)

                tb_logger.add_scalar("train/loss", mean_loss, log_step)
                tb_logger.add_scalar("train/epoch", epoch, log_step)
                if "lr" in optimizer.param_groups[0]:
                    tb_logger.add_scalar("optim/lr", optimizer.param_groups[0]["lr"], log_step)

            if cfg.train.log_param_interval > 0 and optim_step % cfg.train.log_param_interval == 0 and (
                    step + 1) % cfg.train.grad_accumulation == 0:
                if is_rank_zero:
                    # Compute the step here rather than reusing the value from the loss-logging
                    # branch: the two intervals are independent, so that value may be stale.
                    log_step = optim_step + latest_step
                    stats = {}
                    for k, v in model.named_parameters():
                        v_detached = v.detach()

                        if torch.isnan(v_detached).sum() > 0:
                            logger.info(f"# NaN in param {k}")
                        if torch.isinf(v_detached).sum() > 0:
                            logger.info(f"# Inf in param {k}")

                        if len(v_detached.shape) == 0:
                            stats[f"param/{k}/scalar"] = v_detached.item()
                            continue

                        stats[f"param/{k}/mean"] = v_detached.mean().item()
                        if v_detached.shape[0] > 1:
                            stats[f"param/{k}/norm"] = torch.norm(v_detached).item()
                            if len(v_detached.shape) > 1:
                                stats[f"param/{k}/norm_dim1"] = torch.norm(v_detached, dim=1).mean().item()

                        # Gradients are still populated here; zero_grad() runs after this block.
                        if v.grad is not None:
                            if len(v.shape) >= 1 and v.shape[0] > 1:
                                stats[f"grad/{k}/norm"] = torch.norm(v.grad).item()
                                stats[f"effective_lr/{k}"] = (torch.norm(v.grad) / torch.norm(v_detached)).item()
                            else:
                                stats[f"grad/{k}/mean"] = v.grad.mean().item()

                            if cfg.optim.optimizer == 'AdamW':
                                optim_state = optimizer.state[v]
                                stats[f"adam/{k}/exp_avg/mean"] = torch.mean(optim_state['exp_avg']).item()
                                stats[f"adam/{k}/exp_avg_sq/mean"] = torch.mean(optim_state['exp_avg_sq']).item()
                                if len(v.shape) >= 1 and v.shape[0] > 1:
                                    stats[f"adam/{k}/exp_avg/var"] = torch.var(optim_state['exp_avg']).item()
                                    stats[f"adam/{k}/exp_avg_sq/var"] = torch.var(optim_state['exp_avg_sq']).item()

                    for k, v in stats.items():
                        tb_logger.add_scalar(k, v, log_step)

            if (step + 1) % cfg.train.grad_accumulation == 0:
                optimizer.zero_grad()

            # NOTE(review): the two save blocks below run on every micro-batch while optim_step sits
            # on a multiple of the interval, so with grad_accumulation > 1 the same path is written
            # once per micro-batch. Left unchanged because gating it would alter what gets saved.
            if cfg.train.checkpoint_interval is not None and cfg.train.checkpoint_interval > 0:
                if optim_step > 0 and optim_step % cfg.train.checkpoint_interval == 0:
                    logger.info(f"### Save checkpoint {optim_step + latest_step}")
                    accelerator.save_state(exp_folder / f"checkpoint_{epoch}_{optim_step + latest_step}")

            if cfg.train.save_model_interval is not None and cfg.train.save_model_interval > 0:
                if optim_step % cfg.train.save_model_interval == 0 and optim_step > 0:
                    if accelerator.is_main_process:
                        logger.info(f"### Save model {optim_step + latest_step}")
                        output_dir = exp_folder / f"hf_model_{epoch}_{optim_step + latest_step}"
                        unwrapped_model = accelerate.utils.extract_model_from_parallel(accelerator.unwrap_model(model))
                        # nn.Module.bfloat16() casts the parameters in place, which would switch the
                        # live training weights (on this process only) to bf16 for the rest of the
                        # run. Export a bf16 copy of the state dict and leave the model untouched.
                        bf16_state_dict = {
                            name: tensor.to(torch.bfloat16) if torch.is_floating_point(tensor) else tensor
                            for name, tensor in unwrapped_model.state_dict().items()
                        }
                        unwrapped_model.save_pretrained(output_dir, state_dict=bf16_state_dict)
                        if tokenizer:
                            tokenizer.save_pretrained(output_dir)

            # NOTE(review): this check runs after the batch has been processed and before
            # optim_step is incremented, and ``epoch`` starts at 0 - confirm that the resulting
            # number of optimizer steps / epochs is the intended one.
            if optim_step >= cfg.train.max_steps or epoch > cfg.train.max_epochs:
                logger.info(f"End training at step {optim_step + latest_step} and epoch {epoch}")
                RUN_TRAIN = False
                break

            if (step + 1) % cfg.train.grad_accumulation == 0:
                optim_step += 1
            step += 1

        logger.info(f"### End epoch {epoch}")
        epoch += 1

    # Final validation. The interval guard is evaluated first so the modulus can never be zero
    # (val_interval == 1 previously raised ZeroDivisionError here); with an interval of 1 the
    # training loop already validates at every optimizer step.
    # NOTE(review): the ``- 1`` in the modulus is kept from the original - confirm whether
    # ``optim_step % val_interval`` was intended.
    val_interval = cfg.train.val_interval
    if val_interval > 1 and optim_step != 0 and optim_step % (val_interval - 1) != 0:
        evaluate(model, valid_dl, cfg, logger, tb_logger, optim_step + latest_step, accelerator, loss_func)

    if cfg.train.checkpoint_final:
        logger.info(f"### Save final checkpoint")
        accelerator.save_state(exp_folder / f"checkpoint_annealed_{optim_step + latest_step}")

    if cfg.train.save_final_model:
        logger.info(f"### Save final model")
        output_dir = exp_folder / f"hf_model_annealed_{optim_step + latest_step}"
        unwrapped_model = accelerate.utils.extract_model_from_parallel(accelerator.unwrap_model(model))
        # In-place cast to bf16; training is finished at this point, so the live weights
        # are no longer needed in their original precision.
        unwrapped_model = unwrapped_model.bfloat16()

        unwrapped_model.save_pretrained(
            output_dir,
            is_main_process=accelerator.is_main_process,
            save_function=accelerator.save,
            state_dict=accelerator.get_state_dict(model),
        )
        if tokenizer and not do_resume_training:
            logger.info(f"### Save final tokenizer")
            tokenizer.save_pretrained(output_dir)

    logger.info(f"Finished saving model weights on rank {rank}")
    logger.info("End training!")


if __name__ == "__main__":
    # ``import collections`` alone does not guarantee that the ``abc`` submodule is loaded.
    import collections.abc
    import operator
    from functools import reduce


    def update(d, u):
        """Recursively merge mapping ``u`` into dict ``d`` and return ``d``."""
        for k, v in u.items():
            if isinstance(v, collections.abc.Mapping):
                d[k] = update(d.get(k, {}), v)
            else:
                d[k] = v
        return d


    def getFromDict(dataDict, mapList):
        """Follow the key path ``mapList`` through the nested ``dataDict`` and return the value."""
        return reduce(operator.getitem, mapList, dataDict)


    def setInDict(dataDict, mapList, value):
        """Set ``value`` at the key path ``mapList`` inside the nested ``dataDict``.

        Every key except the last one must already exist (``KeyError`` otherwise).
        """
        getFromDict(dataDict, mapList[:-1])[mapList[-1]] = value


    def convert_string_value(value):
        """Convert a command-line override string into a Python value.

        Recognised forms: ``true``/``True``/``false``/``False``, ``None``, ``(a,b)`` tuples,
        ``[a,b]`` lists, ``{k:v,k2:v2}`` dicts (elements are converted recursively; nesting of
        containers is not supported because elements are split on every comma), integers and
        floats. Anything else is returned unchanged as a string.

        Raises:
            ValueError: If a dict element does not contain a ``:``.
        """

        def split_items(text):
            # Drop the surrounding brackets; an empty body means an empty container.
            inner = text[1:-1]
            return inner.split(",") if inner else []

        if value in ("false", "False"):
            return False
        if value in ("true", "True"):
            return True
        if value == "None":
            return None
        # Require the brackets to actually enclose the value, so that a plain string which merely
        # contains them (e.g. "foo(bar)") is not cut apart.
        if value.startswith("(") and value.endswith(")"):
            return tuple(convert_string_value(v) for v in split_items(value))
        if value.startswith("[") and value.endswith("]"):
            return [convert_string_value(v) for v in split_items(value)]
        if value.startswith("{") and value.endswith("}"):
            converted = {}
            for item in split_items(value):
                parts = item.split(":")
                if len(parts) < 2:
                    raise ValueError(f"dict element without ':' separator: {item!r}")
                converted[parts[0]] = convert_string_value(parts[1])
            return converted

        try:
            return int(value)
        except ValueError:
            pass
        try:
            return float(value)
        except ValueError:
            return value


    default_config_name = "default_config.yaml"

    parser = argparse.ArgumentParser(description="Train Model")
    parser.add_argument(
        "-c", "--config", type=str, default=default_config_name, help="config file name"
    )

    # Everything argparse does not know is treated as a ``section.key=value`` config override.
    args, unknown_args = parser.parse_known_args()

    config_name = args.config
    if not config_name.endswith((".yaml", ".yml")):
        config_name += ".yaml"

    # Use the normalised name; previously the raw argument was used, which made the suffix
    # handling above a no-op.
    config_file = os.path.join("config", config_name)
    with open(config_file, "r") as f:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects. Only load trusted
        # config files, or switch to yaml.safe_load if the configs use no Python-specific tags.
        config_dict = yaml.load(f, Loader=yaml.Loader)

    for arg in unknown_args:
        if "=" not in arg:
            raise UserWarning(f"argument unknown: {arg}")
        # Split on the first "=" only so that values may themselves contain "=".
        dotted_key, raw_value = arg.split("=", 1)
        keys = dotted_key.split(".")
        value = convert_string_value(raw_value)
        print(keys, value)
        setInDict(config_dict, keys, value)

    print(config_dict)

    config = Config(config_dict=config_dict)

    main(cfg=config)
