import os
import numpy as np
import argparse

import torch
from tqdm import tqdm
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
from torchvision.utils import save_image

from flow.sinkhorn_flow import Sinkhorn

def main(opts):
    """Generate Sinkhorn vector-field data from CIFAR-10 batches and save it.

    For each of ``opts.epoch`` passes over the CIFAR-10 training split, every
    batch is flattened to 2-D, handed to ``Sinkhorn(opts, x_0)``, advanced with
    ``forward()``, and the two tensors returned by ``get_state()`` are written
    as ``.npy`` files to::

        <opts.data_root>/epoch_<e>/vector_fields_<i>.npy
        <opts.data_root>/epoch_<e>/x_noise_<i>.npy

    where ``i`` is the batch index within the epoch.

    Args:
        opts: Parsed command-line namespace. This function itself reads
            ``dataset``, ``batch_size``, ``device``, ``epoch`` and
            ``data_root``; the whole namespace is also passed to ``Sinkhorn``.

    Raises:
        SystemExit: Always, with status 0, after all epochs have been written
            (the function ends with ``sys.exit(0)``).
    """
    import sys

    main_dir = os.path.join('./datasets', opts.dataset)
    # No Normalize step is applied: batches keep whatever range ToTensor emits.
    dataset = CIFAR10(
        root=main_dir, train=True, download=True,
        transform=transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]))
    dataloader = DataLoader(
        dataset, batch_size=opts.batch_size, shuffle=True, num_workers=4,
        drop_last=True, pin_memory=True)

    device = torch.device(opts.device)

    for e in range(opts.epoch):
        epoch_dir = os.path.join(opts.data_root, f'epoch_{e}')
        # makedirs (not mkdir) so a missing or nested data_root is created too;
        # exist_ok makes re-runs into an existing directory a no-op.
        os.makedirs(epoch_dir, exist_ok=True)

        with tqdm(dataloader, dynamic_ncols=True) as tqdmDataLoader:
            for index, (images, _) in enumerate(tqdmDataLoader):
                # Flatten each image to one row: (batch, everything else).
                x_0 = images.view(len(images), -1).to(device)

                # Compute the gradient (vector-field) data for this batch.
                SD_batch = Sinkhorn(opts, x_0)
                SD_batch.forward()
                # Original author's note on the returned pair:
                # record_sinkdiv, record_support, each
                # [time * batch_size, 3*32*32] -- TODO confirm against Sinkhorn.
                vector_fields, x_noise = SD_batch.get_state()

                np.save(
                    os.path.join(epoch_dir, f'vector_fields_{index}.npy'),
                    vector_fields.detach().cpu().numpy())
                np.save(
                    os.path.join(epoch_dir, f'x_noise_{index}.npy'),
                    x_noise.detach().cpu().numpy())
        print(f'Successfully generated vector_fields from random x_noise, epoch: {e}')

    print('Successfully generated vector_fields from random x_noise')
    # Kept from the original: the function terminates the process on success.
    sys.exit(0)


if __name__ == "__main__":
    # Every command-line option as (flag, type, default), listed in the order
    # it is registered with the parser.
    option_table = [
        # Read directly by main().
        ('--dataset', str, 'CIFAR10'),
        ('--batch_size', int, 128),
        ('--epoch', int, 20),
        # The options below (except --device / --data_root) are not read by
        # main() itself; the namespace is passed on to Sinkhorn, which
        # presumably consumes them -- confirm in flow.sinkhorn_flow.
        ('--T', int, 100),
        ('--device', str, 'cuda:0'),
        ('--data_root', str, './data_root_hdd/20_data'),
        ('--SD_lr', float, 0.16),
        ('--blur', float, 0.1),
        ('--scaling', float, 0.95),
        ('--backend', str, 'online'),
    ]

    parser = argparse.ArgumentParser()
    for flag, kind, fallback in option_table:
        parser.add_argument(flag, type=kind, default=fallback)

    opts = parser.parse_args()
    main(opts)
