"""
proximal step for DEQ
"""
import torch
import torchvision.datasets
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import configargparse, argparse
import os, h5py, yaml
import torch.backends.cudnn as cudnn
import sys, json
import tqdm
import matplotlib
sys.path.append("../")
import operators.operator as op
# import MRI.networks_MRI as nets
import networks.network_luser as nets
import torch.optim as optim
import math, time
from networks.network import *
# from utils.training import *
import operators.forward_models as fm
from utils.dataloader import *
from utils.misc import *
# from utils import operators_DEQx2 as functions
import pdb
import operators.operator_deq as deqop
import utils.parsers as parsers

def make_parser():
    """Build the command-line parser used by this script.

    The parser returned by ``parsers.make_base_parser()`` is passed through
    ``parsers.add_celeba_args`` and then ``parsers.add_luser_args``; finally
    the script-specific ``--file_name`` option (output directory prefix) is
    registered on the result.

    Returns:
        The fully configured parser object.
    """
    configured = parsers.add_luser_args(
        parsers.add_celeba_args(parsers.make_base_parser())
    )
    configured.add(
        "--file_name",
        type=str,
        default="luser_dw/mri/",
        help="saving directory",
    )
    return configured
def _checkpoint_state(epoch, model, optimizer, lr_scheduler):
    """Collect everything needed to resume training into one dict."""
    return {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': lr_scheduler.state_dict(),
    }


def _progress_postfix(iter_meters, val_meter):
    """Format the per-iterate training losses and the validation MSE for tqdm."""
    postfix = {f'x{k}': f'{meter.avg:.6f}' for k, meter in enumerate(iter_meters)}
    postfix['ts_mse'] = f'{val_meter.avg:.6f}'
    return postfix


def _reconstruct(model, y, n_iters):
    """Apply ``model`` ``n_iters`` times, starting from the clamped measurement.

    The model is called as ``model(Xk, y, False)``, exactly as in the
    evaluation paths of this script.

    Returns:
        ``(X0, Xk)``: the initial guess (``y`` clamped to [0, 1]) and the
        final iterate.
    """
    X0 = torch.clamp(y, min=0, max=1).clone()
    Xk = torch.clone(X0)
    for _ in range(n_iters):
        Xk = model(Xk, y, False)
    return X0, Xk


# Module-level parser instance, built at import time.
parser = make_parser()

if __name__ == '__main__':
    args = parser.parse_args()
    args.shared_eta = True

    print(args)
    print(os.getenv('CUDA_VISIBLE_DEVICES'), flush=True)
    cudnn.benchmark = True
    # NOTE(review): `random` (and `nn` below) are not imported explicitly in
    # this file; presumably they leak in through one of the wildcard imports.
    random.seed(110)
    torch.manual_seed(110)

    # ---- Load data ----
    noise_tag = str(args.noise_level).replace('.', '')
    filestr = f"ks{args.kernel_size}_sigma{args.sigma}_var{noise_tag}"
    args.file_name += filestr
    train_loader, trainset_size, val_loader, valset_size, test_loader, testset_size, save_path = load_data(args)

    os.makedirs(save_path, exist_ok=True)
    # Dump the arguments BEFORE attaching `args.device`: a torch.device is
    # not JSON-serialisable.
    with open(os.path.join(save_path, 'args.txt'), 'w') as f:
        json.dump(args.__dict__, f, indent=2)
        print('joined successfully!')
    args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # ---- Network setup ----
    forward_operator = fm.GaussianBlur(sigma=args.sigma,
                                       n_channels=args.nc,
                                       kernel_size=args.kernel_size,
                                       ).to(args.device)
    measurement_process = op.OperatorPlusNoise(forward_operator, noise_sigma=args.noise_level).to(args.device)

    g = nets.single_layer(num_channel=args.nc, num_features=64,
                          kernel_size=3, stride=1, padding=1, img_dim=args.img_dim)
    deq_inner = nets.DEQFixedPoint(g, deqop.anderson,
                                   in_channels=args.nc, out_channels=32,
                                   kernel_size=3, stride=1, padding=1,
                                   m=args.and_m, tol=args.and_tol, max_iter=args.and_maxiters,
                                   beta=args.and_beta)
    invBlock = nets.inverse_block_aux(forward_operator, [deq_inner], args, maxiters=1).to(args.device)
    print("# Parmeters: ", sum(a.numel() for a in invBlock.parameters()))

    # ---- Optimiser / scheduler ----
    # NOTE(review): the model is put in train() mode here and never switched
    # to eval(), including during validation and testing. Presumably the
    # third argument of invBlock(...) carries the train/eval distinction;
    # confirm.
    invBlock.train()
    opt = torch.optim.Adam(invBlock.parameters(), lr=args.lr)
    scheduler = optim.lr_scheduler.StepLR(optimizer=opt, step_size=int(args.sched_step), gamma=float(args.lr_gamma))
    criteria = nn.MSELoss()
    start_epoch = 0

    # ---- Load weights / resume training ----
    # map_location lets a checkpoint written on a GPU machine be loaded on a
    # CPU-only host (and vice versa).
    # NOTE(review): weights are loaded when `args.eval` is False; confirm this
    # condition is not inverted and that `args.load_path` is always set here.
    if not args.eval:
        print('Loading Pre-trained invBlock from', args.load_path)
        invBlock.load_state_dict(torch.load(args.load_path, map_location=args.device)['state_dict'])
    if args.continue_train and os.path.exists(args.load_path):
        checkpoint = torch.load(args.load_path, map_location=args.device)
        invBlock.load_state_dict(checkpoint['state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        opt.load_state_dict(checkpoint['optimizer'])
        # Checkpoints store the index of the epoch that was running when they
        # were written, so resuming re-runs that epoch.
        start_epoch = checkpoint['epoch']
        print('Succesfully loaded training session')

    # ---- Training or testing ----
    if args.train:
        for epoch in range(start_epoch, args.n_epochs):
            loss_meters = [AverageMeter() for _ in range(args.maxiters)]
            val_meters = AverageMeter()
            # The bar total counts full training batches only; the validation
            # updates further down therefore push it past 100%.
            with tqdm.tqdm(total=(trainset_size - trainset_size % args.batch_size)) as _tqdm:
                _tqdm.set_description('epoch: {}/{}'.format(epoch + 1, args.n_epochs))
                steps = 0
                for X, _ in train_loader:
                    bs = X.shape[0]
                    X = X.to(args.device)
                    y = measurement_process(X)
                    Xk = torch.clamp(y, min=0, max=1).clone()
                    loss = 0.0
                    for k in range(args.maxiters):
                        Xk = invBlock(Xk, y, True)
                        loss_k = criteria(Xk, X)
                        loss_meters[k].update(loss_k.item(), bs)
                        if args.use_aux_loss:
                            # Auxiliary loss: every intermediate iterate
                            # contributes to the objective.
                            loss += loss_k
                    # The final iterate is always penalised (in addition to
                    # its auxiliary term when use_aux_loss is set).
                    loss += criteria(Xk, X)

                    opt.zero_grad()
                    loss.backward()
                    opt.step()

                    torch.cuda.empty_cache()
                    _tqdm.set_postfix(_progress_postfix(loss_meters, val_meters))
                    _tqdm.update(bs)
                    steps += 1
                    # Write a rolling "latest" checkpoint once the counter
                    # exceeds 1000 steps, then restart the count.
                    if steps > 1000:
                        torch.save(_checkpoint_state(epoch, invBlock, opt, scheduler),
                                   os.path.join(save_path, f'epoch_{epoch}_latest.state'))
                        steps = 0

                # End-of-epoch checkpoint.
                torch.save(_checkpoint_state(epoch, invBlock, opt, scheduler),
                           os.path.join(save_path, f'epoch_{epoch}.state'))

                # Validation pass; remember the last batch for plotting.
                val_batch = None
                with torch.no_grad():
                    for X, _ in val_loader:
                        bs = X.shape[0]
                        X = X.to(args.device)
                        y = measurement_process(X)
                        X0, Xk = _reconstruct(invBlock, y, args.maxiters)
                        loss = criteria(Xk, X)
                        val_meters.update(loss.item(), bs)
                        _tqdm.set_postfix(_progress_postfix(loss_meters, val_meters))
                        _tqdm.update(bs)
                        val_batch = (Xk, X0, X)

                # Only plot when validation produced a batch; otherwise X0
                # would be undefined.
                if val_batch is not None:
                    Xk, X0, X = val_batch
                    plot_DEQasprox_ca(Xk, X0, X, criteria, save_path, epoch)

            # Advance the LR schedule once per epoch. It is stepped after the
            # checkpoints are written so that a saved scheduler state matches
            # the epoch index stored next to it (resuming re-runs that epoch).
            scheduler.step()
    else:
        criteria_title = ['mse', 'avgInit', 'avgPSNR', 'deltaPSNR', 'avgSSIM']
        loss_meters = [AverageMeter() for _ in criteria_title]
        # Last fully evaluated batch, kept for the final plot.
        plot_batch = None
        with tqdm.tqdm(total=(testset_size - testset_size % args.batch_size)) as _tqdm:
            _tqdm.set_description('epoch: {}/{}'.format(1, args.n_epochs))
            _tqdm.set_postfix({title: f'{0:.6f}' for title in criteria_title})
            for X, _ in test_loader:
                bs = X.shape[0]
                # Stop at the first batch with fewer than 4 samples.
                # NOTE(review): reason for the threshold is not visible here.
                if bs < 4:
                    break
                with torch.no_grad():
                    X = X.to(args.device)
                    y = measurement_process(X)
                    X0, Xk = _reconstruct(invBlock, y, args.maxiters)
                    mse = criteria(Xk, X)
                    avg_init_psnr, avg_recon_psnr, avg_delta_psnr, avg_ssim = compute_metrics3chan(Xk, X, X0)

                    criteria_list = [mse, avg_init_psnr, avg_recon_psnr, avg_delta_psnr, avg_ssim]
                    for meter, value in zip(loss_meters, criteria_list):
                        meter.update(value.item(), bs)
                    # Record a consistent (Xk, X0, X) triple; X itself may be
                    # rebound to a short batch before the loop breaks.
                    plot_batch = (Xk, X0, X)

                    torch.cuda.empty_cache()
                    _tqdm.set_postfix({title: f'{meter.avg:.6f}'
                                       for title, meter in zip(criteria_title, loss_meters)})
                    _tqdm.update(bs)
            if plot_batch is not None:
                Xk, X0, X = plot_batch
                plot_DEQasprox_ca(Xk, X0, X, criteria, save_path, 0)
