import os
import argparse
import re

import numpy as np
import pandas as pd

from tqdm import tqdm

import torch as torch
from torch.utils.data import DataLoader

import lightning.pytorch as L

from ..data.synthetic import CoordinatedBinaryClassificationDataModule, BinaryClassificationDataModule
from ..model.lr import LatentConformalLogisticRegression, LogisticRegressionModule

class LogisticRegression:
    """Minimal ``fit``/``predict`` wrapper around ``LogisticRegressionModule``.

    Each call to :meth:`fit` or :meth:`predict` builds its own Lightning
    ``Trainer``, so callers only deal with this object and a datamodule.

    Args:
        max_epochs: Number of epochs handed to the training ``Trainer``.
        **hparams: Forwarded unchanged to ``LogisticRegressionModule``.
    """

    def __init__(self, max_epochs, **hparams):
        self.max_epochs = max_epochs
        self.model = LogisticRegressionModule(**hparams)

    def fit(self, datamodule):
        """Train the wrapped module on ``datamodule`` and return ``self``.

        Training is logged with a TensorBoard logger under
        ``log/lightning/bootstrap`` (run name ``lr``).
        """
        logger = L.loggers.TensorBoardLogger('log/lightning/bootstrap', name="lr")
        trainer = L.Trainer(max_epochs=self.max_epochs, logger=logger)
        trainer.fit(self.model, datamodule=datamodule)
        return self

    def predict(self, datamodule):
        """Run prediction over ``datamodule`` and return one tensor.

        The per-batch outputs returned by ``Trainer.predict`` are concatenated
        along dimension 0.
        """
        trainer = L.Trainer(logger=False)
        # Pass the datamodule by keyword (as ``fit`` does): the second
        # positional parameter of ``Trainer.predict`` is ``dataloaders``.
        batch_preds = trainer.predict(self.model, datamodule=datamodule)
        return torch.cat(batch_preds, dim=0)

    
def run(
        name: str,
        dgp: str,
        seed: int,
        n: int,
        n_bootstrap: int,
        hidden_dim: int,
        dropout: float,
        lr: float,
        batch_size: int,
        epochs: int,
        overwrite: bool
    ):
    """Fit one model per bootstrap resample and save all predictions to a CSV.

    For every bootstrap iteration the training set of the datamodule is
    resampled, a fresh ``LogisticRegression`` is trained (or its weights are
    loaded from a previously saved state dict), and predictions on the
    datamodule's predict dataset are collected. The per-iteration frames are
    concatenated and written to ``results/data/bootstrap/{name}.csv``.

    Args:
        name: Identifier used in the CSV and model file names.
        dgp: Data-generating process, passed to ``BinaryClassificationDataModule``.
        seed: Seed for the NumPy generator and ``L.seed_everything``.
        n: Passed to ``BinaryClassificationDataModule`` (presumably the sample
            size; confirm against the datamodule).
        n_bootstrap: Number of bootstrap iterations.
        hidden_dim: Model hyperparameter forwarded to ``LogisticRegression``.
        dropout: Model hyperparameter forwarded to ``LogisticRegression``.
        lr: Model hyperparameter forwarded to ``LogisticRegression``.
        batch_size: Passed to ``BinaryClassificationDataModule``.
        epochs: Maximum number of training epochs per bootstrap model.
        overwrite: When false, an existing CSV aborts the run and existing
            model files are loaded instead of being retrained.

    Returns:
        None. Results are written to disk.
    """
    csv_file_path = f'results/data/bootstrap/{name}.csv'

    if os.path.exists(csv_file_path) and not overwrite:
        print(f"File {csv_file_path} already exists. Use --overwrite to overwrite.")
        return

    # Neither torch.save nor DataFrame.to_csv creates missing parent
    # directories; create them up front so a long run cannot fail at the
    # first checkpoint or, worse, at the final CSV write.
    model_dir = 'results/model/bootstrap'
    os.makedirs(os.path.dirname(csv_file_path), exist_ok=True)
    os.makedirs(model_dir, exist_ok=True)

    rng = np.random.default_rng(seed)
    L.seed_everything(seed)

    datamodule = BinaryClassificationDataModule(dgp, rng, n, batch_size)

    dfs = []
    for b in tqdm(range(n_bootstrap)):
        # Resample even when the model is loaded from disk so that the
        # generator state advances identically in both cases.
        datamodule.resample_traindataset(rng)

        model = LogisticRegression(
            max_epochs=epochs,
            input_dim=datamodule.train_dataset.x.shape[1],
            hidden_dim=hidden_dim,
            dropout=dropout,
            lr=lr
            )

        model_file_path = f"{model_dir}/{name}_{b:04d}.pth"
        if os.path.exists(model_file_path) and not overwrite:
            # Reuse weights from an earlier (possibly interrupted) run.
            model.model.load_state_dict(torch.load(model_file_path))
        else:
            model = model.fit(datamodule)
            torch.save(model.model.state_dict(), model_file_path)

        # "b" is a scalar and is broadcast by pandas to every row of the frame.
        data = {
            "b": b,
            "y4": datamodule.predict_dataset.y.numpy(),
            "z4": datamodule.predict_dataset.z.numpy(),
            "z4_hat": model.predict(datamodule).numpy(),
        }

        df = pd.DataFrame(data)
        dfs.append(df)

    df = pd.concat(dfs, axis=0)
    df.to_csv(csv_file_path, index=False)

# Command-line arguments that are deliberately left out of the run name.
_NAME_EXCLUDED_ARGS = ("dgp", "overwrite")

# Compiled outside the f-string below: before Python 3.12 an f-string
# expression may not contain a backslash, so inlining r'[^\w]' there is a
# SyntaxError on older interpreters.
_NON_WORD_CHAR = re.compile(r'[^\w]')


def build_run_name(args):
    """Build the run identifier used in output file names.

    Every ``key``/``value`` pair in ``args`` (except the excluded keys
    ``dgp`` and ``overwrite``) becomes ``"{key}-{value}"``, with each
    non-word character of ``str(value)`` replaced by ``_``; the parts are
    joined with ``_`` in the mapping's iteration order.
    """
    return "_".join(
        f"{key}-{_NON_WORD_CHAR.sub('_', str(value))}"
        for key, value in args.items()
        if key not in _NAME_EXCLUDED_ARGS
    )


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--dgp', type=str, default='sin', choices=['sin'])
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--n', type=int, default=1000)
    parser.add_argument('--n_bootstrap', type=int, default=100)
    parser.add_argument('--hidden_dim', type=int, default=64)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--overwrite', action='store_true')

    args = parser.parse_args()
    args = dict(vars(args))

    name = build_run_name(args)

    # The parsed argument names match the keyword parameters of ``run``.
    run(name, **args)
