import pytorch_lightning as pl
import torch
import torch.utils.data as data
from easydict import EasyDict
from torch.utils.data import DataLoader
import argparse
from net.dataset import SynteticAnomalyDetection
from net.Lightning_model import LightningModel

def search(config, *, num_samples=1000, n=20, model_dir='../data/models_new',
           max_epochs=10, training=True):
    """Train a ``LightningModel`` on synthetic data and report its accuracy.

    Builds a ``SynteticAnomalyDetection`` dataset of ``3 * num_samples`` items
    and randomly splits it into equally sized train/val/test subsets. It then
    (optionally) trains the model while checkpointing the epoch with the
    highest ``val_acc``, and reloads those best weights. Finally it evaluates
    on the test split and on the train split.

    Args:
        config: Mapping of model hyper-parameters; wrapped in ``EasyDict`` and
            passed to ``LightningModel(config=...)``.
        num_samples: Size of each of the train/val/test splits.
        n: Forwarded as ``n=`` to ``SynteticAnomalyDetection``.
        model_dir: Directory the ``ModelCheckpoint`` callback writes to.
        max_epochs: Number of training epochs.
        training: If False, skip fitting and evaluate the freshly built model.

    Returns:
        The ``'test_acc'`` entry reported by ``trainer.test`` on the test split.
    """
    # Define data: one pool, split into three equal, disjoint subsets.
    dataset = SynteticAnomalyDetection(n=n, num_samples=3 * num_samples)
    train, val, test = data.random_split(
        dataset, [num_samples, num_samples, num_samples])
    train_dl = DataLoader(train, batch_size=2, shuffle=True)
    val_dl = DataLoader(val, batch_size=200)
    test_dl = DataLoader(test, batch_size=200)

    # Train, keeping the single best checkpoint by val_acc (plus the last one).
    # NOTE(review): the literal 'f' in the filename template looks like a
    # stray f-string prefix; it ends up in the checkpoint file name.
    ckpt = pl.callbacks.ModelCheckpoint(dirpath=model_dir,
                                        filename='{epoch}-f{val_acc:.5f}',
                                        save_top_k=1,
                                        monitor='val_acc',
                                        save_last=True, mode='max')
    config = EasyDict(config)
    model = LightningModel(config=config)
    trainer = pl.Trainer(max_epochs=max_epochs, accelerator='gpu',
                         callbacks=[ckpt], check_val_every_n_epoch=2)
    if training:
        trainer.fit(model, train_dl, val_dl)
        # best_model_path is '' when no best checkpoint was recorded (e.g. the
        # validation loop never ran); torch.load('') would raise, so fall back
        # to the weights left at the end of training.
        if ckpt.best_model_path:
            # Load onto CPU; load_state_dict copies tensors onto the model's
            # own device, so no second copy is materialised on the GPU.
            checkpoint = torch.load(ckpt.best_model_path, map_location='cpu')
            model.load_state_dict(checkpoint['state_dict'])
        else:
            print("No best checkpoint recorded; evaluating final weights.")

    # trainer.test returns one metrics dict per dataloader; 'test_acc' is
    # presumably logged by LightningModel's test step (confirm).
    test_score = trainer.test(model, dataloaders=test_dl, verbose=False)[0]['test_acc']
    print(f"Test accuracy: {test_score}")
    # Re-run the test loop on the training split so train and test accuracy
    # can be compared.
    train_score = trainer.test(model, dataloaders=train_dl, verbose=False)[0]['test_acc']
    print(f"Train accuracy: {train_score}")
    return test_score


if __name__ == "__main__":
    # Guarded so importing this module (e.g. to reuse `search`) does not
    # start a training run.
    parser = argparse.ArgumentParser(
        description="Train and evaluate a LightningModel of the given type "
                    "on synthetic anomaly detection data.")
    # `--model_type` is a flag-style option made mandatory via required=True.
    # choices= validates it; unlike `assert`, this is not stripped by python -O.
    parser.add_argument('--model_type', type=str, required=True,
                        choices=['DSS', 'SchurNet', 'Siamese'],
                        help='Name of the model type')
    args = parser.parse_args()
    siamese_type = args.model_type
    print(siamese_type)
    # dims: 1 -> four layers of width 256 -> 1 (assumed to be layer widths;
    # confirm against LightningModel).
    config = EasyDict(lr=0.005, wd=1e-5, decay_factor=0.7,
                      dims=[1] + [256] * 4 + [1], siamese=siamese_type)
    acc = search(config=config)
    print(acc)
