import logging
from argparse import Namespace

import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import EarlyStopping, LearningRateFinder
from torch import nn
from torch.utils.data import DataLoader, TensorDataset, random_split

# Raise the "pytorch_lightning" logger's threshold to ERROR so that its
# lower-severity (info/warning) messages are not emitted.
logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)

from lm_understanding.baselines.baseline import Baseline
from lm_understanding.question_template import TemplateModelBehavior


class SingleLayerNN(pl.LightningModule):
    """Single linear layer followed by a sigmoid (a logistic-regression model).

    The model maps ``n_features`` inputs to one output and is trained with
    binary cross-entropy using the Adam optimizer.
    """

    def __init__(self, n_features):
        """Build the linear layer.

        Args:
            n_features: Size of the last dimension of the input tensors.
        """
        super().__init__()
        self.layer = nn.Linear(n_features, 1)
        # Learning rate passed to Adam in ``configure_optimizers``.
        self.lr = 1e-2

    def forward(self, x):
        """Return ``sigmoid(layer(x))``, i.e. values in the range [0, 1]."""
        return torch.sigmoid(self.layer(x))

    def training_step(self, batch, batch_idx):
        """Compute and log the binary cross-entropy loss for one batch.

        Args:
            batch: An ``(inputs, targets)`` pair; targets must have the same
                shape as the layer output.
            batch_idx: Index of the batch (unused).

        Returns:
            The scalar loss tensor for the batch.
        """
        x, y = batch
        # Feed raw logits to the fused sigmoid + BCE op rather than applying
        # the sigmoid first: it computes the same quantity but stays
        # numerically stable when the sigmoid saturates.
        logits = self.layer(x)
        loss = nn.functional.binary_cross_entropy_with_logits(logits, y)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with rate ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)


class LinearRegression(Baseline):
    """Baseline that fits a ``SingleLayerNN`` on the encoded training questions.

    Despite the name, the fitted model is a single linear layer with a sigmoid
    output trained with binary cross-entropy, so predictions lie in [0, 1].

    NOTE(review): ``self.encoded_questions`` is not assigned in this class; it
    is presumably populated by ``Baseline`` - confirm against the base class.
    """

    def __init__(self, model_behavior: TemplateModelBehavior, baseline_config: Namespace, *args, **kwargs):
        """Forward all arguments to ``Baseline`` and start with no fitted model."""
        super().__init__(model_behavior, baseline_config, *args, **kwargs)
        # Set by ``train``; ``predict`` requires it to be non-None.
        self.model = None

    def train(self):
        """Fit a ``SingleLayerNN`` on the ``'train'`` split and store it.

        Trains for up to 100 epochs on shuffled mini-batches of 128 examples,
        with progress bar, checkpointing, model summary and logger disabled.
        There is no validation split and no early stopping.

        Raises:
            AssertionError: If ``self.encoded_questions`` is None.
        """
        assert self.encoded_questions is not None
        inputs = torch.tensor(self.encoded_questions['train'], dtype=torch.float32)
        answers = np.asarray(self.model_behavior.answers('train'))
        # Add a trailing dimension so targets match the model's (batch, 1) output.
        targets = torch.tensor(answers, dtype=torch.float32).unsqueeze(1)

        batch_size = 128
        train_dataloader = DataLoader(
            TensorDataset(inputs, targets),
            batch_size=batch_size,
            shuffle=True,
        )

        # Read the feature count from the converted tensor so the raw encoding
        # does not itself need to expose ``.shape``.
        model = SingleLayerNN(inputs.shape[1])
        trainer = pl.Trainer(
            max_epochs=100,
            enable_progress_bar=False,
            enable_checkpointing=False,
            enable_model_summary=False,
            logger=False,
            callbacks=[],
        )
        trainer.fit(model, train_dataloader)
        self.model = model

    @torch.inference_mode()
    def predict(self, split: str, question_idx: int) -> float:
        """Return the fitted model's output for one encoded question.

        Args:
            split: Key into ``self.encoded_questions`` selecting the split.
            question_idx: Index of the question within that split.

        Returns:
            The model's sigmoid output as a Python float.

        Raises:
            AssertionError: If ``train`` has not been called or
                ``self.encoded_questions`` is None.
        """
        assert self.model is not None
        assert self.encoded_questions is not None
        # Build the input with an explicit dtype on the model's device instead
        # of relying on the legacy ``torch.Tensor`` constructor's defaults.
        embedding = torch.as_tensor(
            self.encoded_questions[split][question_idx],
            dtype=torch.float32,
            device=self.model.device,
        )
        return self.model(embedding).item()

