#%%
import torch
from nns_trf import TrF, train
from diffusers.optimization import get_cosine_schedule_with_warmup
from data_utils import get_dataset
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module='torch')

from Configs_trf import TrainingConfig

config = TrainingConfig()

# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using gpu' if device == 'cuda' else 'Using cpu.')

# True: train one model per seed and save its checkpoint.
# False: load previously saved checkpoints instead of training.
train_model = False

battery_dataset = 'matr_1_Q'
dataset = get_dataset(battery_dataset)

# Split each data split into roughly `n_batches` mini-batches. Floor division
# yields 0 when a split has fewer than `n_batches` samples, and DataLoader
# rejects a zero batch size, so clamp to at least 1.
n_batches = 9
train_batch_size = max(1, dataset.train_data.feature.shape[0] // n_batches)
eval_batch_size = max(1, dataset.test_data.feature.shape[0] // n_batches)

train_dataloader = torch.utils.data.DataLoader(
    dataset.train_data,
    batch_size=train_batch_size,
    shuffle=True,
)

test_dataloader = torch.utils.data.DataLoader(
    dataset.test_data,
    batch_size=eval_batch_size,
    shuffle=True,
)

# Per-sample feature shape with a leading 1 prepended. Read the shape directly
# rather than copying the whole feature tensor to host memory via .cpu().numpy().
input_shape = (1,) + tuple(dataset.train_data.feature.shape[1:])

# Dataset-specific mask size passed to TrF; 'mix_20_Q' uses a smaller value.
# NOTE(review): rationale for 2 vs 10 is not visible here — confirm with TrF.
mask_size = 2 if battery_dataset == 'mix_20_Q' else 10

import os

# Checkpoints are saved to and loaded from the same directory, so models trained
# with train_model = True can later be reloaded with train_model = False.
checkpoint_dir = './workspaces/trained_models_TrF'

for seed in range(10):
    torch.manual_seed(seed)  # seed torch's RNG so each seed's run starts from a fixed state

    model = TrF(input_dim=config.sequence_length,
                input_shape=input_shape,
                num_blocks=config.num_blocks,
                num_channels=config.channels,
                class_dropout_prob=config.class_dropout_prob,
                mask_size=mask_size).to(device)  # g_φ

    optim = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer=optim,
        num_warmup_steps=config.lr_warmup_steps,
        num_training_steps=(len(train_dataloader) * config.num_epochs),
    )

    checkpoint_path = f'{checkpoint_dir}/{config.output_dir}_{seed}_{battery_dataset}.pth'

    if train_model:
        train(model, train_dataloader, test_dataloader, optim, lr_scheduler, device, config, validation=True)
        # Create the target directory (including any subdirectory embedded in
        # config.output_dir) so torch.save does not fail on a fresh checkout.
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
        torch.save(model.state_dict(), checkpoint_path)
    else:
        # map_location lets a checkpoint saved from GPU tensors load on a
        # CPU-only machine (and vice versa).
        state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)

    # Leave the model in inference mode whether it was trained or loaded.
    model.eval()
    


