import os
import argparse
import re

import numpy as np
import pandas as pd

import torch as torch

import lightning.pytorch as L

from ..data.llm import LLMPreferenceDataModule
from ..model.llm import LLMLatentConformalBradleyTerry

def run(
        name: str,
        dgp: str,
        seed: int,
        n: int,
        hidden_dim: int,
        dropout: float,
        alpha: float,
        beta: float,
        oracle_z: bool,
        pretrained: bool,
        lr: float,
        batch_size: int,
        epochs: int,
        overwrite: bool,
        scaleweight: bool,
        ) -> None:
    """Train (or reload) the preference model and dump its predictions to CSV.

    The result is written to ``results/data/llm/{name}.csv``. If that file
    already exists and ``overwrite`` is false, the function prints a notice
    and returns immediately, before seeding or building anything.

    Args:
        name: Base name used for the CSV and model checkpoint files.
        dgp: Used only as the TensorBoard logger run name.
        seed: Seed for the NumPy generator handed to the data module and for
            ``L.seed_everything``.
        n: Forwarded to ``LLMPreferenceDataModule``.
        hidden_dim: Accepted for CLI compatibility; not used by this function.
        dropout: Accepted for CLI compatibility; not used by this function.
        alpha: Forwarded to the model.
        beta: Forwarded to the model.
        oracle_z: Forwarded to the model.
        pretrained: If true, the model is initialised from the ``weight`` of
            the ``"data1"`` training dataset and the checkpoint file name
            gets a ``_pretrained`` suffix.
        lr: Learning rate forwarded to the model.
        batch_size: Forwarded to ``LLMPreferenceDataModule``.
        epochs: ``max_epochs`` of the Lightning trainer.
        overwrite: Re-train and rewrite outputs even if they already exist.
        scaleweight: With ``pretrained``, multiply the initial weight by 0.25.
    """
    csv_file_path = f'results/data/llm/{name}.csv'
    # Check for an existing result first so a skipped run has no side effects
    # (no global re-seeding, no data loading).
    if os.path.exists(csv_file_path) and not overwrite:
        print(f"File {csv_file_path} already exists. Use --overwrite to overwrite.")
        return

    rng = np.random.default_rng(seed)
    L.seed_everything(seed)

    datamodule = LLMPreferenceDataModule(rng, n, batch_size)
    first_dataset = datamodule.train_datasets["data1"]

    initial_weight = None
    if pretrained:
        initial_weight = first_dataset.weight.detach()
        if scaleweight:
            # Out-of-place on purpose: ``detach()`` shares storage with the
            # dataset tensor, so ``*=`` would rescale the dataset's own
            # ``weight`` as well.
            initial_weight = initial_weight * 0.25

    model = LLMLatentConformalBradleyTerry(
        input_dim=first_dataset.x1.shape[1],
        alpha=alpha,
        beta=beta,
        lr=lr,
        oracle_z=oracle_z,
        initial_weight=initial_weight)
    logger = L.loggers.TensorBoardLogger('log/lightning/llm/train', name=dgp)
    trainer = L.Trainer(max_epochs=epochs, logger=logger)

    post_fix = "_pretrained" if pretrained else ""
    model_file_path = f'results/model/llm/{name}{post_fix}.pth'

    # Make sure both output directories exist up front; otherwise the save
    # would only fail after the (expensive) training has finished.
    for out_path in (model_file_path, csv_file_path):
        os.makedirs(os.path.dirname(out_path), exist_ok=True)

    if not os.path.exists(model_file_path) or overwrite:
        trainer.fit(model, datamodule)
        torch.save(model.state_dict(), model_file_path)
    else:
        # Reuse the previously trained checkpoint instead of re-training.
        model.load_state_dict(torch.load(model_file_path))

    preds_list = trainer.predict(model, datamodule)
    data = {}

    # One (i, y, z) column triple per training dataset, numbered from 1.
    for idx, key in enumerate(datamodule.train_datasets, start=1):
        dataset = datamodule.train_datasets[key]
        data[f"i{idx}"] = dataset.i.numpy()
        data[f"y{idx}"] = dataset.y.numpy()
        data[f"z{idx}"] = dataset.z.numpy()

    # Each entry of ``preds_list`` is handled as a list of per-batch dicts of
    # tensors; batches are concatenated per output key into one column.
    for preds in preds_list:
        for key in preds[0].keys():
            data[key] = torch.cat([batch[key] for batch in preds], dim=0).numpy()

    df = pd.DataFrame(data)
    df.to_csv(csv_file_path, index=False)

# Matches every character that is not a word character; used to make argument
# values safe to embed in file names.
_NON_WORD = re.compile(r'[^\w]')


def make_run_name(args, excluded=("dgp", "overwrite", "scaleweight")):
    """Build the run/file name from a mapping of command-line arguments.

    Every argument whose key is not in ``excluded`` contributes a
    ``key-value`` part, with non-word characters of the value replaced by
    ``_``; parts are joined with ``_`` in the mapping's iteration order.
    A ``_scaleweight`` suffix is appended when ``args["scaleweight"]`` is
    truthy.

    The substitution is done outside the f-string because a backslash inside
    an f-string replacement field is a syntax error before Python 3.12.
    """
    parts = []
    for key, value in args.items():
        if key in excluded:
            continue
        safe_value = _NON_WORD.sub('_', str(value))
        parts.append(f"{key}-{safe_value}")

    run_name = "_".join(parts)
    if args.get("scaleweight"):
        run_name += "_scaleweight"
    return run_name


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--dgp', type=str, default='tanh', choices=['lin', 'tanh', 'sin'])
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--n', type=int, default=1000)
    parser.add_argument('--hidden_dim', type=int, default=64)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--alpha', type=float, default=1.0)
    parser.add_argument('--beta', type=float, default=1.0)
    parser.add_argument('--oracle_z', action='store_true')
    parser.add_argument('--pretrained', action='store_true')
    parser.add_argument('--scaleweight', action='store_true')
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--overwrite', action='store_true')

    args = parser.parse_args()
    # Plain dict so the parsed options can be splatted into ``run``.
    args = dict(vars(args))

    name = make_run_name(args)

    run(name, **args)
