import argparse
import math
import os
import random
import subprocess
import sys
import warnings

import numpy as np
import torch
import torchvision.models as models
from scipy import linalg
from torchvision.utils import save_image

from models.Model import UNet


def setup_seed(seed):
    """Seed every random number generator used by this script.

    The same ``seed`` is applied to PyTorch's CPU generator, all CUDA
    generators, NumPy's global RNG and Python's ``random`` module, so a run
    with a fixed seed reproduces the same noise.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seeder in seeders:
        seeder(seed)

@torch.no_grad()
def eval(opts):
    """Generate ``opts.eval_num_samples`` images with the trained model, then run FID.

    Loads UNet weights from ``opts.save_weight_dir/opts.test_load_weight``.
    Then, one batch at a time, it:

    1. starts from uniform noise in ``[0, 1)`` of shape
       ``(batch_size, 3, 32, 32)``;
    2. applies ``opts.T`` fixed-step updates
       ``x <- x - opts.SD_lr * model(x, t)``, with ``t`` counting up from 0;
    3. saves the results to ``opts.generated_dir`` as
       ``<generatedImgName>_<index>.png``.

    Exactly ``eval_num_samples`` images are written. Finally ``pytorch_fid``
    is run (with the current interpreter) between
    ``<opts.fid_data_root>/real`` and ``opts.generated_dir``.

    Args:
        opts: parsed command-line options for this script.
    """
    setup_seed(opts.seed)
    device = torch.device(opts.device)

    # NOTE(review): the UNet is always built with T=100, independent of opts.T.
    # Presumably this must match the training configuration of the checkpoint;
    # if opts.T > 100 the loop below feeds time steps >= 100. Confirm.
    model = UNet(T=100, ch=opts.channel, ch_mult=opts.channel_mult, attn=opts.attn,
                 num_res_blocks=opts.num_res_blocks, dropout=0.).to(device)
    ckpt = torch.load(os.path.join(opts.save_weight_dir, opts.test_load_weight),
                      map_location=device)
    model.load_state_dict(ckpt)
    print("model load weight done.")
    model.eval()

    # makedirs (unlike mkdir) also creates missing parent directories.
    os.makedirs(opts.generated_dir, exist_ok=True)

    # Ceiling division: just enough rounds to reach eval_num_samples.
    # `// batch_size + 1` ran a fully wasted round when it divided evenly.
    num_sampling_rounds = (opts.eval_num_samples + opts.batch_size - 1) // opts.batch_size
    lr = opts.SD_lr  # constant step size, hoisted out of the inner loop
    num_picture = 0
    for sampling_round in range(num_sampling_rounds):
        noisyImage = torch.rand(size=[opts.batch_size, 3, 32, 32], device=device)
        for time_step in range(opts.T):
            # Every sample in the batch uses the same time step.
            t = noisyImage.new_full((opts.batch_size,), time_step, dtype=torch.long)
            vector = model(noisyImage, t)  # [batch_size, 3, 32, 32]
            noisyImage = noisyImage - lr * vector
        noisyImage = noisyImage.cpu()  # no .detach() needed under no_grad
        for i, image in enumerate(noisyImage):
            # The last round may overshoot; stop at exactly eval_num_samples.
            if num_picture >= opts.eval_num_samples:
                break
            index = sampling_round * opts.batch_size + i
            save_image(image, os.path.join(
                opts.generated_dir, '%s_%d.png' % (opts.generatedImgName, index)))
            num_picture += 1
    print(f'Successfully Generate {num_picture} Picture!')

    # An argument list (no shell) copes with spaces and absolute paths.
    # sys.executable guarantees pytorch_fid runs in this same environment.
    result = subprocess.run(
        [sys.executable, '-m', 'pytorch_fid',
         os.path.join(opts.fid_data_root, 'real'), opts.generated_dir],
        check=False,
    )
    if result.returncode != 0:
        # Best-effort like the original os.system call: report, don't crash
        # after all the samples have already been generated.
        print(f'pytorch_fid exited with code {result.returncode}')

def main(opts):
    """Script entry point: run sampling and FID evaluation with the parsed ``opts``."""
    return eval(opts)

def build_parser():
    """Return the command-line argument parser for this sampling/FID script."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=43)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--T', type=int, default=100)
    parser.add_argument('--device', type=str, default='cuda:3')
    parser.add_argument('--fid_data_root', type=str, default='./fid_data_root')
    parser.add_argument('--SD_lr', type=float, default=0.05)
    parser.add_argument('--save_weight_dir', type=str, default='./Checkpoints/24_64_index')
    parser.add_argument('--test_load_weight', type=str, default='ckpt_100000_.pt')
    parser.add_argument('--generated_dir', type=str,
                        default='./fid_data_root/24_64/generated_100000')
    parser.add_argument('--generatedImgName', type=str, default='sinkhorn')
    parser.add_argument('--eval_num_samples', type=int, default=50000)
    # U-Net architecture; presumably must match the checkpoint being loaded.
    # --lr and --dropout are not read by eval() (sampling uses dropout=0.);
    # they are kept for CLI compatibility.
    parser.add_argument('--lr', type=float, default=1e-5)
    parser.add_argument('--channel', type=int, default=128)
    # nargs + type=int rather than type=list: type=list split a CLI value
    # into single-character strings (e.g. "1234" -> ['1', '2', '3', '4']).
    # Usage: --channel_mult 1 2 3 4
    parser.add_argument('--channel_mult', type=int, nargs='+', default=[1, 2, 3, 4])
    # Usage: --attn 2   (pass --attn with no values for an empty list)
    parser.add_argument('--attn', type=int, nargs='*', default=[2])
    parser.add_argument('--num_res_blocks', type=int, default=2)
    parser.add_argument('--dropout', type=float, default=0.15)
    return parser


if __name__ == "__main__":
    opts = build_parser().parse_args()
    main(opts)