import torchvision.datasets
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import argparse, json, sys, matplotlib
import os, h5py, yaml, math, time
import torch.backends.cudnn as cudnn
sys.path.append("../")

import operators.operator as op
import operators.single_coil_mri as mrimodel
import networks.network as nets
import torch.optim as optim
from networks.network import *
from utils.dataloader import *
from utils.misc import *
from tqdm import tqdm
import utils.parsers as parsers

def make_parser():
    """Assemble the command-line parser for LU MRI training and evaluation.

    Starts from the project's base parser, layers on the MRI-specific and
    LU-specific option groups, and finally registers ``--file_name``.

    Returns:
        The fully configured parser object produced by ``utils.parsers``.
    """
    base_parser = parsers.make_base_parser()
    mri_parser = parsers.add_mri_args(base_parser)
    full_parser = parsers.add_lu_args(mri_parser)
    full_parser.add("--file_name", type=str, default="lu/mri/", help="saving directory")
    return full_parser

parser = make_parser()
args = parser.parse_args()
print(args)
cuda = torch.cuda.is_available()
print('Cuda is available: ' + str(cuda))
# NOTE(review): cudnn.benchmark picks kernels non-deterministically, so the
# seeds below do not make GPU runs bit-for-bit reproducible.
cudnn.benchmark = True
SEED = 110  # single seed shared by Python's RNG, torch, and the project helper
random.seed(SEED)
torch.manual_seed(SEED)
set_seed(SEED)


""" LOAD DATA """
if not args.eval:
    # NOTE(review): the returned load_path is not used below; checkpoint loading
    # reads args.load_path instead -- confirm that is intended.
    train_loader, val_loader, load_path, args.save_path, tr_length, val_length = load_data(args)
else:
    test_loader, args.save_path, ts_length = load_val_data(args)

print(args.save_path)
os.makedirs(args.save_path, exist_ok=True)
# Dump the run configuration before args.device is attached: a torch.device
# is not JSON-serializable.
with open(os.path.join(args.save_path, 'args.txt'), 'w') as f:
    json.dump(vars(args), f, indent=2)
args.device = torch.device('cuda' if cuda else 'cpu')

""" NETWORK SETUP """
# Undersampling mask built with a fixed seed (seed=10), presumably so the same
# k-space pattern is used across runs -- TODO confirm create_mask semantics.
mask = mrimodel.create_mask(shape=[args.img_dim, args.img_dim, 2], acceleration=args.mri_acc, center_fraction=0.04,
                            seed=10)
forward_operator = mrimodel.cartesianSingleCoilMRI(kspace_mask=mask).to(args.device)
measurement_process = op.OperatorPlusNoise(forward_operator, noise_sigma=args.noise_level).to(args.device)

# The DnCNN's inner `dncnn` module is handed to LU_prox as its learned prior.
DnCNN = nets.DnCNN(channels=2, num_of_layers=17).to(args.device)
invBlock = nets.LU_prox(forward_operator, DnCNN.dncnn, args).to(args.device)
print("# Parmeters: ", sum(a.numel() for a in invBlock.parameters()))

# Guard against an unset load_path: os.path.exists(None) raises TypeError.
if args.pretrain and args.load_path and os.path.exists(args.load_path):
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only host.
    checkpoint = torch.load(args.load_path, map_location=args.device)
    invBlock.load_state_dict(checkpoint['state_dict'])
    print('Model loaded successfully!')
else:
    if args.pretrain:
        # Make it obvious that a requested checkpoint was not found, instead of
        # silently continuing with randomly initialised weights.
        print(f'Warning: --pretrain set but no checkpoint found at {args.load_path!r}')
    print('No Model loaded!!')

""" BEGIN TRAINING or VALIDATION """


def _prepare_batch(X):
    """Convert a loader batch into the 2-channel tensor fed to the network.

    Moves ``X`` to ``args.device``, normalizes it with ``op.normalize``, inserts
    a channel axis at dim 1, and concatenates an all-zero second channel
    (presumably the imaginary part of a complex image -- TODO confirm).
    """
    X = X.to(args.device)
    X = op.normalize(X, X.shape[0]).unsqueeze(1)
    return torch.cat((X, torch.zeros_like(X)), dim=1)


if not args.eval:
    print('Begin training...')
    # NOTE(review): only the DnCNN weights are optimized; if LU_prox owns extra
    # learnable parameters they are never updated -- confirm this is intended.
    opt = torch.optim.Adam(DnCNN.parameters(), lr=args.lr)
    scheduler = optim.lr_scheduler.StepLR(optimizer=opt, step_size=int(args.sched_step), gamma=float(args.lr_gamma))
    criteria = nn.MSELoss()

    resize = torchvision.transforms.Resize(args.img_dim)
    X_val, _ = next(iter(val_loader))  # fixed batch used for per-epoch plots
    for epoch in range(args.n_epochs):
        # Re-enter training mode every epoch: the end-of-epoch evaluation below
        # switches the model to eval mode.
        invBlock.train()
        num_meters = args.maxiters if args.use_aux_loss else 1
        loss_meters = [AverageMeter() for _ in range(num_meters)]
        val_meters = AverageMeter()
        with tqdm(total=(tr_length - tr_length % args.batch_size)) as _tqdm:
            _tqdm.set_description('epoch: {}/{}'.format(epoch + 1, args.n_epochs))
            for X, _ in train_loader:
                bs = X.shape[0]
                X = _prepare_batch(X)
                y = measurement_process(X).squeeze()

                Xk = forward_operator.adjoint(y)
                if args.use_aux_loss:
                    # Supervise every unrolled iterate, not just the last one.
                    loss = 0.0
                    for k in range(args.maxiters):
                        Xk = invBlock.forward_singlestep(Xk, y)
                        loss_k = criteria(Xk, X)
                        loss += loss_k
                        loss_meters[k].update(loss_k.item(), bs)
                    loss += criteria(Xk, X)  # double the weight of final
                else:
                    Xk = invBlock(Xk, y)
                    loss = criteria(Xk, X)
                    loss_meters[0].update(loss.item(), bs)
                opt.zero_grad()
                loss.backward()
                opt.step()

                torch.cuda.empty_cache()
                _tqdm.set_postfix({f'x{k}': f'{meter.avg:.6f}' for k, meter in enumerate(loss_meters)})
                _tqdm.update(bs)

            # StepLR counts epochs; without this call the LR never decays.
            scheduler.step()

            # Evaluation mode for plotting and validation. This matters if the
            # network has BatchNorm/Dropout (standard DnCNN uses BatchNorm).
            invBlock.eval()

            # SAVE RESULTS (every epoch)
            state = {
                'epoch': epoch,
                'state_dict': invBlock.state_dict(),
                'optimizer': opt.state_dict(),
                'scheduler': scheduler.state_dict(),
            }
            torch.save(state, os.path.join(args.save_path, f'epoch_{epoch}.state'))
            plot_LU_result(X_val, args, op, invBlock, forward_operator, measurement_process, criteria, args.save_path, epoch)

            with torch.no_grad():
                for X, _ in val_loader:
                    bs = X.shape[0]
                    X = _prepare_batch(X)
                    y = measurement_process(X).squeeze()
                    Xk = invBlock(forward_operator.adjoint(y), y)
                    # Record the validation reconstruction error (previously the
                    # last training loss was logged here by mistake).
                    val_meters.update(criteria(Xk, X).item(), bs)

            # Report validation once; the bar's counter is sized for training
            # samples only, so it is not advanced during validation.
            if args.use_aux_loss:
                postfix = {f'x{k}': f'{loss_meters[k].avg:.6f}' for k in range(args.maxiters)}
                postfix['ts_mse'] = f'{val_meters.avg:.6f}'
            else:
                postfix = {'tr_mse': f'{loss_meters[-1].avg:.6f}', 'val_mse': f'{val_meters.avg:.6f}'}
            _tqdm.set_postfix(postfix)
else:
    print('validation mode')
    # Evaluate in eval mode (BatchNorm running stats, no dropout).
    invBlock.eval()
    resize = torchvision.transforms.Resize(args.img_dim)
    criteria = nn.MSELoss()
    criteria_title = ['mse', 'avgInit', 'avgPSNR', 'deltaPSNR', 'avgSSIM']
    len_meter = len(criteria_title)
    loss_meters = [AverageMeter() for _ in range(len_meter)]
    with tqdm(total=(ts_length - ts_length % args.batch_size)) as _tqdm:
        _tqdm.set_description('epoch: {}/{}'.format(1, args.n_epochs))
        with torch.no_grad():
            for X, _ in test_loader:
                bs = X.shape[0]
                # Stop at the first batch with fewer than 4 images; presumably
                # compute_metrics needs at least 4 -- TODO confirm.
                if bs < 4:
                    break
                X = _prepare_batch(X)
                y = measurement_process(X).squeeze()
                X0 = forward_operator.adjoint(y)  # initial (adjoint) estimate
                Xk = invBlock(torch.clone(X0), y)
                mse = criteria(Xk, X)
                avg_init_psnr, avg_recon_psnr, avg_delta_psnr, avg_ssim = compute_metrics(Xk, X, X0)

                criteria_list = [mse, avg_init_psnr, avg_recon_psnr, avg_delta_psnr, avg_ssim]
                for meter, value in zip(loss_meters, criteria_list):
                    meter.update(value.item(), bs)

                torch.cuda.empty_cache()
                _tqdm.set_postfix({title: f'{meter.avg:.6f}' for title, meter in zip(criteria_title, loss_meters)})
                _tqdm.update(bs)
        X, _ = next(iter(test_loader))
        plot_LU_result(X, args, op, invBlock, forward_operator, measurement_process, criteria, args.save_path, 0)
