#%%
import warnings
from pathlib import Path

import torch
from diffusers.optimization import get_cosine_schedule_with_warmup
from flow_matching.path import AffineProbPath
from flow_matching.path.scheduler import CondOTScheduler
from torchinfo import summary

from data_utils import get_dataset
from nns import DiT, InitializerNet, train

warnings.filterwarnings("ignore", category=UserWarning, module='torch')

from Configs import TrainingConfig

config = TrainingConfig()

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using gpu' if device == 'cuda' else 'Using cpu.')

# Seed torch's global RNG so the random draws made later in the script repeat
# from run to run.
torch.manual_seed(0)

# True -> train from scratch and save the weights; False -> load a pre-trained model.
train_model = False

#%% get the dataset
battery_dataset = 'matr_1_Q'
dataset = get_dataset(battery_dataset)

# Aim for roughly `n_batches` mini-batches per epoch. Floor division can leave
# a remainder, so a loader may yield more than `n_batches` batches. The
# max(1, ...) guard prevents batch_size=0 (which DataLoader rejects with a
# ValueError) when a split holds fewer than `n_batches` samples.
n_batches = 9
train_batch_size = max(1, dataset.train_data.feature.shape[0] // n_batches)
eval_batch_size = max(1, dataset.test_data.feature.shape[0] // n_batches)

train_dataloader = torch.utils.data.DataLoader(
    dataset.train_data,
    batch_size=train_batch_size,
    shuffle=True,
)

# NOTE(review): the evaluation loader is shuffled as well — confirm this is
# intended (shuffle=False would make per-batch evaluation order deterministic).
test_dataloader = torch.utils.data.DataLoader(
    dataset.test_data,
    batch_size=eval_batch_size,
    shuffle=True,
)

#%% model and initializer
# Per-sample input shape with a leading singleton dimension. Read it from the
# tensor's own shape instead of copying the whole feature tensor to host
# memory via .cpu().numpy() just to inspect its dimensions.
input_shape = (1,) + tuple(dataset.train_data.feature.shape[1:])
# NOTE(review): 'mix_20_Q' uses a mask size of 2 and every other dataset 10 —
# presumably tied to each dataset's sequence layout; confirm against DiT.
mask_size = 2 if battery_dataset == 'mix_20_Q' else 10

model = DiT(input_dim=config.sequence_length,
            input_shape=input_shape,
            num_blocks=config.num_blocks,
            num_channels=config.channels,
            class_dropout_prob=config.class_dropout_prob,
            mask_size=mask_size).to(device)

initializer = InitializerNet(input_dim=config.sequence_length, mask_size=mask_size).to(device)

# One random time value per training sample, used as dummy input for the
# summary's forward pass. Drawn on the CPU and then moved (as before) so the
# CPU RNG stream consumed here, and hence later shuffling, is unchanged.
t = torch.rand(dataset.train_data.label.shape[0]).to(device)
# torchinfo prints its table itself unless verbose=0 (outside notebooks its
# default is to print), so wrapping the call in print() showed it twice.
# NOTE(review): this forwards the entire training set in one pass; if that is
# too large for the device, pass a slice of the data instead.
print(summary(model,
              input_data=(dataset.train_data.label, t, dataset.train_data.feature),
              device=device,
              verbose=0))
#%% train the model
path = AffineProbPath(scheduler=CondOTScheduler())

optim = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
# Cosine decay with linear warmup, stepped once per training batch.
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optim,
    num_warmup_steps=config.lr_warmup_steps,
    num_training_steps=(len(train_dataloader) * config.num_epochs),
)

# NOTE(review): freshly trained weights are written under ./workspaces/, but
# pre-trained weights are read from ./workspaces/trained_models_FM/. Move a new
# checkpoint there (or align the two paths) before reloading it.
# NOTE(review): only `model` is saved/loaded; `initializer` keeps its freshly
# constructed weights on the load path — confirm it needs no checkpoint.
if train_model:
    train(model, initializer, train_dataloader, test_dataloader, path, optim, lr_scheduler, device, config, validation=False)
    save_file = Path(f'./workspaces/{config.output_dir}_{battery_dataset}.pth')
    # Create the target directory first so a long training run is not lost to
    # a missing folder at save time.
    save_file.parent.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), save_file)
else:
    # map_location lets a checkpoint saved from a GPU load on the CPU fallback
    # selected above; without it torch.load fails when CUDA is unavailable.
    model.load_state_dict(torch.load(
        f'./workspaces/trained_models_FM/{config.output_dir}_{battery_dataset}.pth',
        map_location=device,
        weights_only=True,
    ))

# Switch to inference mode on both paths (disables training-only behaviour such
# as dropout), not just after loading.
model.eval()

