import os
import json
import time
import math
import torch
import torch.nn.functional as F
import argparse
from statistics import mean
from functools import partial
import lm_eval
from tqdm import trange, tqdm
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoConfig,
    DataCollatorWithPadding,
    get_scheduler,
    set_seed,
)
from datasets import load_dataset
# from liger_kernel.transformers.functional import liger_cross_entropy

from model import Model


def tokenize(batch, max_length, *, tok=None):
    """Tokenize the ``'text'`` field of a ``datasets`` example or batch.

    Args:
        batch: Mapping with a ``'text'`` entry: a single string, or a list of
            strings when used with ``Dataset.map(..., batched=True)``.
        max_length: Sequences longer than this many tokens are truncated.
        tok: Tokenizer (callable) to use. Defaults to the module-level
            ``tokenizer`` created in this script's ``__main__`` section; pass
            one explicitly when calling this function from other code, where
            that global does not exist.

    Returns:
        Whatever the tokenizer returns for the text (e.g. ``input_ids`` and
        ``attention_mask`` for a Hugging Face tokenizer).
    """
    if tok is None:
        # Only resolved when no tokenizer was supplied, so explicit callers
        # never depend on the script-level global.
        tok = tokenizer
    return tok(batch['text'], max_length=max_length, truncation=True)

def parse_args(argv=None):
    """Parse the command-line configuration for the fine-tuning run.

    Args:
        argv: Optional list of argument strings; ``None`` reads ``sys.argv``.

    Returns:
        ``argparse.Namespace`` holding the training configuration.

    Raises:
        SystemExit: via ``parser.error`` when ``--gradient_accumulation_steps``
            is smaller than 1 (the training loop takes it modulo).
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--model_name", type=str, default="Qwen/Qwen2.5-7B")
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--train_batch_size", type=int, default=32)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)

    parser.add_argument("--dataset_name", type=str, default="NAME/iclr-dev")
    parser.add_argument("--dataset_config", type=str, default="people-v1-10sample-textonly")

    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-5,
        help="Learning rate",
    )

    parser.add_argument(
        "--weight_decay",
        type=float,
        default=0.01,
        help="Decoupled (AdamW) weight-decay coefficient (0 → no decay)",
    )
    parser.add_argument("--lr_scheduler_type", type=str, default="constant_with_warmup")
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--logging_steps",
        type=float,
        default=0.01,
        help="Fraction of total optimizer steps between log lines",
    )
    parser.add_argument("--gpu_flops", type=float, default=312e12)
    parser.add_argument("--max_grad_norm", type=float, default=1.0)
    parser.add_argument("--warmup_ratio", type=float, default=0.1)
    parser.add_argument("--max_length", type=int, default=2048)
    parser.add_argument("--seed", type=int, default=1)
    parser.add_argument("--save_only_last_epoch", action='store_true', default=False)

    args = parser.parse_args(argv)
    if args.gradient_accumulation_steps < 1:
        parser.error("--gradient_accumulation_steps must be >= 1")
    return args


def build_param_groups(model, weight_decay, no_decay=("bias", "norm")):
    """Split ``model``'s parameters into AdamW groups with and without decay.

    A parameter goes into the no-decay group when any substring in
    ``no_decay`` occurs in its name; every other parameter gets
    ``weight_decay``. Parameter order follows ``model.named_parameters()``.

    Returns:
        ``[{"params": decayed, "weight_decay": weight_decay},
        {"params": not_decayed, "weight_decay": 0.0}]``
    """
    decayed, not_decayed = [], []
    for name, param in model.named_parameters():
        if any(key in name for key in no_decay):
            not_decayed.append(param)
        else:
            decayed.append(param)
    return [
        {"params": decayed, "weight_decay": weight_decay},
        {"params": not_decayed, "weight_decay": 0.0},
    ]


def optimizer_steps_per_epoch(num_batches, grad_accum):
    """Return how many optimizer steps one epoch of ``num_batches`` takes.

    The training loop also steps on a trailing, partially filled accumulation
    window, so this is ``ceil(num_batches / grad_accum)``, not floor division.
    """
    return (num_batches + grad_accum - 1) // grad_accum


if __name__ == '__main__':
    args = parse_args()

    set_seed(args.seed)

    device = torch.device('cuda')

    # Kept at module level on purpose: `tokenize` falls back to this global.
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name,
        padding_side='left',
    )

    data = load_dataset(args.dataset_name, args.dataset_config)['train']
    # batched=True hands `tokenize` lists of texts, so the tokenizer processes
    # many examples per call instead of one at a time.
    train_dataset = data.map(
        partial(tokenize, max_length=args.max_length),
        batched=True,
        remove_columns='text',
        num_proc=16,
    )

    collator = DataCollatorWithPadding(
        tokenizer,
        return_tensors="pt",
        pad_to_multiple_of=64,
    )

    dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.train_batch_size,
        shuffle=True,
        collate_fn=collator,
        pin_memory=True,
        num_workers=2,
    )

    # Allow TF32 tensor-core math for float32 matmuls/convolutions (set once).
    torch.set_float32_matmul_precision('high')
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    model = Model.from_pretrained(args.model_name, device='cuda')

    optimizer = torch.optim.AdamW(
        build_param_groups(model, args.weight_decay),
        lr=args.learning_rate,
        betas=(0.9, 0.95),
        fused=True,
    )

    grad_accum = args.gradient_accumulation_steps
    num_batches = len(dataloader)
    # Must match the loop below, which also steps on a trailing partial window.
    updates_per_epoch = optimizer_steps_per_epoch(num_batches, grad_accum)
    total_train_steps = args.num_train_epochs * updates_per_epoch
    num_warmup_steps = math.ceil(total_train_steps * args.warmup_ratio)
    # Logging cadence is measured in optimizer steps, not micro-batches.
    log_steps = max(math.floor(total_train_steps * args.logging_steps), 1)
    model_num_parameters = sum(p.numel() for _, p in model.named_parameters(remove_duplicate=False))

    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=total_train_steps,
    )

    print("***** Running training *****")
    print(f"  Num examples = {len(train_dataset)}")
    print(f"  Num Epochs = {args.num_train_epochs}")
    print(f"  Micro batch size = {args.train_batch_size}")
    print(f"  Total batch size = {args.train_batch_size * grad_accum}")
    print(f"  Total optimization steps = {total_train_steps}")
    print(f"  Log steps = {log_steps}")

    global_step = 0
    total_tokens = 0
    iter_start = time.perf_counter()

    with tqdm(total=total_train_steps) as pbar:
        for epoch in range(1, args.num_train_epochs + 1):
            model.train()
            # Sum of *unscaled* micro-batch losses since the last optimizer step.
            window_loss = 0.0
            window_micro_batches = 0

            for step, batch in enumerate(dataloader):
                input_ids = batch['input_ids'].to(device, non_blocking=True)
                attention_mask = batch['attention_mask'].to(device, non_blocking=True)

                # Size of the accumulation window this micro-batch belongs to.
                # The final window of an epoch can be shorter than `grad_accum`;
                # dividing by its real size keeps its gradient correctly scaled.
                window_start = (step // grad_accum) * grad_accum
                window_size = min(grad_accum, num_batches - window_start)

                loss = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    compute_loss=True,
                ).loss
                (loss / window_size).backward()
                window_loss = window_loss + loss.detach()
                window_micro_batches += 1
                # Includes padding positions, which the model still processes.
                total_tokens += input_ids.numel()

                if (step + 1) % grad_accum == 0 or (step + 1) == num_batches:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(),
                        max_norm=args.max_grad_norm,
                        foreach=True,
                    )
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad(set_to_none=True)
                    pbar.update(1)
                    global_step += 1

                    # Mean unscaled loss over this optimizer step's micro-batches
                    # (kept on device; only synced to host when logged).
                    step_loss = window_loss / window_micro_batches
                    window_loss = 0.0
                    window_micro_batches = 0

                    if global_step % log_steps == 0:
                        torch.cuda.synchronize()
                        tokens_per_second = total_tokens / (time.perf_counter() - iter_start)
                        mfu = (tokens_per_second * 6 * model_num_parameters) / args.gpu_flops
                        lr = lr_scheduler.get_last_lr()[0]
                        total_tokens = 0
                        iter_start = time.perf_counter()
                        # tqdm.write keeps the progress bar intact.
                        tqdm.write(f'Loss: {step_loss.item():.3f} | Tokens/s: {tokens_per_second:.0f} | MFU: {mfu * 100:.1f}% | Grad norm: {grad_norm.item():.2f} | LR: {lr:.1e}')

            if not args.save_only_last_epoch or epoch == args.num_train_epochs:
                checkpoint_dir = os.path.join(args.output_dir, f'checkpoint-{epoch}')
                tqdm.write(f'Saving epoch {epoch}...')
                model.save_pretrained(checkpoint_dir)
                tokenizer.save_pretrained(checkpoint_dir)
