from __future__ import annotations

import pickle
import math
import numpy as np
import torch
import os
from typing import Iterable
from tqdm.auto import tqdm
from torch.optim.lr_scheduler import LambdaLR

from src.model.base import BaseTransformer
from src.run.config import RunConfig
from src.run.utils import get_batch, set_seeds
from src.run.logger import get_tqdm_kwargs
from src.run.dataloader import InterleavedDataLoader
from src.run.distributed import is_main_process, barrier, get_raw_model

def _make_lr_lambda(
    warmup_steps: int,
    total_opt_steps: int,
    start_factor: float,
    min_factor: float,
):
    """Build a LambdaLR multiplier: linear warmup followed by cosine decay.

    Args:
        warmup_steps: Optimizer steps spent ramping the multiplier linearly
            from ``start_factor`` to 1.0.
        total_opt_steps: Optimizer step at which the cosine decay reaches
            ``min_factor``.
        start_factor: Multiplier at optimizer step 0.
        min_factor: Multiplier at (and after) the end of the cosine decay.

    Returns:
        A function mapping an optimizer step index to an LR multiplier.
    """
    # Guard against a zero-length decay phase (e.g. an empty "all" loader),
    # which would otherwise divide by zero.
    decay_steps = max(total_opt_steps - warmup_steps, 1)

    def lr_lambda(current_step: int) -> float:
        if current_step < warmup_steps:
            # linear warmup from start_factor to 1.0
            return start_factor + (1.0 - start_factor) * (current_step / warmup_steps)
        # Cosine annealing from 1.0 to min_factor. Progress is clamped so steps
        # beyond the planned schedule stay at min_factor instead of following
        # the cosine back up toward the peak LR.
        progress = min((current_step - warmup_steps) / decay_steps, 1.0)
        return min_factor + (1.0 - min_factor) * (1 + math.cos(math.pi * progress)) / 2

    return lr_lambda


def _format_rolling_losses(
    losses: dict[str, list[float]],
    class_labels: list[str],
    window: int = 200,
) -> str:
    """Format the mean of the last ``window`` losses per label for logging.

    Labels with no recorded losses are shown as ``nan``. Every entry ends with
    a trailing space, e.g. ``"CORE: 2.31 AUX: nan "``.
    """
    parts = []
    for label in class_labels:
        recent = losses[label][-window:]
        avg_rolling_loss = np.mean(recent) if recent else float("nan")
        parts.append(f"{label[:4].upper()}: {avg_rolling_loss:.2f} ")
    return "".join(parts)


def _save_checkpoint(
    *,
    model: BaseTransformer,
    opt: torch.optim.Optimizer,
    scheduler: LambdaLR | None,
    save_args: dict,
    res_dir,
    postfix: str,
    num_steps: int,
    total_steps: int,
    losses: dict[str, list[float]],
    logger,
) -> None:
    """Write a model/optimizer/scheduler checkpoint and a pickled loss history.

    Files go to ``res_dir / save_args["dir"]``: ``{prefix}{postfix}.pth`` for
    the checkpoint and ``losses{postfix}.pkl`` for the per-label loss arrays.
    """
    ckpt_dir = res_dir / save_args["dir"]
    prefix = save_args["prefix"]

    os.makedirs(ckpt_dir, exist_ok=True)
    checkpoint_path = f"{ckpt_dir}/{prefix}{postfix}.pth"

    checkpoint = {
        # store the unwrapped model's weights
        'model': get_raw_model(model).state_dict(),
        'opt': opt.state_dict(),
        'step': num_steps,
        'total_steps': total_steps,
        'scheduler_epoch': scheduler.last_epoch if scheduler is not None else None,
    }
    torch.save(checkpoint, checkpoint_path)
    logger.info(f"Saved checkpoint at step {num_steps} to {checkpoint_path}")

    loss_dump = {label: np.array(values) for label, values in losses.items()}
    with open(ckpt_dir / f"losses{postfix}.pkl", "wb") as f:
        pickle.dump(loss_dump, f)


def do_train(
    model: BaseTransformer,
    config: RunConfig,
    data_labels: Iterable[str],
    state: dict | None = None,
    save_args: dict | None = None,
) -> tuple[BaseTransformer, dict]:
    """
    Train a transformer model on specified data labels.

    Batches are drawn from ``config.loaders[label]["train"]``. They are
    interleaved (weighted) when more than one label is given. Gradients are
    accumulated over ``config.accumulation_steps`` micro-batches and clipped to
    a global norm of 1.0 before every optimizer step. A trailing partial
    accumulation window is flushed on the final micro-step.

    When ``config.lr_schedule`` is set, the LR follows a linear warmup (10% of
    optimizer steps, starting at 1e-8) and then cosine decay to 10% of
    ``config.lr``. The schedule length is sized from the ``"all"`` loader
    rather than from ``data_labels``.

    Args:
        model: Model to train. It may be wrapped; checkpoints store the raw
            model obtained via ``get_raw_model``.
        config: Run configuration.
        data_labels: Data labels to train on (sorted before use).
        state: State to resume from, e.g. a previous return value or a loaded
            checkpoint. Recognised keys: ``"opt"`` (optimizer state dict),
            ``"scheduler_epoch"``, ``"step"`` (micro-steps to skip) and
            ``"total_steps"`` (only used for a mismatch warning).
        save_args: Checkpoint settings. When non-empty it must contain
            ``"dir"`` (relative to ``config.res_dir``), ``"prefix"`` and
            ``"do_checkpoint"``. Checkpoints are written every 30000
            micro-steps and at the final step.

    Returns:
        The trained model (in the same wrapped/unwrapped state as the input)
        and a state dict with keys ``"opt"``, ``"step"``, ``"total_steps"``
        and ``"scheduler_epoch"``.

    Raises:
        AssertionError: If ``save_args`` is non-empty but lacks a required key.
    """

    # unpack run config
    acc_steps = config.accumulation_steps
    lr_schedule = config.lr_schedule
    loaders = config.loaders
    res_dir = config.res_dir
    epochs = config.epochs
    logger = config.logger
    lr = config.lr

    if state is None:
        state = {}
    if save_args is None:
        save_args = {}

    if save_args:
        for required_key in ("dir", "prefix", "do_checkpoint"):
            assert required_key in save_args, f"save_args is missing {required_key!r}"

    logger.info(f"---- Begin Train | Data Labels: {data_labels} ----")

    set_seeds(config.seed)

    data_labels = sorted(data_labels)

    model.train()

    # setup optimizer (optionally restored from a previous run)
    opt = torch.optim.AdamW(model.parameters(), lr=lr, fused=True)
    if state.get("opt") is not None:
        opt.load_state_dict(state["opt"])

    class_labels = ["core"] + config.aux_labels
    losses = {label: [] for label in class_labels}

    # setup data loader and calculate total (micro) steps
    if len(data_labels) == 1:
        loader = loaders[data_labels[0]]["train"]
    else:
        loader = InterleavedDataLoader([loaders[x]["train"] for x in data_labels], weighted=True)

    num_batches = len(loader)
    total_steps = num_batches * epochs
    pbar = tqdm(total=total_steps, **get_tqdm_kwargs(logger, ncols=150))
    pbar.refresh()

    cur_lr = lr
    scheduler = None
    if lr_schedule:
        # Scheduler steps count optimizer steps (i.e. after gradient accumulation).
        total_opt_steps_all_data = (len(loaders['all']['train']) * epochs) // acc_steps
        warmup_steps = round(0.1 * total_opt_steps_all_data)
        start_lr = 1e-8
        min_lr = 0.1 * lr

        # LambdaLR is easy to save/restore: only last_epoch needs persisting.
        lr_lambda = _make_lr_lambda(
            warmup_steps=warmup_steps,
            total_opt_steps=total_opt_steps_all_data,
            start_factor=start_lr / lr,
            min_factor=min_lr / lr,
        )
        scheduler = LambdaLR(opt, lr_lambda)
        cur_lr = scheduler.get_last_lr()[0]

        # Restore the scheduler position. The LR is evaluated directly from the
        # lambda rather than via scheduler.get_lr(), which warns when called
        # outside scheduler.step().
        resume_epoch = state.get("scheduler_epoch")
        if resume_epoch is not None:
            scheduler.last_epoch = resume_epoch
            restored_lrs = [base_lr * lr_lambda(resume_epoch) for base_lr in scheduler.base_lrs]
            for group, group_lr in zip(opt.param_groups, restored_lrs):
                group['lr'] = group_lr
            scheduler._last_lr = restored_lrs
            cur_lr = restored_lrs[0]
            logger.info(f"Restored scheduler to epoch {scheduler.last_epoch}, LR: {cur_lr:.6e}")

    resume_step = state.get("step", -1)
    state_total_steps = state.get("total_steps", -1)
    if state_total_steps > 0 and state_total_steps != total_steps:
        logger.warning(f"Total steps mismatch: {state_total_steps} != {total_steps}")

    do_checkpoint = save_args.get("do_checkpoint", False)

    num_steps = 0
    for epoch_idx in range(epochs):

        # reset loader for each epoch
        loader.reset(epoch_idx)

        for _ in range(num_batches):

            # Batches are drawn even for skipped (already-trained) steps so a
            # resumed run consumes the data stream in the same order.
            x, y, batch_label = get_batch(loader)

            num_steps += 1
            pbar.update()
            if num_steps <= resume_step:
                continue

            is_final_step = num_steps == total_steps

            # forward pass; index 1 of the output is used as the loss
            loss = model.forward(
                tokens=x,
                targets=y,
            )[1]

            # scale loss for gradient accumulation
            (loss / acc_steps).backward()

            # loss logging
            loss_val = loss.item()
            losses[batch_label].append(loss_val)

            # set_description refreshes the bar itself (refresh=True default)
            pbar.set_description(f"LR: {cur_lr:.2e} L: {loss_val:.2f} LB: {batch_label[:4].upper()}")

            # logger printout
            if num_steps == 1 or num_steps % 200 == 0 or is_final_step:
                loss_str = _format_rolling_losses(losses, class_labels)
                logger.info(f"Step: {num_steps}, LR: {cur_lr:.2e}, Loss: {loss_str}")

            # Step at the end of every accumulation window, and also on the
            # final micro-step so a trailing partial window is applied BEFORE
            # the final checkpoint is written.
            window_complete = num_steps % acc_steps == 0
            if window_complete or is_final_step:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                opt.step()
                opt.zero_grad(set_to_none=True)

                # Only full windows advance the schedule, matching the floor
                # division used to size it.
                if scheduler is not None and window_complete:
                    scheduler.step()
                    cur_lr = scheduler.get_last_lr()[0]

            # save checkpoint every 30K steps or at final step
            if do_checkpoint and (num_steps % 30000 == 0 or is_final_step):
                postfix = "" if is_final_step else f"_step-{num_steps}"
                if is_main_process():
                    _save_checkpoint(
                        model=model,
                        opt=opt,
                        scheduler=scheduler,
                        save_args=save_args,
                        res_dir=res_dir,
                        postfix=postfix,
                        num_steps=num_steps,
                        total_steps=total_steps,
                        losses=losses,
                        logger=logger,
                    )

                # sync after checkpoint
                barrier()

        # sync after epoch
        barrier()

    pbar.close()

    # NOTE(review): "step" equals this call's total_steps; feeding this state
    # into a later call makes that call skip this many micro-steps — confirm
    # callers intend that (or drop "step") when moving to a new stage.
    new_state = {
        'opt': opt.state_dict(),
        'step': num_steps,
        'total_steps': total_steps,
        'scheduler_epoch': scheduler.last_epoch if scheduler is not None else None,
    }

    return model, new_state