"""
proximal step for DEQ
"""
import torchvision.datasets
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import configargparse, argparse
import os, h5py, yaml
import torch.backends.cudnn as cudnn
import sys, json
sys.path.append("../")
# from MRI.dataloader_fastmri import singleCoilFastMRIDataloader
import operators.single_coil_mri as mrimodel
import operators.forward_models as fm
import operators.operator as op
# import MRI.networks_MRI as nets
import networks.network_luser as nets
import torch.optim as optim
import math, time
from networks.network import *
# from utils.training import *
from utils.dataloader import *
from utils.misc import *
# from utils import operators_DEQx2 as functions
import pdb
import operators.operator_deq as deqop
from tqdm import tqdm
# import matplotlib
import utils.parsers as parsers

def make_parser():
    """Build the command-line parser used by this script.

    Starts from ``parsers.make_base_parser()`` and passes the parser through
    ``add_ct_args``, ``add_luser_args`` and ``add_deqip_args`` in that order,
    then registers the script-specific ``--file_name`` option.

    Returns:
        The parser object handed back by the last helper, with
        ``--file_name`` added.
    """
    cli = parsers.make_base_parser()
    # Each helper receives the parser and returns the extended parser.
    for extend in (parsers.add_ct_args, parsers.add_luser_args, parsers.add_deqip_args):
        cli = extend(cli)
    cli.add("--file_name", type=str, default="luser_sw/ct/", help="saving directory")
    return cli
# Parse the command line once; everything below reads its settings from `args`.
parser = make_parser()
args = parser.parse_args()
args.shared_eta = True  # always set to True after parsing
print(args)

# Record whether CUDA is usable and echo the visible-device mask from the environment.
cuda = torch.cuda.is_available()
print(os.environ.get('CUDA_VISIBLE_DEVICES'), flush=True)
cudnn.benchmark = True

# Minute-resolution timestamp taken at script start.
timeStamp = datetime.now().strftime("%Y-%m-%d-%H%M")

# ---- Load data ----
# Evaluation runs load only the test loader; training runs load train + validation loaders.
# Both helpers also return the ray transform and the base output directory.
if args.eval:
    ts_loader, save_path, ts_length, ray_trafo = load_val_data(args)
    args.save_path = save_path + '_val'
else:
    train_loader, tr_length, val_loader, val_length, ray_trafo, save_path = load_data(args)
    args.save_path = save_path
print('save_path: ' + args.save_path)
os.makedirs(args.save_path, exist_ok=True)

# Persist the run configuration next to the outputs.
with open(os.path.join(args.save_path, 'args.txt'), 'w') as f:
    json.dump(vars(args), f, indent=2)

# Assigned only after the dump above, so the device object is not written to args.txt.
args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ---- Network setup ----
# Forward model built from the batch size and the ray transform returned by the data loader.
forward_operator = fm.CT(args.batch_size, ray_trafo)

# `g` is handed to the fixed-point module together with the Anderson solver settings.
g = nets.single_layer_CT(
    num_channel=1, num_features=64, kernel_size=3, stride=1, padding=1, img_dim=args.img_dim,
)
deq_inner = nets.DEQFixedPoint(
    g, deqop.anderson,
    in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1,
    m=args.and_m, tol=args.and_tol, max_iter=args.and_maxiters, beta=args.and_beta,
)
invBlock = nets.LUSER_SW_CT(forward_operator, deq_inner, args).to(args.device)

# Report the total number of parameter elements in the model.
n_params = sum(p.numel() for p in invBlock.parameters())
print("# Parmeters: ", n_params)

# Optimizer, step-wise LR schedule and reconstruction loss.
invBlock.train()
opt = torch.optim.Adam(invBlock.parameters(), lr=args.lr)
scheduler = optim.lr_scheduler.StepLR(
    optimizer=opt,
    step_size=int(args.sched_step),
    gamma=float(args.lr_gamma),
)
criteria = nn.MSELoss()

# ---- Load model ----
# Optionally restore weights from a checkpoint dict that stores them under 'state_dict'.
if args.pretrain:
    print('Loading Pre-trained invBlock from', args.load_path)
    # map_location remaps the stored tensors onto the device chosen for this run, so a
    # checkpoint written on a GPU can also be restored on a CPU-only machine.
    checkpoint = torch.load(args.load_path, map_location=args.device)
    invBlock.load_state_dict(checkpoint['state_dict'])

# ---- Training (with per-epoch validation) or stand-alone evaluation ----
# Both loaders yield (y, X) pairs: y is fed to forward_operator.adjoint and to the
# network, X is the target the reconstruction is compared against.
# NOTE(review): the model is left in train() mode during validation/evaluation and the
# third positional argument of invBlock(...) is True there as well — confirm that this
# is intended for the layers inside the network.
if not args.eval:
    for epoch in range(args.n_epochs):
        loss_meters = AverageMeter()
        val_meters = AverageMeter()
        with tqdm(total=(tr_length - tr_length % args.batch_size)) as _tqdm:
            _tqdm.set_description('epoch: {}/{}'.format(epoch + 1, args.n_epochs))

            # ---- Training pass ----
            for y, X in train_loader:
                bs = X.shape[0]
                X, y = X.to(args.device), y.to(args.device)
                # Initial estimate: adjoint of the forward operator applied to y.
                X0 = forward_operator.adjoint(y)
                Xk = torch.clone(X0)

                if args.use_aux_loss:
                    # Accumulate the loss after every single step; only the loss of the
                    # last step is recorded in the running training meter.
                    loss = 0.0
                    for k in range(args.maxiters):
                        Xk = invBlock.forward_singlestep(Xk, y, k, train=True)
                        loss_k = criteria(Xk, X)
                        loss += loss_k
                    loss_meters.update(loss_k.item(), bs)
                else:
                    Xk = invBlock(Xk, y, True)
                    loss = criteria(Xk, X)
                    loss_meters.update(loss.item(), bs)
                opt.zero_grad()
                loss.backward()
                opt.step()

                torch.cuda.empty_cache()
                _tqdm.set_postfix({'tr_mse': f'{loss_meters.avg:.6f}', 'val_mse': f'{val_meters.avg:.6f}'})
                _tqdm.update(bs)

            # Advance the LR schedule once per epoch. Without this call the StepLR
            # scheduler never changes the learning rate.
            scheduler.step()

            # Checkpoint once per epoch, after the last optimizer step, instead of
            # rewriting the same file after every batch.
            state = {
                'epoch': epoch,
                'state_dict': invBlock.state_dict(),
                'optimizer': opt.state_dict(),
                'scheduler': scheduler.state_dict(),
            }
            torch.save(state, os.path.join(args.save_path, f'epoch_{epoch}.state'))

            # ---- Validation pass (no gradients) ----
            with torch.no_grad():
                for y, X in val_loader:
                    bs = X.shape[0]
                    X, y = X.to(args.device), y.to(args.device)
                    X0 = forward_operator.adjoint(y)
                    Xk = torch.clone(X0)  # presumably [bs, H, W] (original note: [bs, 362, 362])
                    Xk = invBlock(Xk, y, True)

                    loss = criteria(Xk, X)
                    val_meters.update(loss.item(), bs)
                    _tqdm.set_postfix({'tr_mse': f'{loss_meters.avg:.6f}', 'val_mse': f'{val_meters.avg:.6f}'})
                    _tqdm.update(bs)
                # Plot using the tensors of the last batch processed above.
                plot_CT(Xk, X0, X, criteria, args.save_path, epoch)

else:
    criteria = nn.MSELoss()
    criteria_title = ['mse', 'avgInit', 'avgPSNR', 'deltaPSNR', 'avgSSIM']
    len_meter = len(criteria_title)
    # One running average per reported metric, in the same order as criteria_title.
    loss_meters = [AverageMeter() for _ in range(len_meter)]
    with tqdm(total=(ts_length - ts_length % args.batch_size)) as _tqdm:
        _tqdm.set_description('epoch: {}/{}'.format(1, args.n_epochs))
        for y, X in ts_loader:
            with torch.no_grad():
                bs = X.shape[0]
                X, y = X.to(args.device), y.to(args.device)
                X0 = forward_operator.adjoint(y)
                Xk = torch.clone(X0)  # presumably [bs, H, W] (original note: [bs, 362, 362])
                Xk = invBlock(Xk, y, True)

                # The MSE is computed once and reused for reporting.
                mse = criteria(Xk, X)
                # A dimension is inserted at index 1 of each tensor before the metric helper.
                avg_init_psnr, avg_recon_psnr, avg_delta_psnr, avg_ssim = compute_metrics1chan(
                    Xk.unsqueeze(1), X.unsqueeze(1), X0.unsqueeze(1))
                criteria_list = [mse, avg_init_psnr, avg_recon_psnr, avg_delta_psnr, avg_ssim]
                for meter, value in zip(loss_meters, criteria_list):
                    meter.update(value.item(), bs)

                torch.cuda.empty_cache()
                _tqdm.set_postfix({title: f'{meter.avg:.6f}'
                                   for title, meter in zip(criteria_title, loss_meters)})
                _tqdm.update(bs)
        # Plot using the tensors of the last evaluated batch; 999 is passed where the
        # training branch passes the epoch index.
        plot_CT(Xk, X0, X, criteria, args.save_path, 999)