from pathlib import Path
import json
from datetime import datetime
from shutil import copyfile

import torch
from torch import nn
from torch.optim import AdamW, Optimizer
from transformers import AutoTokenizer

from lr_scheduler import CosineScheduler, StableDrop
# from modeling.mamba import MambaConfig, MambaForCausalLM, load_pretrained
from modeling.mamba2.modeling_mamba2_dao import Mamba2ForCausalLM, MambaConfig
from modeling.rwkv5 import RWKV5
from utils import get_non_embed_param_count, get_param_count
from trainer import Trainer
from arguments import Args


def get_tok_path(model_name):
    """Return the tokenizer directory for ``model_name``.

    Accepts the same Mamba-2 and RWKV-5 names that ``get_model`` accepts:
    ``mamba2-*`` / ``mamba2`` / ``mamba-2`` map to the Mamba tokenizer and
    ``rwkv5-*`` / ``rwkv5`` / ``rwkv-5`` map to the RWKV-5 tokenizer.

    Raises:
        ValueError: if the name does not belong to a supported model family.
    """
    if model_name.startswith('mamba2-') or model_name in ('mamba2', 'mamba-2'):
        return 'tokenizers/mamba-tok'
    if model_name.startswith('rwkv5-') or model_name in ('rwkv5', 'rwkv-5'):
        return 'tokenizers/rwkv5-tok'
    raise ValueError(f"Invalid model name: {model_name}")


def get_tokenizer(args: Args):
    """Load the tokenizer that matches ``args.model``.

    The directory comes from ``get_tok_path``. ``trust_remote_code=True`` lets
    ``AutoTokenizer`` execute custom tokenizer code shipped in that directory.
    """
    path_to_tokenizer = get_tok_path(args.model)
    print(f"Loading tokenizer from: {path_to_tokenizer}")
    return AutoTokenizer.from_pretrained(path_to_tokenizer, trust_remote_code=True)


def get_ckpt_path(path):
    """Resolve ``path`` to a single ``.pt`` checkpoint file.

    ``path`` may be a ``str`` or a ``Path``. If it ends in ``.pt`` it is
    returned as a ``Path`` unchanged. Otherwise it is treated as a directory
    that must contain exactly one entry whose name ends in ``.pt``, and that
    entry's path is returned.

    Raises:
        ValueError: if the directory contains no ``.pt`` entry or more than one.
    """
    if str(path).endswith(".pt"):
        return Path(path)

    directory = Path(path)
    # `iterdir()` already yields paths prefixed with `directory`; joining them
    # onto `directory` again would duplicate a relative prefix
    # (e.g. `ckpt/ckpt/model.pt`).
    checkpoint_files = [entry for entry in directory.iterdir() if entry.name.endswith(".pt")]
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(checkpoint_files) != 1:
        raise ValueError(f"None or multiple .pt found in {directory}")
    return checkpoint_files[0]


def get_model(args: Args, device='cuda') -> nn.Module:
    """Build the model selected by ``args.model`` and move it to ``device``.

    Supported families:
      * Mamba-2 (``mamba2-*``, ``mamba2``, ``mamba-2``): if ``args.rand_init``
        is truthy, randomly initialised from the JSON config at
        ``args.model_config``; otherwise loaded with
        ``Mamba2ForCausalLM.from_pretrained(args.pretrained_path)``. In both
        cases the model is cast to ``torch.bfloat16``.
      * RWKV-5 (``rwkv5-*``, ``rwkv5``, ``rwkv-5``): loaded from the single
        ``.pt`` checkpoint resolved by ``get_ckpt_path(args.pretrained_path)``.

    Args:
        args: Parsed training arguments.
        device: Target device passed to ``.to(device=...)``.

    Returns:
        The constructed model.

    Raises:
        NotImplementedError: for Mamba-1 names (``mamba``, ``mamba1``, ``mamba-1``).
        ValueError: for any other unrecognised model name.
    """
    if args.model in ['mamba', 'mamba1', 'mamba-1']:
        # Mamba-1 support is disabled: its loader (`MambaForCausalLM`,
        # `load_pretrained` from `modeling.mamba`) is not imported in this file.
        raise NotImplementedError(f"Mamba-1 models are not supported: {args.model}")

    if args.model.startswith('mamba2-') or args.model in ['mamba2', 'mamba-2']:
        if bool(args.rand_init):
            print(f"Initting model from {args.model_config}")
            # Use a context manager so the config file handle is closed promptly.
            with open(args.model_config, encoding="utf-8") as config_file:
                config_data = json.load(config_file)
            config = MambaConfig(**config_data)
            model = Mamba2ForCausalLM(config)
        else:
            print(f"Loading model from {args.pretrained_path}")
            model = Mamba2ForCausalLM.from_pretrained(args.pretrained_path)
        return model.to(device=device, dtype=torch.bfloat16)

    if args.model.startswith('rwkv5-') or args.model in ['rwkv5', 'rwkv-5']:
        print(f"Loading model checkpoint from {args.pretrained_path}")
        return RWKV5(get_ckpt_path(args.pretrained_path)).to(device=device)

    raise ValueError(f"Invalid model: {args.model}")


def get_optimizer(args: Args, model: nn.Module) -> Optimizer:
    """Create an AdamW optimizer over every parameter of ``model``.

    Uses betas ``(0.9, 0.95)`` and ``args.weight_decay``. All parameters share a
    single parameter group, so weight decay applies uniformly to all of them.
    No ``lr`` is passed, so AdamW's default learning rate is the initial value.
    NOTE(review): presumably the LR scheduler overwrites it — confirm in
    `lr_scheduler`.
    """
    adam_betas = (0.9, 0.95)
    return AdamW(model.parameters(), betas=adam_betas, weight_decay=args.weight_decay)


def _resolve_n_steps(value, total_n_steps):
    """Convert a step-count setting into an absolute number of optimizer steps.

    Values strictly between 0 and 1 are read as a fraction of ``total_n_steps``;
    any other value is truncated with ``int`` and used as an absolute count.
    """
    if 0 < value < 1:
        return int(total_n_steps * value)
    return int(value)


def get_lr_scheduler(args: Args, optimizer: Optimizer):
    """Create the learning-rate scheduler named by ``args.lr_scheduler``.

    ``args.n_train_steps`` is the total number of optimization steps. The number
    of batches actually observed is ``n_train_steps * grad_accum``, because the
    scheduler's ``step`` is only called once every ``grad_accum`` observed batches.

    ``args.n_warmup_steps`` and ``args.n_drop_steps`` may each be an absolute
    step count or, when strictly between 0 and 1, a fraction of
    ``args.n_train_steps``.

    Supported names:
      * ``cosine`` / ``cos`` -> ``CosineScheduler``
      * ``warmupstabledrop`` / ``wsd`` / ``sd`` / ``stabledrop`` -> ``StableDrop``

    Raises:
        ValueError: if ``args.lr_scheduler`` is not one of the supported names.
    """
    total_n_steps = args.n_train_steps
    print(f"====  LR scheduler {total_n_steps = } ====")
    # Warmup and drop lengths may be given as a fixed fraction of the total steps.
    warmup_iters = _resolve_n_steps(args.n_warmup_steps, total_n_steps)
    drop_iters = _resolve_n_steps(args.n_drop_steps, total_n_steps)

    scheduler_name = args.lr_scheduler
    # Membership test: comparing a str with `==` against a list is always False.
    if scheduler_name in ("cosine", "cos"):
        return CosineScheduler(
            optimizer,
            start_lr=args.lr,
            n_warmup_steps=warmup_iters,
            n_steps=total_n_steps,  # previously `lr_decay_iter`
            cur_step=args.start_step,
            # `resume_no_optimze` (sic) must match the scheduler's keyword name
            # — TODO confirm in `lr_scheduler` before renaming.
            resume_no_optimze=args.resume_no_optimize,
        )
    if scheduler_name in ("warmupstabledrop", "wsd", "sd", "stabledrop"):
        return StableDrop(
            optimizer,
            max_lr=args.lr,
            n_warmup_steps=warmup_iters,
            n_steps=total_n_steps,  # previously `lr_decay_iter`
            n_drop_steps=drop_iters,
            cur_step=args.start_step,
            resume_no_optimze=args.resume_no_optimize,
        )
    # Fail loudly instead of hitting an UnboundLocalError on an unknown name.
    raise ValueError(f"Invalid LR scheduler: {scheduler_name}")


def get_run_name(args: Args, *, with_timestamp: bool = False) -> str:
    """Build a run name that encodes the main training hyperparameters.

    The name joins the model name with learning rate, max length, batch size,
    gradient accumulation, packing count, state-reset interval, repeat-data
    setting and random-init flag, each behind a short tag.

    Args:
        args: Parsed training arguments.
        with_timestamp: If True, append the current local time as
            ``_YYMMDDhhmmss``. Defaults to False, which produces the same name
            as before this option existed.

    Returns:
        The run name string.
    """
    run_name = (
        f"{args.model}"
        f"_lr{args.lr}"
        f"_T{args.max_length}"
        f"_B{args.batch_size}"
        f"_GA{args.grad_accum}"
        f"_P{args.packing_count}"
        f"_SR{args.state_reset_interval}"
        f"_RD{args.repeat_data}"
        f"_RI{args.rand_init}"
    )
    if with_timestamp:
        # Only read the clock when the timestamp is actually requested.
        run_name += f"_{datetime.now().strftime('%y%m%d%H%M%S')}"
    return run_name


def main():
    """Script entry point: parse CLI args, build training components, and train.

    Builds, in order, the tokenizer, the model, the optimizer and the LR
    scheduler. It then prints parameter counts and hands everything, plus a run
    name derived from the args, to ``Trainer`` and calls ``train()``.
    """
    args: Args = Args().parse_args()
    print("============ args ================")
    print(args)
    # NOTE(review): args / model config are not saved to disk here; presumably
    # the Trainer handles outputs — confirm.

    tokenizer = get_tokenizer(args)
    model = get_model(args)
    optimizer = get_optimizer(args, model)
    scheduler = get_lr_scheduler(args, optimizer)

    print("========== Model init finished =========")
    n_params = get_param_count(model)
    n_non_embed_params = get_non_embed_param_count(model)
    print(f"Number of parameter {n_params}, Number of non-e parameter {n_non_embed_params}")
    print("=======================================================")

    Trainer(
        args,
        model,
        tokenizer,
        optimizer,
        scheduler,
        get_run_name(args),
    ).train()


if __name__ == "__main__":
    main()
