import torchvision.datasets
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import argparse, json, sys, matplotlib
import os, h5py, yaml, math, time
import torch.backends.cudnn as cudnn
sys.path.append("../")

import operators.operator as op
import operators.single_coil_mri as mrimodel
import networks.network as nets
import torch.optim as optim
from networks.network import *
from utils.dataloader import *
from utils.misc import *
from tqdm import tqdm

# matplotlib.use("Qt5Agg")


def _str2bool(value):
    """Parse a boolean command-line value.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string (including ``"False"``), so such flags could never be
    switched off from the command line. This converter maps the usual
    spellings explicitly and rejects anything else.

    Args:
        value: the raw string from the command line (or an already-parsed bool).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'Boolean value expected, got {value!r}')


parser = argparse.ArgumentParser()
parser.add_argument('--n_epochs', type=int, default=100)
parser.add_argument('--batch_size', type=int, default=2, help='Batch size, even number')
parser.add_argument('--batch_size_val', type=int, default=4, help='Validation set batch size, even number')
parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
# NOTE(review): args.device is later overwritten based on torch.cuda.is_available(),
# so this flag currently has no effect.
parser.add_argument("--device", type=str, default="cpu", help="cpu or cuda")
parser.add_argument('--noise_level', type=float, default=0.01)
parser.add_argument("--path", type=str, default="../saved_models/", help="network saving directory")
parser.add_argument('--nc', type=int, default=1, help='number of channels in an image')
parser.add_argument('--dataset', type=str, default='MRI', help='dataset: MRI, CT, CelebA, VOC, VOCDetect')

parser.add_argument('--mri_acc', type=float, default=4, help='MRI acceleration, 4 or 8')
parser.add_argument("--file_name", type=str, default="LU_MRI/4x_", help="saving directory")
# Boolean flags use _str2bool so that e.g. "--train False" really disables training.
parser.add_argument('--pretrain', type=_str2bool, default=False, help='pretrain')
parser.add_argument('--train', type=_str2bool, default=True, help='train or eval')
parser.add_argument('--use_aux_loss', type=_str2bool, default=False, help='Use aux loss')
parser.add_argument('--save_path', type=str, default='')
parser.add_argument('--load_path', type=str, default='')
parser.add_argument('--maxiters', type=int, default=2, help='Main max iterations')

parser.add_argument('--crop', type=_str2bool, default=False, help='image size = 150 if crop, 320 if not')
parser.add_argument('--n_features', type=int, default=64, help='hidden channels in g')
parser.add_argument('--img_dim', type=int, default=320, help='image size')
parser.add_argument('--eta', type=float, default=1e-4, help='initial eta, WAS 0.01')
parser.add_argument('--lr_gamma', type=float, default=0.2)
parser.add_argument('--sched_step', type=int, default=40)
args = parser.parse_args()
print(args)

# Report GPU availability (torch.cuda.is_available() already returns a bool).
cuda = torch.cuda.is_available()
print(f'Cuda is available: {cuda}')

# NOTE(review): cudnn.benchmark=True lets cuDNN choose convolution kernels by
# timing them, which can make GPU results differ slightly between runs even
# though every RNG below is seeded.
cudnn.benchmark = True

# Seed Python's RNG, torch's RNG and the project-level helper with one value.
_SEED = 110
random.seed(_SEED)
torch.manual_seed(_SEED)
set_seed(_SEED)

# ------------------------------------------------------------------ LOAD DATA
# Training needs train + test loaders; evaluation needs only a validation loader.
# Both project helpers also return the directory where outputs are written.
if not args.train:
    val_loader, save_path, val_length = load_val_data(args)
else:
    train_loader, test_loader, load_path, save_path, tr_length, ts_length = load_data(args)

args.save_path = save_path
print(args.save_path)
os.makedirs(args.save_path, exist_ok=True)

# Persist the run configuration. This must happen before args.device is turned
# into a torch.device below, because json cannot serialise that object.
args_file = os.path.join(args.save_path, 'args.txt')
with open(args_file, 'w') as fh:
    json.dump(vars(args), fh, indent=2)

# NOTE(review): this ignores the --device flag and always prefers CUDA when present.
args.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# -------------------------------------------------------------- NETWORK SETUP
# k-space sampling mask built by the project helper. The fixed seed presumably
# makes the mask identical across runs -- confirm in mrimodel.create_mask.
mask = mrimodel.create_mask(shape=[args.img_dim, args.img_dim, 2], acceleration=args.mri_acc, center_fraction=0.04,
                            seed=10)
forward_operator = mrimodel.cartesianSingleCoilMRI(kspace_mask=mask).to(args.device)
# Forward operator followed by additive noise of standard deviation args.noise_level.
measurement_process = op.OperatorPlusNoise(forward_operator, noise_sigma=args.noise_level).to(args.device)

DnCNN = nets.DnCNN(channels=2, num_of_layers=17).to(args.device)
# The unrolled reconstruction network wraps the DnCNN body as its learned component.
invBlock = nets.LU_prox(forward_operator, DnCNN.dncnn, args).to(args.device)
print("# Parmeters: ", sum(a.numel() for a in invBlock.parameters()))

if args.pretrain and os.path.exists(args.load_path):
    # map_location remaps tensors saved on another device (e.g. a GPU-trained
    # checkpoint) onto the current one, so CPU-only hosts can still load it.
    checkpoint = torch.load(args.load_path, map_location=args.device)
    invBlock.load_state_dict(checkpoint['state_dict'])
    print('Model loaded successfully!')
else:
    if args.pretrain:
        # Pretraining was requested but the path is wrong: say so explicitly
        # instead of silently starting from random weights.
        print(f'Checkpoint not found: {args.load_path!r}')
    print('No Model loaded!!')

# --------------------------------------------------- BEGIN TRAINING or VALIDATION


def _simulate_measurements(X):
    """Turn a batch of ground-truth images into one inverse-problem instance.

    The batch is moved to ``args.device`` and normalised with ``op.normalize``.
    It then gets a second, all-zero channel (presumably the imaginary part
    expected by the 2-channel DnCNN) and is passed through
    ``measurement_process`` to obtain noisy measurements.

    Args:
        X: a batch as produced by the data loaders.

    Returns:
        A tuple ``(X, y, X0)``:
            X: the 2-channel target on ``args.device``.
            y: the squeezed measurements.
            X0: ``forward_operator.adjoint(y)``, used as the network's initial estimate.
    """
    X = X.to(args.device)
    bs = X.shape[0]
    X = op.normalize(X, bs).unsqueeze(1)
    X = torch.cat((X, torch.zeros_like(X)), dim=1).to(args.device)
    # NOTE(review): squeeze() also drops the batch axis when bs == 1.
    y = measurement_process(X).squeeze()
    return X, y, forward_operator.adjoint(y)


if args.train:
    print('Begin training...')
    # NOTE(review): only DnCNN's weights are given to the optimizer; any extra
    # learnable parameters owned by LU_prox itself are not trained -- confirm intent.
    opt = torch.optim.Adam(DnCNN.parameters(), lr=args.lr)
    scheduler = optim.lr_scheduler.StepLR(optimizer=opt, step_size=int(args.sched_step), gamma=float(args.lr_gamma))
    criteria = nn.MSELoss()

    resize = torchvision.transforms.Resize(args.img_dim)
    X_val, _ = next(iter(test_loader))  # fixed batch used for the per-epoch plots
    if args.crop:
        X_val = resize(X_val)
    for epoch in range(args.n_epochs):
        num_meters = args.maxiters if args.use_aux_loss else 1
        loss_meters = [AverageMeter() for _ in range(num_meters)]
        val_meters = AverageMeter()
        invBlock.train()  # re-enable training behaviour after last epoch's eval()
        with tqdm(total=(tr_length - tr_length % args.batch_size)) as _tqdm:
            _tqdm.set_description('epoch: {}/{}'.format(epoch + 1, args.n_epochs))
            for X, _ in train_loader:
                bs = X.shape[0]
                X, y, Xk = _simulate_measurements(X)
                loss = 0.0
                if args.use_aux_loss:
                    # Supervise every unrolled iterate, not only the final output.
                    for k in range(args.maxiters):
                        Xk = invBlock.forward_singlestep(Xk, y)
                        loss_k = criteria(Xk, X)
                        loss += loss_k
                        loss_meters[k].update(loss_k.item(), bs)
                    loss += criteria(Xk, X)  # double the weight of final
                else:
                    Xk = invBlock(Xk, y)
                    loss = criteria(Xk, X)
                    loss_meters[0].update(loss.item(), bs)
                opt.zero_grad()
                loss.backward()
                opt.step()

                torch.cuda.empty_cache()
                _tqdm.set_postfix({f'x{k}': f'{meter.avg:.6f}' for k, meter in enumerate(loss_meters)})
                _tqdm.update(bs)

            # Inference behaviour (e.g. BatchNorm running statistics, if the
            # network has any) for plotting and validation.
            invBlock.eval()

            # SAVE RESULTS (every epoch)
            state = {
                'epoch': epoch,
                'state_dict': invBlock.state_dict(),
                'optimizer': opt.state_dict(),
                'scheduler': scheduler.state_dict(),
            }
            torch.save(state, os.path.join(args.save_path, f'epoch_{epoch}.state'))
            plot_LU_result(X_val, args, op, invBlock, forward_operator, measurement_process, criteria, args.save_path, epoch)

            # Validation on the held-out loader: record the *validation* MSE.
            with torch.no_grad():
                for X, _ in test_loader:
                    bs = X.shape[0]
                    X, y, Xk = _simulate_measurements(X)
                    Xk = invBlock(Xk, y)
                    val_meters.update(criteria(Xk, X).item(), bs)

            if args.use_aux_loss:
                postfix = {f'x{k}': f'{loss_meters[k].avg:.6f}' for k in range(args.maxiters)}
                postfix['ts_mse'] = f'{val_meters.avg:.6f}'
            else:
                postfix = {'tr_mse': f'{loss_meters[-1].avg:.6f}', 'val_mse': f'{val_meters.avg:.6f}'}
            _tqdm.set_postfix(postfix)

        # Advance the StepLR schedule once per epoch (it was never stepped before).
        scheduler.step()
else:
    print('validation mode')
    invBlock.eval()
    resize = torchvision.transforms.Resize(args.img_dim)
    criteria = nn.MSELoss()
    criteria_title = ['mse', 'avgInit', 'avgPSNR', 'deltaPSNR', 'avgSSIM']
    len_meter = len(criteria_title)
    loss_meters = [AverageMeter() for _ in range(len_meter)]
    with tqdm(total=(val_length - val_length % args.batch_size)) as _tqdm:
        _tqdm.set_description('epoch: {}/{}'.format(1, args.n_epochs))
        for X, _ in val_loader:
            bs = X.shape[0]
            # NOTE(review): hard-coded 4; stops at the first batch smaller than
            # that (presumably a short final batch).
            if bs < 4:
                break
            with torch.no_grad():
                X, y, X0 = _simulate_measurements(X)
                Xk = invBlock(torch.clone(X0), y)
                mse = criteria(Xk, X)
                avg_init_psnr, avg_recon_psnr, avg_delta_psnr, avg_ssim = compute_metrics(Xk, X, X0)

            criteria_list = [mse, avg_init_psnr, avg_recon_psnr, avg_delta_psnr, avg_ssim]
            for meter, value in zip(loss_meters, criteria_list):
                meter.update(value.item(), bs)

            torch.cuda.empty_cache()
            _tqdm.set_postfix({title: f'{meter.avg:.6f}' for title, meter in zip(criteria_title, loss_meters)})
            _tqdm.update(bs)
        X, _ = next(iter(val_loader))
        plot_LU_result(X, args, op, invBlock, forward_operator, measurement_process, criteria, save_path, 0)
