"""
proximal step for DEQ
"""
import torchvision.datasets
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import configargparse, argparse
import os, h5py, yaml
import torch.backends.cudnn as cudnn
import sys, json
sys.path.append("../")
# from MRI.dataloader_fastmri import singleCoilFastMRIDataloader
import operators.single_coil_mri as mrimodel
import operators.operator as op
# import MRI.networks_MRI as nets
import networks.network_luser as nets
import torch.optim as optim
import math, time
from networks.network import *
# from utils.training import *
from utils.dataloader import *
from utils.misc import *
# from utils import operators_DEQx2 as functions
import pdb
import operators.operator_deq as deqop
from tqdm import tqdm
import utils.parsers as parsers

def make_parser():
    """Build the command-line parser used by this script.

    Starts from the project's base parser, layers the MRI and "luser"
    argument groups on top of it, and finally registers ``--file_name``
    (the saving directory, default ``"luser_sw/mri/"``).

    Returns:
        The fully configured parser object produced by ``utils.parsers``.
    """
    parser = parsers.add_luser_args(
        parsers.add_mri_args(
            parsers.make_base_parser()
        )
    )
    parser.add(
        "--file_name",
        type=str,
        default="luser_sw/mri/",
        help="saving directory",
    )
    return parser
# Parse the command line once at module load; everything below reads `args`.
parser = make_parser()
args = parser.parse_args()
# This script always forces shared_eta on, whatever was parsed.
args.shared_eta = True

print(args)
# is_available() already returns a bool, so no conditional is needed.
cuda = torch.cuda.is_available()
# Echo which GPUs are visible to this process (None when the variable is unset).
print(os.environ.get('CUDA_VISIBLE_DEVICES'), flush=True)
# Enable cuDNN's benchmark mode.
cudnn.benchmark = True

""" Load Data """
timeStamp = datetime.now().strftime("%Y-%m-%d-%H%M")
if not args.eval:
    # Only training needs the train/test loaders.
    train_loader, test_loader, load_path, save_path, tr_length, ts_length = load_data(args)
# The validation data is loaded in BOTH modes: the statements below read
# `val_length` and `save_path` unconditionally, and the evaluation branch
# iterates `val_loader`. Loading it only when training left all three names
# undefined under --eval (NameError on the print below).
# As before, the save_path returned here overrides the one from load_data.
val_loader, save_path, val_length = load_val_data(args)
print(val_length)
args.save_path = save_path
os.makedirs(args.save_path, exist_ok=True)
# Persist the run configuration next to the checkpoints.
with open(os.path.join(args.save_path, 'args.txt'), 'w') as f:
    json.dump(args.__dict__, f, indent=2)
    print('joined successfully!')
# Must stay after the dump above: a torch.device object is not JSON-serializable.
args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

""" Network Setup """

# Sampling mask for the MRI forward model; the seed is fixed so the same
# mask is produced on every run.
mask = mrimodel.create_mask(
    shape=[args.img_dim, args.img_dim, 2],
    acceleration=args.mri_acc,
    center_fraction=0.04,
    seed=10,
)

# Forward operator and its noisy counterpart used to simulate measurements.
forward_operator = mrimodel.cartesianSingleCoilMRI(kspace_mask=mask)
forward_operator = forward_operator.to(args.device)
measurement_process = op.OperatorPlusNoise(
    forward_operator,
    noise_sigma=args.noise_level,
).to(args.device)

# Single-layer map `g`, wrapped in a fixed-point block that is given the
# Anderson solver and its settings from the command line.
g = nets.single_layer(
    num_channel=2,
    num_features=64,
    kernel_size=3,
    stride=1,
    padding=1,
    img_dim=args.img_dim,
)
deq_inner = nets.DEQFixedPoint(
    g,
    deqop.anderson,
    in_channels=2,
    out_channels=32,
    kernel_size=3,
    stride=1,
    padding=1,
    m=args.and_m,
    tol=args.and_tol,
    max_iter=args.and_maxiters,
    beta=args.and_beta,
)

# Full block that is trained/evaluated below; report its parameter count.
invBlock = nets.inverse_block_mri_prox(forward_operator, deq_inner, args).to(args.device)
print("# Parmeters: ", sum(p.numel() for p in invBlock.parameters()))

""" Begin Training ..."""
invBlock.train()
# Adam over all invBlock parameters, with a step-wise learning-rate decay.
opt = torch.optim.Adam(invBlock.parameters(), lr=args.lr)
scheduler = optim.lr_scheduler.StepLR(optimizer=opt, step_size=int(args.sched_step), gamma=float(args.lr_gamma))
criteria = nn.MSELoss()

# Optionally warm-start from a saved checkpoint. Only the model weights are
# restored; optimizer and scheduler state in the file are not loaded here.
if args.pretrain:
    print('Loading Pre-trained invBlock from', args.load_path)
    # map_location remaps the stored tensors onto the device in use, so a
    # checkpoint written on a GPU can still be loaded on a CPU-only host.
    checkpoint = torch.load(args.load_path, map_location=args.device)
    invBlock.load_state_dict(checkpoint['state_dict'])

""" BEGIN TRAINING or VALIDATION """
if not args.eval:
    # ---------------- Training ----------------
    # Draw one test batch up front; the same batch is handed to the plotting
    # helper after every epoch.
    X_val, _ = next(iter(test_loader))
    resize = torchvision.transforms.Resize(args.img_dim)
    for epoch in range(args.n_epochs):
        # One running average per unrolled iteration k.
        loss_meters = [AverageMeter() for _ in range(args.maxiters)]
        with tqdm(total=(tr_length - tr_length % args.batch_size)) as _tqdm:
            _tqdm.set_description('epoch: {}/{}'.format(epoch + 1, args.n_epochs))
            for X, _ in train_loader:
                X = X.to(args.device)
                bs = X.shape[0]
                # Normalize, add a channel axis, then append an all-zero second
                # channel along dim 1 (presumably real/imaginary parts - confirm).
                X = op.normalize(X, bs).unsqueeze(1)
                zeros = torch.zeros_like(X)
                X = torch.cat((X, zeros), dim=1).to(args.device)
                # NOTE(review): squeeze() without a dim also removes the batch
                # axis when bs == 1 - confirm the loader never yields such a batch.
                y = measurement_process(X).squeeze().detach()

                # Start from the adjoint applied to the measurements and apply
                # the block maxiters times, summing the MSE of every iterate.
                Xk = forward_operator.adjoint(y).detach()
                loss = 0.0
                for k in range(args.maxiters):
                    Xk = invBlock(Xk, y, k, True)
                    loss_k = criteria(Xk, X)
                    loss_meters[k].update(loss_k.item(), bs)
                    loss += loss_k

                opt.zero_grad()
                loss.backward()
                opt.step()

                torch.cuda.empty_cache()
                _tqdm.set_postfix({f'x{k}': f'{loss_meters[k].avg:.6f}' for k in range(args.maxiters)})
                _tqdm.update(bs)

            # Advance the learning-rate schedule once per epoch. Without this
            # call the StepLR scheduler was created and checkpointed but never
            # applied, so the learning rate stayed at args.lr for the whole run.
            scheduler.step()

            # Save the result
            state = {
                'epoch': epoch,
                'state_dict': invBlock.state_dict(),
                'optimizer': opt.state_dict(),
                'scheduler': scheduler.state_dict(),
            }
            torch.save(state, os.path.join(args.save_path, f'epoch_{epoch}.state'))
            plot_DEQasProx_result(X_val, args, op, invBlock, forward_operator, measurement_process, criteria, args.save_path, epoch)

else:
    # ---------------- Evaluation ----------------
    # NOTE(review): invBlock was put in train() mode during setup and is never
    # switched to eval() here - confirm that is intended for this model.
    resize = torchvision.transforms.Resize(args.img_dim)
    criteria = nn.MSELoss()
    criteria_title = ['mse', 'avgInit', 'avgPSNR', 'deltaPSNR', 'avgSSIM']
    len_meter = len(criteria_title)
    loss_meters = [AverageMeter() for _ in range(len_meter)]
    s = 0
    with tqdm(total=(val_length - val_length % args.batch_size)) as _tqdm:
        _tqdm.set_description('epoch: {}/{}'.format(1, args.n_epochs))
        for X, _ in val_loader:
            with torch.no_grad():
                bs = X.shape[0]
                # Stop at the first batch holding fewer than 4 samples.
                if bs < 4:
                    break
                X = X.to(args.device)
                # Same input preparation as in training.
                X = op.normalize(X, bs).unsqueeze(1)
                zeros = torch.zeros_like(X)
                X = torch.cat((X, zeros), dim=1).to(args.device)
                y = measurement_process(X).squeeze()

                # X0 is kept as the initial estimate for the metrics helper;
                # Xk is the iterate that gets refined.
                X0 = forward_operator.adjoint(y).detach()
                Xk = torch.clone(X0)
                for k in range(args.maxiters):
                    Xk = invBlock(Xk, y, k, True)
                mse = criteria(Xk, X)
                avg_init_psnr, avg_recon_psnr, avg_delta_psnr, avg_ssim = compute_metrics(Xk, X, X0)

                # Order must match criteria_title.
                criteria_list = [mse, avg_init_psnr, avg_recon_psnr, avg_delta_psnr, avg_ssim]
                for meter, value in zip(loss_meters, criteria_list):
                    meter.update(value.item(), bs)

                torch.cuda.empty_cache()
                _tqdm.set_postfix({title: f'{meter.avg:.6f}' for title, meter in zip(criteria_title, loss_meters)})
                _tqdm.update(bs)
        # Plot the reconstruction for one validation batch.
        (X, _) = next(iter(val_loader))
        plot_DEQasProx_result(X, args, op, invBlock, forward_operator, measurement_process, criteria, args.save_path, 0)