"""
Two DEQs:
one replaces DnCNN, and
one solves the inverse update.
"""
import torchvision.datasets
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import configargparse
import os, h5py, yaml
import torch.backends.cudnn as cudnn
import sys
import json
sys.path.append("../")
import operators.single_coil_mri as mrimodel
import networks.network_luser as nets
from networks.network import DnCNN
import torch.optim as optim
import math, time
from utils.misc import *
from utils.dataloader import *
import operators.operator_deq as deqop
import operators.operator as op
import pdb

import utils.parsers as parsers

def make_parser():
    """Assemble the command-line parser for this MRI reconstruction script.

    The parser is built by ``parsers.make_base_parser()``. It is then passed
    through ``parsers.add_mri_args`` and ``parsers.add_deqip_args``, and finally
    a ``--file_name`` option (the saving directory) is registered on it.

    Returns:
        The parser produced by the last ``parsers.add_*`` call, with
        ``--file_name`` added. ``.add`` suggests a configargparse parser --
        NOTE(review): confirm against ``utils.parsers``.
    """
    result = parsers.make_base_parser()
    for extend_with in (parsers.add_mri_args, parsers.add_deqip_args):
        result = extend_with(result)
    result.add("--file_name", type=str, default="deqip/mri/", help="saving directory")
    return result

def _prepare_batch(X, device, measurement_process, forward_operator):
    """Turn a raw image batch into (ground truth, measurement, initial estimate).

    The batch is moved to ``device`` and normalized with ``op.normalize``. A
    channel axis is added at dim 1, and an all-zero second channel is
    concatenated. Downstream, ``view_as_complex`` treats the two channels as
    (real, imag).

    Returns:
        X:  the 2-channel ground truth.
        y:  ``measurement_process(X).squeeze()``, the noisy measurement.
        X0: ``forward_operator.adjoint(y)``, the initial estimate fed to the DEQ.

    Callers wrap this in ``torch.no_grad()``; no gradients are needed here.
    """
    X = X.to(device)
    X = op.normalize(X, X.shape[0]).unsqueeze(1)
    X = torch.cat((X, torch.zeros_like(X)), dim=1).to(device)
    y = measurement_process(X).squeeze()
    X0 = forward_operator.adjoint(y)
    return X, y, X0


def _psnr(x, ref):
    """Return the PSNR (dB) of ``x`` against ``ref`` for a peak value of 1.

    Returns ``inf`` for identical inputs instead of raising ZeroDivisionError.
    """
    mse = float(nn.MSELoss()(x, ref))
    if mse == 0.0:
        return float('inf')
    return 20 * math.log10(1 / math.sqrt(mse))


def _save_comparison_figure(X0, Xk, X, path, n_show=3):
    """Save a 3-row grid comparing the initial estimate, the reconstruction and the truth.

    ``X0`` and ``Xk`` are 2-channel tensors. They are permuted channel-last and
    viewed as complex, so dim 1 must have size 2. Each is then displayed as its
    magnitude clamped to [0, 1]. The clean image is channel 0 of ``X``.

    Args:
        n_show: number of columns (samples) to draw. It is capped at the batch
            size so that small batches do not raise IndexError.
    """
    x_hat = torch.clamp(torch.abs(torch.view_as_complex(Xk.permute(0, 2, 3, 1).contiguous())), min=0, max=1)
    init = torch.view_as_complex(X0.permute(0, 2, 3, 1).contiguous())
    init_clamp = torch.clamp(torch.abs(init), min=0, max=1)
    X_true = X[:, 0, :, :]

    plt.figure()
    for i in range(min(n_show, X_true.shape[0])):
        # Row 1: initial estimate X0.
        plt.subplot(3, n_show, i + 1)
        psnr = _psnr(init_clamp[i], X_true[i])
        plt.imshow(init_clamp[i].detach().cpu(), cmap='gray')
        plt.title('$x_0$, PSNR = {0:.3f}'.format(psnr))
        plt.axis('off')
        # Row 2: DEQ reconstruction. A raw string avoids the invalid "\h" escape.
        plt.subplot(3, n_show, i + 1 + n_show)
        psnr = _psnr(x_hat[i], X_true[i])
        plt.imshow(x_hat[i].detach().cpu(), cmap='gray')
        plt.title(r'$\hat{x}$, ' + 'PSNR = {0:.3f}'.format(psnr))
        plt.axis('off')
        # Row 3: ground truth.
        plt.subplot(3, n_show, i + 1 + 2 * n_show)
        plt.imshow(X_true[i].detach().cpu(), cmap='gray')
        plt.title('Clean image')
        plt.axis('off')
    plt.savefig(path)
    plt.close()


if __name__ == '__main__':
    parser = make_parser()
    args = parser.parse_args()
    print(args)
    cuda = torch.cuda.is_available()
    print(os.getenv('CUDA_VISIBLE_DEVICES'), flush=True)
    cudnn.benchmark = True
    random.seed(110)
    torch.manual_seed(110)

    # ---- Load data ----
    if args.train:
        train_loader, val_loader, load_path, save_path, trainset_size, valset_size = load_data(args)
    else:
        test_loader, save_path, testset_size = load_val_data(args)
    os.makedirs(save_path, exist_ok=True)
    with open(os.path.join(save_path, 'args.txt'), 'w') as f:
        json.dump(args.__dict__, f, indent=2)

    # Set after the dump, presumably because torch.device is not JSON-serializable.
    args.device = torch.device('cuda' if cuda else 'cpu')

    # ---- Forward model setup ----
    mask = mrimodel.create_mask(shape=[args.img_dim, args.img_dim, 2], acceleration=args.mri_acc, center_fraction=0.04,
                                seed=10)
    forward_operator = mrimodel.cartesianSingleCoilMRI(kspace_mask=mask).to(args.device)
    measurement_process = op.OperatorPlusNoise(forward_operator, noise_sigma=args.noise_level).to(args.device)

    # ---- Network setup ----
    R = DnCNN(2, num_of_layers=args.n_layers, features=args.n_features)
    deq_inner = nets.inverse_block_mri(forward_operator, R, args)
    inv_model = nets.DEQIPFixedPoint(deq_inner, deqop.anderson,  m=args.and_m, tol=args.and_tol,
                                       max_iter=args.and_maxiters, beta=args.and_beta)
    # Move to the device BEFORE any optimizer state is created or loaded.
    # Optimizer.load_state_dict casts state onto the params' current device, and
    # the test branch below never moved the model at all.
    inv_model.to(args.device)

    print("# Parmeters: ", sum(a.numel() for a in inv_model.parameters()))

    # ---- Load pre-trained denoiser ----
    if args.pretrain:
        print('Loading Pre-trained R from', args.load_path)
        inv_model.g.R.load_state_dict(torch.load(args.load_path, map_location=args.device))

    if args.train:
        X_val, _ = next(iter(val_loader))
        start_epoch = 0
        opt = torch.optim.Adam(inv_model.parameters(), lr=args.lr)
        scheduler = optim.lr_scheduler.StepLR(optimizer=opt, step_size=int(args.sched_step), gamma=float(args.lr_gamma))
        if args.continue_train:
            checkpoint = torch.load(args.load_path, map_location=args.device)
            inv_model.load_state_dict(checkpoint['state_dict'])
            opt.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            # The checkpoint is written after epoch `epoch` has finished, so resume
            # at the next epoch instead of repeating it.
            start_epoch = checkpoint['epoch'] + 1

        # ---- Training ----
        criteria = nn.MSELoss()
        for epoch in range(start_epoch, args.n_epochs):
            # Re-enter train mode each epoch; the validation step below switches to eval().
            inv_model.train()
            loss_meter = AverageMeter()
            with tqdm.tqdm(total=(trainset_size - trainset_size % args.batch_size)) as _tqdm:
                _tqdm.set_description('epoch: {}/{}'.format(epoch + 1, args.n_epochs))
                for X_raw, _ in train_loader:
                    bs = X_raw.shape[0]
                    with torch.no_grad():
                        X, y, X0 = _prepare_batch(X_raw, args.device, measurement_process, forward_operator)
                    Xk = inv_model(y, X0, train=True, zero_init=False)
                    loss = criteria(Xk, X)
                    loss_meter.update(loss.item(), bs)

                    opt.zero_grad()
                    loss.backward()
                    opt.step()

                    torch.cuda.empty_cache()

                    _tqdm.set_postfix({'x0': f'{loss_meter.avg:.6f}'})
                    _tqdm.update(bs)

            # Advance the LR schedule once per epoch. It was previously never
            # stepped, so sched_step and lr_gamma had no effect.
            scheduler.step()

            state = {
                'epoch': epoch,
                'state_dict': inv_model.state_dict(),
                'optimizer': opt.state_dict(),
                'scheduler': scheduler.state_dict(),
            }
            torch.save(state, os.path.join(save_path, f'epoch_{epoch}.state'))

            # ---- Validation on a fixed batch ----
            # eval() so any BatchNorm/Dropout layers use inference behaviour.
            inv_model.eval()
            with torch.no_grad():
                X, y, X0 = _prepare_batch(X_val, args.device, measurement_process, forward_operator)
                Xk = inv_model(y, X0, False, False)
                _save_comparison_figure(X0, Xk, X, os.path.join(save_path, f'{epoch}_results.png'))
    else:
        X_val, _ = next(iter(test_loader))

        print(inv_model.eta)
        # Inference must run in eval mode (previously eval() was only called afterwards).
        inv_model.eval()
        with torch.no_grad():
            X, y, X0 = _prepare_batch(X_val, args.device, measurement_process, forward_operator)
            Xk = inv_model(y, X0, False, False)
            _save_comparison_figure(X0, Xk, X, os.path.join(save_path, 'final_results.png'))

        # Testing if the quality improves over iterations (stub: currently only
        # freezes the parameters and reports the last one's requires_grad flag).
        for param in inv_model.parameters():
            param.requires_grad = False
        with torch.no_grad():
            inv_model.eval()
            print(param.requires_grad)
