import random
import numpy as np
import os
import math
import torch
from typing import Tuple, Literal
from tqdm.auto import tqdm
from torch.optim.lr_scheduler import LambdaLR

from src.model.base import BaseTransformer
from src.run.config import RunConfig
from src.run.utils import get_batch, log_batch_counts, set_seeds
from src.run.logger import get_tqdm_kwargs
from src.run.distributed import is_main_process, barrier, broadcast_object, get_raw_model

def do_coreftaux(
    model: BaseTransformer,
    config: RunConfig,
    do_ft: bool = True,
    aux_prc: float = 0.5,
    core_prc: float = 1.0,
    data_label: str = "core",
    state: dict | None = None,
    save_args: dict | None = None,
) -> Tuple[BaseTransformer, dict]:
    """
    Train a transformer on "core" batches, optionally followed by a shuffled
    fine-tuning mix of `data_label` and "core" batches.

    The per-epoch batch schedule is `len_base` "core" batches followed (when
    `do_ft` is set) by a shuffled mix whose size is the share of the
    fine-tuning budget (`len_all - len_base`) proportional to the size of the
    `data_label` loader among all auxiliary loaders.

    Args:
        model: Model to train; switched to train mode and updated in place
        config: Run configuration (reads accumulation_steps, lr_schedule,
            aux_labels, loaders, res_dir, epochs, logger, lr and seed)
        do_ft: Whether to append the fine-tuning mix; requires
            `data_label != "core"`
        aux_prc: Fraction of the fine-tuning budget given to auxiliary data,
            in [0, 1]; the rest of the budget is "core" data
        core_prc: Fraction of the core train loader length used for the base
            part of the schedule, in [0, len_all / len_all_core]
        data_label: Loader label used for the auxiliary fine-tuning batches
        state: Optional resume state; the keys "opt", "step" and
            "scheduler_epoch" are read when present
        save_args: Checkpoint arguments; when non-empty it must contain
            "dir", "prefix" and "do_checkpoint"

    Returns:
        Trained model and a state dict with the keys "opt", "step",
        "total_steps" and "scheduler_epoch"

    Raises:
        AssertionError: If `save_args` is incomplete, a percentage is out of
            range, a required loader is empty, or `data_label` is "core"
            while `do_ft` is set
    """

    # function-scope import so this block does not depend on a file-level change
    from collections import deque

    log_every = 200             # logger printout interval (steps)
    loss_window = 200           # number of recent losses in the rolling average
    checkpoint_every = 30000    # intermediate checkpoint interval (steps)

    # unpack run config
    acc_steps = config.accumulation_steps
    lr_schedule = config.lr_schedule
    aux_labels = config.aux_labels
    loaders = config.loaders
    res_dir = config.res_dir
    epochs = config.epochs
    logger = config.logger
    lr = config.lr

    if state is None:
        state = dict()

    if save_args is None:
        save_args = dict()

    if save_args:
        assert "dir" in save_args
        assert "prefix" in save_args
        assert "do_checkpoint" in save_args

    logger.info(f"---- Begin CoreFTAux | Data Label: {data_label} ----")

    set_seeds(config.seed)

    model.train()

    class_labels = ["core"] + sorted(aux_labels)

    opt = torch.optim.AdamW(model.parameters(), lr=lr, fused=True)
    if "opt" in state:
        opt.load_state_dict(state["opt"])

    # only the most recent losses are ever read, so keep a bounded window
    # instead of a list that grows with every step
    losses = {label: deque(maxlen=loss_window) for label in ("core", data_label)}

    #--- calculate batches ---

    len_all = len(loaders["all"]["train"])
    len_all_aux = sum(len(loaders[label]["train"]) for label in aux_labels)
    len_all_core = len(loaders["core"]["train"])

    assert len_all_core > 0, "core train loader is empty"
    max_core_prc = len_all / len_all_core

    assert 0 <= aux_prc <= 1, "0 <= aux_prc <= 1"
    assert 0 <= core_prc <= max_core_prc, f"0 <= core_prc <= {max_core_prc}"

    # base part of the schedule, and the remaining budget for fine-tuning
    len_base = round(len_all_core * core_prc)
    len_ft = len_all - len_base

    # split the fine-tuning budget between auxiliary and core data
    len_aux = round(len_ft * aux_prc)
    len_core_for_ft = len_ft - len_aux

    utilized_core = len_core_for_ft + len_base
    if utilized_core > len_all_core:
        logger.warning(f"utilized_core > len_all_core, ({utilized_core} > {len_all_core}), must resample core")
    if len_aux > len_all_aux:
        logger.warning(f"len_aux > len_all_aux, ({len_aux} > {len_all_aux}), must resample aux")

    core_batches = ["core"] * len_base

    aux_batches = []
    if do_ft:

        assert data_label != "core", "data_label cannot be 'core' for ft phase"
        assert len_all_aux > 0, "aux train loaders are empty, cannot build ft phase"

        # this label's share of all auxiliary data scales both parts of the mix
        cur_aux_len = len(loaders[data_label]["train"])
        cur_aux_prc = cur_aux_len / len_all_aux
        cur_aux_samples = round(cur_aux_prc * len_aux)
        aux_batches += [data_label] * cur_aux_samples

        cur_core_samples = round(cur_aux_prc * len_core_for_ft)
        aux_batches += ["core"] * cur_core_samples

    # every process uses the order produced by src=0
    random.shuffle(aux_batches)
    aux_batches = broadcast_object(aux_batches, src=0)

    batches = core_batches + aux_batches
    log_batch_counts(batches, logger)

    logger.info(
        f"len_all: {len_all}, len_all_aux: {len_all_aux}, len_all_core: {len_all_core}, "
        f"len_base: {len_base}, len_ft: {len_ft}, len_aux: {len_aux}, len_core_for_ft: {len_core_for_ft}, "
        f"aux_prc: {aux_prc}, core_prc: {round(core_prc, 4)}, max_core_prc: {round(max_core_prc, 4)}, do_ft: {do_ft}"
    )

    #---- start run ----

    total_steps = len(batches) * epochs
    pbar = tqdm(total = total_steps, **get_tqdm_kwargs(logger, ncols=150))

    if is_main_process(): # Force initial display
        pbar.refresh()

    cur_lr = lr
    scheduler = None
    if lr_schedule:

        # scheduler steps are based on optimizer steps (accounting for gradient accumulation)
        total_opt_steps_all_data = (len_all * epochs) // acc_steps
        warmup_steps = round(0.1 * total_opt_steps_all_data)

        max_lr = lr
        min_lr = max_lr * 0.1
        start_lr = 1e-8
        start_factor = start_lr / max_lr  # multiplier to get initial LR of 1e-8
        min_factor = min_lr / max_lr

        # at least one decay step, so a very short run cannot divide by zero
        decay_steps = max(1, total_opt_steps_all_data - warmup_steps)

        # lambda lr scheduler - easy to save/restore, just need last_epoch
        def lr_lambda(current_step):
            if current_step < warmup_steps:
                # linear warmup from start_factor to 1.0
                return start_factor + (1.0 - start_factor) * (current_step / warmup_steps)
            # cosine annealing from 1.0 to min_lr/max_lr; clamped so the LR
            # stays at its minimum instead of rising again past the schedule end
            cosine_step = min(current_step - warmup_steps, decay_steps)
            return min_factor + (1.0 - min_factor) * (1 + math.cos(math.pi * cosine_step / decay_steps)) / 2

        scheduler = LambdaLR(opt, lr_lambda)
        # report the LR the scheduler actually applied at construction
        cur_lr = scheduler.get_last_lr()[0]

        # restore scheduler from state
        restored_epoch = state.get("scheduler_epoch")
        if restored_epoch is not None:
            scheduler.last_epoch = restored_epoch
            cur_lr = scheduler.get_lr()[0]
            scheduler._last_lr = [cur_lr]
            opt.param_groups[0]['lr'] = cur_lr
            logger.info(f"Restored scheduler to epoch {scheduler.last_epoch}, LR: {cur_lr:.6e}")

    def apply_optimizer_step():
        # clip, step and clear gradients for one accumulation window
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        opt.zero_grad(set_to_none=True)

    resume_step = state.get("step", -1)
    num_steps = 0
    for epoch_idx in range(epochs):

        # reset loader for each epoch
        for label in class_labels:
            loaders[label]["train"].reset(epoch_idx)

        # batch loop
        for batch_label in batches:

            # get batch (also fetched for skipped steps, so the loaders advance on resume)
            x, y, _ = get_batch(loaders[batch_label]["train"])

            num_steps += 1
            pbar.update()
            if num_steps <= resume_step:
                continue

            is_final_step = num_steps == total_steps

            # forward pass
            loss = model.forward(
                tokens=x,
                targets=y,
            )[1]

            # scale loss for gradient accumulation
            scaled_loss = loss / acc_steps
            scaled_loss.backward()

            # loss logging
            loss_val = loss.item()
            losses[batch_label].append(loss_val)

            # update progress bar
            desc_str = f"LR: {cur_lr:.2e} L: {loss_val:.2f} LB: {batch_label[:4].upper()}"
            pbar.set_description(desc_str)
            pbar.refresh()

            # logger printout
            if (num_steps == 1) or (num_steps % log_every == 0) or is_final_step:
                loss_str = ""
                for label_name, recent in losses.items():
                    label_str = label_name[:4].upper()
                    avg_rolling_loss = np.mean(list(recent)) if recent else float('nan')
                    loss_str += f"{label_str}: {avg_rolling_loss:.2f} "
                logger.info(f"Step: {num_steps}, LR: {cur_lr:.2e}, Loss: {loss_str}")

            # ensure accumulation is complete
            if num_steps % acc_steps == 0:

                # step optimizer
                apply_optimizer_step()

                # update learning rate
                if scheduler is not None:
                    scheduler.step()
                    cur_lr = scheduler.get_last_lr()[0]

            elif is_final_step:

                # apply the leftover gradients of a partial accumulation
                # window BEFORE the final checkpoint, so the saved weights
                # match the returned model; the scheduler is not advanced
                # here, so it keeps counting full accumulation windows only
                apply_optimizer_step()

            # save checkpoint every 30K steps or at final step
            is_checkpoint_step = num_steps % checkpoint_every == 0
            if save_args.get("do_checkpoint", False) and (is_checkpoint_step or is_final_step):

                postfix = ""
                if not is_final_step:
                    postfix = f"_step-{num_steps}"

                if is_main_process():

                    ckpt_dir = res_dir / save_args["dir"]
                    prefix = save_args["prefix"]

                    os.makedirs(ckpt_dir, exist_ok=True)
                    checkpoint_path = f"{ckpt_dir}/{prefix}{postfix}.pth"

                    checkpoint = {
                        'model': get_raw_model(model).state_dict(),
                        'opt': opt.state_dict(),
                        'step': num_steps,
                        'total_steps': total_steps,
                        'scheduler_epoch': None,
                    }
                    if scheduler is not None:
                        checkpoint['scheduler_epoch'] = scheduler.last_epoch

                    torch.save(checkpoint, checkpoint_path)
                    logger.info(f"Saved checkpoint at step {num_steps} to {checkpoint_path}")

                # sync after checkpoint
                barrier()

        # sync after epoch
        barrier()

    # close progress bar
    pbar.close()

    state = {
        "opt": opt.state_dict(),
        "step": num_steps,
        "total_steps": total_steps,
        "scheduler_epoch": scheduler.last_epoch if scheduler is not None else None,
    }

    return model, state