import torch
import random
import math
from tqdm.auto import tqdm
import numpy as np
from torch.optim.lr_scheduler import LambdaLR
import os
import itertools

from src.model.moe import MoETransformer
from src.model.lora import LoRATransformer
from src.run.config import RunConfig
from src.run.utils import get_batch, get_select_mask, log_batch_counts, set_seeds
from src.run.logger import get_tqdm_kwargs
from src.run.distributed import get_raw_model, barrier, broadcast_object, is_main_process

def do_routed(
    model: MoETransformer | LoRATransformer,
    config: RunConfig,
    batch_groups: list[dict[str, list[tuple[str, tuple, tuple]] | int]],
    max_active_experts: int = 2,
    state: dict | None = None,
    save_args: dict | None = None,
    optimize_training: bool = True,
) -> tuple[MoETransformer | LoRATransformer, dict]:
    """
    Run the routed training loop over one or more batch groups.

    Each batch group is iterated for ``config.epochs`` epochs. Every schedule
    entry in a group is expanded into ``config.accumulation_steps`` micro-steps
    (forward + backward), followed by a single optimizer step. Only the
    optimizers of the experts listed in the entry's backward tuple are stepped;
    gradients of all parameter groups are cleared afterwards.

    Args:
        model: Routed transformer to train.
        config: Run configuration (loaders, logger, lr, epochs, ...).
        batch_groups: List of dicts, each with the keys
            ``"batches"``: list of ``(loader_name, experts_forward, experts_backward)``
            ``"lr_step_freq"``: the scheduler is stepped when the number of
            completed optimizer-step slots is a multiple of this value.
        max_active_experts: Only used to pad the expert columns of the
            progress-bar description.
        state: Optional state from a previous run. Recognised keys are
            ``"opts"``, ``"step"``, ``"total_steps"`` and ``"scheduler_epoch"``.
        save_args: Optional checkpoint arguments. If non-empty it must contain
            ``"dir"``, ``"prefix"`` and ``"do_checkpoint"``.
        optimize_training: Forwarded to ``model.forward`` as ``optimize``.

    Returns:
        tuple[Trained model, state dictionary]. The state dictionary has the
        keys ``"opts"``, ``"step"`` (micro-steps performed), ``"total_steps"``
        (schedule entries * epochs) and ``"scheduler_epoch"``.
    """

    if state is None:
        state = dict()

    if save_args is None:
        save_args = dict()

    if save_args:
        assert "dir" in save_args
        assert "prefix" in save_args
        assert "do_checkpoint" in save_args

    # unpack run config
    acc_steps = config.accumulation_steps
    lr_schedule = config.lr_schedule
    aux_labels = config.aux_labels
    res_dir = config.res_dir
    loaders = config.loaders
    epochs = config.epochs
    logger = config.logger
    lr = config.lr

    model.train()

    class_labels = ["core"] + aux_labels

    # validate batches
    for batches_info in batch_groups:
        batches = batches_info["batches"]
        for l_name, fwd_experts, bck_experts in batches:
            assert l_name in class_labels
            assert all(e in class_labels for e in fwd_experts)
            assert all(e in class_labels for e in bck_experts)

    # one optimizer (and one loss history) per expert label
    opts = {}
    losses = {}
    for label in class_labels:

        losses[label] = []
        params = list(get_raw_model(model).get_params(label))
        opts[label] = torch.optim.AdamW(params, lr=lr, fused=True)

        if "opts" in state:
            opts[label].load_state_dict(state["opts"][label])

    total_loss = 0.0

    # total_steps counts schedule entries; it is what gets stored in the state
    # and in checkpoints, so its meaning is kept as-is.
    total_steps = sum(len(x['batches']) for x in batch_groups) * epochs

    # num_steps below counts micro-steps (acc_steps per schedule entry), so the
    # last step of the run has to be detected in the same unit.
    total_micro_steps = total_steps * acc_steps

    cur_lr = lr
    if lr_schedule:

        # Scheduler steps are based on optimizer steps (accounting for gradient accumulation)
        total_micro_steps_all_data = len(loaders['all']['train']) * epochs
        total_opt_steps_all_data = total_micro_steps_all_data // acc_steps
        warmup_steps = round(0.1 * total_opt_steps_all_data)

        max_lr = lr
        min_lr = max_lr * 0.1
        start_lr = 1e-8
        start_factor = start_lr / max_lr  # Multiplier to get initial LR of 1e-8

        # LambdaLR scheduler - easy to save/restore, just need last_epoch
        def lr_lambda(current_step):
            if current_step < warmup_steps:
                # LinearLR: factor goes from start_factor to 1.0 over warmup_steps
                return start_factor + (1.0 - start_factor) * (current_step / warmup_steps)
            else:
                # CosineAnnealingLR: from 1.0 to min_lr/max_lr
                min_factor = min_lr / max_lr
                cosine_step = current_step - warmup_steps
                # guard against a zero-length cosine phase (would divide by zero)
                T_max = max(1, total_opt_steps_all_data - warmup_steps)
                # NOTE(review): for current_step beyond total_opt_steps_all_data the
                # cosine is periodic, so the factor rises again - confirm intended.
                return min_factor + (1.0 - min_factor) * (1 + math.cos(math.pi * cosine_step / T_max)) / 2

        core_opt = opts["core"]
        scheduler = LambdaLR(core_opt, lr_lambda)

        # LambdaLR only rewrote the LR of the core optimizer; push the same
        # initial LR to every optimizer so that they all start in sync.
        cur_lr = scheduler.get_last_lr()[0]
        for opt in opts.values():
            opt.param_groups[0]['lr'] = cur_lr

        # restore scheduler from state
        if "scheduler_epoch" in state and state["scheduler_epoch"] is not None:
            scheduler.last_epoch = state["scheduler_epoch"]
            cur_lr = scheduler.get_lr()[0]  # get_lr() returns absolute LR (base_lr * lambda)
            scheduler._last_lr = [cur_lr]
            for opt in opts.values():
                opt.param_groups[0]['lr'] = cur_lr
            logger.info(f"Restored scheduler to epoch {scheduler.last_epoch}, LR: {cur_lr:.6e}")

    # restore step from state
    resume_step = state.get("step", -1)
    state_total_steps = state.get("total_steps", -1)
    # only warn when the restored value actually differs from the current one
    if state_total_steps > 0 and state_total_steps != total_steps:
        logger.warning(f"Total steps mismatch: {state_total_steps} != {total_steps}")

    num_steps = 0
    for batches_info in batch_groups:

        batches = batches_info["batches"]
        lr_step_freq = batches_info["lr_step_freq"]

        logger.info(f"Batch group start, LR: {cur_lr:.6e}, {len(batches)} batches, lr_step_freq: {lr_step_freq}")

        # total micro-steps for progress bar (each batch repeated acc_steps times)
        num_micro_batches = len(batches) * epochs * acc_steps
        pbar = tqdm(total = num_micro_batches, **get_tqdm_kwargs(logger, ncols=150))
        pbar.refresh()

        # train loop
        for epoch_idx in range(epochs):

            # reset loaders for each epoch
            for label in class_labels:
                loaders[label]["train"].reset(epoch_idx)

            # batch loop
            for batch in batches:

                # get batch info
                loader_name, experts_forward, experts_backward = batch
                loader = loaders[loader_name]["train"]

                # gradient accumulation loop
                for _ in range(acc_steps):

                    # get data (also consumed while fast-forwarding on resume)
                    x, y, batch_label = get_batch(loader)

                    num_steps += 1
                    pbar.update()
                    if num_steps <= resume_step:
                        continue

                    # forward pass
                    sel_mask = get_select_mask(class_labels, experts_forward, device=x.device)
                    _, loss = model.forward(
                        tokens=x,
                        targets=y,
                        select_mask=sel_mask,
                        optimize=optimize_training
                    )

                    # scale loss for gradient accumulation
                    scaled_loss = loss / acc_steps
                    scaled_loss.backward()

                    # loss logging
                    loss_val = loss.item()
                    losses[batch_label].append(loss_val)
                    total_loss += loss_val

                    # update progress bar
                    exp_for_str = ','.join([e[:2].upper() for e in experts_forward])
                    exp_for_str = exp_for_str.ljust(max_active_experts * 3 - 1)
                    exp_bck_str = ','.join([e[:2].upper() for e in experts_backward])
                    exp_bck_str = exp_bck_str.ljust(max_active_experts * 3 - 1)
                    desc_str = f"LR: {cur_lr:.2e} L: {loss_val:.2f} LB: {batch_label[:4].upper()} EF: {exp_for_str} EB: {exp_bck_str}"

                    pbar.set_description(desc_str)
                    pbar.refresh()

                # skip optimizer/scheduler step if we're still in resume mode
                if num_steps <= resume_step:
                    continue

                is_final_step = num_steps == total_micro_steps

                # logger printout
                if (num_steps == 1) or (num_steps % 200 == 0) or is_final_step:
                    loss_str = ""
                    for label in class_labels:
                        label_str = label[:4].upper()
                        loss_slice = losses[label][-200:]
                        avg_rolling_loss = np.mean(loss_slice) if len(loss_slice) > 0 else float('nan')
                        loss_str += f"{label_str}: {avg_rolling_loss:.2f} "
                    logger.info(f"Step: {num_steps}, LR: {cur_lr:.2e}, Loss: {loss_str}")

                # step optimizer (only the experts routed for the backward pass)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                for exp in experts_backward:
                    opts[exp].step()

                # zero grads for all params
                for opt in opts.values():
                    opt.zero_grad(set_to_none=True)

                # update learning rate
                if lr_schedule and (num_steps//acc_steps) % lr_step_freq == 0:
                    scheduler.step()
                    cur_lr = scheduler.get_last_lr()[0]
                    for opt in opts.values():
                        opt.param_groups[0]['lr'] = cur_lr

                # save checkpoint every 30K steps or at final step
                is_checkpoint_step = (num_steps % 30000 == 0)
                if save_args.get("do_checkpoint", False) and (is_checkpoint_step or is_final_step):

                    # intermediate checkpoints carry the step in their file name
                    postfix = ""
                    if not is_final_step:
                        postfix = f"_step-{num_steps}"

                    if is_main_process():

                        ckpt_dir = res_dir / save_args["dir"]
                        prefix = save_args["prefix"]

                        os.makedirs(ckpt_dir, exist_ok=True)
                        checkpoint_path = f"{ckpt_dir}/{prefix}{postfix}.pth"

                        checkpoint = {
                            'model': get_raw_model(model).state_dict(),
                            'opts': {label: opt.state_dict() for label, opt in opts.items()},
                            'step': num_steps,
                            'total_steps': total_steps,
                            'scheduler_epoch': None,
                        }
                        if lr_schedule:
                            checkpoint['scheduler_epoch'] = scheduler.last_epoch

                        torch.save(checkpoint, checkpoint_path)
                        logger.info(f"Saved checkpoint at step {num_steps} to {checkpoint_path}")

                    # sync after checkpoint
                    barrier()

            # sync after epoch
            barrier()

        # close progress bar
        pbar.close()

    state = {
        'opts': {label: opt.state_dict() for label, opt in opts.items()},
        'step': num_steps,
        'total_steps': total_steps,
        'scheduler_epoch': scheduler.last_epoch if lr_schedule else None,
    }

    return model, state


def do_routed_ordered(
    model: MoETransformer | LoRATransformer,
    config: RunConfig,
    core_prc: float,
    ft_aux_prc: float,
    base_aux_prc: float = 0.0,
    do_ft: bool = True,
    equal_compute: bool = True,
    state: dict | None = None,
    save_args: dict | None = None,
) -> tuple[MoETransformer | LoRATransformer, dict]:
    """
    Build an ordered (base phase, then ft phase) schedule and train on it.

    The base phase updates only the "core" parameters. The ft phase runs
    "core" plus one aux expert in the forward pass and updates only that aux
    expert. Both schedules are handed to ``do_routed`` as separate batch groups.

    Args:
        model: Routed transformer to train
        config: Run configuration
        core_prc: length of base phase as proportion of core-only run length
        ft_aux_prc: proportion of aux during ft
        base_aux_prc: proportion of aux during base
        do_ft: Whether to train on ft batches
        equal_compute: Rescale both phase lengths using parameter-count ratios
            (``config.num_baseline_params`` versus the core / aux parameter counts)
        state: State dictionary
        save_args: Save checkpoint arguments

    Returns:
        tuple[Trained model, state dictionary]
    """

    # unpack run config
    loaders = config.loaders
    logger = config.logger
    aux_labels = config.aux_labels
    acc_steps = config.accumulation_steps
    optimize_training = config.optimize_routed_training

    logger.info("---- Begin Routed Ordered ----")

    set_seeds(config.seed)

    # --- dataset sizes ---

    len_all = len(loaders["all"]["train"])
    len_all_aux = sum(len(loaders[name]["train"]) for name in aux_labels)
    len_all_core = len(loaders["core"]["train"])
    max_core_prc = len_all / len_all_core

    assert 0 <= core_prc <= max_core_prc, f"0 <= core_prc <= {max_core_prc}"
    assert 0 <= ft_aux_prc <= 1, "0 <= ft_aux_prc <= 1"
    assert 0 <= base_aux_prc <= 1, "0 <= base_aux_prc <= 1"

    # --- phase lengths ---

    len_base = round(len_all_core * core_prc)
    len_ft = len_all - len_base

    if equal_compute:

        unwrapped = get_raw_model(model)
        core_param_count = sum(p.numel() for p in unwrapped.get_params("core"))

        # stretch the base phase by the baseline/core parameter ratio
        base_factor = config.num_baseline_params / core_param_count
        logger.info(f"compute-equal base phase: base_factor: {base_factor:.2f}")
        len_base = round(len_base * base_factor)

        # stretch the ft phase; assumes only one aux is active at a time
        aux_param_total = sum(
            sum(p.numel() for p in unwrapped.get_params(name))
            for name in aux_labels
        )
        aux_core_ratio = (aux_param_total / len(aux_labels)) / core_param_count
        ft_factor = base_factor / ((2/3) + aux_core_ratio)
        logger.info(f"compute-equal ft phase: ft_factor: {ft_factor:.2f}, aux_core_ratio: {aux_core_ratio:.2f}")
        len_ft = round(len_ft * ft_factor)

    # --- split each phase into aux and core portions ---

    len_aux_for_base = round(len_base * base_aux_prc)
    len_core_for_base = len_base - len_aux_for_base

    len_aux_for_ft = round(len_ft * ft_aux_prc)
    len_core_for_ft = len_ft - len_aux_for_ft

    utilized_core = len_core_for_ft + len_core_for_base
    if utilized_core > len_all_core:
        logger.warning(f"utilized_core > len_all_core, ({utilized_core} > {len_all_core}), must resample core")

    utilized_aux = len_aux_for_ft + len_aux_for_base
    if utilized_aux > len_all_aux:
        logger.warning(f"utilized_aux > len_all_aux, ({utilized_aux} > {len_all_aux}), must resample aux")

    # --- build schedules: entries are ( label, (params_forward), (params_backward) ) ---

    base_batches = [("core", ("core",), ("core",))] * (len_core_for_base // acc_steps)
    ft_per_label = []

    for name in aux_labels:

        # share of this aux dataset among all aux data
        share = len(loaders[name]["train"]) / len_all_aux

        n_base_aux = round(share * len_aux_for_base) // acc_steps
        n_ft_aux = round(share * len_aux_for_ft) // acc_steps
        n_ft_core = round(share * len_core_for_ft) // acc_steps

        base_batches.extend([(name, ("core",), ("core",))] * n_base_aux)

        label_schedule = (
            [(name, ("core", name), (name,))] * n_ft_aux
            + [("core", ("core", name), (name,))] * n_ft_core
        )
        random.shuffle(label_schedule)
        ft_per_label.append(label_schedule)

    logger.info(
        f"\nlen_all: {len_all}, len_all_aux: {len_all_aux}, len_all_core: {len_all_core},"
        f"\nlen_base: {len_base}, len_ft: {len_ft},"
        f"\nlen_core_for_base: {len_core_for_base}, len_aux_for_base: {len_aux_for_base},"
        f"\nlen_core_for_ft: {len_core_for_ft}, len_aux_for_ft: {len_aux_for_ft},"
        f"\ncore_prc: {round(core_prc, 4)}, max_core_prc: {round(max_core_prc, 4)},"
        f"\nft_aux_prc: {round(ft_aux_prc, 4)}, base_aux_prc: {round(base_aux_prc, 4)}"
    )

    random.shuffle(base_batches)
    base_batches = broadcast_object(base_batches, src=0)
    log_batch_counts(base_batches, logger)

    # Interleave sublists: [[a,a,a],[b,b,b,b],[c,c,c]] -> [a,b,c,a,b,c,a,b,c,b]
    ft_batches = []
    for column in itertools.zip_longest(*ft_per_label):
        for entry in column:
            # zip_longest pads exhausted sublists with None
            if entry is not None:
                ft_batches.append(entry)
    ft_batches = broadcast_object(ft_batches, src=0)
    log_batch_counts(ft_batches, logger)

    # --- combine ---

    batch_groups = [{"batches": base_batches, "lr_step_freq": 1}]
    if do_ft:
        batch_groups.append({"batches": ft_batches, "lr_step_freq": len(aux_labels)})

    # --- train ---

    return do_routed(
        model,
        config,
        batch_groups,
        state=state,
        save_args=save_args,
        optimize_training=optimize_training,
    )


def do_routed_unordered(
    model: MoETransformer | LoRATransformer,
    config: RunConfig,
    aux_route_prc: float,
    robust_prc: float,
    core_prc: float = 1.0,
    state: dict | None = None,
    save_args: dict | None = None,
) -> tuple[MoETransformer | LoRATransformer, dict]:
    """
    Train a routed transformer with gradient routing on a shuffled schedule.

    Core and aux schedule entries are built, padded with plain core entries up
    to ``len(loaders["all"]["train"]) // accumulation_steps``, shuffled, and
    handed to ``do_routed`` as a single batch group.

    Args:
        model: Routed transformer to train
        config: Run configuration
        aux_route_prc: Control leakage of aux data to core (proportion of aux
            entries whose backward pass also updates "core")
        robust_prc: Control proportion of core batches that are "core robustness"
            batches; these are split evenly into ``len(aux_labels) + 1`` shares,
            one of which stays core-only
        core_prc: Control data proportion in terms of percent of core-only run data (remainder is aux data)
        state: State dictionary
        save_args: Save checkpoint arguments

    Returns:
        tuple[Trained model, state dictionary]
    """

    # unpack run config
    aux_labels = config.aux_labels
    loaders = config.loaders
    logger = config.logger
    acc_steps = config.accumulation_steps
    optimize_training = config.optimize_routed_training

    logger.info("---- Begin Routed Unordered ----")

    set_seeds(config.seed)

    assert 0 <= aux_route_prc <= 1, "aux_route_prc must be between 0 and 1"
    assert 0 <= robust_prc <= 1, "robust_prc must be between 0 and 1"
    assert 0 <= core_prc <= 1, "core_prc must be between 0 and 1"

    len_all = len(loaders["all"]["train"])
    len_all_aux = sum(len(loaders[label]["train"]) for label in aux_labels)
    len_all_core = len(loaders["core"]["train"])
    len_core = round(len_all_core * core_prc)
    len_aux = len_all - len_core

    # --- core batches (divided by acc_steps, each repeated acc_steps times in do_routed) ---

    core_batches = []

    # robust core samples are split into N equal shares: one per aux label plus
    # one that stays core-only; whatever does not divide evenly goes to non-robust
    N = len(aux_labels) + 1
    num_core_robust = round(len_core * robust_prc)
    num_core_non_robust = len_core - num_core_robust
    remainder = num_core_robust % N
    num_core_non_robust += remainder
    num_core_robust -= remainder
    num_core_robust_per_aux = num_core_robust // N
    num_simple_core = num_core_non_robust + num_core_robust_per_aux

    core_batches += [("core", ("core",), ("core",))] * (num_simple_core // acc_steps)
    for label in aux_labels:
        core_batches += [("core", ("core", label), ("core", label))] * (num_core_robust_per_aux // acc_steps)

    log_batch_counts(core_batches, logger)

    # --- aux batches (divided by acc_steps, each repeated acc_steps times in do_routed) ---

    aux_batches = []

    for label in aux_labels:

        cur_aux_len = len(loaders[label]["train"])
        cur_aux_prc = cur_aux_len / len_all_aux
        cur_aux_samples = round(cur_aux_prc * len_aux)
        num_aux_routed = round(cur_aux_samples * aux_route_prc)
        num_aux_non_routed = cur_aux_samples - num_aux_routed

        aux_batches += [(label, ("core", label), ("core", label))] * (num_aux_routed // acc_steps)
        aux_batches += [(label, ("core", label), (label,))] * (num_aux_non_routed // acc_steps)

    log_batch_counts(aux_batches, logger)

    # --- combine ---

    # len_all is the number of micro-batches in the "all" loader; each schedule
    # entry consumes acc_steps of them, so this is the target schedule length
    target_len = len_all // acc_steps

    batches = core_batches + aux_batches
    # pad with plain core entries (integer division above loses a few entries)
    batches += [("core", ("core",), ("core",))] * (target_len - len(batches))
    random.shuffle(batches)
    batches = broadcast_object(batches, src=0)
    # the schedule is measured in optimizer steps, so compare against target_len
    # (comparing against len_all would fail for every acc_steps > 1)
    assert len(batches) == target_len, "number of batches does not match len_all // acc_steps"
    batch_groups = [{"batches": batches, "lr_step_freq": 1}]

    # --- train ---

    return do_routed(model, config, batch_groups, state=state, save_args=save_args, optimize_training=optimize_training)


def do_routed_unordered_arbsub(
    model: MoETransformer | LoRATransformer,
    config: RunConfig,
    aux_route_prc: float,
    core_robust_prc: float,
    aux_robust_prc: float,
    state: dict | None = None,
    save_args: dict | None = None,
) -> tuple[MoETransformer | LoRATransformer, dict]:
    """
    Train a routed transformer with gradient routing for arbitrary subset of auxes.

    "Robustness" entries run the forward pass with an additional subset of aux
    experts active; the subsets are enumerated over all sizes, shuffled once,
    and then assigned round-robin.

    Args:
        model: Routed transformer to train
        config: Run configuration
        aux_route_prc: Control leakage of aux data to core (proportion of aux
            entries whose backward pass also updates "core")
        core_robust_prc: Control proportion of core batches that are "core robustness" batches
        aux_robust_prc: Control proportion of aux batches that are "aux robustness" batches
        state: State dictionary
        save_args: Save checkpoint arguments

    Returns:
        tuple[Trained model, state dictionary]
    """

    # unpack run config
    aux_labels = config.aux_labels
    loaders = config.loaders
    logger = config.logger
    acc_steps = config.accumulation_steps

    logger.info("---- Begin Routed Unordered Arbitrary Subsets ----")

    set_seeds(config.seed)

    assert 0 <= aux_route_prc <= 1, "aux_route_prc must be between 0 and 1"
    assert 0 <= core_robust_prc <= 1, "core_robust_prc must be between 0 and 1"
    assert 0 <= aux_robust_prc <= 1, "aux_robust_prc must be between 0 and 1"

    def shuffled_subsets(labels):
        # every subset of `labels` (including the empty one), in shuffled order
        subsets = [
            sorted(combo)
            for size in range(len(labels) + 1)
            for combo in itertools.combinations(labels, size)
        ]
        random.shuffle(subsets)
        return subsets

    # schedule entries are ( label, (params_forward), (params_backward) )

    # --- core batches (divided by acc_steps, each repeated acc_steps times in do_routed) ---

    core_total = len(loaders["core"]["train"])
    robust_core = round(core_total * core_robust_prc)
    plain_core = core_total - robust_core

    core_batches = [("core", ("core",), ("core",))] * (plain_core // acc_steps)

    core_subsets = shuffled_subsets(aux_labels)
    for idx in range(robust_core // acc_steps):
        # forward and backward use the same participants for core robustness
        members = ("core", *core_subsets[idx % len(core_subsets)])
        core_batches.append(("core", members, members))

    log_batch_counts(core_batches, logger)

    # --- aux batches (divided by acc_steps, each repeated acc_steps times in do_routed) ---

    aux_batches = []

    for label in aux_labels:

        peer_subsets = shuffled_subsets([name for name in aux_labels if name != label])

        # --- calculate subset batch lengths (divided by acc_steps) ---

        aux_total = len(loaders[label]["train"])
        robust_aux = round(aux_robust_prc * aux_total)
        plain_aux = aux_total - robust_aux

        robust_routed = round(robust_aux * aux_route_prc)
        robust_unrouted = robust_aux - robust_routed

        plain_routed = round(plain_aux * aux_route_prc)
        plain_unrouted = plain_aux - plain_routed

        # --- make batches ---

        aux_batches.extend([(label, ("core", label), (label,))] * (plain_unrouted // acc_steps))
        aux_batches.extend([(label, ("core", label), ("core", label))] * (plain_routed // acc_steps))

        # robust + routed: extra peers in the forward pass, core and label updated
        aux_batches.extend(
            (label, ("core", label, *peer_subsets[idx % len(peer_subsets)]), ("core", label))
            for idx in range(robust_routed // acc_steps)
        )

        # robust + not routed: extra peers in the forward pass, only label updated
        aux_batches.extend(
            (label, ("core", label, *peer_subsets[idx % len(peer_subsets)]), (label,))
            for idx in range(robust_unrouted // acc_steps)
        )

    log_batch_counts(aux_batches, logger)

    # --- combine ---

    schedule = core_batches + aux_batches
    len_all = len(loaders["all"]["train"])
    # pad with plain core entries up to len_all // acc_steps
    padding = len_all // acc_steps - len(schedule)
    schedule += [("core", ("core",), ("core",))] * padding
    random.shuffle(schedule)
    schedule = broadcast_object(schedule, src=0)
    batch_groups = [{"batches": schedule, "lr_step_freq": 1}]

    # --- train ---

    return do_routed(
        model,
        config,
        batch_groups,
        max_active_experts=len(aux_labels)+1,
        state=state,
        save_args=save_args,
        optimize_training=False,
    )