import os
# import utils
import torch
import random
import math
import argparse
import numpy as np
from tqdm import tqdm
import torch.optim as optim
from models.Model import UNet
from dataloader import myDataLoader
from flow.sinkhorn_flow import Sinkhorn
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
import torch.nn.functional as F
from Scheduler import GradualWarmupScheduler
from torchvision.utils import save_image

def _load_folder(data_root, folder):
    """Load one data folder and flatten both arrays to one image per row.

    Args:
        data_root: directory containing the data folders.
        folder: name of the sub-folder to load; it must contain
            ``vector_fields.npy`` and ``x_noise.npy``.

    Returns:
        ``(vector_fields, x_noise)`` as CPU tensors viewed as ``(-1, 3072)``.
        3072 = 3 * 32 * 32, matching the ``view(-1, 3, 32, 32)`` applied
        when batches are fed to the model.
    """
    print('Start Data Loading')
    vector = np.load(os.path.join(data_root, folder, 'vector_fields.npy'))
    print('Load vector ', folder)
    noise = np.load(os.path.join(data_root, folder, 'x_noise.npy'))
    print('Load noise ', folder)
    vector_fields = torch.from_numpy(vector).cpu()
    x_noise = torch.from_numpy(noise).cpu()
    print('vector_fields.shape', vector_fields.shape)
    print('x_noise.shape', x_noise.shape)
    vector_fields = vector_fields.view(-1, 3072)
    x_noise = x_noise.view(-1, 3072)
    print('Successfully Loaded!')
    return vector_fields, x_noise


def main(opts):
    """Train the UNet to regress the stored vector fields from noisy inputs.

    Each epoch loads one sub-folder of ``opts.data_root``. Folders are visited
    in sorted order and cycle when there are more epochs than folders. The
    rows of ``x_noise.npy`` are the network inputs and the matching rows of
    ``vector_fields.npy`` are the MSE targets. Rows are consumed in
    ``opts.num_minibatch`` contiguous chunks of ``opts.minibatch_size``. Each
    chunk is shuffled and fed in ``opts.train_step`` batches of
    ``opts.train_size``. The timestep passed to the model for a row is its
    global row index modulo ``opts.T``.
    NOTE(review): this assumes the data files were written with the timestep
    cycling fastest along the row axis. Confirm against the data generator.

    Checkpoints (model ``state_dict`` only) are written to
    ``opts.save_weight_dir`` every ``opts.intervals`` global steps and at the
    end of every chunk.

    Raises:
        FileNotFoundError: if ``opts.data_root`` contains no sub-directories.
    """
    device = torch.device(opts.device)
    device_count = torch.cuda.device_count()
    print("cuda.device_count", device_count)
    setup_seed(opts.seed)

    net_model = UNet(T=opts.T, ch=opts.channel, ch_mult=opts.channel_mult, attn=opts.attn,
                     num_res_blocks=opts.num_res_blocks, dropout=opts.dropout).to(device)
    optimizer = torch.optim.AdamW(
        net_model.parameters(), lr=opts.lr, weight_decay=1e-4)
    # NOTE: no LR scheduler is stepped, so the learning rate stays constant.
    # torch.save does not create missing directories.
    os.makedirs(opts.save_weight_dir, exist_ok=True)
    global_step = 0

    # os.listdir order is filesystem-dependent; sort so the folder used in
    # each epoch is reproducible, and skip stray non-directory entries.
    folder_list = sorted(
        name for name in os.listdir(opts.data_root)
        if os.path.isdir(os.path.join(opts.data_root, name))
    )
    if not folder_list:
        raise FileNotFoundError(f'no data folders found under {opts.data_root}')
    for folder in folder_list:
        print(f'now folder is {folder}')

    for e in range(opts.epoch):
        folder = folder_list[e % len(folder_list)]
        vector_fields, x_noise = _load_folder(opts.data_root, folder)

        # Never index past the end of a smaller-than-expected file.
        n_rows = min(vector_fields.size(0), x_noise.size(0))
        num_chunks = min(opts.num_minibatch, n_rows // opts.minibatch_size)
        if num_chunks < opts.num_minibatch:
            print(f'warning: {folder} has {n_rows} rows; using {num_chunks} '
                  f'of {opts.num_minibatch} requested minibatches')

        print('start trainning', "epoch: ", e)
        for num_index in range(num_chunks):
            # Shuffle inside chunk `num_index` while keeping *global* row
            # indices, so every chunk reads its own rows and the timestep
            # derived from the index matches the row's position in the file.
            offset = num_index * opts.minibatch_size
            index = torch.randperm(opts.minibatch_size) + offset
            v = vector_fields[index]
            x = x_noise[index]
            t = index % opts.T
            for step in range(opts.train_step):
                lo = step * opts.train_size
                hi = lo + opts.train_size
                x_step = x[lo:hi].view(-1, 3, 32, 32).to(device)
                v_step = v[lo:hi].view(-1, 3, 32, 32).to(device)
                t_step = t[lo:hi].to(device)

                optimizer.zero_grad()
                loss = F.mse_loss(net_model(x_step, t_step), v_step, reduction='sum')
                loss = loss / 1000  # fixed rescaling of the summed loss
                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    net_model.parameters(), opts.grad_clip)
                optimizer.step()

                # NOTE(review): this also saves at the end of *every* chunk
                # (up to num_minibatch checkpoints per epoch); confirm the
                # disk usage is intended.
                if global_step % opts.intervals == 0 or step == opts.train_step - 1:
                    torch.save(net_model.state_dict(), os.path.join(opts.save_weight_dir, 'ckpt_' + str(global_step) + "_.pt"))
                    print('Successfully trained', global_step, loss.item())
                global_step += 1

        # Release this folder's full arrays before the next epoch loads the
        # next folder, so peak memory is not doubled during the reload.
        del vector_fields, x_noise

def setup_seed(seed):
    """Seed every random number generator the training script relies on.

    Applies ``seed`` to the following generators, in order:
    - torch's CPU generator;
    - the generators of all CUDA devices;
    - NumPy's global generator;
    - Python's ``random`` module.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)
    


def build_parser():
    """Build the command-line parser for this training script.

    Float-valued hyper-parameters use ``type=float``. With ``type=int``,
    argparse rejects values such as ``--lr 1e-4``. List-valued options use
    ``nargs`` with ``type=int``. With ``type=list``, ``--channel_mult 1234``
    becomes the character list ``['1', '2', '3', '4']``. Defaults are
    unchanged.

    Returns:
        A configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--algorithm', type=str, default='SD')
    parser.add_argument('--dataset', type=str, default='CIFAR10')
    parser.add_argument('--epoch', type=int, default=24)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--train_size', type=int, default=64,
                        help='samples per optimizer step')
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--intervals', type=int, default=30000,
                        help='save a checkpoint every N global steps')
    parser.add_argument('--T', type=int, default=100,
                        help='number of timesteps; a row uses timestep (row index %% T)')
    parser.add_argument('--train_step', type=int, default=200,
                        help='optimizer steps per shuffled minibatch chunk')
    parser.add_argument('--num_minibatch', type=int, default=390,
                        help='shuffled minibatch chunks per epoch')
    parser.add_argument('--minibatch_size', type=int, default=12800,
                        help='rows per shuffled minibatch chunk')
    parser.add_argument('--seed', type=int, default=0)
    # SD
    parser.add_argument('--SD_lr', type=float, default=0.03)
    # Step size follows the schedule SD_lr * exp((t - T) / (T / 4)).
    parser.add_argument('--blur', type=float, default=0.05)
    parser.add_argument('--scaling', type=float, default=0.8)
    parser.add_argument('--backend', type=str, default='online')
    # U-net
    parser.add_argument('--lr', type=float, default=1e-5)
    parser.add_argument('--channel', type=int, default=128)
    parser.add_argument('--channel_mult', type=int, nargs='+', default=[1, 2, 3, 4])
    parser.add_argument('--attn', type=int, nargs='*', default=[2])
    parser.add_argument('--num_res_blocks', type=int, default=2)
    parser.add_argument('--dropout', type=float, default=0.15)
    # LR warm-up multiplier
    parser.add_argument('--multiplier', type=int, default=2)
    # gradient clipping (max global norm)
    parser.add_argument('--grad_clip', type=float, default=1.)
    # checkpoints
    parser.add_argument('--save_weight_dir', type=str, default='./Checkpoints/15_data')
    # parser.add_argument('--tese_load_weight', type = str, default = 'ckpt_199_.pt')
    # sampling / evaluation
    parser.add_argument('--sampled_dir', type=str, default='./results/')
    parser.add_argument('--NoisyImgName', type=str, default='NoisyImgs')
    parser.add_argument('--sampledImgName', type=str, default='SampledImgs')
    parser.add_argument('--nrow', type=int, default=8)
    parser.add_argument('--eval_interval', type=int, default=20)
    parser.add_argument('--test_dir', type=str, default='./testdata/')
    parser.add_argument('--data_root', type=str, default='./data_root_hdd',
                        help='directory of data folders, each holding '
                             'vector_fields.npy and x_noise.npy')
    return parser


if __name__ == "__main__":
    opts = build_parser().parse_args()
    main(opts)
