import torchvision.datasets
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import argparse, json, sys, pdb, configargparse

# import os, h5py, yaml, math, time
sys.path.append("../")

import torch.backends.cudnn as cudnn
import operators.operator as op
import operators.operator_deq as deqop
import operators.single_coil_mri as mrimodel
# import networks.network as nets
import networks.network_luser as nets
import torch.optim as optim
from networks.network import *
from utils.dataloader import *
from utils.misc import *
import matplotlib
from tqdm import tqdm
import utils.parsers as parsers


def make_parser():
    """Build the command-line parser for this MRI script.

    Starts from ``parsers.make_base_parser()``, layers on the MRI and luser
    argument groups, then registers ``--file_name`` (the saving directory).

    Returns:
        The configured parser (the object returned by the ``parsers`` helpers).
    """
    parser = parsers.make_base_parser()
    for add_argument_group in (parsers.add_mri_args, parsers.add_luser_args):
        parser = add_argument_group(parser)
    parser.add("--file_name", type=str, default="luser_dw/mri/", help="saving directory")
    return parser


parser = make_parser()
args = parser.parse_args()
print(args)
cuda = torch.cuda.is_available()
print(os.getenv('CUDA_VISIBLE_DEVICES'), flush=True)
# Let cuDNN auto-tune convolution algorithms (pays off when input sizes stay fixed).
cudnn.benchmark = True

# ---- Load Data ----
# Evaluation only needs the held-out loader; training needs train + validation loaders.
if args.eval:
    ts_loader, save_path, ts_length = load_val_data(args)
else:
    train_loader, val_loader, load_path, save_path, tr_length, val_length = load_data(args)

args.save_path = save_path
print('save_path = ', args.save_path)
os.makedirs(args.save_path, exist_ok=True)
# Snapshot the configuration next to the outputs. This happens before args.device is
# attached, since json.dump cannot serialize a torch.device.
with open(os.path.join(args.save_path, 'args.txt'), 'w') as f:
    json.dump(vars(args), f, indent=2)
args.device = torch.device('cuda' if cuda else 'cpu')

""" Network Setup """
mask = mrimodel.create_mask(shape=[args.img_dim, args.img_dim, 2], acceleration=args.mri_acc, center_fraction=0.04,
                            seed=10)
forward_operator = mrimodel.cartesianSingleCoilMRI(kspace_mask=mask).to(args.device)
measurement_process = op.OperatorPlusNoise(forward_operator, noise_sigma=args.noise_level).to(args.device)

# Keyword arguments shared by every DEQ block built below (both branches used the same values).
deq_kwargs = dict(in_channels=2, out_channels=32, kernel_size=3, stride=1, padding=1,
                  m=args.and_m, tol=args.and_tol, max_iter=args.and_maxiters, beta=args.and_beta)


def _make_single_layer():
    """Build one learned layer g (2 channels, 64 features, 3x3 kernels) used inside a DEQ block."""
    return nets.single_layer(num_channel=2, num_features=64, kernel_size=3, stride=1, padding=1,
                             img_dim=args.img_dim)


if args.use_aux_loss:
    # One DEQ block per unrolled iteration. With --diffW every iteration owns its g; otherwise
    # at most 4 g's are created and each is shared by a contiguous run of iterations.
    nBlocks = args.maxiters if args.diffW else min(4, args.maxiters)
    g_list = [_make_single_layer() for _ in range(nBlocks)]
    # Map iteration i in [0, maxiters) evenly onto a g index in [0, nBlocks). When maxiters is a
    # multiple of 4 this equals the previous i // (maxiters // 4); unlike that form it cannot raise
    # IndexError (maxiters not a multiple of 4) or ZeroDivisionError (maxiters < 4).
    # NOTE(review): this branch builds nets.DEQFixedPoint while the branch below builds
    # nets.DEQFixedPoint_MRI -- confirm which one this MRI script is meant to use.
    deq_list = [nets.DEQFixedPoint(g_list[i * nBlocks // args.maxiters], deqop.anderson, **deq_kwargs)
                for i in range(args.maxiters)]
    invBlock = nets.inverse_block_full(forward_operator, deq_list, args).to(args.device)
else:
    # Eight independent DEQ blocks; all g's are created first, then the DEQ wrappers, in order.
    g_layers = [_make_single_layer() for _ in range(8)]
    deq_blocks = [nets.DEQFixedPoint_MRI(g, deqop.anderson, **deq_kwargs) for g in g_layers]
    invBlock = nets.inverse_block_reverse_full(forward_operator, *deq_blocks, args).to(args.device)
print("# Parmeters: ", sum(a.numel() for a in invBlock.parameters()))

""" Begin Training """
invBlock.train()
opt = torch.optim.Adam(invBlock.parameters(), lr=args.lr)
# Multiplies the learning rate by lr_gamma every sched_step calls to scheduler.step().
scheduler = optim.lr_scheduler.StepLR(optimizer=opt, step_size=int(args.sched_step), gamma=float(args.lr_gamma))
criteria = nn.MSELoss()

# LOAD pretrained weights. Only the model state_dict is restored; the optimizer/scheduler
# states stored in the checkpoint are not, and the epoch counter starts again from 0.
# NOTE(review): load_data() also returns a load_path that is never stored on args, so
# args.load_path must come from the parser -- confirm that is the intended source.
if args.pretrain:
    print('loaded model: ', args.load_path)
    # map_location lets a checkpoint saved from a GPU run be loaded on a CPU-only machine.
    checkpoint = torch.load(args.load_path, map_location=args.device)
    invBlock.load_state_dict(checkpoint['state_dict'])
    print('Model loaded successfully!')
else:
    print('no model loaded')

""" Training or Eval """
def _prepare_batch(X):
    """Build the reconstruction inputs for one batch of images.

    ``X`` is normalized with ``op.normalize``, given a channel axis at dim 1, moved to
    ``args.device`` and concatenated with an all-zero second channel (presumably the
    imaginary part of a complex image -- TODO confirm). The measurements ``y`` come from
    ``measurement_process`` and the initial estimate ``X0`` is ``forward_operator.adjoint(y)``
    detached from the autograd graph.

    Returns:
        (X, y, X0): two-channel ground truth, measurements, initial estimate.
    """
    bs = X.shape[0]
    X = op.normalize(X, bs).unsqueeze(1).to(args.device)
    X = torch.cat((X, torch.zeros_like(X)), dim=1)
    y = measurement_process(X).squeeze()
    X0 = forward_operator.adjoint(y).detach()
    return X, y, X0


if not args.eval:
    """ Training Mode """
    for epoch in range(args.n_epochs):
        loss_meters = AverageMeter()
        val_meters = AverageMeter()
        # Validation below switches to eval mode, so re-enable training mode every epoch.
        invBlock.train()
        with tqdm(total=(tr_length - tr_length % args.batch_size)) as _tqdm:
            _tqdm.set_description('epoch: {}/{}'.format(epoch + 1, args.n_epochs))
            for X, _ in train_loader:
                bs = X.shape[0]
                # NOTE(review): the epoch stops at the first odd-sized batch (normally only the
                # final partial one); why an even batch is required is not visible here.
                if bs % 2 == 1: break
                X, y, X0 = _prepare_batch(X)
                Xk = torch.clone(X0)

                if args.use_aux_loss:
                    # Deep supervision: sum the MSE of every unrolled iterate.
                    loss = 0.0
                    for i in range(args.maxiters):
                        Xk = invBlock.forward_singlestep(Xk, y, i, train=True)
                        loss_k = criteria(Xk, X)
                        loss += loss_k
                    # Only the final iterate's MSE is logged, so tr_mse stays comparable to val_mse.
                    loss_meters.update(loss_k.item(), bs)
                else:
                    Xk = invBlock(Xk, y, train=True)
                    loss = criteria(Xk, X)
                    loss_meters.update(loss.item(), bs)

                opt.zero_grad()
                loss.backward()
                opt.step()

                torch.cuda.empty_cache()
                _tqdm.set_postfix({'tr_mse': f'{loss_meters.avg:.6f}', 'val_mse': f'{val_meters.avg:.6f}'})
                _tqdm.update(bs)

            """ Validation """
            # Inference in eval mode so layers such as BatchNorm/Dropout (if any) behave
            # deterministically and do not update running statistics from validation data.
            invBlock.eval()
            with torch.no_grad():
                for X, _ in val_loader:
                    bs = X.shape[0]
                    X, y, X0 = _prepare_batch(X)
                    Xk = invBlock(torch.clone(X0), y, train=True)
                    loss = criteria(Xk, X)

                    val_meters.update(loss.item(), bs)
                    # The bar's total is the training-set size, so validation only refreshes
                    # the postfix and no longer pushes the bar past 100%.
                    _tqdm.set_postfix({'tr_mse': f'{loss_meters.avg:.6f}', 'val_mse': f'{val_meters.avg:.6f}'})
                plot_MRI(Xk, X0, X, criteria, args.save_path, epoch)

        # Advance the StepLR schedule once per epoch (it was previously never stepped).
        scheduler.step()

        # Save the result once per epoch instead of rewriting the same file after every batch.
        state = {
            'epoch': epoch,
            'state_dict': invBlock.state_dict(),
            'optimizer': opt.state_dict(),
            'scheduler': scheduler.state_dict(),
        }
        torch.save(state, os.path.join(args.save_path, f'epoch_{epoch}.state'))

else:
    """ Evaluation Mode """
    criteria_title = ['mse', 'avgInit', 'avgPSNR', 'deltaPSNR', 'avgSSIM']
    len_meter = len(criteria_title)
    loss_meters = [AverageMeter() for _ in range(len_meter)]
    # Inference must not run in training mode (invBlock.train() was called during setup).
    invBlock.eval()

    with tqdm(total=(ts_length - ts_length % args.batch_size)) as _tqdm:
        _tqdm.set_description('epoch: {}/{}'.format(1, args.n_epochs))
        with torch.no_grad():
            for X, _ in ts_loader:
                bs = X.shape[0]
                X, y, X0 = _prepare_batch(X)
                Xk = invBlock(torch.clone(X0), y, train=True)

                mse = criteria(Xk, X)
                avg_init_psnr, avg_recon_psnr, avg_delta_psnr, avg_ssim = compute_metrics(Xk, X, X0)

                # Same order as criteria_title.
                criteria_list = [mse, avg_init_psnr, avg_recon_psnr, avg_delta_psnr, avg_ssim]
                for meter, value in zip(loss_meters, criteria_list):
                    meter.update(value.item(), bs)

                torch.cuda.empty_cache()
                _tqdm.set_postfix({f'{title}': f'{meter.avg:.6f}' for title, meter in zip(criteria_title, loss_meters)})
                _tqdm.update(bs)
        # Plots the last batch processed above.
        plot_MRI(Xk, X0, X, criteria, args.save_path, 999)
