import copy
import dataclasses
import io
import json
import os
import shutil
import tempfile
import time
import warnings
from contextlib import nullcontext
from datetime import datetime
from pathlib import Path
from typing import Any

from torch import nn

from mqar.generators import MqarBatch

try:
    from torch.amp import GradScaler
except ImportError:
    pass

import numpy as np
import torch
import wandb
from torch.nn import CrossEntropyLoss
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from tqdm import tqdm

from config import get_wandb_entity, wandb_login
from mqar_zoology.associative_recall import IGNORED_TOKEN
from mqar.grad_flow import init_grad_flow_data, calculate_grad_norm, update_grad_flow_data, format_grad_flow_logs
from utils.common import set_seed

# Architecture names collected at import time: 'mamba_ssm' is added only when the
# optional `mamba_ssm` package can be imported, 'mamba_tiny' is always added.
available_mamba_architectures = []
try:
    from mamba_ssm.utils.hf import load_state_dict_hf as mamba_ssm__load_state_dict_hf
    available_mamba_architectures.extend(['mamba_ssm'])
except ImportError:
    # warnings.warn("could not import 'mamba_ssm'")
    pass
available_mamba_architectures.extend(['mamba_tiny'])
# NOTE(review): this switches on autograd anomaly detection for the whole process as
# soon as the module is imported. The PyTorch docs describe this mode as a debugging
# aid that slows execution down — confirm it is meant to stay on for regular runs.
torch.autograd.set_detect_anomaly(True)



def _get_save_subdirs(save_dir: Path):
    """Return the standard output sub-directories located under ``save_dir``.

    Only the paths are built; nothing is created on disk.

    Returns:
        A 5-tuple ``(train_logs_dir, val_logs_dir, test_logs_dir,
        best_models_dir, model_checkpoints_dir)``.
    """
    subdir_names = (
        "train_logs",
        "val_logs",
        "test_logs",
        "best_models",
        "model_checkpoints",
    )
    return tuple(save_dir / subdir_name for subdir_name in subdir_names)


def clean_up(start_datetime_str, verbose=False):
    """Delete the scratch directory ``./tmp/<start_datetime_str>`` if it exists.

    The path is relative to the current working directory. Nothing happens when
    the directory is absent.
    """
    if verbose:
        print('\nrunning clean up\n')

    scratch_dir = f'./tmp/{start_datetime_str}'
    if not os.path.exists(scratch_dir):
        return
    shutil.rmtree(scratch_dir)


def get_available_context(device_name: str, use_amp: bool = False):
    """
    Build the context managers used around a forward pass.

    Returns ``(autocast_ctx, device_ctx)``:
      - without CUDA (or for a non-CUDA device) both are ``nullcontext``;
      - on CUDA with ``use_amp=False`` only the device context is active, so no
        autocast is applied;
      - on CUDA with ``use_amp=True`` autocast uses bfloat16 when the device's
        major compute capability is >= 8 and float16 otherwise.
    """
    target = torch.device(device_name)
    if not torch.cuda.is_available() or target.type != 'cuda':
        return nullcontext(), nullcontext()

    if not use_amp:
        # mixed precision switched off: no autocast context
        return nullcontext(), torch.cuda.device(target)

    # pick the reduced-precision dtype from the compute capability
    major, _minor = torch.cuda.get_device_capability(target)
    amp_dtype = torch.bfloat16 if major >= 8 else torch.float16
    autocast_ctx = torch.amp.autocast(device_type='cuda', dtype=amp_dtype)
    return autocast_ctx, torch.cuda.device(target)


def _copy_state_dict(model):
    """Return a detached CPU snapshot of ``model.state_dict()``.

    Every tensor is cloned, so later in-place updates of the live model do not
    alter the snapshot.
    """
    snapshot = {}
    for key, tensor in model.state_dict().items():
        snapshot[key] = tensor.detach().cpu().clone()
    return snapshot


def atomic_write_bytes(data: bytes, final_path: str):
    """Write ``data`` to ``final_path`` so readers never see a partial file.

    The bytes are written to a temporary file in the destination directory,
    flushed and fsynced, and then moved over ``final_path`` with ``os.replace``.
    The destination directory is created if it does not exist.

    Args:
        data: Bytes to write.
        final_path: Destination path; converted to an absolute path.

    Raises:
        Whatever the underlying write/fsync/replace raises. In that case the
        temporary file is removed before the exception propagates, so failed
        writes do not leave ``._*.tmp`` files behind.
    """
    final_path = os.path.abspath(final_path)
    target_dir = os.path.dirname(final_path)
    os.makedirs(target_dir, exist_ok=True)

    # the temp file lives next to the destination so os.replace stays on one filesystem
    fd, tmp = tempfile.mkstemp(dir=target_dir, prefix='._', suffix='.tmp')
    try:
        with os.fdopen(fd, 'wb') as f:
            f.write(data)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp, final_path)
    except BaseException:
        # do not leak the temp file when anything above fails (including Ctrl-C)
        try:
            os.unlink(tmp)
        except OSError:
            pass
        raise


def _save_model_checkpoint(model_state, models_dir, name, verbose=False):
    """Serialize ``model_state`` with ``torch.save`` and store it as ``<models_dir>/<name>.pt``.

    The payload is serialized in memory first and then handed to
    ``atomic_write_bytes``, which performs the actual write.
    """
    if verbose:
        print(f"saving {name} model checkpoint...")

    os.makedirs(models_dir, exist_ok=True)

    with io.BytesIO() as buffer:
        torch.save(model_state, buffer)
        payload = buffer.getvalue()

    checkpoint_path = os.path.join(models_dir, f'{name}.pt')
    atomic_write_bytes(payload, checkpoint_path)

    if verbose:
        print("saved")
        
def _save_run_config(run_config, save_dir):
    """Write ``run_config`` as indented JSON to ``<save_dir>/run_config.json``."""
    os.makedirs(save_dir, exist_ok=True)
    encoded_config = json.dumps(run_config, indent=2).encode()
    config_path = os.path.join(save_dir, 'run_config.json')
    atomic_write_bytes(encoded_config, config_path)

def _save_logs(logs: dict[str, Any], logs_dir: Path, name):
    """Write ``logs`` as indented JSON to ``<logs_dir>/<name>.json``.

    The file ends up with mode 0o644. While the directory and file are being
    created the process umask is set to 0o022; the previous umask is restored
    afterwards so this helper does not change the permissions of files created
    elsewhere in the process.

    Args:
        logs: JSON-serializable mapping.
        logs_dir: Directory for the log file; created if missing.
        name: File name without the ``.json`` suffix.
    """
    output_file = logs_dir / f'{name}.json'

    # write with non-root permissions
    previous_umask = os.umask(0o022)  # or 0o002 if using a shared group
    try:
        os.makedirs(logs_dir, exist_ok=True)
        atomic_write_bytes(json.dumps(logs, indent=2).encode(), str(output_file))
    finally:
        # os.umask is process-wide: always put the caller's value back
        os.umask(previous_umask)

    # atomic_write_bytes creates the file through tempfile.mkstemp, whose files are
    # only accessible to the owner, so the final mode has to be set explicitly
    os.chmod(output_file, 0o644)  # or 0o664 for group-read/write

def _safe_path_name(name: str, default='default') -> str:
    """Turn ``name`` into a string made of alphanumeric runs joined by single ``_``.

    Every maximal run of non-alphanumeric characters acts as a separator, so the
    result has no leading, trailing or repeated underscores. ``default`` is
    returned when ``name`` contains no alphanumeric character at all.
    """
    runs = []
    current_run = []
    for char in name:
        if char.isalnum():
            current_run.append(char)
        elif current_run:
            # a separator closes the run collected so far
            runs.append(''.join(current_run))
            current_run = []
    if current_run:
        runs.append(''.join(current_run))

    return '_'.join(runs) or default


def run_train_loop(
        model: nn.Module,
        dataloaders: dict[str, DataLoader],
        run_config: dict[str, Any],
        debug_text: str = None,
) -> dict[str, Any]:
    """Evaluate ``model`` on the validation split and train it if required.

    Flow:
      1. Seed, move the model to ``run_config['runtime']['device']`` and run an
         initial validation pass.
      2. If ``runtime.should_train`` is falsy, or the initial validation accuracy
         already exceeds ``training.threshold_accuracy``, save the initial
         validation logs and return them.
      3. Otherwise build AdamW, the optional LR scheduler, the optional AMP
         ``GradScaler`` and the optional wandb run, then delegate to ``_train_model``.

    Args:
        model: Model to evaluate/train.
        dataloaders: Mapping with a ``'val'`` entry used here; the whole mapping is
            forwarded to ``_train_model``.
        run_config: Run configuration; the ``runtime``, ``training``, ``wandb`` and
            ``io`` sections are read here.
        debug_text: Optional prefix for the progress-bar descriptions.

    Returns:
        The initial validation logs on the early-exit path, otherwise the results
        returned by ``_train_model``.
    """
    runtime_config = run_config['runtime']
    train_config = run_config['training']
    wandb_config = run_config['wandb']

    save_dir = Path(run_config["io"]["run_results_dir"])
    safe_run_name = _safe_path_name(wandb_config['run_name'])

    set_seed(runtime_config['seed'], verbose=True)

    model.to(runtime_config['device'])

    start_str = datetime.now().strftime("%Y_%m_%d__%H_%M_%S")

    label_smoothing = train_config.get('label_smoothing', 0)
    loss_fn = CrossEntropyLoss(ignore_index=IGNORED_TOKEN, label_smoothing=label_smoothing)

    # dirs (only the train/val log dirs are used in this function)
    train_logs_dir, val_logs_dir, _, _, _ = _get_save_subdirs(save_dir)

    # ---

    model.eval()

    desc_text = (f"{debug_text} | " if debug_text is not None else "")

    initial_val_log, _ = evaluate_split(
        model, dataloaders['val'], run_config,
        loss_fn=loss_fn, step=0, desc_text=desc_text,
    )
    val_accuracy = initial_val_log['accuracy']

    if (not runtime_config['should_train']) or (val_accuracy > train_config['threshold_accuracy']):
        _save_logs(logs=initial_val_log, logs_dir=val_logs_dir, name=safe_run_name)
        _save_logs(logs=initial_val_log, logs_dir=train_logs_dir, name=safe_run_name)  # todo
        return initial_val_log

    # prepare for training

    # build optimizer
    optimizer_config = train_config['optimizer']
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=optimizer_config['learning_rate'],
        betas=optimizer_config['betas'],
        weight_decay=optimizer_config['weight_decay'],
    )

    # build scheduler from config if provided
    if (scheduler_config := train_config.get('scheduler')) is not None:
        scheduler = _construct_scheduler_from_config(optimizer, scheduler_config)
    else:
        scheduler = None

    # scaler (for AMP)
    if train_config.get('use_amp', False):
        try:
            scaler_cls = GradScaler
        except NameError:
            # the module-level `from torch.amp import GradScaler` is allowed to fail
            # silently; fall back to the CUDA-specific class instead of crashing here
            from torch.cuda.amp import GradScaler as scaler_cls
        scaler = scaler_cls()
    else:
        scaler = None

    # wandb
    if wandb_config['activate']:

        os.environ['WANDB_CACHE_DIR'] = wandb_config['cache_dir']

        if not wandb_config['verbose']:
            os.environ["WANDB_SILENT"] = "true"  # suppresses wandb console output
            os.environ["WANDB_CONSOLE"] = "off"  # don't wrap/redirect console

        wandb_login()

        # if a run is already active in this process, close it first
        if wandb.run is not None:
            wandb.finish()

        # init
        wandb_run = wandb.init(
            project=wandb_config['project_name'],
            group=wandb_config['group_name'],
            name=wandb_config['run_name'],
            dir=wandb_config['output_dir'],
            entity=get_wandb_entity(),
            config=run_config,
            reinit='finish_previous',
        )

    else:
        wandb_run = None

    # train model (the returned model object is not needed here)
    _, results = _train_model(
        model, dataloaders,
        optimizer, scheduler, scaler,
        loss_fn, run_config,
        wandb_run=wandb_run,
        desc_text=desc_text,
    )

    wandb.finish()
    clean_up(start_str)

    return results


def wandb_log(logs: dict, step: int, name: str = None, commit: bool = True):
    """Send ``logs`` to wandb for ``step``.

    When ``name`` is given, every key is logged as ``"<key> (<name>)"``;
    otherwise ``logs`` is passed through unchanged.
    """
    if name is None:
        payload = logs
    else:
        payload = {f"{key} ({name})": value for key, value in logs.items()}

    wandb.log(payload, step=step, commit=commit)


def _train_model(
        model, dataloaders,
        optimizer, scheduler, scaler,
        loss_fn,
        run_config,
        wandb_run=None,
        desc_text=None,
):
    """Run the training loop, then evaluate the best checkpoint on the test split.

    One pass is made over ``dataloaders['train']``; every batch counts as one step.
    Validation runs every ``training.evaluate_any_num_steps`` steps and on the last
    step, and the loop stops early once ``_perform_evaluation_step`` asks for it.
    After the loop the best and the final weights are saved, the best weights are
    loaded back into ``model`` and the test split is evaluated.

    Args:
        model: Model to train; modified in place.
        dataloaders: Mapping with ``'train'``, ``'val'`` and ``'test'`` loaders.
        optimizer, scheduler, scaler: Forwarded to ``train_or_evaluate_batch_stats``
            (``scheduler`` and ``scaler`` may be None).
        loss_fn: Loss used for training and evaluation.
        run_config: Run configuration (``runtime``, ``training``, ``wandb``, ``io``).
        wandb_run: Active wandb run, or None.
        desc_text: Prefix for progress-bar descriptions.

    Returns:
        ``(model, test_logs_copy)`` where ``model`` holds the best weights.
    """
    runtime_config = run_config['runtime']
    train_config = run_config['training']
    wandb_config = run_config['wandb']

    device = runtime_config['device']

    save_dir = Path(run_config["io"]["run_results_dir"])
    use_amp = train_config.get('use_amp', False)

    # gradient accumulation setup
    accum_steps = int(train_config.get('grad_accum_steps', 1))
    assert accum_steps >= 1
    # train_or_evaluate_batch_stats keeps its accumulation position in the config;
    # reset it so a config reused from an earlier run cannot start mid-cycle
    train_config['_ga_counter'] = 0

    # periodic saving of train logs is optional: a missing/None value disables it
    raw_logs_save_steps = train_config.get("save_logs_any_num_steps")
    logs_save_steps = int(raw_logs_save_steps) if raw_logs_save_steps is not None else None

    # prepare logs dir
    safe_run_name = _safe_path_name(wandb_config['run_name'])

    # ...
    model.to(device)
    model.train()

    # init
    step = 0
    best_accuracy = 0
    best_state = _copy_state_dict(model)
    grad_flow_data = init_grad_flow_data(model)

    # tqdm
    use_tqdm = train_config['view_train_tqdm']
    if use_tqdm:
        train_bar = tqdm(
            dataloaders['train'],
            desc=f'{desc_text}Train (batches)',
            unit='batch',
            # leave=False,
            leave=True,
            dynamic_ncols=True,
        )
        train_iterator = train_bar
    else:
        train_bar = None
        train_iterator = dataloaders['train']
    num_training_steps = len(train_iterator)
    last_step = num_training_steps - 1

    # dirs
    (train_logs_dir, val_logs_dir, test_logs_dir,
     best_models_dir, model_checkpoints_dir) = _get_save_subdirs(save_dir)

    should_break = False

    # training loop: steps
    for step, batch in enumerate(train_iterator):

        cur_step_logs = {}

        # train
        train_logs, train_logs_per_n_facts, _ = train_or_evaluate_batch_stats(
            model, batch, device, train_config,
            ignored_token=IGNORED_TOKEN, train=True,
            loss_fn=loss_fn, optimizer=optimizer, scheduler=scheduler,
            use_amp=use_amp, scaler=scaler,
        )

        # save logs
        if (logs_save_steps is not None) and (step % logs_save_steps == 0):
            _save_logs(logs=train_logs, logs_dir=train_logs_dir, name=safe_run_name)

        # evaluate: periodically (never at step 0) and always on the last step
        is_eval_step = (step > 0) and (step % train_config['evaluate_any_num_steps'] == 0)
        if is_eval_step or (step == last_step):

            best_accuracy, best_state, should_break, cur_step_logs = _perform_evaluation_step(
                model, dataloaders, run_config, loss_fn, step, desc_text,
                best_accuracy, best_state, cur_step_logs, grad_flow_data, save_dir,
            )

            # save logs
            # NOTE(review): this writes the *train* logs into the val logs directory;
            # the validation logs are not returned by _perform_evaluation_step —
            # confirm whether val logs were intended here.
            _save_logs(logs=train_logs, logs_dir=val_logs_dir, name=safe_run_name)

        # update logs and log to wandb
        if wandb_config['activate'] and (step % accum_steps == 0):
            cur_step_logs.update({'train': train_logs})
            if wandb_config['log_per_n_results']:
                cur_step_logs.update({'train.per_n': train_logs_per_n_facts})
            # the last step of the loader is not sent to wandb
            if step != last_step:
                wandb_log(cur_step_logs, step=step)

        # update tqdm
        if train_bar is not None:
            train_bar.set_postfix({
                'loss': f"{train_logs['loss']:.3e}",
                'accuracy': f"{train_logs['accuracy']:.3f}",
                'best_accuracy': f"{best_accuracy:.3f}",
                'grad_norm': f"{train_logs['grad_norm']:.2e}",
            })

        if should_break:
            break

    # end run

    # close tqdm
    if train_bar is not None:
        train_bar.close()

    # finish wandb run
    if wandb_run is not None:
        wandb_run.summary["status"] = "finished"
        wandb_run.finish()

    # save model state (best and last)
    final_state = _copy_state_dict(model)
    _save_model_checkpoint(best_state, best_models_dir, safe_run_name)
    _save_model_checkpoint(final_state, model_checkpoints_dir, safe_run_name)

    # evaluate best model on test set, labelled with the last training step reached
    model.load_state_dict(best_state)
    model.eval()
    test_logs, _ = evaluate_split(
        model, dataloaders['test'], run_config,
        loss_fn=loss_fn, step=step, desc_text=desc_text,
    )

    # save results
    _save_logs(logs=test_logs, logs_dir=test_logs_dir, name=safe_run_name)
    final_results = copy.copy(test_logs)

    # return best model
    best_model = model

    wandb.finish()

    return best_model, final_results


def _perform_evaluation_step(
        model, dataloaders, run_config, loss_fn, step, desc_text,
        best_accuracy, best_state, cur_step_logs, grad_flow_data, save_dir,
):
    """Validate the model in the middle of training and update the bookkeeping.

    Evaluates ``dataloaders['val']``, adds the enabled wandb entries to
    ``cur_step_logs`` (modified in place), stores a new best checkpoint when the
    validation accuracy improves, optionally stores a regular checkpoint, and
    reports whether the accuracy threshold has been exceeded.

    Returns:
        ``(best_accuracy, best_state, should_break, cur_step_logs)``.
    """
    training_cfg = run_config['training']
    logging_cfg = run_config['wandb']
    stop_threshold = training_cfg['threshold_accuracy']
    checkpoint_name = _safe_path_name(logging_cfg['run_name'])

    *_, best_models_dir, model_checkpoints_dir = _get_save_subdirs(save_dir)

    # validation pass; the model is put back into train mode right after
    model.eval()
    val_logs, val_logs_per_n_facts = evaluate_split(
        model, dataloaders['val'], run_config,
        loss_fn=loss_fn, step=step, desc_text=desc_text,
    )
    val_accuracy = val_logs['accuracy']
    model.train()

    # collect whatever wandb is configured to receive
    if logging_cfg['activate']:
        if logging_cfg['log_val']:
            cur_step_logs['val'] = val_logs
        if logging_cfg['log_per_n_results']:
            cur_step_logs['val.per_n'] = val_logs_per_n_facts
        if logging_cfg['log_grad_flow']:
            update_grad_flow_data(model, grad_flow_data)
            cur_step_logs['grad_flow'] = format_grad_flow_logs(model, grad_flow_data)

    # new best: snapshot and persist it
    if val_accuracy > best_accuracy:
        best_accuracy = val_accuracy
        best_state = _copy_state_dict(model)
        _save_model_checkpoint(best_state, best_models_dir, name=checkpoint_name)

    # regular checkpoint of the current weights
    if training_cfg['save_model_at_evaluation']:
        current_state = _copy_state_dict(model)
        _save_model_checkpoint(current_state, model_checkpoints_dir, name=checkpoint_name)

    # ask the caller to stop once the threshold is exceeded
    should_break = bool(val_accuracy > stop_threshold)
    if should_break:
        model.train()

    return best_accuracy, best_state, should_break, cur_step_logs


def _construct_scheduler_from_config(optimizer, scheduler_config):
    """
    Build a ``LambdaLR`` whose LR *multiplier* is made of three linear pieces:
      - warmup: ratio climbs 0 → 1 during the first `num_warmup_steps` steps
      - decay:  ratio falls 1 → decay_factor during the next `num_decay_steps` steps
      - tail:   ratio stays at decay_factor

    The ratio is shared by all param groups (each group's base lr is multiplied by it).

    Keys read from `scheduler_config`:
      - decay_factor (float in [0,1], default 1.0)
      - num_warmup_steps (int >= 0, default 0)
      - num_decay_steps  (int >= 0, default 0)

    Returns None when `scheduler_config` is empty or when the resulting schedule
    would be the constant 1.0. Raises ValueError for out-of-range values.
    """
    if not scheduler_config:
        return None

    warmup_steps = int(scheduler_config.get("num_warmup_steps", 0) or 0)
    decay_steps = int(scheduler_config.get("num_decay_steps", 0) or 0)
    final_ratio = float(scheduler_config.get("decay_factor", 1.0))

    if min(warmup_steps, decay_steps) < 0:
        raise ValueError("num_warmup_steps and num_decay_steps must be >= 0.")
    if not (0.0 <= final_ratio <= 1.0):
        raise ValueError("decay_factor must be in [0, 1].")

    # a schedule that is 1.0 everywhere needs no scheduler object
    has_phases = warmup_steps != 0 or decay_steps != 0
    if not has_phases and abs(final_ratio - 1.0) < 1e-12:
        return None

    decay_end = warmup_steps + decay_steps

    def lr_lambda(step: int) -> float:
        # warmup segment
        if warmup_steps > 0 and step < warmup_steps:
            return step / float(warmup_steps)

        # decay segment: interpolate between 1 and final_ratio
        if decay_steps > 0 and step < decay_end:
            progress = (step - warmup_steps) / float(decay_steps)  # in [0,1]
            return (1.0 - progress) * 1.0 + progress * final_ratio

        # flat tail
        return final_ratio

    return LambdaLR(optimizer, lr_lambda)


@dataclasses.dataclass
class Accumulator:
    """Running tally of scored predictions and how many of them were correct."""
    # number of correct predictions counted so far
    correct: int = 0
    # number of predictions counted so far
    total: int = 0


def evaluate_split(
        model, dataloader, config,
        loss_fn=None, ignored_token=IGNORED_TOKEN, step=0,
        verbose=False,
        desc_text="",
):
    """Evaluate ``model`` over every batch of ``dataloader``.

    Counts are accumulated over all batches and additionally grouped by each
    batch's ``N_facts`` value. The loss is averaged over the counted tokens, using
    each batch's mean loss weighted by its token count.

    Args:
        model: Model to evaluate; moved to ``config['runtime']['device']`` and put
            into eval mode.
        dataloader: Iterable of batches exposing ``N_facts``.
        config: Full run configuration; the ``training`` and ``runtime`` sections
            are read (``dataset`` too when ``verbose``).
        loss_fn: Optional loss; when None the reported loss stays 0.
        ignored_token: Forwarded to ``train_or_evaluate_batch_stats``.
        step: Training step shown in the progress-bar description.
        verbose: Print a short summary before and after evaluation.
        desc_text: Prefix for the progress-bar description.

    Returns:
        ``(split_logs, split_logs_per_n_facts)``: overall ``accuracy``,
        ``error_rate`` and ``loss``, and per-``N_facts`` ``accuracy``,
        ``error_rate`` and ``mean_n_correct`` keyed by ``str(N_facts)``.

    Raises:
        ZeroDivisionError: if the dataloader yields no batches.
    """
    train_config = config['training']
    use_tqdm = train_config['view_evaluation_tqdm']

    if verbose:
        print(f"\nEvaluating over '{config['dataset']}'...")

    device = config['runtime']['device']

    model.to(device)
    model.eval()

    per_n_facts_accumulators = {}
    overall_accumulator = Accumulator(total=0, correct=0)
    loss_sum = 0

    if use_tqdm:
        val_bar = tqdm(
            dataloader,
            desc=f'{desc_text}Step {step} | Evaluation (batches)',
            unit='batch',
            leave=False,
        )
        val_iterator = val_bar
    else:
        val_iterator = dataloader
        val_bar = None

    for batch in val_iterator:

        n_facts = batch.N_facts

        # pass the training section (not the whole run config): that is the mapping
        # train_or_evaluate_batch_stats reads its settings from
        batch_logs, _, batch_accumulator = train_or_evaluate_batch_stats(
            model, batch, device, train_config,
            ignored_token=ignored_token, train=False, loss_fn=loss_fn,
        )
        n_query_tokens = batch_accumulator.total
        n_correct_preds = batch_accumulator.correct

        if n_facts not in per_n_facts_accumulators:
            per_n_facts_accumulators[n_facts] = Accumulator(total=0, correct=0)

        # update accumulators
        per_n_facts_accumulators[n_facts].total += n_query_tokens
        per_n_facts_accumulators[n_facts].correct += n_correct_preds
        overall_accumulator.total += n_query_tokens
        overall_accumulator.correct += n_correct_preds

        # update sums (batch mean loss weighted by the number of counted tokens)
        if loss_fn is not None:
            batch_mean_loss = batch_logs['loss']
            loss_sum += batch_mean_loss * n_query_tokens

    # mean accumulators
    overall_accuracy = float(overall_accumulator.correct / overall_accumulator.total)
    overall_error_rate = 1 - overall_accuracy
    mean_n_correct_per_n_facts = {str(n): float(x.correct / (x.total / n)) for n, x in per_n_facts_accumulators.items()}
    accuracy_per_n_facts = {str(n): float(x.correct / x.total) for n, x in per_n_facts_accumulators.items()}
    error_rate_per_n_facts = {str(n): (1 - x) for n, x in accuracy_per_n_facts.items()}

    loss_mean = loss_sum / overall_accumulator.total

    if val_bar is not None:
        # accuracy and loss are separate postfix entries
        val_bar.set_postfix(accuracy=f"{overall_accuracy:.3f}", loss=f"{loss_mean:.3e}")
        val_bar.close()

    if verbose:
        print(
            f"Evaluation completed! "
            f"\n{overall_accuracy = }"
            f"\n{accuracy_per_n_facts = }"
            f"\n{mean_n_correct_per_n_facts = }\n",
        )

    split_logs = {
        'accuracy': overall_accuracy,
        'error_rate': overall_error_rate,
        'loss': loss_mean,
    }
    split_logs_per_n_facts = {
        'accuracy': accuracy_per_n_facts,
        'error_rate': error_rate_per_n_facts,
        'mean_n_correct': mean_n_correct_per_n_facts,
    }

    return split_logs, split_logs_per_n_facts


def train_or_evaluate_batch_stats(
        model, batch: MqarBatch, device, train_config, ignored_token=IGNORED_TOKEN,
        train=False, loss_fn=None,
        optimizer=None, scheduler=None, scaler=None,
        use_amp=False,
):
    """Run one batch through ``model`` and collect its metrics; optionally train on it.

    With ``train=True`` the loss is back-propagated and, once every
    ``grad_accum_steps`` calls, gradients are optionally clipped, the optimizer is
    stepped and the scheduler (if any) is advanced. The position inside the
    accumulation cycle is kept in ``train_config['_ga_counter']``, so
    ``train_config`` is modified in place.

    Args:
        model: Called as ``model(input_ids=...)``; its output must expose ``.logits``.
        batch: Batch providing ``x_ids`` and ``y_true_ids`` (and ``size`` when training).
        device: Device the inputs are moved to.
        train_config: Training settings (``grad_accum_steps``; when training also
            ``clip_grad`` and ``clip_grad_max_norm``).
        ignored_token: Target value excluded from the accuracy counts; None counts
            every position.
        train: Train on the batch instead of only evaluating it.
        loss_fn: Loss function; required when ``train`` is True.
        optimizer, scheduler, scaler: Training objects; ``scaler`` is required when
            ``use_amp`` is True.
        use_amp: Use autocast (where available) and the gradient scaler.

    Returns:
        ``(batch_logs, batch_logs_per_n_facts, batch_accumulator)``; the second
        item is currently always empty.
    """
    batch_logs = {}
    batch_logs_per_n_facts = {}

    x_ids = batch.x_ids.to(device)
    y_true_ids = batch.y_true_ids.to(device)

    # gradient accumulation setup
    accum_steps = int(train_config.get('grad_accum_steps', 1) or 1)
    ga_counter = int(train_config.get('_ga_counter', 0))

    # set train/eval mode + zero grads
    if train:
        model.train()
        # zero grads only at the start of an accumulation cycle
        if ga_counter % accum_steps == 0:
            optimizer.zero_grad()
    else:
        model.eval()

    # get the right autocast + device contexts (only active on CUDA)
    autocast_ctx, device_ctx = get_available_context(device, use_amp=use_amp)

    with torch.set_grad_enabled(train):
        # both move-to-device & mixed-precision
        with device_ctx, autocast_ctx:
            y_pred_logits = model(input_ids=x_ids).logits  # [B, L, V]
            y_pred_logits = y_pred_logits.transpose(2, 1)  # [B, V, L]

    # Compute cross-entropy in float32 for stability
    if loss_fn is not None:
        raw_loss = loss_fn(y_pred_logits.float(), y_true_ids)
        loss = raw_loss
        if train and accum_steps > 1:
            # average the loss so accumulated grads match full-batch grads
            loss = loss / accum_steps

    # training-only: backward + step + scheduler
    if train:

        assert loss_fn is not None

        clip_grad_max_norm = train_config['clip_grad_max_norm']
        do_step = ((ga_counter + 1) % accum_steps == 0)

        if use_amp:
            assert scaler is not None, "Pass in a GradScaler when training"
            scaler.scale(loss).backward()
            if do_step:
                if train_config['clip_grad']:
                    # gradients must be unscaled before their norm is clipped
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_max_norm)
                scaler.step(optimizer)  # this internally calls optimizer.step()
                scaler.update()

        else:
            loss.backward()
            if do_step:
                if train_config['clip_grad']:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_max_norm)
                optimizer.step()  # step the optimizer manually when not using amp

        # advance the learning rate schedule *after* the optimizer has stepped
        if do_step and scheduler is not None:
            scheduler.step()

        # update accumulation counter (wraps to 0 at boundary)
        train_config['_ga_counter'] = (ga_counter + 1) % accum_steps

    # collect predictions
    y_pred_ids = y_pred_logits.argmax(dim=1)
    y_correct = y_pred_ids.eq(y_true_ids)
    if ignored_token is not None:
        y_correct = y_correct.masked_select(y_true_ids.ne(ignored_token)).detach().cpu()

    accuracy = y_correct.float().mean()

    # .item() works on any device and yields a plain Python int, matching the
    # Accumulator field types (a .numpy() call would fail on a CUDA tensor)
    batch_accumulator = Accumulator(
        correct=int(y_correct.sum().item()),
        total=y_correct.numel(),
    )

    # convert tensors / arrays to plain Python values so the logs can be dumped as JSON
    def _serializable(x):
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
        if isinstance(x, np.ndarray):
            x = x.tolist()
        return x

    # logs
    batch_logs['accuracy'] = _serializable(accuracy)
    batch_logs['error_rate'] = _serializable(1 - accuracy)
    if loss_fn is not None:
        # the un-averaged loss is reported, not the accumulation-scaled one
        batch_logs['loss'] = _serializable(raw_loss)
    if train:
        batch_logs['learning_rate'] = _serializable(optimizer.param_groups[0]['lr'])
        batch_logs['batch_size'] = _serializable(batch.size)
        batch_logs['grad_norm'] = _serializable(calculate_grad_norm(model))

    # # per n_facts logs  # currently removed, TODO
    # n = batch.N_facts
    # batch_logs_per_n_facts[f'accuracy[{n}]'] = batch_accuracy
    # batch_logs_per_n_facts[f'loss[{n}]'] = batch_loss
    # if train:
    #     batch_logs_per_n_facts[f'grad_norm[{n}]'] = grad_norm

    return batch_logs, batch_logs_per_n_facts, batch_accumulator
