import warnings
warnings.filterwarnings("ignore")
import os
import sys
sys.path.append(os.getcwd())

from optimizers.dct_adamw import CheapLowRankAdamW as DCTAdamW
from optimizers.ldadam import LDAdamW
from optimizers.frugal import GaloreAdamW as FrugalGaloreAdamW
from optimizers.fira import FiraAdamW
from optimizers.galore import GaLoreAdamW
from optimizers.trion import Trion

import argparse
import json
from pathlib import Path
import random
import schedulefree

import numpy as np
import torch
import wandb

import config
from data.utils import DataReader, get_dataset
import distributed
from models.utils import get_model
from optim.optim import train
from optim.utils import cos_inf_schedule, wsd_schedule
import torch.nn as nn

# Substrings of module names whose nn.Linear weights are optimized in low rank.
# (An alternative list previously kept here as a comment: q_proj, v_proj,
# up_proj, down_proj, gate_proj, k_proj, o_proj.)
_LOWRANK_TARGET_MODULES = ("c_attn", "c_proj", "w1", "w2")


def _split_lowrank_params(model):
    """Partition the model parameters into (lowrank_params, regular_params).

    The weight of every ``nn.Linear`` module whose name contains one of
    ``_LOWRANK_TARGET_MODULES`` goes into the low-rank list; every other
    parameter of the model goes into the regular list.
    """
    lowrank_params = []
    for module_name, module in model.named_modules():
        if not isinstance(module, nn.Linear):
            continue

        if not any(key in module_name for key in _LOWRANK_TARGET_MODULES):
            print(f'Do not enable GaLore for weights in module {module_name} of shape {module.weight.shape}')
            continue

        print(f'Enable LDAdamW for weights in module {module_name} of shape {module.weight.shape} ')
        lowrank_params.append(module.weight)

    # Compare by identity: tensors do not support plain `in` membership tests.
    lowrank_ids = {id(p) for p in lowrank_params}
    regular_params = [p for p in model.parameters() if id(p) not in lowrank_ids]
    return lowrank_params, regular_params


def _build_lowrank_optimizer(args, model):
    """Build the optimizer for ``args.opt`` in ldadamw / galoreadamw / fira / frugal."""
    lowrank_params, regular_params = _split_lowrank_params(model)

    if args.opt == 'ldadamw':
        params = [
            {'params': regular_params, 'enable_lowrank': False},
            {'params': lowrank_params, 'enable_lowrank': True},
        ]
        return LDAdamW(
            params,
            lr=args.lr,
            betas=(args.beta1, args.beta2),
            eps=1e-8,
            weight_decay=args.weight_decay,
            rank=args.lowrank_rank,
            rho=args.beta1,
            proj_type='std',
            proj_method=args.lowrank_proj,
            error_feedback=args.lowrank_use_ef,
        )

    if args.opt in ('galoreadamw', 'fira'):
        # Both optimizers are handed 'std' where the config says 'svd'.
        proj_type = 'std' if args.lowrank_proj == 'svd' else args.lowrank_proj
        lowrank_group = {
            'params': lowrank_params,
            'rank': args.lowrank_rank,
            'update_proj_gap': args.lowrank_upd_gap,
            'proj_type': proj_type,
        }
        if args.opt == 'galoreadamw':
            lowrank_group['scale'] = 1
            return GaLoreAdamW(
                [{'params': regular_params}, lowrank_group],
                lr=args.lr,
                betas=(args.beta1, args.beta2),
                eps=1e-8,
                weight_decay=args.weight_decay,
            )
        lowrank_group['alpha'] = 1.0
        # NOTE(review): only lr is forwarded here; betas / weight_decay fall
        # back to whatever FiraAdamW defaults to -- confirm this is intended.
        return FiraAdamW([{'params': regular_params}, lowrank_group], lr=args.lr)

    if args.opt == 'frugal':
        params = [
            {'params': regular_params, 'is_proj_params': False},
            {'params': lowrank_params, 'is_proj_params': True},
        ]
        return FrugalGaloreAdamW(
            params,
            density=args.lowrank_rank,
            proj_type=args.lowrank_proj,
            update_gap=args.lowrank_upd_gap,
            # adam specific
            betas=(args.beta1, args.beta2),
            lr=args.lr,
            weight_decay=args.weight_decay,
            eps=1e-8,
            rotate_states=bool(args.lowrank_rotate_states),
            distributed=bool(args.lowrank_distributed),
        )

    raise ValueError(f"Not a low-rank optimizer: {args.opt}")


def _build_trion_optimizer(args, model):
    """Build Trion with a Muon group for linear weights and an AdamW group for the rest.

    2-D parameters having ``args.vocab_size`` as one of their dimensions and
    all 1-D parameters go to the AdamW group; remaining 2-D parameters go to
    the Muon group.
    """
    params_embedding_lm_head = [
        p for p in model.parameters()
        if p.ndim == 2 and args.vocab_size in tuple(p.shape)
    ]
    params_one_dim = [p for p in model.parameters() if p.ndim == 1]
    params_linear = [
        p for p in model.parameters()
        if p.ndim == 2 and args.vocab_size not in tuple(p.shape)
    ]

    param_groups = [
        dict(  # Muon group
            params=params_linear,
            use_muon=True,
            lr=args.lr,
            weight_decay=args.weight_decay,
            momentum=0.95,
            ns_type=args.muon_ns_type,
            scaling_type=args.scaling_type,
            rank=args.lowrank_rank,
        ),
        dict(  # AdamW group
            params=params_embedding_lm_head + params_one_dim,
            use_muon=False,
            lr=args.lr,
            betas=(args.adamw_beta1, args.adamw_beta2),
            weight_decay=args.weight_decay,
        ),
    ]
    return Trion(param_groups)


def _build_optimizer(args, model, group_specs):
    """Instantiate the optimizer selected by ``args.opt`` (SGD when unrecognized)."""
    if args.opt == "adamw":
        return torch.optim.AdamW(
            group_specs,
            lr=args.lr,
            betas=(args.beta1, args.beta2),
            weight_decay=args.weight_decay,
        )
    if args.opt == "SFAdamW":
        return schedulefree.AdamWScheduleFree(
            group_specs,
            lr=args.lr,
            betas=(args.beta1, args.beta2),
            weight_decay=args.weight_decay,
            warmup_steps=args.warmup_steps,
        )
    if "sparse-adamw" in args.opt:
        # NOTE(review): BaseSparseAdamW is not imported anywhere in this file;
        # selecting this optimizer raises NameError until an import is added.
        strategy = args.opt.replace('-sparse-adamw', '')
        return BaseSparseAdamW(
            group_specs,
            sparsity=0.5,
            lr=args.lr,
            betas=(args.beta1, args.beta2),
            weight_decay=args.weight_decay,
            strategy=strategy,
            use_sparse_grad=bool(args.use_sparse_grad),
        )
    if args.opt == "blockadam":
        # NOTE(review): AdamBlockEVLayerWise is not imported anywhere in this
        # file; selecting this optimizer raises NameError until an import is added.
        return AdamBlockEVLayerWise(
            group_specs,
            block_size=args.blockadam_block_size,
            lr=args.lr,
            beta1=args.beta1,
            beta2=args.beta2,
            weight_decay=args.weight_decay,
            eps=1e-8,
            ev_func=args.blockadam_ev_func,
            use_sqrt=True,
            update_interval=args.blockadam_update_interval,
        )
    if args.opt == "dct-adamw":
        return DCTAdamW(
            group_specs,
            lr=args.lr,
            weight_decay=args.weight_decay,
            rank=args.lowrank_rank,
            proj=args.lowrank_proj,
            use_ef=bool(args.lowrank_use_ef),
            q_ef=args.lowrank_q_ef,
            distributed=True,
            update_proj_gap=args.lowrank_upd_gap,
            rotate_subspace=args.lowrank_rotate_states,
            max_shape=args.lowrank_max_shape,
            betas=(args.beta1, args.beta2),
            eps=1e-8,
        )
    if args.opt in ('ldadamw', 'galoreadamw', 'fira', 'frugal'):
        return _build_lowrank_optimizer(args, model)
    if args.opt == 'trion':
        # This branch used to be nested inside the low-rank branch above and
        # could never run, so "trion" silently trained with SGD.
        return _build_trion_optimizer(args, model)
    return torch.optim.SGD(
        group_specs, lr=args.lr, momentum=0.9, weight_decay=args.weight_decay
    )


def _build_scheduler(args, opt):
    """Create the LR scheduler named by ``args.scheduler``; None for "none"."""
    if args.scheduler == "none":
        return None

    assert args.warmup_steps < args.iterations, "Warmup steps must be < iterations."
    if args.scheduler in ["cos", "linear"]:
        # initial lr is args.lr / div_factor
        # final lr is initial_lr/final_div_factor = args.lr / div_factor / final_div_factor
        return torch.optim.lr_scheduler.OneCycleLR(
            optimizer=opt,
            # One max_lr per *actual* optimizer group: several optimizers above
            # build their own groups instead of using the model's group specs,
            # and OneCycleLR rejects a list of the wrong length.
            max_lr=[group.get("lr", args.lr) for group in opt.param_groups],
            total_steps=args.iterations,
            pct_start=args.warmup_steps / args.iterations,
            anneal_strategy=args.scheduler,
            cycle_momentum=False,
            div_factor=1e2,
            final_div_factor=0.1,
        )
    if args.scheduler == "cos_inf":
        lambda_schedule = cos_inf_schedule(
            n_iterations=args.iterations,
            n_warmup=args.warmup_steps,
            n_inf=args.cos_inf_steps,
            div_factor=1e2,
            final_div_factor=0.1,
        )
        return torch.optim.lr_scheduler.LambdaLR(opt, lambda_schedule)
    if args.scheduler == "wsd":
        lambda_schedule = wsd_schedule(
            n_iterations=args.iterations,
            n_warmup=args.warmup_steps,
            fract_decay=args.wsd_fract_decay,
            init_div_factor=1e2,
            final_lr_factor=args.wsd_final_lr_scale,  # should be 0 here
            decay_type=args.decay_type,
        )
        return torch.optim.lr_scheduler.LambdaLR(opt, lambda_schedule)
    raise NotImplementedError(f"Unknown scheduler type: {args.scheduler}.")


def main(args):
    """Run one training experiment described by ``args``.

    Sets up the distributed backend, seeds the RNGs, optionally initializes
    wandb, builds the data readers, model, optimizer and LR scheduler, calls
    ``train`` and, on the master process, writes the returned stats together
    with the arguments to ``<exp_dir>/summary.json``.

    Exits the process with status 666 when dct-adamw is configured with
    ``lowrank_q_ef != 0`` while ``lowrank_use_ef == 0``.
    """
    distributed_backend = distributed.make_backend_from_args(args)
    args = distributed_backend.get_adjusted_args_for_process(args)
    args.world_size = distributed_backend.get_world_size()

    if ('dct-adamw' in args.opt) and (args.lowrank_use_ef == 0) and (args.lowrank_q_ef != 0):
        print(
            "Invalid dct-adamw configuration: lowrank_q_ef != 0 while "
            "lowrank_use_ef == 0.",
            file=sys.stderr,
        )
        sys.exit(666)

    if args.full_eval_at is None:
        args.full_eval_at = []

    # NOTE args.seed is offset per worker in get_adjusted_args_for_process
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)
    if "cuda" in args.device:
        torch.cuda.set_device(torch.device(args.device))
    # torch.use_deterministic_algorithms(True)  # CUBLAS_WORKSPACE_CONFIG=:4096:8

    exp_name = get_exp_name(args, distributed_backend)
    exp_dir = Path(args.results_base_folder) / exp_name
    if distributed_backend.is_master_process() and args.wandb:
        wandb.init(
            project=args.wandb_project,
            group=args.wandb_group,
            job_type=args.wandb_job_type,
            name=exp_name,
            config=vars(args),
        )
        wandb.define_metric("iter")
        wandb.define_metric("train/*", step_metric="iter")
        wandb.define_metric("val/*", step_metric="iter")
        wandb.define_metric("lr", step_metric="iter")

    print(f"Starting Experiment: {exp_name}")
    print(f"Experiment Directory: {exp_dir}")
    print(f"Config:\n{vars(args)}\n")

    print(f"Loading dataset: '{args.dataset}'")
    datareaders = get_data_readers(args)

    model = get_model(args).to(device=args.device)

    # TODO: take care of initializing the model if args.use_pretrained != 'none'
    print(f"\nModel:\n{model}")

    model = distributed_backend.transform_model(model)
    raw_model = distributed_backend.get_raw_model(model)

    # The group specs list parameters by name; swap the names for the actual
    # tensors of the (possibly wrapped) model so the specs can be handed to an
    # optimizer directly.
    group_specs = raw_model.get_parameter_group_specs()
    param_name_mapping = dict(model.named_parameters())
    optimized_params_cnt = 0
    for g in group_specs:
        params = []
        for p_name in g["params"]:
            translated_p_names = (
                distributed_backend.translate_model_parameter_name_for_node(p_name)
            )
            params += [param_name_mapping[name] for name in translated_p_names]
        g["params"] = params
        optimized_params_cnt += sum(p.numel() for p in params)

    params_cnt = raw_model.get_num_params()
    nonemb_param_cnt = (
        params_cnt
        - raw_model.lm_head.weight.numel()
        - raw_model.transformer.wte.weight.numel()
    )
    print("number of parameters: %.2fM" % (params_cnt / 1e6,))
    print("number of optimized parameters: %.2fM" % (optimized_params_cnt / 1e6,))
    print("number of non-embedding parameters: %.2fM" % (nonemb_param_cnt / 1e6,))
    if args.wandb and distributed_backend.is_master_process():
        wandb.log(
            {
                "parameters": params_cnt,
                "optimized_parameters": optimized_params_cnt,
                "non_embedding_parameters": nonemb_param_cnt,
            }
        )

    opt = _build_optimizer(args, model, group_specs)
    print(f"\nOptimizer:\n{opt}")

    scheduler = _build_scheduler(args, opt)

    latest_ckpt_dir = exp_dir / "ckpts" / "latest"
    latest_ckpt_file = latest_ckpt_dir / "main.pt"
    if latest_ckpt_file.exists():
        if args.auto_resume:  # Auto resume overwrites resume_from
            args.resume_from = str(latest_ckpt_dir)
        else:
            # Without auto_resume the stale checkpoint is discarded (instead
            # of raising) and the run starts over in the same directory.
            latest_ckpt_file.unlink(missing_ok=True)
    elif distributed_backend.is_master_process():
        exp_dir.mkdir(parents=True, exist_ok=True)

    stats = train(
        model=model,
        opt=opt,
        datareaders=datareaders,
        scheduler=scheduler,
        exp_dir=exp_dir,
        distributed_backend=distributed_backend,
        cfg=args,
    )

    stats["args"] = vars(args)
    if distributed_backend.is_master_process():
        with open(exp_dir / "summary.json", "w") as fs:
            json.dump(stats, fs)
    distributed_backend.finalize()


def get_args():
    """Parse the command line in two stages and return the final arguments.

    A bootstrap parser only understands ``--config_format``; everything it
    does not recognize is forwarded, together with the parser and the
    partially filled namespace, to ``config.parse_args_with_format``.
    """
    bootstrap_parser = argparse.ArgumentParser(allow_abbrev=False)
    bootstrap_parser.add_argument(
        "--config_format",
        default="config",
        choices=config.registered_formats(),
    )

    known, leftover = bootstrap_parser.parse_known_args()
    parsed = config.parse_args_with_format(
        format=known.config_format,
        base_parser=bootstrap_parser,
        args=leftover,
        namespace=known,
    )
    return parsed


def get_exp_name(args, distributed_backend):
    """Build the experiment name used for the results folder and the wandb run.

    An explicit ``args.experiment_name`` is returned unchanged. Otherwise the
    name is assembled from the run's hyper-parameters, optionally prefixed
    with ``args.wandb_run_prefix`` and followed by seed / averaging /
    optimizer suffixes.
    """
    explicit_name = args.experiment_name
    if explicit_name is not None:
        return explicit_name

    worker_rank = distributed_backend.rank

    # Generic layout; it is always built, the mup layout below replaces it.
    name = "_".join(
        (
            f"{args.dataset}",
            f"{args.model}",
            f"nlayers{args.n_layer}",
            f"nhead{args.n_head}",
            f"lr{args.lr}",
            f"sched_{args.scheduler}",
            f"warmup{args.warmup_steps}",
            f"decay_{args.decay_type}",
            f"{args.wsd_fract_decay}",
            f"iter{args.iterations}",
            f"bs{args.batch_size}x{args.acc_steps}",
            f"ws{args.world_size}",
        )
    )

    # for mup
    if args.model == "mup_noam":
        name = "_".join(
            (
                f"{args.dataset}",
                f"{args.model}",
                f"opt{args.opt}",
                f"nlayers{args.n_layer}",
                f"lr{args.lr}",
                f"sched_{args.scheduler}",
                f"decay_{args.decay_type}",
                f"iter{args.iterations}",
                f"init{args.init_std}",
                f"sce{args.scale_emb}",
                f"scd{args.scale_depth}",
            )
        )

    if args.wandb_run_prefix != "none":
        name = "_".join((args.wandb_run_prefix, name))

    # args.seed carries a per-worker offset; subtracting the rank makes every
    # worker report the same seed in the name.
    suffixes = [
        f"_seed{args.seed - worker_rank}",
        f"_data_seed{args.data_seed}",
    ]
    if args.weight_average:
        suffixes.append("_WA")
    if args.opt == "SFAdamW":
        suffixes.append(f"_beta1_{args.beta1}")
        suffixes.append(f"_beta2_{args.beta2}")

    return name + "".join(suffixes)


def get_data_readers(args, verbose=True):
    """Create the training and validation ``DataReader`` objects.

    Both readers share batch size, sequence length, data seed and RAM
    setting and sample without replacement; only the training reader is
    constructed with ``auto_shard=True``.

    Returns a dict with the keys ``"train"`` and ``"val"``. When ``verbose``
    is true the token count of each reader is printed.
    """
    sources = get_dataset(args)
    shared_kwargs = dict(
        batch_size=args.batch_size,
        sequence_length=args.sequence_length,
        seed=args.data_seed,
        with_replacement=False,
        keep_in_ram=args.data_in_ram,
    )

    readers = {
        "train": DataReader(
            data_src=sources["train"], auto_shard=True, **shared_kwargs
        ),
        # NOTE Identical Per Rank
        "val": DataReader(
            data_src=sources["val"], auto_shard=False, **shared_kwargs
        ),
    }

    if verbose:
        print(f"Num training tokens: {readers['train'].num_tokens}")
        print(f"Num validation tokens: {readers['val'].num_tokens}")

    return readers


# Script entry point: parse the command line, then run the experiment.
if __name__ == "__main__":
    args = get_args()
    main(args)
