import time
import torch
from trainkit.timing import print_time_taken
from trainkit.logging import TrainLogger
import nn
from nn import cola_nn
import losses
import optimizers
from data import get_loaders
from model import set_seed
from model import train_epoch
from model import eval_model
from torchinfo import summary
from math import prod
import wandb


def train_supervised(
    model="MLP",
    data_dir=None,
    output_dir="./logs",
    dataset="mnist",
    train_subset=1.0,
    seed=21,
    epochs=10,
    eval_freq=5,
    batch_size=128,
    optimizer="adamw",
    lr=1e-3,
    wd=0,
    mom=0.9,
    scheduler="cosine",
    sch_steps=0,
    loss="cross_entropy",
    depth=2,
    width=100,
    residual=True,
    layer_norm=True,
    dropout=0.0,
    strategy="low_rank_strategy",
    rank_frac=1,
    use_wandb=False,
    run_name=None,
):
    """Build a model, convert it with ``cola_nn.colafy``, then train and evaluate it.

    The base model is summarized with ``torchinfo.summary`` before and after
    the CoLA conversion, and the parameter / mult-add counts of both versions
    are logged and added to the run configuration.

    Args:
        model: Name of the model constructor looked up on the project ``nn`` module.
        data_dir: Forwarded to ``get_loaders`` as ``root``.
        output_dir: Forwarded to ``TrainLogger``.
        dataset: Dataset identifier forwarded to ``get_loaders``.
        train_subset: Forwarded to ``get_loaders`` (presumably the fraction of
            the training set to use -- confirm against ``get_loaders``).
        seed: Passed to ``set_seed``.
        epochs: Number of ``train_epoch`` calls.
        eval_freq: The test loader is evaluated on epochs where
            ``epoch % eval_freq == 0`` and always on the last epoch. ``0``
            disables the periodic evaluation.
        batch_size: Forwarded to ``get_loaders``.
        optimizer: Name of a factory on the project ``optimizers`` module,
            called with the model parameters and ``lr``, ``wd``, ``mom``.
        lr: Learning rate given to the optimizer and, as ``init_lr``, to the scheduler.
        wd: Forwarded to the optimizer factory.
        mom: Forwarded to the optimizer factory.
        scheduler: Name of a factory on ``optimizers``, called with the
            optimizer, ``init_lr=lr`` and ``num_steps=sch_steps``.
        sch_steps: Forwarded to the scheduler factory as ``num_steps``.
        loss: Name of a zero-argument factory on the project ``losses`` module.
        depth: Forwarded to the model constructor.
        width: Forwarded to the model constructor; also scales the CoLA rank.
        residual: Forwarded to the model constructor.
        layer_norm: Forwarded to the model constructor.
        dropout: Forwarded to the model constructor.
        strategy: Name of an attribute on ``cola_nn`` passed to ``colafy``.
        rank_frac: The rank passed to ``colafy`` is ``int(rank_frac * width + 0.5)``.
        use_wandb: When true, start a wandb run in project ``'struct'`` and
            log train/test accuracy on evaluation epochs.
        run_name: Name given to the wandb run.

    Raises:
        ValueError: If the training loader yields no batches.
    """
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    # Take a private copy: before Python 3.13 ``locals()`` hands back the
    # frame's shared mapping, which the language docs say not to modify, and
    # ``config`` is mutated further down. At this point it holds the call
    # arguments plus ``device`` only.
    config = dict(locals())
    tic = time.time()
    set_seed(seed)

    run_log = TrainLogger(config, output_dir=output_dir)

    train_loader, _, test_loader = get_loaders(dataset, batch_size, root=data_dir, train_subset=train_subset)
    # Peek at one batch to learn the per-sample input shape.
    first_batch = next(iter(train_loader), None)
    if first_batch is None:
        raise ValueError('train_loader produced no batches; cannot infer the input shape')
    in_shape = (1, *first_batch[0].shape[1:])

    # NOTE(review): the model is not moved to `device` in this function;
    # presumably train_epoch / eval_model handle placement -- confirm.
    net = getattr(nn, model)(dim_in=prod(in_shape), dim_out=train_loader.dataset.n_classes, depth=depth, width=width,
                             residual=residual, layer_norm=layer_norm, dropout=dropout)
    print('Base model:')
    stats = summary(net, in_shape)
    base_params = stats.trainable_params
    base_flops = stats.total_mult_adds

    colafy_strategy = getattr(cola_nn, strategy)
    # Adding 0.5 before truncation rounds the fractional rank to the nearest integer.
    cola_nn.colafy(net, strategy=colafy_strategy, rank=int(rank_frac * width + 0.5))
    print('CoLA model:')
    stats = summary(net, in_shape)
    cola_params = stats.trainable_params
    cola_flops = stats.total_mult_adds

    info = {'base_params': base_params, 'base_flops': base_flops, 'cola_params': cola_params, 'cola_flops': cola_flops}
    config.update(info)

    run_log.logger.info(f'Base params: {base_params} | CoLA params: {cola_params}')
    run_log.logger.info(f'Base flops: {base_flops} | CoLA flops: {cola_flops}')

    opt = getattr(optimizers, optimizer)(net.parameters(), lr=lr, wd=wd, mom=mom)
    # NOTE(review): the scheduler is constructed here but is neither stepped
    # nor handed to train_epoch anywhere in this function -- confirm whether
    # the learning-rate schedule is actually applied.
    lr_schedule = getattr(optimizers, scheduler)(opt, init_lr=lr, num_steps=sch_steps)
    criterion = getattr(losses, loss)()

    wandb_run = None
    if use_wandb:
        wandb_run = wandb.init(project='struct', config=config, save_code=True, name=run_name)

    last_epoch = epochs - 1
    try:
        for epoch in range(epochs):
            train_info = train_epoch(net, train_loader, opt, criterion, device)
            run_log.metrics.append((epoch, train_info['loss'], 'loss'))
            run_log.metrics.append((epoch, train_info['acc'], 'TrA'))
            run_log.logger.info(f"E: {epoch} | L: {train_info['loss']:.1f} | TrA: {train_info['acc']:.1f}")
            # A zero eval_freq switches periodic evaluation off instead of
            # raising ZeroDivisionError; the final epoch is always evaluated.
            periodic_eval = bool(eval_freq) and epoch % eval_freq == 0
            if periodic_eval or epoch == last_epoch:
                test_info = eval_model(net, test_loader, device=device)
                run_log.logger.info(f"E: {epoch} | TeA: {test_info['acc']:.1f}")
                run_log.metrics.append((epoch, test_info['acc'], 'TeA'))
                metrics = {'train_acc': train_info['acc'], 'test_acc': test_info['acc']}
                if wandb_run is not None:
                    wandb_run.log(metrics)
    except Exception:
        # Close the wandb run as failed so an in-process caller (e.g. a loop
        # over configurations) is not left with a dangling open run.
        if wandb_run is not None:
            wandb_run.finish(exit_code=1)
        raise

    toc = time.time()
    print_time_taken(toc - tic, logger=run_log.logger)
    run_log.finalize_logging()
    if wandb_run is not None:
        wandb_run.finish()


def entrypoint(**kwargs):
    """Forward every keyword argument unchanged to ``train_supervised``.

    This is the callable handed to ``fire.Fire`` in the ``__main__`` guard.
    The result of ``train_supervised`` is discarded and ``None`` is returned.
    """
    train_supervised(**kwargs)
    return None


if __name__ == '__main__':
    # Imported lazily so that `fire` is only required when this file is run
    # as a script, not when it is imported as a module.
    from fire import Fire

    # Hand the wrapper to python-fire, which builds the command line from it.
    Fire(entrypoint)
