"""
2 DEQs:
one replaces DnCNN, and
one solves the inverse update
"""
import torchvision.datasets
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import configargparse
import os, h5py, yaml
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.nn as nn
import sys
import json
sys.path.append("../")
# from MRI.dataloader_fastmri import singleCoilFastMRIDataloader
# import MRI.single_coil_mri as mrimodel
import networks.network_luser as nets
from networks.network import DnCNN
import math, time
from utils.dataloader import *
from utils.misc import AverageMeter, compute_metrics1chan, plot_CT

import operators.forward_models as forwardops
import operators.operator_deq as deqop
import operators.operator as op
import tqdm
from dival import get_standard_dataset

import pdb
import utils.parsers as parsers

# ---- Command-line configuration ------------------------------------------
# Start from the project's base parser, then layer the CT and DEQ-IP option
# groups on top of it (all three builders live in utils.parsers).
parser = parsers.add_deqip_args(
    parsers.add_ct_args(
        parsers.make_base_parser()
    )
)
parser.add("--file_name", type=str, default="deqip/ct/", help="saving directory")

args = parser.parse_args()
print(args)

# ---- Process-wide runtime settings ---------------------------------------
# NOTE(review): `torch`, `random` and `datetime` are not imported explicitly
# in this file; presumably they are re-exported by
# `from utils.dataloader import *` -- confirm.
cuda = bool(torch.cuda.is_available())
print(os.getenv('CUDA_VISIBLE_DEVICES'), flush=True)
cudnn.benchmark = True

# Seed both the Python and the torch RNGs with the same fixed value.
random.seed(110)
torch.manual_seed(110)

# Minute-resolution timestamp of script start-up.
timeStamp = f"{datetime.now():%Y-%m-%d-%H%M}"

# ---- Load data ------------------------------------------------------------
# Evaluation-only runs build a single test loader; training runs build a
# train loader and a validation loader. Both helpers also hand back the
# object bound to `ray_trafo` and the output directory `save_path`.
if not args.train:
    ts_loader, save_path, ts_length, ray_trafo = load_val_data(args)
else:
    train_loader, tr_length, val_loader, val_length, ray_trafo, save_path = load_data(args)

# NOTE(review): the training branch used to choose between `save_path` and
# `save_path + '_val'` based on `args.train`, but it only ran when
# `args.train` was truthy, so the '_val' suffix was never applied. Both modes
# therefore record the directory unchanged.
args.save_path = save_path

# Persist the parsed arguments next to the results for reproducibility.
os.makedirs(save_path, exist_ok=True)
with open(os.path.join(save_path, 'args.txt'), 'w') as args_file:
    json.dump(vars(args), args_file, indent=2)

# The device is attached only after the dump above: a torch.device object is
# not JSON-serializable.
if torch.cuda.is_available():
    args.device = torch.device('cuda')
else:
    args.device = torch.device('cpu')


# ---- Forward model setup --------------------------------------------------
# CT measurement operator built from the batch size and the ray transform
# returned by the data-loading helper.
forward_operator = forwardops.CT(args.batch_size, ray_trafo)

# Alternative inner DEQ, currently disabled:
# g = nets.single_layer_rgb(channels=2, img_dim=args.img_dim, num_of_layers=4)
# deq_inner = nets.DEQFixedPoint_x2(g, op.anderson, m=args.and_m, tol=args.and_tol,
#                                        max_iter=args.and_maxiters, beta=args.and_beta)

# ---- Network setup --------------------------------------------------------
# R is a DnCNN; `inverse_block_mri` combines it with the forward operator into
# one update block, and DEQIPFixedPoint wraps that block together with the
# Anderson solver and its settings.
R = DnCNN(args.nc, num_of_layers=args.n_layers)
deq_inner = nets.inverse_block_mri(forward_operator, R, args)
inv_model = nets.DEQIPFixedPoint(
    deq_inner,
    deqop.anderson,
    m=args.and_m,
    tol=args.and_tol,
    max_iter=args.and_maxiters,
    beta=args.and_beta,
)

# Report the total number of scalar parameters in the assembled model.
n_params = sum(p.numel() for p in inv_model.parameters())
print("# Parmeters: ", n_params)  #


# criteria = nn.MSELoss(reduction='sum')
# inv_model2.train()
# opt2 = torch.optim.Adam(inv_model2.parameters(), lr=args.lr)


# Optionally warm-start the regularizer R from a stand-alone state_dict file.
if args.pretrain:
    print('Loading Pre-trained R from', args.load_path)
    # map_location='cpu': without it, tensors saved from a CUDA device are
    # restored onto that same device, which fails on a CPU-only host. The
    # model has not been moved to args.device yet at this point, so CPU is
    # also the natural target for the copy performed by load_state_dict.
    # NOTE(security): torch.load unpickles the file -- only load checkpoints
    # from trusted sources.
    pretrained_R = torch.load(args.load_path, map_location='cpu')
    inv_model.g.R.load_state_dict(pretrained_R)
      
if args.train:
    # ---- Training (optionally resumed from a checkpoint) ------------------
    inv_model.train().to(args.device)
    opt = torch.optim.Adam(inv_model.parameters(), lr=args.lr)
    scheduler = optim.lr_scheduler.StepLR(
        optimizer=opt, step_size=int(args.sched_step), gamma=float(args.lr_gamma))
    criteria = nn.MSELoss()
    start_epoch = 0
    if args.continue_train:
        # map_location lets a checkpoint written on another device (e.g. a
        # GPU) be restored on whatever device this run uses.
        # NOTE(security): torch.load unpickles the file -- only load
        # checkpoints from trusted sources.
        checkpoint = torch.load(args.load_path, map_location=args.device)
        inv_model.load_state_dict(checkpoint['state_dict'])
        opt.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        # NOTE(review): checkpoints store the index of the epoch during which
        # they were written, so resuming from an end-of-epoch file re-runs
        # that epoch once more. Left as-is to stay compatible with existing
        # checkpoint files.
        start_epoch = checkpoint['epoch']
        print('Model loaded successfully')
    val_loss = 0.0

    def _training_state(epoch_idx):
        """Bundle everything needed to resume training into one dict."""
        return {
            'epoch': epoch_idx,
            'state_dict': inv_model.state_dict(),
            'optimizer': opt.state_dict(),
            'scheduler': scheduler.state_dict(),
        }

    # ---- Begin training ---------------------------------------------------
    for epoch in range(start_epoch, args.n_epochs):
        # Validation below switches the model to eval mode; switch back here.
        inv_model.train()
        loss_meters = [AverageMeter()]
        with tqdm.tqdm(total=(tr_length - tr_length % args.batch_size)) as _tqdm:
            _tqdm.set_description('epoch: {}/{}'.format(epoch + 1, args.n_epochs))
            count = 0
            for y, X in train_loader:
                # Data movement and the initial estimate need no gradients.
                with torch.no_grad():
                    X, y = X.to(args.device).unsqueeze(1), y.to(args.device).unsqueeze(1)
                    bs = X.shape[0]
                    # Initial estimate: adjoint of the forward operator
                    # applied to the measurements.
                    X0 = forward_operator.adjoint(y)
                Xk = inv_model(y, X0, train=True, zero_init=False)
                loss = criteria(Xk, X)
                loss_meters[0].update(loss.item(), bs)

                opt.zero_grad()
                loss.backward()
                opt.step()

                torch.cuda.empty_cache()

                _tqdm.set_postfix({'x0': f'{loss_meters[0].avg:.6f}', 'val': f'{val_loss:.6f}'})
                _tqdm.update(bs)

                # Rolling intra-epoch checkpoint every 500 iterations.
                count += 1
                if count >= 500:
                    count = 0
                    torch.save(_training_state(epoch),
                               os.path.join(save_path, f'epoch_{epoch}_latest.state'))

            # End-of-epoch checkpoint.
            torch.save(_training_state(epoch), os.path.join(save_path, f'epoch_{epoch}.state'))

        # ---- Validation ---------------------------------------------------
        val_meter = AverageMeter()
        # Evaluate in eval mode so that layers such as batch-norm or dropout
        # (if the network contains any) neither use batch statistics nor
        # update their running statistics from validation data.
        inv_model.eval()
        with torch.no_grad():
            with tqdm.tqdm(total=(val_length - val_length % args.batch_size)) as val_tqdm:
                val_tqdm.set_description('Val epoch: {}/{}'.format(epoch + 1, args.n_epochs))
                for y, X in val_loader:
                    X, y = X.to(args.device).unsqueeze(1), y.to(args.device).unsqueeze(1)
                    bs = X.shape[0]
                    X0 = forward_operator.adjoint(y)
                    Xk = inv_model(y, X0, train=False, zero_init=False)
                    loss = criteria(Xk, X)
                    val_meter.update(loss.item(), bs)
                    val_tqdm.set_postfix({'val': f'{val_meter.avg:.6f}'})
                    val_tqdm.update(bs)
            val_loss = val_meter.avg
            # Save an example figure built from the last validation batch.
            plot_CT(Xk[:,0], X0[:,0], X[:,0], criteria, args.save_path, epoch)

        # Advance the StepLR schedule once per epoch. The scheduler was
        # created, checkpointed and restored but never stepped, so the
        # learning rate never decayed. Stepping after the checkpoints are
        # written keeps the saved scheduler state aligned with the saved
        # epoch index, which is the epoch a resumed run starts from.
        scheduler.step()
            # Save Image Example
            #if (epoch+1) % 10 == 0: 
#                 plt.figure()
#                 x_hat = torch.clamp(Xk, min=0, max=1).squeeze(1)
#                 init_clamp = torch.clamp(X0, min=0, max=1).squeeze(1)
#                 X_true = X[:, 0, :, :]
#                 criteria2 = nn.MSELoss()
#                 for i in range(3):
#                     plt.subplot(3, 3, i + 1)
#                     psnr = 20 * math.log10(1 / math.sqrt(criteria2(init_clamp[i], X_true[i])))
#                     plt.imshow(init_clamp[i].detach().cpu(), cmap='gray')
#                     plt.title('$x_0$, PSNR = {0:.3f}'.format(psnr))
#                     plt.axis('off')
#                     plt.subplot(3, 3, i + 4)
#                     psnr = 20 * math.log10(1 / math.sqrt(criteria2(x_hat[i], X_true[i])))
#                     plt.imshow(x_hat[i].detach().cpu(), cmap='gray')
#                     plt.title('$\hat{x}$, ' + 'PSNR = {0:.3f}'.format(psnr))
#                     plt.axis('off')
#                     plt.subplot(3, 3, i + 7)
#                     plt.imshow(X_true[i].detach().cpu(), cmap='gray')
#                     plt.title('Clean image')
#                     plt.axis('off')

#                 # plt.show()
#                 plt.savefig(os.path.join(save_path, f'{epoch}_results.png'))
#                 plt.close()


else:
    checkpoint = torch.load(args.load_path)
    inv_model.load_state_dict(checkpoint['state_dict'])
    inv_model.to(args.device)
    """ Eval """
    criteria = nn.MSELoss()
    criteria_title = ['mse', 'avgInit', 'avgPSNR', 'deltaPSNR', 'avgSSIM']
    len_meter = len(criteria_title)
    loss_meters = [AverageMeter() for _ in range(len_meter)]

    ts_length = len(ts_loader.dataset)
    with tqdm.tqdm(total=(ts_length - ts_length % args.batch_size)) as _tqdm:
        _tqdm.set_description('epoch: {}/{}'.format(1, args.n_epochs))
        for y, X in ts_loader:  # X, _ = next(iter(ts_loader))
        #             S = S.unsqueeze(1).to(args.device)
            with torch.no_grad():
                bs = X.shape[0]
                X, y = X.to(args.device).unsqueeze(1), y.to(args.device).unsqueeze(1)
                X0 = forward_operator.adjoint(y) # should be 4 channel
                Xk = inv_model(y, X0, train=False, zero_init=False)

                mse = criteria(Xk, X)
                avg_init_psnr, avg_recon_psnr, avg_delta_psnr, avg_ssim = compute_metrics1chan(Xk, X, X0)

                criteria_list = [mse, avg_init_psnr, avg_recon_psnr, avg_delta_psnr, avg_ssim]
                for k in range(len_meter):
                    loss_meters[k].update(criteria_list[k].item(), bs)

                torch.cuda.empty_cache()
                _tqdm.set_postfix({f'{criteria_title[k]}': f'{loss_meters[k].avg:.6f}' for k in range(len_meter)})
                _tqdm.update(bs)
        plot_CT(Xk[:,0], X0[:,0], X[:,0], criteria, args.save_path, 0)

# def PSNR(Xk, X):
#     criteria2 = nn.MSELoss()
# #     loss = criteria2(torch.abs(torch.view_as_complex(Xk.permute(0, 2, 3, 1))), X[:, 0, :, :])
#     loss = criteria2(Xk, X)
#     return 20 * math.log10(1 / math.sqrt(loss))



# X_val,_ = next(iter(valloader))
# print(inv_model.eta)
# with torch.no_grad():
#     X = X_val.to(args.device)
# #     X = op.normalize(X, X.size(0)).unsqueeze(1)
#     y = measurement_process(X)
#     X0 = torch.clamp(y, min=0, max=1).clone() #forward_operator.adjoint(y)
#     Xk = inv_model(y, X0, False, False)


#     plt.figure()
#     x_hat = torch.clamp(Xk, min=0, max=1).squeeze(1)
#     init_clamp = torch.clamp(X0, min=0, max=1).squeeze(1)
#     X_true = X[:, 0, :, :]
#     criteria2 = nn.MSELoss()
#     for i in range(3):
#         plt.subplot(3, 3, i + 1)
#         psnr = 20 * math.log10(1 / math.sqrt(criteria2(init_clamp[i], X_true[i])))
#         plt.imshow(init_clamp[i].detach().cpu(), cmap='gray')
#         plt.title('$x_0$, PSNR = {0:.3f}'.format(psnr))
#         plt.axis('off')
#         plt.subplot(3, 3, i + 4)
#         psnr = 20 * math.log10(1 / math.sqrt(criteria2(x_hat[i], X_true[i])))
#         plt.imshow(x_hat[i].detach().cpu(), cmap='gray')
#         plt.title('$\hat{x}$, ' + 'PSNR = {0:.3f}'.format(psnr))
#         plt.axis('off')
#         plt.subplot(3, 3, i + 7)
#         plt.imshow(X_true[i].detach().cpu(), cmap='gray')
#         plt.title('Clean image')
#         plt.axis('off')

#     # plt.show()
#     plt.savefig(os.path.join(save_path, 'final_results.png'))
# # Testing if the quality improves over iterations.
# for param in inv_model.parameters():
#     param.requires_grad = False
# with torch.no_grad():
#     inv_model.eval()
#     print(param.requires_grad)


