import os

import torch
from pytorch_lightning import LightningModule
import numpy as np
import torchmetrics
from tqdm import tqdm

from .torch_utils import archs, get_optimizer, get_loss, get_scheduler


class TorchLightningModel(LightningModule):
    """LightningModule wrapping a classifier architecture looked up in ``archs``.

    The network is built by ``getattr(archs, architecture)(...)`` and trained
    with the per-sample loss from ``get_loss(loss_name, reduction="none")``,
    averaged over the batch. Accuracy is tracked separately for training and
    validation with ``torchmetrics.Accuracy``.
    """

    def __init__(self, n_features, n_classes, epochs, loss_name='ce', n_channels=None,
        learning_rate=1e-4, momentum=0.0, weight_decay=0.0, batch_size=256, optimizer='sgd',
        architecture='arch_001', train_type=None, grad_clip: float=0.):
        """Build the network, the loss function and the accuracy metrics.

        Args:
            n_features: forwarded to the architecture factory.
            n_classes: number of output classes; forwarded to the architecture factory.
            epochs: total epoch count, passed to ``get_scheduler``.
            loss_name: loss identifier resolved by ``get_loss``.
            n_channels: forwarded to the architecture factory.
            learning_rate: passed to ``get_optimizer``.
            momentum: passed to ``get_optimizer``.
            weight_decay: passed to ``get_optimizer``.
            batch_size: stored on the instance; not used inside this class.
            optimizer: optimizer identifier passed to ``get_optimizer``.
            architecture: name of the factory function looked up in ``archs``.
            train_type: only printed; not stored or used by this class.
            grad_clip: stored on the instance but not applied here —
                presumably consumed by the caller (e.g. as a Trainer
                gradient-clipping value); confirm.
        """
        print(f'lr: {learning_rate}, opt: {optimizer}, loss: {loss_name}, '
              f'arch: {architecture}, batch_size: {batch_size}, '
              f'momentum: {momentum}, weight_decay: {weight_decay}, grad_clip: {grad_clip}, '
              f'train_type: {train_type}')
        super().__init__()
        self.n_features = n_features
        self.n_classes = n_classes
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.architecture = architecture
        self.loss_name = loss_name
        self.grad_clip = grad_clip
        self.optimizer = optimizer
        self.weight_decay = weight_decay
        self.momentum = momentum
        self.epochs = epochs

        arch_fn = getattr(archs, self.architecture)
        arch_params = dict(n_features=n_features, n_classes=self.n_classes, n_channels=n_channels)
        self.model = arch_fn(**arch_params)

        # reduction="none" yields per-sample losses; the steps average them explicitly.
        self.base_loss_fn = get_loss(self.loss_name, reduction="none")

        # NOTE(review): torchmetrics >= 0.11 requires Accuracy(task=..., num_classes=...);
        # this no-argument form assumes an older pinned torchmetrics — confirm the version.
        self.trn_acc = torchmetrics.Accuracy()
        self.val_acc = torchmetrics.Accuracy()

    def forward(self, x):
        """Return the wrapped network's output for input ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_nb):
        """Compute the batch-mean training loss and log accuracy and loss.

        ``batch`` is unpacked as ``(x, y)``. Returns the scalar loss that
        Lightning backpropagates.
        """
        x, y = batch
        logits = self(x)
        loss = self.base_loss_fn(logits, y).mean()

        self.trn_acc(logits, y)
        self.log('trn_acc', self.trn_acc, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log('trn_base_loss', loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log validation accuracy and batch-mean loss.

        Metrics are logged under the ``tst_`` prefix.
        """
        x, y = batch
        logits = self(x)
        loss = self.base_loss_fn(logits, y).mean()
        self.val_acc(logits, y)
        self.log('tst_acc', self.val_acc, on_step=True, on_epoch=True, prog_bar=True)
        self.log('tst_base_loss', loss, on_step=True, on_epoch=True, prog_bar=True)

    def predict_step(self, batch, batch_idx: int , dataloader_idx: int = None):
        """Return raw network outputs for the inputs in ``batch[0]``."""
        return self(batch[0])

    def configure_optimizers(self):
        """Create the optimizer and LR scheduler via the project helpers.

        Returned in list form, so Lightning uses its default scheduler
        interval ('epoch'). To step the scheduler every batch instead, return
        ``{"optimizer": opt, "lr_scheduler": {"scheduler": scheduler, "interval": "step"}}``.
        """
        opt = get_optimizer(self.model, self.optimizer, self.learning_rate, self.momentum, self.weight_decay)
        scheduler = get_scheduler(opt, self.epochs)
        return [opt], [scheduler]

    def predict_ds(self, ds, batch_size=32, num_workers=12, device='cuda'):
        """Predict hard class labels (argmax over dim 1) for every sample in ``ds``.

        Args:
            ds: dataset whose items yield the model input at index 0.
            batch_size: DataLoader batch size.
            num_workers: DataLoader worker processes.
            device: device the inputs are moved to; assumes ``self.model``
                already lives on this device.

        Returns:
            A 1-D numpy array of predicted labels in dataset order, or an
            empty int64 array if ``ds`` is empty.
        """
        # Remember the caller's mode so inference does not leak eval() into training.
        was_training = self.model.training
        self.model.eval()
        loader = torch.utils.data.DataLoader(
            ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)
        preds = []
        try:
            # no_grad: inference only, so skip building the autograd graph.
            with torch.no_grad():
                for batch in tqdm(loader, desc="[predict_ds]"):
                    logits = self.model(batch[0].to(device))
                    preds.append(logits.argmax(1).cpu().numpy())
        finally:
            self.model.train(was_training)
            # Drop the loader promptly so its worker processes are released.
            del loader
        if not preds:
            # np.concatenate([]) raises ValueError; return an empty label array instead.
            return np.empty(0, dtype=np.int64)
        return np.concatenate(preds)
