import math
import os
import numpy as np
from tqdm import tqdm
import time

# os.environ["CUDA_VISIBLE_DEVICES"] = " 0 "

import torch
import torch.nn as nn
import torchvision
from torch.optim.adam import Adam
from torch.utils.data.dataloader import DataLoader
from torch.optim import lr_scheduler
from torch.nn import functional
from torch.autograd import Variable
import itertools
import sys
import argparse

from dataset import HEVCDataSet, UVGDataSet, crop, merge
from net import VideoCompressor
from src.zoo.image import model_architectures as architectures
from compressai.zoo import cheng2020_attn, cheng2020_anchor

# Number of CUDA devices visible to this process.
gpu_num = torch.cuda.device_count()

# Lookup table keyed by the rate-distortion lambda value.
# NOTE(review): the integer values presumably select the quality level of the
# pretrained I-frame codec (compressai cheng2020 models) — confirm at the caller.
lambda_quality_map = {
    256: 3,
    512: 4,
    1024: 5,
    2048: 6,
}


def activate_grad(module):
    """Turn gradient tracking on for every parameter of ``module``.

    Args:
        module: object exposing ``parameters()`` (e.g. an ``nn.Module``);
            each yielded parameter gets ``requires_grad = True``.
    """
    trainable = module.parameters()
    for weight in trainable:
        weight.requires_grad = True


def close_grad(module):
    """Freeze ``module`` by turning gradient tracking off for all its parameters.

    Args:
        module: object exposing ``parameters()`` (e.g. an ``nn.Module``);
            each yielded parameter gets ``requires_grad = False``.
    """
    frozen = module.parameters()
    for weight in frozen:
        weight.requires_grad = False


def cal_rd_cost(distortion: torch.Tensor, bpp: torch.Tensor, lambda_weight):
    """Return the rate-distortion cost ``lambda_weight * distortion + bpp``.

    Args:
        distortion: distortion term (e.g. an MSE tensor).
        bpp: rate term in bits per pixel (tensor or numpy value).
        lambda_weight: trade-off multiplier applied to the distortion.
    """
    # Keep the weighted distortion as the left operand: callers mix a tensor
    # distortion with a numpy bpp, and the operand order decides the result type.
    weighted_distortion = lambda_weight * distortion
    return weighted_distortion + bpp


def cal_bpp(likelihood: torch.Tensor, num_pixels: int):
    """Estimate bits per pixel from element likelihoods.

    Computes ``-sum(log2(likelihood)) / num_pixels`` using natural logs
    divided by ``-ln(2) * num_pixels``.

    Args:
        likelihood: tensor of probabilities; summed over all elements.
        num_pixels: number of pixels the bits are averaged over.

    Returns:
        Scalar tensor with the bits-per-pixel estimate.
    """
    log_likelihood_total = torch.log(likelihood).sum()
    denominator = -math.log(2) * num_pixels
    return log_likelihood_total / denominator


def cal_bits(likelihood: torch.Tensor):
    """Return the total information content ``-sum(log2(likelihood))`` in bits.

    Args:
        likelihood: tensor of probabilities; summed over all elements.

    Returns:
        Scalar tensor with the estimated number of bits.
    """
    neg_ln2 = -math.log(2)
    return torch.log(likelihood).sum() / neg_ln2


def cal_distoration(A: torch.Tensor, B: torch.Tensor):
    """Return the mean squared error between ``A`` and ``B``.

    Uses the functional form of MSE (mean reduction), which is what
    ``nn.MSELoss()`` delegates to, without building a module per call.
    """
    return nn.functional.mse_loss(A, B)


def cal_psnr(distortion: torch.Tensor, max_val: float = 1.0):
    """Convert a mean-squared-error distortion into PSNR in dB.

    PSNR = 10 * log10(max_val**2 / distortion). With the default
    ``max_val=1.0`` (assumes signals are scaled to [0, 1]) this is exactly
    the original ``-10 * log10(distortion)``.

    Args:
        distortion: MSE value(s) as a tensor.
        max_val: peak signal value, e.g. 255.0 for 8-bit images.

    Returns:
        Tensor with the same shape as ``distortion``.
    """
    psnr = -10 * torch.log10(distortion)
    if max_val != 1.0:
        # 10*log10(max^2) == 20*log10(max). This is skipped for the default
        # so the result stays bit-identical to the original formula.
        psnr = psnr + 20 * math.log10(max_val)
    return psnr


def Var(x):
    """Move ``x`` to the current CUDA device and wrap it with the legacy
    ``Variable`` constructor.

    NOTE(review): in modern PyTorch ``Variable(t)`` appears to return a
    detached alias of ``t`` with ``requires_grad=False``. Confirm callers rely
    on that before simplifying this to ``x.cuda()``.
    """
    on_device = x.cuda()
    return Variable(on_device)


def test_new_cropped(net, test_dataset, factor, overlap, write_strame_flag=False, calrealbits=False):
    # w: motion + delta, y: I / P res + hyper + delta
    # y[0] -> w[1] -> y[1] -> .... -> w[i] -> y[i] -> ... -> w[K], y[K]
    # perform AISTAT 13 HSAVI, iter**K, intractable

    # y[0] -> SGA, w[1], y[1] - w[k], y[K] -> AUN
    # optimize y[0] w.r.t RD[0:K]

    # initialize w[1] using y[0]

    # w[1] -> SGA, y[1] - w[k], y[k] -> AUN
    # optimize w[1] w.r.t RD[1:K]

    # initialize y[1] using w[1]
    # optimize y[1] w.r.t RD[1:K]
    test_loader = DataLoader(dataset=test_dataset, shuffle=False, num_workers=1, batch_size=1, pin_memory=True)
    sumbpp, sumbpp_mv_y, sumbpp_mv_z = 0, 0, 0
    sumbpp_res_y, sumbpp_res_z, sumbpp_real = 0, 0, 0
    sumpsnr, sumpsnr_pre = 0, 0
    eval_step = 0
    gop_num = 0
    avg_loss = torch.zeros(size=[1, ])
    dir = 1
    close_grad(I_codec)
    close_grad(net)
    I_codec.eval()
    net.eval()
    ############################ iterative tuning #############################
    for batch_idx, input in enumerate(test_loader):
        if batch_idx % 10 == 0:
            print("[info] testing : %d/%d" % (batch_idx, len(test_loader)))
        input_images = input[0].squeeze(0)
        seqlen = input_images.size()[0]
        cropped_len = len(crop(torch.unsqueeze(input_images[0, :, :, :], 0), factor=factor, overlap=overlap)[0])
        print("number of blocks:", cropped_len)
        B, C, H, W = torch.unsqueeze(input_images[0, :, :, :], 0).shape[0], \
                     torch.unsqueeze(input_images[0, :, :, :], 0).shape[1], \
                     torch.unsqueeze(input_images[0, :, :, :], 0).shape[2], \
                     torch.unsqueeze(input_images[0, :, :, :], 0).shape[3]
        cropped_blocks = []
        for m in range(seqlen):
            cur_frame = torch.unsqueeze(input_images[m, :, :, :], 0)
            cropped_images, _, _, _, _ = crop(cur_frame, factor=factor, overlap=overlap)
            cropped_blocks.append(cropped_images)
        total_bits_stack, rec_img_stack = [], []
        ############################ initial testing ##########################
        total_rd_cost = 0
        for i in range(seqlen):
            cur_frame = Var(torch.unsqueeze(input_images[i, :, :, :].cuda(), 0))
            b, h, w = cur_frame.shape[0], cur_frame.shape[2], cur_frame.shape[3]
            num_pixels = b * h * w
            if i == 0:
                with torch.no_grad():
                    arr = I_codec([cur_frame, "test_for_first", "testing"])
                I_rec = arr['x_hat']
                I_likelihood_y, I_likelihood_z = arr["likelihoods"]['y'], arr["likelihoods"]['z']
                ref_image = I_rec.detach().clone()
                y_bpp = cal_bpp(likelihood=I_likelihood_y, num_pixels=num_pixels).cpu().detach().numpy()
                z_bpp = cal_bpp(likelihood=I_likelihood_z, num_pixels=num_pixels).cpu().detach().numpy()
                psnr = cal_psnr(distortion=cal_distoration(I_rec, cur_frame)).cpu().detach().numpy()
                bpp = y_bpp + z_bpp
                distortion = cal_distoration(cur_frame, I_rec)
                rd_cost = cal_rd_cost(distortion, bpp, lambda_for_test)
                print("\n------------------------------------ GOP {0} --------------------------------------".format(
                    batch_idx + 1))
                print("I frame:  ", "bpp:", bpp, "\t", "psnr:", psnr, "\t", "rd_cost:", rd_cost.cpu().detach().numpy())
            else:
                with torch.no_grad():
                    clipped_recon_image, _, _, _, bpp_feature, bpp_z, bpp_mv, bpp, _, _, _, _, _, _, _, _, _ = \
                        net(referframe=ref_image, input_image=cur_frame, iter=0, total_iter=0, stage="test_for_first",
                            mode="test")
                ref_image = clipped_recon_image

                distortion = cal_distoration(cur_frame, clipped_recon_image)
                rd_cost = cal_rd_cost(distortion, bpp, lambda_for_test)
                psnr = 10 * (torch.log(1 * 1 / distortion) / np.log(10)).cpu().detach().numpy()
                mv_bpp = bpp_mv
                res_y_bpp = bpp_feature
                res_z_bpp = bpp_z
                bpp = bpp
                print("P{0} frame: ".format(i), "mv_bpp:", mv_bpp.cpu().detach().numpy(), "\t", "res_bpp:",
                      res_y_bpp.cpu().detach().numpy(), "\t", "res_hyper_bpp:", res_z_bpp.cpu().detach().numpy(), "\t",
                      "bpp:", bpp.cpu().detach().numpy(), "\t", "psnr", psnr, "\t", "\t", "rd_cost",
                      rd_cost.cpu().detach().numpy())
            total_rd_cost += rd_cost
        print("total_rd_cost_initial:", total_rd_cost.cpu().detach().numpy())

        for n in range(cropped_len):
            ############# initial testing end, tuning start #######################
            I_y_stack, I_z_stack, delta_I_stack = [], [], []
            mv_feature_stack, delta_mv_stack = [], []
            feature_stack, z_stack, delta_res_stack = [], [], []
            sub_iter_I, sub_iter_w, sub_iter_y, sub_lr = 2000, 400, 400, 1e-3
            for i in range(seqlen):
                cur_frame = Var(cropped_blocks[i][n].cuda())
                b, h, w = cur_frame.shape[0], cur_frame.shape[2], cur_frame.shape[3]
                num_pixels = b * h * w
                if i == 0:
                    ###################### initialize param #######################
                    with torch.no_grad():
                        # use deterministic rounding here
                        arr = I_codec([cur_frame, "test_for_first", "testing"])
                    I_y_stack.append(arr['y'].detach().clone().requires_grad_(True))
                    I_z_stack.append(arr['z'].detach().clone().requires_grad_(True))
                    delta_I_stack.append(torch.tensor(arr["delta"]).clone().detach().requires_grad_(True))
                    ##################### actual tuning ###########################
                    cur_params = I_y_stack + I_z_stack + delta_I_stack
                    optimizer_I = Adam(params=cur_params, lr=sub_lr)
                    for sub_it in range(sub_iter_I):
                        optimizer_I.zero_grad()
                        for sub_i in range(seqlen):
                            sub_cur_frame = Var(cropped_blocks[sub_i][n].cuda())
                            if sub_i == 0:
                                result = I_codec(
                                    [sub_cur_frame, "finetune", "training", sub_it, sub_iter_I, I_y_stack[0], I_z_stack[0],
                                     delta_I_stack[0]])
                                recon_image = result['x_hat']
                                I_likelihood_y, I_likelihood_z = result["likelihoods"]['y'], result["likelihoods"]['z']
                                y_bpp, z_bpp = cal_bpp(likelihood=I_likelihood_y, num_pixels=num_pixels), cal_bpp(
                                    likelihood=I_likelihood_z, num_pixels=num_pixels)
                                bpp = y_bpp + z_bpp
                            else:
                                clipped_recon_image, _, _, _, _, _, _, bpp, _, _, _, _, _, _, _, _, _ = \
                                    net(referframe=ref_image, input_image=sub_cur_frame, iter=0, total_iter=0,
                                        stage="test_for_first", mode="training")
                                recon_image = clipped_recon_image
                                bpp = bpp
                            ref_image = recon_image
                            distortion = cal_distoration(sub_cur_frame, recon_image)
                            rd_cost = cal_rd_cost(distortion, bpp, lambda_for_test) / seqlen
                            rd_cost.backward(retain_graph=True)
                        optimizer_I.step()
                        optimizer_I.zero_grad()
                    for param in cur_params:
                        param.requires_grad = False
                    ###################### sub testing start ######################
                    total_rd_cost = 0
                    for sub_i in range(seqlen):
                        sub_cur_frame = Var(cropped_blocks[sub_i][n].cuda())
                        if sub_i == 0:
                            with torch.no_grad():
                                result = I_codec(
                                    [sub_cur_frame, "finetune", "test", 0, 0, I_y_stack[0], I_z_stack[0], delta_I_stack[0]])
                            I_likelihood_y, I_likelihood_z = result["likelihoods"]['y'], result["likelihoods"]['z']
                            ref_image = result['x_hat'].detach().clone()
                            y_bpp, z_bpp = cal_bpp(likelihood=I_likelihood_y,
                                                   num_pixels=num_pixels).cpu().detach().numpy(), cal_bpp(
                                likelihood=I_likelihood_z, num_pixels=num_pixels).cpu().detach().numpy()
                            I_rec = result["x_hat"]
                            psnr = cal_psnr(distortion=cal_distoration(I_rec, sub_cur_frame)).cpu().detach().numpy()
                            bpp = y_bpp + z_bpp
                            distortion = cal_distoration(sub_cur_frame, I_rec)
                            rd_cost = cal_rd_cost(distortion, bpp, lambda_for_test)
                            print(
                                "\n------------------------------------ GOP {0} --------------------------------------".format(
                                    batch_idx + 1))
                            print("I frame block {0} :  ".format(n), "bpp:", bpp, "\t", "psnr:", psnr, "\t", "rd_cost:",
                                  rd_cost.cpu().detach().numpy())
                        else:
                            with torch.no_grad():
                                clipped_recon_image, _, _, _, bpp_feature, bpp_z, bpp_mv, bpp, _, _, _, _, _, _, _, _, _ = \
                                    net(referframe=ref_image, input_image=sub_cur_frame, iter=0, total_iter=0,
                                        stage="test_for_first", mode="test")
                            ref_image = clipped_recon_image
                            distortion = cal_distoration(sub_cur_frame, clipped_recon_image)
                            rd_cost = cal_rd_cost(distortion, bpp, lambda_for_test)
                            psnr = 10 * (torch.log(1 * 1 / distortion) / np.log(10)).cpu().detach().numpy()
                            mv_bpp = bpp_mv
                            res_y_bpp = bpp_feature
                            res_z_bpp = bpp_z
                            bpp = bpp
                            print("P{0} frame block {1} : ".format(sub_i, n), "mv_bpp:",
                                  mv_bpp.cpu().detach().numpy(), "\t", "res_bpp:", res_y_bpp.cpu().detach().numpy(), "\t",
                                  "res_hyper_bpp:", res_z_bpp.cpu().detach().numpy(), "\t", "bpp:",
                                  bpp.cpu().detach().numpy(), "\t", "psnr", psnr, "\t", "\t", "rd_cost",
                                  rd_cost.cpu().detach().numpy())
                        total_rd_cost += rd_cost
                    print("total_rd_cost_middle {0}:".format(i), total_rd_cost.cpu().detach().numpy())
                else:
                    ################## initialize motion vectors of frame i #######
                    for sub_i in range(i + 1):
                        sub_cur_frame = Var(cropped_blocks[sub_i][n].cuda())
                        if sub_i == 0:
                            with torch.no_grad():
                                result = I_codec(
                                    [sub_cur_frame, "finetune", "test", 0, 0, I_y_stack[0], I_z_stack[0], delta_I_stack[0]])
                            ref_image = result['x_hat'].detach().clone()
                        else:
                            with torch.no_grad():
                                if sub_i < i:
                                    clipped_recon_image, _, _, _, _, _, _, _, _, _, _, _ = \
                                        net(referframe=ref_image, input_image=sub_cur_frame, iter=0, total_iter=0,
                                            stage="finetune", mode="test", \
                                            feature=feature_stack[sub_i - 1][0], z=z_stack[sub_i - 1][0],
                                            delta=delta_res_stack[sub_i - 1][0], \
                                            mvfeature=mv_feature_stack[sub_i - 1][0], delta_mv=delta_mv_stack[sub_i - 1][0],
                                            calrealbits=calrealbits)
                                elif sub_i == i:
                                    clipped_recon_image, _, _, _, _, _, _, _, mvfeature, _, _, _, delta_mv, _, _, _, _ = \
                                        net(referframe=ref_image, input_image=sub_cur_frame, iter=0, total_iter=0,
                                            stage="test_for_first", mode="test")
                                    mv_feature_stack.append([mvfeature.detach().clone().requires_grad_(True)])
                                    delta_mv_stack.append([delta_mv.detach().clone().requires_grad_(False)])
                                else:
                                    assert (0)
                            ref_image = clipped_recon_image
                    ##################### actual tuning mv of frame i #############
                    cur_params = mv_feature_stack[i - 1]
                    optimizer_mv = Adam(params=cur_params, lr=sub_lr)
                    for sub_it in range(sub_iter_w):
                        optimizer_mv.zero_grad()
                        for sub_i in range(seqlen):
                            sub_cur_frame = Var(cropped_blocks[sub_i][n].cuda())
                            if sub_i == 0:
                                I_y_for_optim, I_z_for_optim, delta_I_for_optim = I_y_stack[0], I_z_stack[0], delta_I_stack[
                                    0]
                                result = I_codec([sub_cur_frame, "finetune", "test", 0, 0, I_y_for_optim, I_z_for_optim,
                                                  delta_I_for_optim])
                                recon_image = result['x_hat']
                                I_likelihood_y, I_likelihood_z = result["likelihoods"]['y'], result["likelihoods"]['z']
                                y_bpp, z_bpp = cal_bpp(likelihood=I_likelihood_y, num_pixels=num_pixels), cal_bpp(
                                    likelihood=I_likelihood_z, num_pixels=num_pixels)
                                bpp = y_bpp + z_bpp
                            else:
                                if sub_i < i:
                                    clipped_recon_image, _, _, _, _, _, _, bpp, _, _, _, _ = \
                                        net(referframe=ref_image, input_image=sub_cur_frame, iter=0, total_iter=0,
                                            stage="finetune", mode="test", \
                                            feature=feature_stack[sub_i - 1][0], z=z_stack[sub_i - 1][0],
                                            delta=delta_res_stack[sub_i - 1][0], \
                                            mvfeature=mv_feature_stack[sub_i - 1][0], delta_mv=delta_mv_stack[sub_i - 1][0],
                                            calrealbits=calrealbits)
                                elif sub_i == i:
                                    clipped_recon_image, _, _, _, _, _, _, bpp, _, _, _, _ = \
                                        net(referframe=ref_image, input_image=sub_cur_frame, iter=sub_it, total_iter=sub_iter_w,
                                            stage="finetune_flow", mode="training", \
                                            mvfeature=mv_feature_stack[sub_i - 1][0], delta_mv=delta_mv_stack[sub_i - 1][0])
                                else:
                                    clipped_recon_image, _, _, _, _, _, _, bpp, _, _, _, _, _, _, _, _, _ = \
                                        net(referframe=ref_image, input_image=sub_cur_frame, iter=0, total_iter=0,
                                            stage="test_for_first", mode="training")
                                recon_image = clipped_recon_image
                                bpp = bpp
                            ref_image = recon_image
                            if sub_i >= i:
                                distortion = cal_distoration(sub_cur_frame, recon_image)
                                rd_cost = cal_rd_cost(distortion, bpp, lambda_for_test) / seqlen
                                rd_cost.backward(retain_graph=True)
                        optimizer_mv.step()
                        optimizer_mv.zero_grad()
                    for param in cur_params:
                        param.requires_grad = False
                    ###################### sub testing start ######################
                    total_rd_cost = 0
                    for sub_i in range(seqlen):
                        sub_cur_frame = Var(cropped_blocks[sub_i][n].cuda())
                        if sub_i == 0:
                            with torch.no_grad():
                                result = I_codec(
                                    [sub_cur_frame, "finetune", "test", 0, 0, I_y_stack[0], I_z_stack[0], delta_I_stack[0]])
                            I_likelihood_y, I_likelihood_z = result["likelihoods"]['y'], result["likelihoods"]['z']
                            I_rec = result['x_hat'].detach().clone()
                            ref_image = I_rec
                            y_bpp, z_bpp = cal_bpp(likelihood=I_likelihood_y,
                                                   num_pixels=num_pixels).cpu().detach().numpy(), cal_bpp(
                                likelihood=I_likelihood_z, num_pixels=num_pixels).cpu().detach().numpy()
                            psnr = cal_psnr(distortion=cal_distoration(I_rec, sub_cur_frame)).cpu().detach().numpy()
                            bpp = y_bpp + z_bpp
                            distortion = cal_distoration(sub_cur_frame, I_rec)
                            rd_cost = cal_rd_cost(distortion, bpp, lambda_for_test)
                            print(
                                "\n------------------------------------ GOP {0} --------------------------------------".format(
                                    batch_idx + 1))
                            print("I frame block {} :  ".format(n), "bpp:", bpp, "\t", "psnr:", psnr, "\t", "rd_cost:",
                                  rd_cost.cpu().detach().numpy())
                        else:
                            with torch.no_grad():
                                if sub_i < i:
                                    clipped_recon_image, _, _, _, bpp_feature, bpp_z, bpp_mv, bpp, _, _, _, _ = \
                                        net(referframe=ref_image, input_image=sub_cur_frame, iter=0, total_iter=0,
                                            stage="finetune", mode="test", \
                                            feature=feature_stack[sub_i - 1][0], z=z_stack[sub_i - 1][0],
                                            delta=delta_res_stack[sub_i - 1][0], \
                                            mvfeature=mv_feature_stack[sub_i - 1][0], delta_mv=delta_mv_stack[sub_i - 1][0],
                                            calrealbits=calrealbits)
                                elif sub_i == i:
                                    clipped_recon_image, _, _, _, bpp_feature, bpp_z, bpp_mv, bpp, _, _, _, _ = \
                                        net(referframe=ref_image, input_image=sub_cur_frame, iter=0, total_iter=0,
                                            stage="finetune_flow", mode="test", \
                                            mvfeature=mv_feature_stack[sub_i - 1][0], delta_mv=delta_mv_stack[sub_i - 1][0])
                                else:
                                    clipped_recon_image, _, _, _, bpp_feature, bpp_z, bpp_mv, bpp, _, _, _, _, _, _, _, _, _ = \
                                        net(referframe=ref_image, input_image=sub_cur_frame, iter=0, total_iter=0,
                                            stage="test_for_first", mode="test")
                            ref_image = clipped_recon_image
                            distortion = cal_distoration(sub_cur_frame, clipped_recon_image)
                            rd_cost = cal_rd_cost(distortion, bpp, lambda_for_test)
                            psnr = 10 * (torch.log(1 * 1 / distortion) / np.log(10)).cpu().detach().numpy()
                            mv_bpp = bpp_mv
                            res_y_bpp = bpp_feature
                            res_z_bpp = bpp_z
                            bpp = bpp
                            print("P{0} frame block {1} : ".format(sub_i, n), "mv_bpp:",
                                  mv_bpp.cpu().detach().numpy(), "\t", "res_bpp:", res_y_bpp.cpu().detach().numpy(), "\t",
                                  "res_hyper_bpp:", res_z_bpp.cpu().detach().numpy(), "\t", "bpp:",
                                  bpp.cpu().detach().numpy(), "\t", "psnr", psnr, "\t", "\t", "rd_cost",
                                  rd_cost.cpu().detach().numpy())
                        total_rd_cost += rd_cost
                    print("total_rd_cost_middle_mv {0}:".format(i), total_rd_cost.cpu().detach().numpy())
                    ################## initialize residule of frame i #############
                    for sub_i in range(i + 1):
                        sub_cur_frame = Var(cropped_blocks[sub_i][n].cuda())
                        if sub_i == 0:
                            with torch.no_grad():
                                result = I_codec(
                                    [sub_cur_frame, "finetune", "test", 0, 0, I_y_stack[0], I_z_stack[0], delta_I_stack[0]])
                            ref_image = result['x_hat'].detach().clone()
                        else:
                            with torch.no_grad():
                                if sub_i < i:
                                    clipped_recon_image, _, _, _, _, _, _, _, _, _, _, _ = \
                                        net(referframe=ref_image, input_image=sub_cur_frame, iter=0, total_iter=0,
                                            stage="finetune", mode="test", \
                                            feature=feature_stack[sub_i - 1][0], z=z_stack[sub_i - 1][0],
                                            delta=delta_res_stack[sub_i - 1][0], \
                                            mvfeature=mv_feature_stack[sub_i - 1][0], delta_mv=delta_mv_stack[sub_i - 1][0],
                                            calrealbits=calrealbits)
                                elif sub_i == i:
                                    clipped_recon_image, _, _, _, _, _, _, _, _, feature, z, delta, _, _, _, _ = \
                                        net(referframe=ref_image, input_image=sub_cur_frame, iter=0, total_iter=0,
                                            stage="test_for_stage1", mode="test", \
                                            mvfeature=mv_feature_stack[sub_i - 1][0], delta_mv=delta_mv_stack[sub_i - 1][0])
                                    feature_stack.append([feature.detach().clone().requires_grad_(True)])
                                    z_stack.append([z.detach().clone().requires_grad_(True)])
                                    delta_res_stack.append([delta.detach().clone().requires_grad_(True)])
                                else:
                                    assert (0)
                            ref_image = clipped_recon_image
                    ##################### actual tuning res of frame i ############
                    cur_params = feature_stack[i - 1] + z_stack[i - 1] + delta_res_stack[i - 1]
                    optimizer_res = Adam(params=cur_params, lr=sub_lr)
                    for sub_it in range(sub_iter_y):
                        optimizer_res.zero_grad()
                        for sub_i in range(seqlen):
                            sub_cur_frame = Var(cropped_blocks[sub_i][n].cuda())
                            if sub_i == 0:
                                I_y_for_optim, I_z_for_optim, delta_I_for_optim = I_y_stack[0], I_z_stack[0], delta_I_stack[
                                    0]
                                result = I_codec([sub_cur_frame, "finetune", "test", 0, 0, I_y_for_optim, I_z_for_optim,
                                                  delta_I_for_optim])
                                recon_image = result['x_hat']
                                I_likelihood_y, I_likelihood_z = result["likelihoods"]['y'], result["likelihoods"]['z']
                                y_bpp, z_bpp = cal_bpp(likelihood=I_likelihood_y, num_pixels=num_pixels), cal_bpp(
                                    likelihood=I_likelihood_z, num_pixels=num_pixels)
                                bpp = y_bpp + z_bpp
                            else:
                                if sub_i < i:
                                    clipped_recon_image, _, _, _, _, _, _, bpp, _, _, _, _ = \
                                        net(referframe=ref_image, input_image=sub_cur_frame, iter=0, total_iter=0,
                                            stage="finetune", mode="test", \
                                            feature=feature_stack[sub_i - 1][0], z=z_stack[sub_i - 1][0],
                                            delta=delta_res_stack[sub_i - 1][0], \
                                            mvfeature=mv_feature_stack[sub_i - 1][0], delta_mv=delta_mv_stack[sub_i - 1][0],
                                            calrealbits=calrealbits)
                                elif sub_i == i:
                                    clipped_recon_image, _, _, _, _, _, _, bpp, _, _, _, _ = \
                                        net(referframe=ref_image, input_image=sub_cur_frame, iter=sub_it, total_iter=sub_iter_y,
                                            stage="finetune", mode="training", \
                                            feature=feature_stack[sub_i - 1][0], z=z_stack[sub_i - 1][0],
                                            delta=delta_res_stack[sub_i - 1][0], \
                                            mvfeature=mv_feature_stack[sub_i - 1][0], delta_mv=delta_mv_stack[sub_i - 1][0],
                                            calrealbits=calrealbits)
                                else:
                                    clipped_recon_image, _, _, _, _, _, _, bpp, _, _, _, _, _, _, _, _, _ = \
                                        net(referframe=ref_image, input_image=sub_cur_frame, iter=0, total_iter=0,
                                            stage="test_for_first", mode="training")
                                recon_image = clipped_recon_image
                                bpp = bpp
                            ref_image = recon_image
                            if sub_i >= i:
                                distortion = cal_distoration(sub_cur_frame, recon_image)
                                rd_cost = cal_rd_cost(distortion, bpp, lambda_for_test) / seqlen
                                rd_cost.backward(retain_graph=True)
                        optimizer_res.step()
                        optimizer_res.zero_grad()
                    for param in cur_params:
                        param.requires_grad = False
                    ###################### sub testing start ######################
                    total_rd_cost = 0
                    for sub_i in range(seqlen):
                        sub_cur_frame = Var(cropped_blocks[sub_i][n].cuda())
                        if sub_i == 0:
                            with torch.no_grad():
                                result = I_codec(
                                    [sub_cur_frame, "finetune", "test", 0, 0, I_y_stack[0], I_z_stack[0], delta_I_stack[0]])
                            I_likelihood_y, I_likelihood_z = result["likelihoods"]['y'], result["likelihoods"]['z']
                            I_rec = result['x_hat'].detach().clone()
                            ref_image = I_rec
                            y_bpp, z_bpp = cal_bpp(likelihood=I_likelihood_y,
                                                   num_pixels=num_pixels).cpu().detach().numpy(), cal_bpp(
                                likelihood=I_likelihood_z, num_pixels=num_pixels).cpu().detach().numpy()
                            psnr = cal_psnr(distortion=cal_distoration(I_rec, sub_cur_frame)).cpu().detach().numpy()
                            bpp = y_bpp + z_bpp
                            distortion = cal_distoration(sub_cur_frame, I_rec)
                            rd_cost = cal_rd_cost(distortion, bpp, lambda_for_test)
                            print(
                                "\n------------------------------------ GOP {0} --------------------------------------".format(
                                    batch_idx + 1))
                            print("I frame block {} :  ".format(n), "bpp:", bpp, "\t", "psnr:", psnr, "\t", "rd_cost:",
                                  rd_cost.cpu().detach().numpy())
                        else:
                            with torch.no_grad():
                                if sub_i <= i:
                                    clipped_recon_image, _, _, _, bpp_feature, bpp_z, bpp_mv, bpp, _, _, _, _ = \
                                        net(referframe=ref_image, input_image=sub_cur_frame, iter=0, total_iter=0,
                                            stage="finetune", mode="test", \
                                            feature=feature_stack[sub_i - 1][0], z=z_stack[sub_i - 1][0],
                                            delta=delta_res_stack[sub_i - 1][0], \
                                            mvfeature=mv_feature_stack[sub_i - 1][0], delta_mv=delta_mv_stack[sub_i - 1][0],
                                            calrealbits=calrealbits)
                                else:
                                    clipped_recon_image, _, _, _, bpp_feature, bpp_z, bpp_mv, bpp, _, _, _, _, _, _, _, _, _ = \
                                        net(referframe=ref_image, input_image=sub_cur_frame, iter=0, total_iter=0,
                                            stage="test_for_first", mode="test")
                            ref_image = clipped_recon_image
                            distortion = cal_distoration(sub_cur_frame, clipped_recon_image)
                            rd_cost = cal_rd_cost(distortion, bpp, lambda_for_test)
                            psnr = 10 * (torch.log(1 * 1 / distortion) / np.log(10)).cpu().detach().numpy()
                            mv_bpp = bpp_mv
                            res_y_bpp = bpp_feature
                            res_z_bpp = bpp_z
                            bpp = bpp
                            print("P{0} frame block {1} : ".format(sub_i, n), "mv_bpp:",
                                  mv_bpp.cpu().detach().numpy(), "\t", "res_bpp:", res_y_bpp.cpu().detach().numpy(), "\t",
                                  "res_hyper_bpp:", res_z_bpp.cpu().detach().numpy(), "\t", "bpp:",
                                  bpp.cpu().detach().numpy(), "\t", "psnr", psnr, "\t", "\t", "rd_cost",
                                  rd_cost.cpu().detach().numpy())
                        total_rd_cost += rd_cost
                    print("total_rd_cost_middle_res {0}:".format(i), total_rd_cost.cpu().detach().numpy())
            ####################### formal test ###################################
            rec_block_stack = []
            bits_stack = []
            total_rd_cost_latent_finetune = 0
            for i in range(seqlen):
                cur_frame = Var(cropped_blocks[i][n].cuda())
                b, h, w = cur_frame.shape[0], cur_frame.shape[2], cur_frame.shape[3]
                num_pixels = b * h * w
                if i == 0:
                    with torch.no_grad():
                        arr = I_codec(
                            [cur_frame, "test_for_final", "testing", 0, 0, I_y_stack[0], I_z_stack[0], delta_I_stack[0]])
                    I_rec = arr['x_hat']
                    I_likelihood_y, I_likelihood_z = arr["likelihoods"]['y'], arr["likelihoods"]['z']
                    ref_image = I_rec.clone().detach()
                    y_bpp = cal_bpp(likelihood=I_likelihood_y, num_pixels=num_pixels)
                    z_bpp = cal_bpp(likelihood=I_likelihood_z, num_pixels=num_pixels)
                    psnr = cal_psnr(distortion=cal_distoration(I_rec, cur_frame)).cpu().detach().numpy()
                    bpp = y_bpp + z_bpp
                    rec_block_stack.append(I_rec)

                    distortion = cal_distoration(cur_frame, I_rec)
                    rd_cost = cal_rd_cost(distortion, bpp, lambda_for_test)
                    print("\n", "I frame block {} :  ".format(n), "bpp:", bpp.cpu().detach().numpy(), "\t", "psnr:",
                          psnr, "\t", "rd_cost:", rd_cost.cpu().detach().numpy(), "\t", "delta",
                          delta_I_stack[0].cpu().detach().numpy())
                    gop_num += 1
                    sumbpp_real += bpp
                else:
                    with torch.no_grad():
                        clipped_recon_image, _, _, _, bpp_feature, bpp_z, bpp_mv, bpp, real_bpp_feature, real_bpp_z, real_bpp_mv, real_bpp = \
                            net(referframe=ref_image, input_image=cur_frame, iter=0, total_iter=0, stage="finetune",
                                mode="test", \
                                feature=feature_stack[i - 1][0], z=z_stack[i - 1][0], delta=delta_res_stack[i - 1][0],
                                mvfeature=mv_feature_stack[i - 1][0], delta_mv=delta_mv_stack[i - 1][0],
                                calrealbits=calrealbits)
                    ref_image = clipped_recon_image
                    distortion = cal_distoration(cur_frame, clipped_recon_image)
                    rd_cost = cal_rd_cost(distortion, bpp, lambda_for_test)
                    psnr = 10 * (torch.log(1 * 1 / distortion) / np.log(10)).cpu().detach().numpy()
                    mv_bpp = bpp_mv
                    res_y_bpp = bpp_feature
                    res_z_bpp = bpp_z
                    rec_block_stack.append(clipped_recon_image)
                    print("\n", "P{0} frame block {1} after: ".format(i, n), "mv_bpp:", mv_bpp.cpu().detach().numpy(), "\t",
                          "res_bpp:", res_y_bpp.cpu().detach().numpy(),
                          "\t", "res_hyper_bpp:", res_z_bpp.cpu().detach().numpy(), "\t", "bpp:",
                          bpp.cpu().detach().numpy(), "\t", "real_bpp_mv:", real_bpp_mv.cpu().detach().numpy(), "\t",
                          "real_bpp_res:", real_bpp_feature.cpu().detach().numpy(), "real_bpp_z:",
                          real_bpp_z.cpu().detach().numpy(), "\t", "real_bpp:", real_bpp.cpu().detach().numpy(),
                          "\t", "psnr", psnr, "\t", "\t", "rd_cost",
                          rd_cost.cpu().detach().numpy(), "\t", "delta_res:",
                          delta_res_stack[i - 1][0].cpu().detach().numpy(),
                          "\t", "delta_mv", delta_mv_stack[i - 1][0].cpu().detach().numpy())
                    sumbpp_mv_y += mv_bpp
                    sumbpp_res_y += res_y_bpp
                    sumbpp_res_z += res_z_bpp
                    sumbpp_real += real_bpp
                bits_stack.append(bpp.cpu().detach().numpy() * num_pixels)
                total_rd_cost_latent_finetune += rd_cost
            dir += 1
            print("total_rd_cost_after_res_finetune:", total_rd_cost_latent_finetune.cpu().detach().numpy())

            total_bits_stack.append(bits_stack)
            rec_img_stack.append(rec_block_stack)
        reversed_rec_img_stack = []
        reversed_bits_stack = []
        for x in range(seqlen):
            reversed_block_stack = []
            reversed_block_bits_stack = []
            for y in range(cropped_len):
                reversed_block_stack.append(rec_img_stack[y][x])
                reversed_block_bits_stack.append(total_bits_stack[y][x])

            reversed_rec_img_stack.append(merge(reversed_block_stack, B, C, H, W, factor=factor, overlap=overlap))
            reversed_bits_stack.append(np.sum(reversed_block_bits_stack))

        for p in range(len(reversed_rec_img_stack)):
            distortion = cal_distoration(reversed_rec_img_stack[p], torch.unsqueeze(input_images[p, :, :, :], 0))
            psnr = cal_psnr(distortion)
            bpp = reversed_bits_stack[p] / (B * H * W)
            rd_cost = cal_rd_cost(distortion, bpp, lambda_for_test)
            print("\n", "P{0} frame after: ".format(p), "mv_bpp:", 0, "\t", "mv_hyper_bpp:",
                  0, "\t", "res_bpp:", 0,
                  "\t", "res_hyper_bpp:", 0, "\t", "bpp:",
                  bpp, "\t", "psnr", psnr, "\t", "pre_psnr", 0, "\t", "rd_cost",
                  rd_cost, "\t")
            sumbpp += bpp
            sumpsnr += psnr
            eval_step += 1

    sumbpp /= eval_step
    sumbpp_real /= eval_step
    sumbpp_mv_y /= (eval_step - gop_num)
    sumbpp_mv_z /= (eval_step - gop_num)
    sumbpp_res_y /= (eval_step - gop_num)
    sumbpp_res_z /= (eval_step - gop_num)
    sumpsnr /= eval_step
    sumpsnr_pre /= (eval_step - gop_num)
    print('\nEpoch {0}  Average MSE={1}  Eval Step={2}\n'.format(str(0), str(avg_loss.data), int(eval_step)))
    log = "HEVC_Class_D  : bpp : %.6lf, mv_y_bpp : %.6lf, mv_z_bpp : %.6lf, " \
          " res_y_bpp : %.6lf, res_z_bpp : %.6lf, psnr : %.6lf, psnr_pre : %.6lf\n" % (
              sumbpp, sumbpp_mv_y, sumbpp_mv_z, sumbpp_res_y, sumbpp_res_z, sumpsnr, sumpsnr_pre)
    print(log)


def parse_args():
    """Build the command-line interface for this test script and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``test_lambdas`` (tuple default,
        list when given on the command line), ``factor``, ``overlap``,
        ``test_class``, ``gop_size`` and ``test_gop_num``.
    """
    cli = argparse.ArgumentParser(description="Example testing script")

    # (flags, add_argument keyword options) -- registered in this exact order so
    # the generated --help output is unchanged.
    option_specs = (
        (("--test_lambdas",), dict(type=int, nargs="+", default=(2048, 1024, 512, 256))),
        (("--factor",), dict(type=int, default=100, help="block size = (factor * 64, factor * 64)")),
        (("--overlap",), dict(type=int, default=0, help="overlap area between cropped blocks")),
        (("--test_class",), dict(type=str, default="HEVC_D")),
        (("--gop_size",), dict(type=int, default=10)),
        (("--test_gop_num",), dict(type=int, default=1)),
    )
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)

    return cli.parse_args()


if __name__ == "__main__":
    # Script entry point: for every requested lambda, load the matching I-frame
    # codec and P-frame model checkpoints and run test_new_cropped on the chosen
    # dataset.
    #
    # NOTE(review): test_dataset, lambda_for_test and I_codec are intentionally
    # left as module-level names. test_new_cropped appears to read
    # lambda_for_test and I_codec as globals (they are not among its
    # parameters) -- do not move this block into a function without passing
    # them explicitly.
    args = parse_args()

    # Fail fast on lambdas that have no I-frame codec quality mapping, instead
    # of raising a bare KeyError after the dataset has already been loaded.
    unsupported_lambdas = [lam for lam in args.test_lambdas if lam not in lambda_quality_map]
    if unsupported_lambdas:
        print("\nUnsupported test lambda(s):", unsupported_lambdas,
              "- supported values are", sorted(lambda_quality_map), file=sys.stderr)
        sys.exit(-1)

    if args.test_class.startswith("UVG"):
        test_dataset = UVGDataSet(gop_size=args.gop_size, test_gop_num=args.test_gop_num)
    elif args.test_class.startswith("HEVC"):
        # The last character of the class name selects the HEVC class (e.g. "HEVC_D" -> "D").
        test_dataset = HEVCDataSet(class_=args.test_class[-1], gop_size=args.gop_size, test_gop_num=args.test_gop_num)
    else:
        print("\nThere is no dataset named", args.test_class, "!", file=sys.stderr)
        sys.exit(-1)
    print("Now the tested dataset is:", args.test_class)

    for lambda_for_test in args.test_lambdas:
        print("=========================== testing lambda: {} ====================================".format(lambda_for_test))
        # I-frame codec checkpoint is indexed by the quality level mapped from lambda.
        I_codec_checkpoint = torch.load(
            './checkpoints/cheng2020-anchor-{0}.pth.tar'.format(lambda_quality_map[lambda_for_test]),
            map_location=torch.device('cpu'))
        I_codec = architectures['cheng2020-anchor'].from_state_dict(I_codec_checkpoint).cuda()
        model = VideoCompressor()
        pretrained_dict = torch.load("./checkpoints/{0}.model".format(lambda_for_test))
        # strict=False is kept, but report mismatched keys instead of dropping them silently.
        incompatible = model.load_state_dict(pretrained_dict, strict=False)
        if incompatible.missing_keys or incompatible.unexpected_keys:
            print("Warning: checkpoint key mismatch -- missing:", len(incompatible.missing_keys),
                  "unexpected:", len(incompatible.unexpected_keys), file=sys.stderr)
        model = model.cuda()
        print("Number of Total Parameters:", sum(x.numel() for x in model.parameters()))

        test_new_cropped(model, test_dataset, args.factor, args.overlap, write_strame_flag=False, calrealbits=False)
    sys.exit(0)
