import argparse
import os

import yaml
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import MultiStepLR
import numpy as np

# import datasets
# import models
from utils import get_band_interval, make_band_coords


# CAVE: write the band-interval table for each spectral resolution used.
# s_min / s_max bound the spectral range handed to get_band_interval
# (presumably nanometres -- TODO confirm units against utils).
s_min = 395
s_max = 705

root_dir = "../"

# os.path.join avoids the "..//dataset_preprocess" double slash that the
# previous f-string produced from root_dir's trailing "/".
cave_dir = os.path.join(root_dir, "dataset_preprocess", "dataset", "CAVE", "CAVEdata")
# np.save does not create missing parent directories; ensure the target
# directory exists once, before the loop writes into it.
os.makedirs(cave_dir, exist_ok=True)

for num_band in [8, 16, 31]:
    s_intervals = get_band_interval(s_min=s_min, s_max=s_max, num_band=num_band)

    outfile = os.path.join(cave_dir, f"waves_{num_band}.npy")
    print(num_band, s_intervals, outfile)
    np.save(outfile, s_intervals)
    
    
    
# Pavia Centre: write the band-interval table for the 102-band cube.
# s_min / s_max bound the spectral range (presumably nanometres -- TODO
# confirm units against utils).
s_min = 430
s_max = 860
num_band = 102

root_dir = "../"

# get_band_interval_by_mid_wave was called here without ever being imported,
# so this section raised NameError. It is assumed to live in utils next to
# get_band_interval (confirm). Importing it here rather than at the top of the
# file means a missing helper cannot stop the earlier CAVE section from running.
from utils import get_band_interval_by_mid_wave  # noqa: E402

s_intervals = get_band_interval_by_mid_wave(s_min=s_min, s_max=s_max, num_band=num_band)

# os.path.join avoids the "..//dataset_preprocess" double slash from root_dir's
# trailing "/"; np.save does not create missing parent directories.
pavia_dir = os.path.join(root_dir, "dataset_preprocess", "dataset", "Pavia_Centre")
os.makedirs(pavia_dir, exist_ok=True)
outfile = os.path.join(pavia_dir, f"waves_{num_band}.npy")
print(num_band, s_intervals, outfile)
np.save(outfile, s_intervals)
    
# for num_band in [25, 51, 102]:
#     s_intervals = get_band_interval(s_min = s_min, s_max = s_max, num_band = num_band)

#     outfile = f"{root_dir}/dataset_preprocess/dataset/Pavia_Centre/waves_{num_band}.npy"
#     print(num_band, s_intervals, outfile)
#     np.save(outfile, s_intervals)