"""
2 DEQs:
utils replaces DnCNN, and
utils to solve inverse update
"""
import torchvision.datasets
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import configargparse
import os, h5py, yaml
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.nn as nn
import sys
import json
sys.path.append("../")
# from MRI.dataloader_fastmri import singleCoilFastMRIDataloader
# import MRI.single_coil_mri as mrimodel
import networks.network_luser as nets
from networks.network import DnCNN
import math, time
from utils.dataloader import *
from utils.misc import AverageMeter, compute_metrics3chan, plot_CelebA

import operators.forward_models as forwardops
import operators.operator_deq as deqop
import operators.operator as op
from tqdm import tqdm

import pdb
import utils.parsers as parsers

def make_parser():
    """Assemble the command-line parser used by this script.

    The parser returned by ``parsers.make_base_parser`` is passed through
    ``parsers.add_celeba_args`` and then ``parsers.add_deqip_args``; finally a
    ``--file_name`` string option (default ``"deqip/celeba/"``) is registered.

    Returns:
        The fully configured parser object.
    """
    cli = parsers.add_deqip_args(
        parsers.add_celeba_args(
            parsers.make_base_parser()
        )
    )
    cli.add(
        "--file_name",
        type=str,
        default="deqip/celeba/",
        help="saving directory",
    )
    return cli

if __name__ == '__main__':
    # Script entry point. Depending on ``args.train`` this either trains the
    # fixed-point reconstruction model (with per-epoch validation and
    # checkpointing) or loads a checkpoint and reports test-set metrics.
    parser = make_parser()
    args = parser.parse_args()
    print(args)
    print(os.getenv('CUDA_VISIBLE_DEVICES'), flush=True)
    cudnn.benchmark = True
    # Seed Python's and torch's RNGs with a fixed value.
    random.seed(110)
    torch.manual_seed(110)

    # ---- Load data ----
    train_loader, trainset_size, val_loader, valset_size, test_loader, testset_size, save_path = load_data(args)
    os.makedirs(save_path, exist_ok=True)
    # The arguments are dumped before ``args.device`` is attached below;
    # a ``torch.device`` object is not JSON-serializable.
    with open(os.path.join(save_path, 'args.txt'), 'w') as f:
        json.dump(args.__dict__, f, indent=2)

    args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # ---- Forward model setup ----
    # Tag describing the degradation: kernel size, blur sigma and noise level
    # (with the decimal point stripped from the noise level).
    args.A_type = "ks" + str(args.kernel_size) + "_sigma" + str(args.sigma) + "_var" + str(args.noise_level).replace('.', '')
    args.file_name += args.A_type

    forward_operator = forwardops.GaussianBlur(sigma=args.sigma, kernel_size=args.kernel_size, n_channels=args.nc).to(args.device)

    measurement_process = op.OperatorPlusNoise(forward_operator, noise_sigma=args.noise_level).to(args.device)

    # ---- Network setup ----
    R = DnCNN(args.nc, num_of_layers=args.n_layers, features=args.n_features)
    deq_inner = nets.inverse_block_mri(forward_operator, R, args)
    inv_model = nets.DEQIPFixedPoint(deq_inner, deqop.anderson, m=args.and_m, tol=args.and_tol,
                                     max_iter=args.and_maxiters, beta=args.and_beta)

    print("# Parmeters: ", sum(a.numel() for a in inv_model.parameters()))

    if args.pretrain:
        print('Loading Pre-trained R from', args.load_path)
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
        inv_model.g.R.load_state_dict(torch.load(args.load_path, map_location=args.device))

    if args.train:
        inv_model.train().to(args.device)
        opt = torch.optim.Adam(inv_model.parameters(), lr=args.lr)
        scheduler = optim.lr_scheduler.StepLR(optimizer=opt, step_size=int(args.sched_step), gamma=float(args.lr_gamma))
        criteria = nn.MSELoss()

        # Resume from a full training checkpoint if requested. The stored
        # 'epoch' is the index of the last completed epoch, so training
        # continues with the following one instead of restarting at 0.
        start_epoch = 0
        if args.continue_train:
            checkpoint = torch.load(args.load_path, map_location=args.device)
            inv_model.load_state_dict(checkpoint['state_dict'])
            opt.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            start_epoch = checkpoint['epoch'] + 1
        # Shown in the training progress bar until the first validation pass finishes.
        val_loss = 0.0

        # ---- Training loop ----
        for epoch in range(start_epoch, args.n_epochs):
            loss_meters = [AverageMeter() for _ in range(1)]
            with tqdm(total=(trainset_size - trainset_size % args.batch_size)) as _tqdm:
                _tqdm.set_description('epoch: {}/{}'.format(epoch + 1, args.n_epochs))
                for X, _ in train_loader:
                    # Measurement simulation and initialisation need no gradients.
                    with torch.no_grad():
                        X = X.to(args.device)
                        bs = X.shape[0]
                        y = measurement_process(X)
                        # Initial iterate: the measurement clamped to [0, 1].
                        X0 = torch.clamp(y, min=0, max=1).clone()
                    Xk = inv_model(y, X0, train=True, zero_init=False)
                    loss = criteria(Xk, X)
                    loss_meters[0].update(loss.item(), bs)

                    opt.zero_grad()
                    loss.backward()
                    opt.step()

                    torch.cuda.empty_cache()

                    _tqdm.set_postfix({**{f'x{k}': f'{loss_meters[k].avg:.6f}' for k in range(1)}, 'val': f'{val_loss:.6f}'})
                    _tqdm.update(bs)

                # Advance the learning-rate schedule once per epoch, after the
                # optimizer updates; without this call StepLR never decays the LR.
                scheduler.step()

                # Save a checkpoint for this epoch (model, optimizer, scheduler).
                state = {
                    'epoch': epoch,
                    'state_dict': inv_model.state_dict(),
                    'optimizer': opt.state_dict(),
                    'scheduler': scheduler.state_dict(),
                }
                torch.save(state, os.path.join(save_path, f'epoch_{epoch}.state'))

            # ---- Validation ----
            # NOTE(review): the model is not switched to eval() here; only the
            # ``train=False`` flag is passed to the forward call — confirm that
            # this is sufficient for the layers inside the network.
            val_meter = AverageMeter()
            with torch.no_grad():
                with tqdm(total=(valset_size - valset_size % args.batch_size)) as val_tqdm:
                    val_tqdm.set_description('Val epoch: {}/{}'.format(epoch + 1, args.n_epochs))
                    for X, S in val_loader:
                        X = X.to(args.device)
                        bs = X.shape[0]
                        y = measurement_process(X)
                        X0 = torch.clamp(y, min=0, max=1).clone()
                        Xk = inv_model(y, X0, train=False, zero_init=False)
                        loss = criteria(Xk, X)
                        val_meter.update(loss.item(), bs)
                        val_tqdm.set_postfix({'val': f'{val_meter.avg:.6f}'})
                        val_tqdm.update(bs)
                val_loss = val_meter.avg

                # Every 10th epoch, save a 3x3 figure built from the last
                # validation batch: initialisation, reconstruction, ground truth.
                if (epoch + 1) % 10 == 0:
                    plt.figure()
                    # NOTE(review): squeeze(1) only removes the channel axis when
                    # it has size 1, whereas X_true keeps channel 0 only — confirm
                    # the shapes line up when args.nc > 1.
                    x_hat = torch.clamp(Xk, min=0, max=1).squeeze(1)
                    init_clamp = torch.clamp(X0, min=0, max=1).squeeze(1)
                    X_true = X[:, 0, :, :]
                    criteria2 = nn.MSELoss()
                    for i in range(3):
                        plt.subplot(3, 3, i + 1)
                        # PSNR for a peak value of 1: 20 * log10(1 / sqrt(MSE)).
                        psnr = 20 * math.log10(1 / math.sqrt(criteria2(init_clamp[i], X_true[i])))
                        plt.imshow(init_clamp[i].detach().cpu(), cmap='gray')
                        plt.title('$x_0$, PSNR = {0:.3f}'.format(psnr))
                        plt.axis('off')
                        plt.subplot(3, 3, i + 4)
                        psnr = 20 * math.log10(1 / math.sqrt(criteria2(x_hat[i], X_true[i])))
                        plt.imshow(x_hat[i].detach().cpu(), cmap='gray')
                        # Raw string: '\h' is not a valid escape sequence in a normal literal.
                        plt.title(r'$\hat{x}$, ' + 'PSNR = {0:.3f}'.format(psnr))
                        plt.axis('off')
                        plt.subplot(3, 3, i + 7)
                        plt.imshow(X_true[i].detach().cpu(), cmap='gray')
                        plt.title('Clean image')
                        plt.axis('off')

                    plt.savefig(os.path.join(save_path, f'{epoch}_results.png'))
                    plt.close()

    else:
        # ---- Test-only path: load a full checkpoint and report metrics ----
        checkpoint = torch.load(args.load_path, map_location=args.device)
        inv_model.load_state_dict(checkpoint['state_dict'])
        inv_model.to(args.device)
        # NOTE(review): the model is not switched to eval() before testing —
        # confirm this is intended.
        criteria = nn.MSELoss()
        criteria_title = ['mse', 'avgInit', 'avgPSNR', 'deltaPSNR', 'avgSSIM']
        len_meter = len(criteria_title)
        loss_meters = [AverageMeter() for _ in range(len_meter)]

        with tqdm(total=(testset_size - testset_size % args.batch_size)) as _tqdm:
            _tqdm.set_description('epoch: {}/{}'.format(1, args.n_epochs))
            for X, _ in test_loader:
                with torch.no_grad():
                    bs = X.shape[0]
                    X = X.to(args.device)
                    y = measurement_process(X)
                    # Unlike the training path, the clamped measurement is also
                    # what gets passed to the model as ``y`` here.
                    y = torch.clamp(y, min=0, max=1)

                    X0 = torch.clone(y)
                    Xk = inv_model(y, X0, False, False)

                    mse = criteria(Xk, X)
                    avg_init_psnr, avg_recon_psnr, avg_delta_psnr, avg_ssim = compute_metrics3chan(Xk, X, X0)

                    # Order must match ``criteria_title``.
                    criteria_list = [mse, avg_init_psnr, avg_recon_psnr, avg_delta_psnr, avg_ssim]
                    for k in range(len_meter):
                        loss_meters[k].update(criteria_list[k].item(), bs)

                    torch.cuda.empty_cache()
                    _tqdm.set_postfix({f'{criteria_title[k]}': f'{loss_meters[k].avg:.6f}' for k in range(len_meter)})
                    _tqdm.update(bs)
            # Plot the last test batch.
            # NOTE(review): this uses ``args.save_path`` while the rest of the
            # script writes to the ``save_path`` returned by load_data — confirm
            # the attribute exists and points at the intended directory.
            plot_CelebA(Xk, X0, X, criteria, args.save_path, 999)

