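"""
Loop-unrolled architecture that uses DEQ (deep equilibrium) blocks as learned proximal operators.
With --diffW every unrolled optimization iteration has its own weights; otherwise at most 4 weight
sets are shared across consecutive groups of iterations.
IMAGE DEBLURRING TASK (Gaussian blur + additive noise) on CelebA.
"""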
import cv2
import torchvision.datasets
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import argparse, json, sys, pdb, configargparse

# import os, h5py, yaml, math, time
sys.path.append("../")

import torch.backends.cudnn as cudnn
import operators.forward_models as fm
import operators.operator as op
import operators.operator_deq as deqop
import operators.single_coil_mri as mrimodel
# import networks.network as nets
import networks.network_luser as nets
import torch.optim as optim
from networks.network import *
from utils.dataloader import *
from utils.misc import *
import matplotlib
from tqdm import tqdm
import utils.parsers as parsers

def make_parser():
    """Build the command-line/config parser for this CelebA LUSER script.

    Starts from the shared base parser, layers on the CelebA dataset options and
    the LUSER model options, then adds this script's output-directory flag.
    """
    base = parsers.make_base_parser()
    with_celeba = parsers.add_celeba_args(base)
    full = parsers.add_luser_args(with_celeba)
    full.add("--file_name", type=str, default="luser_dw/celeba/", help="saving directory")
    return full
# Parse CLI / config-file options and echo them for the run log.
parser = make_parser()
args = parser.parse_args()
print(args)

# Record GPU availability and report which devices this process was given.
cuda = bool(torch.cuda.is_available())
visible_devices = os.getenv('CUDA_VISIBLE_DEVICES')
print(visible_devices, flush=True)

# Let cuDNN autotune convolution algorithms (beneficial when input sizes do not
# change between batches).
cudnn.benchmark = True

""" Load Data """
# Tag the output name with blur kernel size, blur sigma and noise level; the noise
# level has its decimal point stripped (e.g. 0.01 -> "001").
noise_tag = str(args.noise_level).replace('.', '')
filestr = f"ks{args.kernel_size}_sig{args.sigma}_var{noise_tag}"
args.file_name = args.file_name + filestr

train_loader, tr_length, val_loader, val_length, ts_loader, ts_length, save_path = load_data(args)

# Evaluation runs write into a sibling "<save_path>_val" directory so their outputs
# stay separate from the training outputs.
if args.eval:
    save_path += '_val'
args.save_path = save_path
print('save_path: ' + args.save_path)
os.makedirs(args.save_path, exist_ok=True)

# Snapshot the full configuration next to the results.
args_file = os.path.join(args.save_path, 'args.txt')
with open(args_file, 'w') as fh:
    json.dump(vars(args), fh, indent=2)

# Must be set only after the JSON dump above: torch.device is not JSON-serializable.
args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

""" Network Setup """
# Measurement model: Gaussian blur followed by additive noise.
forward_operator = fm.GaussianBlur(kernel_size=args.kernel_size, sigma=args.sigma, n_channels=3).to(args.device)
measurement_process = op.OperatorPlusNoise(forward_operator, noise_sigma=args.noise_level).to(args.device)

# Number of distinct weight sets for the per-iteration network g:
#   --diffW -> one per unrolled iteration;
#   default -> 4, shared by consecutive groups of iterations (capped at maxiters so
#              that no weight set is created without being used).
nBlocks = args.maxiters if args.diffW else min(4, args.maxiters)
g_list = [nets.single_layer(num_channel=3, num_features=64, kernel_size=3, stride=1, padding=1, img_dim=args.img_dim)
          for _ in range(nBlocks)]

# Map unrolled iteration i to weight set i * nBlocks // maxiters.  This equals
# i // (maxiters // nBlocks) whenever maxiters is a multiple of nBlocks (and equals i
# under --diffW), but it never indexes past the end of g_list when maxiters is not a
# multiple of 4, and never divides by zero when maxiters < 4.
deq_list = [
    nets.DEQFixedPoint(g_list[i * nBlocks // args.maxiters], deqop.anderson, in_channels=3, out_channels=32,
                       kernel_size=3, stride=1, padding=1, m=args.and_m, tol=args.and_tol,
                       max_iter=args.and_maxiters, beta=args.and_beta)
    for i in range(args.maxiters)]

invBlock = nets.inverse_block_aux(forward_operator, deq_list, args).to(args.device)
# Report learnable parameters excluding the forward operator's (assumes forward_operator
# is registered as a submodule of invBlock — TODO confirm).
print("# Parameters: ", sum(a.numel() for a in invBlock.parameters()) - sum(a.numel() for a in forward_operator.parameters()))

""" Begin Training """
# Order matters: invBlock.train() recurses into submodules, then the forward operator
# is switched back to eval mode.
invBlock.train()
forward_operator.eval()
opt = torch.optim.Adam(invBlock.parameters(), lr=args.lr)
scheduler = optim.lr_scheduler.StepLR(optimizer=opt, step_size=int(args.sched_step), gamma=float(args.lr_gamma))
criteria = nn.MSELoss()

# LOAD pretrained weights. NOTE: only the model state_dict is restored; the optimizer,
# scheduler and epoch counter stored in the checkpoint are not.
if args.pretrain:
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
    checkpoint = torch.load(args.load_path, map_location=args.device)
    invBlock.load_state_dict(checkpoint['state_dict'])
    print('Model loaded successfully!')


""" Training """
if not args.eval:
    for epoch in range(args.n_epochs):
        loss_meters = AverageMeter()
        val_meters = AverageMeter()
        # Bar total is tr_length rounded down to a multiple of batch_size
        # (presumably the train loader drops the last partial batch — TODO confirm).
        with tqdm(total=(tr_length - tr_length % args.batch_size)) as _tqdm:
            _tqdm.set_description('epoch: {}/{}'.format(epoch + 1, args.n_epochs))
            for X, _ in train_loader:
                bs = X.shape[0]
                X = X.to(args.device)
                # Synthesize the degraded measurement on the fly and clip it to [0, 1].
                y = measurement_process(X)
                y = torch.clamp(y, min=0, max=1)

                # The unrolled iterate starts from the measurement; X0 keeps that
                # initial estimate for plotting.
                Xk = torch.clone(y)
                X0 = torch.clone(y)
                if args.use_aux_loss:
                    # Auxiliary loss: sum the MSE of every intermediate iterate, not just the last.
                    loss = 0.0
                    for i in range(args.maxiters):
                        Xk = invBlock.forward_singlestep(Xk, y, i, train=True)
                        loss += criteria(Xk, X)
                else:
                    Xk = invBlock(Xk, y, train=True)
                    loss = criteria(Xk, X)

                opt.zero_grad()
                loss.backward()
                opt.step()
                loss_meters.update(loss.item(), bs)

                torch.cuda.empty_cache()
                _tqdm.set_postfix({'tr_mse': f'{loss_meters.avg:.6f}', 'val_mse': f'{val_meters.avg:.6f}'})
                _tqdm.update(bs)

                # Checkpoint after every batch, overwriting this epoch's file, so progress
                # within a long epoch is not lost.
                state = {
                    'epoch': epoch,
                    'state_dict': invBlock.state_dict(),
                    'optimizer': opt.state_dict(),
                    'scheduler': scheduler.state_dict(),
                }
                torch.save(state, os.path.join(args.save_path, f'epoch_{epoch}.state'))

            # Validation (skipped for epoch 0).
            if epoch > 0:
                with torch.no_grad():
                    for X, _ in val_loader:
                        bs = X.shape[0]
                        X = X.to(args.device)
                        y = measurement_process(X)
                        y = torch.clamp(y, min=0, max=1)

                        X0 = torch.clone(y)
                        Xk = invBlock(X0, y, train=True)

                        loss = criteria(Xk, X)
                        val_meters.update(loss.item(), bs)
                        _tqdm.set_postfix({'tr_mse': f'{loss_meters.avg:.6f}', 'val_mse': f'{val_meters.avg:.6f}'})
                        _tqdm.update(bs)

            # Plot the last batch seen this epoch: a training batch on epoch 0,
            # a validation batch on later epochs.
            plot_CelebA(Xk, X0, X, criteria, args.save_path, epoch)

        # Advance the StepLR schedule once per epoch so the learning rate is actually
        # multiplied by lr_gamma every sched_step epochs (assumes --sched_step counts
        # epochs — TODO confirm).
        scheduler.step()

else:
    # Evaluation on the test split. Titles name, in order, the MSE followed by the
    # four values returned by compute_metrics3chan.
    criteria_title = ['mse', 'avgInit', 'avgPSNR', 'deltaPSNR', 'avgSSIM']
    len_meter = len(criteria_title)
    loss_meters = [AverageMeter() for _ in range(len_meter)]
    last_batch = None  # (Xk, X0, X) of the most recent test batch, used for plotting

    with tqdm(total=(ts_length - ts_length % args.batch_size)) as _tqdm:
        _tqdm.set_description('epoch: {}/{}'.format(1, args.n_epochs))
        for X, _ in ts_loader:
            with torch.no_grad():
                bs = X.shape[0]
                X = X.to(args.device)
                y = measurement_process(X)
                y = torch.clamp(y, min=0, max=1)

                X0 = torch.clone(y)
                Xk = invBlock(X0, y, train=True)

                mse = criteria(Xk, X)
                avg_init_psnr, avg_recon_psnr, avg_delta_psnr, avg_ssim = compute_metrics3chan(Xk, X, X0)

                criteria_list = [mse, avg_init_psnr, avg_recon_psnr, avg_delta_psnr, avg_ssim]
                for meter, value in zip(loss_meters, criteria_list):
                    meter.update(value.item(), bs)

                torch.cuda.empty_cache()
                _tqdm.set_postfix({title: f'{m.avg:.6f}' for title, m in zip(criteria_title, loss_meters)})
                _tqdm.update(bs)
            last_batch = (Xk, X0, X)

        # Plot once, after all metrics are accumulated, using the last test batch. Every
        # call uses the same tag (999), so per-batch calls presumably just re-rendered the
        # same output. The guard avoids a NameError when the test loader yields nothing.
        if last_batch is not None:
            Xk, X0, X = last_batch
            plot_CelebA(Xk, X0, X, criteria, args.save_path, 999)
