import os
import argparse
import re

import numpy as np
import pandas as pd

import torch as torch
from torch.utils.data import DataLoader

import lightning.pytorch as L

from ..data.synthetic import CoordinatedBinaryClassificationDataModule
from ..model.lr import LatentConformalLogisticRegression

def run(
        name: str, dgp: str, seed: int, n: int,
        hidden_dim: int, dropout: float,
        alpha: float, beta: float,
        oracle_z: bool,
        lr: float, batch_size: int, epochs: int,
        overwrite: bool
        ) -> None:
    """Train (or reload) a model on a synthetic DGP and dump predictions to CSV.

    The trained weights are cached at ``results/model/{dgp}/{name}.pth`` and
    the prediction table at ``results/data/{dgp}/{name}.csv``. Each artifact is
    regenerated only when it is missing or ``overwrite`` is set.

    Args:
        name: File stem used for both the weights file and the CSV file.
        dgp: Data-generating-process identifier; forwarded to the datamodule
            and used as the sub-directory / logger name.
        seed: Seed for the NumPy generator and ``L.seed_everything``.
        n: Forwarded to the datamodule (presumably the sample size).
        hidden_dim, dropout, alpha, beta, oracle_z, lr: Forwarded positionally
            to ``LatentConformalLogisticRegression``.
        batch_size: Forwarded to the datamodule.
        epochs: ``max_epochs`` of the Lightning trainer.
        overwrite: When true, retrain and rewrite outputs even if they exist.
    """
    rng = np.random.default_rng(seed)
    L.seed_everything(seed)

    datamodule = CoordinatedBinaryClassificationDataModule(dgp, rng, n, batch_size)

    # The model's first constructor argument is taken from the column count of
    # the "data1" training features.
    input_dim = datamodule.train_datasets["data1"].x.shape[1]
    model = LatentConformalLogisticRegression(
        input_dim, hidden_dim, dropout, alpha, beta, oracle_z, lr
    )
    logger = L.loggers.TensorBoardLogger('log/lightning/synthetic/train', name=dgp)
    trainer = L.Trainer(max_epochs=epochs, logger=logger)

    model_file_path = f'results/model/{dgp}/{name}.pth'

    if not os.path.exists(model_file_path) or overwrite:
        # Create the output directory *before* training so a missing folder
        # cannot throw away a finished training run at save time.
        os.makedirs(os.path.dirname(model_file_path), exist_ok=True)
        trainer.fit(model, datamodule)
        torch.save(model.state_dict(), model_file_path)
    else:
        # map_location keeps the load working on CPU-only machines even if the
        # checkpoint tensors were saved from an accelerator; load_state_dict
        # copies the values onto whatever device the model lives on.
        model.load_state_dict(torch.load(model_file_path, map_location="cpu"))

    csv_file_path = f'results/data/{dgp}/{name}.csv'

    if not os.path.exists(csv_file_path) or overwrite:
        preds_list = trainer.predict(model, datamodule)
        data = {}

        # Ground-truth labels and latents of every training dataset, numbered
        # from 1 in the datamodule's iteration order.
        for i, k in enumerate(datamodule.train_datasets):
            data[f"y{i+1}"] = datamodule.train_datasets[k].y.numpy()
            data[f"z{i+1}"] = datamodule.train_datasets[k].z.numpy()

        # NOTE(review): assumes trainer.predict returns one list of per-batch
        # dicts per predict dataloader; a key repeated across dataloaders is
        # overwritten by the later one -- confirm against the model's
        # predict_step.
        for preds in preds_list:
            for k in preds[0].keys():
                data[k] = torch.cat([x[k] for x in preds], dim=0).numpy()

        df = pd.DataFrame(data)
        os.makedirs(os.path.dirname(csv_file_path), exist_ok=True)
        df.to_csv(csv_file_path, index=False)

def make_run_name(args: dict) -> str:
    """Build the output file stem from the parsed command-line arguments.

    Every argument except ``dgp`` and ``overwrite`` contributes a
    ``key-value`` fragment; fragments are joined with ``_``. Characters in the
    stringified value that are not word characters (e.g. ``.`` or ``-``) are
    replaced by ``_`` so the result is safe to use as a file name.
    """
    parts = []
    for key, value in args.items():
        if key in ("dgp", "overwrite"):
            continue
        # The substitution is done outside the f-string: a backslash inside an
        # f-string expression is a SyntaxError before Python 3.12.
        safe_value = re.sub(r'[^\w]', '_', str(value))
        parts.append(f"{key}-{safe_value}")
    return "_".join(parts)


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--dgp', type=str, default='lin', choices=['lin', 'tanh', 'sin'])
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--n', type=int, default=1000)
    parser.add_argument('--hidden_dim', type=int, default=64)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--alpha', type=float, default=1.0)
    parser.add_argument('--beta', type=float, default=1.0)
    parser.add_argument('--oracle_z', action='store_true')
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--overwrite', action='store_true')

    # Plain dict so the parsed options can be splatted into run(); the option
    # names match run()'s keyword parameters one-to-one.
    args = dict(vars(parser.parse_args()))

    name = make_run_name(args)

    run(name, **args)
