from __future__ import annotations

import os
import math
import torch
import random
from tqdm.auto import tqdm
import numpy as np
from torch.optim.lr_scheduler import LambdaLR

from src.model.demix import DemixTransformer
from src.run.config import RunConfig
from src.run.utils import get_batch, get_select_mask, set_seeds
from src.run.logger import get_tqdm_kwargs
from src.run.distributed import get_raw_model, barrier, is_main_process, broadcast_object

def do_demix(
    model: DemixTransformer,
    config: RunConfig,
    state: dict | None = None,
    save_args: dict | None = None,
) -> tuple[DemixTransformer, dict]:
    """Run the demix training loop over per-label data loaders.

    One AdamW optimizer is created per class label (``"core"`` plus
    ``config.aux_labels``) and one for the ``"SHARED"`` parameter set, each
    built from ``get_raw_model(model).get_params(label)``. Every step draws a
    batch from a single label's loader, back-propagates the loss, clips
    gradients to norm 1.0, and steps that label's optimizer together with the
    shared one.

    Args:
        model: Model to train. It is put in train mode and updated in place.
        config: Run configuration; the fields read here are ``lr_schedule``,
            ``aux_labels``, ``loaders``, ``res_dir``, ``epochs``, ``logger``,
            ``lr`` and ``seed``.
        state: Optional resume state. Recognised keys are ``"opts"``
            (optimizer state dicts keyed by label), ``"step"`` (number of
            steps already completed) and ``"scheduler_epoch"``.
        save_args: Optional checkpoint settings. When non-empty it must
            contain ``"dir"``, ``"prefix"`` and ``"do_checkpoint"``.

    Returns:
        A ``(model, state)`` tuple, where ``state`` holds the optimizer state
        dicts, the step count, the total step count and the scheduler epoch
        (``None`` when no LR schedule is used).

    Raises:
        AssertionError: If ``save_args`` is non-empty but lacks a required key.
    """
    # unpack run config
    lr_schedule = config.lr_schedule
    aux_labels = config.aux_labels
    loaders = config.loaders
    res_dir = config.res_dir
    epochs = config.epochs
    logger = config.logger
    lr = config.lr

    # loop cadence / schedule constants
    log_every = 200            # logger printout interval and rolling-loss window
    checkpoint_every = 30000   # intermediate checkpoint interval (steps)
    warmup_fraction = 0.1      # share of the schedule spent in linear warmup

    if state is None:
        state = {}

    if save_args is None:
        save_args = {}

    if save_args:
        for required_key in ("dir", "prefix", "do_checkpoint"):
            assert required_key in save_args, (
                f"save_args is missing required key '{required_key}'"
            )

    logger.info("---- Begin Demix ----")

    set_seeds(config.seed)

    model.train()

    class_labels = ["core", *aux_labels]

    # one optimizer per class label plus one for the shared parameters
    opts = {}
    losses = {}
    for label in class_labels + ["SHARED"]:

        params = list(get_raw_model(model).get_params(label))
        opts[label] = torch.optim.AdamW(params, lr=lr, fused=True)

        if "opts" in state:
            opts[label].load_state_dict(state["opts"][label])

        if label != "SHARED":
            losses[label] = []

    def set_all_lrs(value):
        # the scheduler only drives the SHARED optimizer, so every other
        # optimizer has to be kept in step with it by hand
        for optimizer in opts.values():
            for group in optimizer.param_groups:
                group['lr'] = value

    #--- calculate batches ---
    # one entry per training batch of each class loader, in shuffled order
    batches = []  # [label, ...]
    for class_label in class_labels:
        batches += [class_label] * len(loaders[class_label]["train"])
    random.shuffle(batches)
    # NOTE(review): presumably replaces the local order with the one from
    # rank 0 so all ranks agree - confirm against broadcast_object
    batches = broadcast_object(batches, src=0)

    total_steps = len(batches) * epochs
    pbar = tqdm(total=total_steps, **get_tqdm_kwargs(logger, ncols=150))
    pbar.refresh()

    cur_lr = lr
    scheduler = None
    if lr_schedule:

        # the schedule length is derived from the 'all' loader, not from `batches`
        total_steps_all_data = len(loaders['all']['train']) * epochs
        warmup_steps = round(warmup_fraction * total_steps_all_data)

        max_lr = lr
        min_lr = max_lr * 0.1
        start_lr = 1e-8
        start_factor = start_lr / max_lr  # Multiplier to get initial LR of 1e-8
        min_factor = min_lr / max_lr
        # guard against a zero-length decay phase (division by zero)
        decay_steps = max(1, total_steps_all_data - warmup_steps)

        # lambda lr scheduler - easy to save/restore, just need last_epoch
        def lr_lambda(current_step):
            if current_step < warmup_steps:
                # linear warmup from start_factor to 1.0
                return start_factor + (1.0 - start_factor) * (current_step / warmup_steps)
            # cosine annealing from 1.0 to min_lr/max_lr; the step is clamped
            # so the LR stays at its minimum instead of rising again if more
            # steps are taken than the schedule was sized for
            cosine_step = min(current_step - warmup_steps, decay_steps)
            return min_factor + (1.0 - min_factor) * (1 + math.cos(math.pi * cosine_step / decay_steps)) / 2

        scheduler = LambdaLR(opts["SHARED"], lr_lambda)

        # restore scheduler from state
        if state.get("scheduler_epoch") is not None:
            scheduler.last_epoch = state["scheduler_epoch"]
            cur_lr = scheduler.get_lr()[0]
            scheduler._last_lr = [cur_lr]
            logger.info(f"Restored scheduler to epoch {scheduler.last_epoch}, LR: {cur_lr:.6e}")
        else:
            # constructing the scheduler already moved SHARED to its initial LR
            cur_lr = scheduler.get_last_lr()[0]

        # without this, the per-label optimizers would take their first step
        # at the full `lr` while SHARED is still at the warmup start
        set_all_lrs(cur_lr)

    # restore step from state
    resume_step = state.get("step", -1)
    num_steps = 0
    for epoch_idx in range(epochs):

        for label in class_labels:
            loaders[label]["train"].reset(epoch_idx)

        for batch_label in batches:

            # get batch; the label returned by get_batch replaces the scheduled one
            loader = loaders[batch_label]["train"]
            x, y, batch_label = get_batch(loader)

            num_steps += 1
            pbar.update()
            # when resuming, batches are still drawn (and discarded) so the
            # loaders advance past the steps that were already trained
            if num_steps <= resume_step:
                continue

            # forward pass
            sel_mask = get_select_mask(class_labels, [batch_label], device=x.device)
            _, loss = model.forward(
                tokens=x,
                targets=y,
                select_mask=sel_mask
            )

            loss.backward()

            # loss logging
            loss = loss.item()
            losses[batch_label].append(loss)

            # update progress bar
            desc_str = f"LR: {cur_lr:.2e} L: {loss:.2f} LB: {batch_label[:4].upper()}"
            pbar.set_description(desc_str)
            pbar.refresh()

            # logger printout: rolling mean of the most recent losses per label
            if (num_steps == 1) or (num_steps % log_every == 0) or (num_steps == total_steps):
                loss_str = ""
                for label in class_labels:
                    label_str = label[:4].upper()
                    loss_slice = losses[label][-log_every:]
                    avg_rolling_loss = np.mean(loss_slice) if len(loss_slice) > 0 else float('nan')
                    loss_str += f"{label_str}: {avg_rolling_loss:.2f} "
                logger.info(f"Step: {num_steps}, LR: {cur_lr:.2e}, Loss: {loss_str}")

            # step optimizers: only the batch's label and the shared parameters
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opts[batch_label].step()
            opts["SHARED"].step()

            # zero grads for all params
            for optimizer in opts.values():
                optimizer.zero_grad(set_to_none=True)

            # update learning rate
            if scheduler is not None:
                scheduler.step()
                cur_lr = scheduler.get_last_lr()[0]
                set_all_lrs(cur_lr)

            # save checkpoint every 30K steps or at final step
            is_checkpoint_step = num_steps % checkpoint_every == 0
            is_final_step = num_steps == total_steps
            if save_args.get("do_checkpoint", False) and (is_checkpoint_step or is_final_step):

                # intermediate checkpoints carry the step in their file name
                postfix = ""
                if not is_final_step:
                    postfix = f"_step-{num_steps}"

                if is_main_process():

                    ckpt_dir = res_dir / save_args["dir"]
                    prefix = save_args["prefix"]

                    os.makedirs(ckpt_dir, exist_ok=True)
                    checkpoint_path = f"{ckpt_dir}/{prefix}{postfix}.pth"

                    checkpoint = {
                        'model': get_raw_model(model).state_dict(),
                        'opts': {label: optimizer.state_dict() for label, optimizer in opts.items()},
                        'step': num_steps,
                        'total_steps': total_steps,
                        'scheduler_epoch': None,
                    }

                    if scheduler is not None:
                        checkpoint['scheduler_epoch'] = scheduler.last_epoch

                    torch.save(checkpoint, checkpoint_path)
                    logger.info(f"Saved checkpoint at step {num_steps} to {checkpoint_path}")

                # sync after checkpoint
                barrier()

        # sync after epoch
        barrier()

    # close progress bar
    pbar.close()

    state = {
        'opts': {label: optimizer.state_dict() for label, optimizer in opts.items()},
        'step': num_steps,
        'total_steps': total_steps,
        'scheduler_epoch': scheduler.last_epoch if scheduler is not None else None,
    }

    return model, state