import torchvision.datasets
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import argparse, json, sys, matplotlib
import os, h5py, yaml, math, time
import torch.backends.cudnn as cudnn
sys.path.append("../")

import operators.operator as op
import operators.single_coil_mri as mrimodel
import networks.network as nets
import torch.optim as optim
from networks.network import *
from utils.dataloader import *
from utils.misc import *
import operators.forward_models as fm
from tqdm import tqdm

# matplotlib.use("Qt5Agg")
def _str2bool(value):
    """Parse a boolean command-line value.

    ``argparse`` with ``type=bool`` calls ``bool(<str>)``, which is True for any
    non-empty string -- including ``'False'`` -- so such flags could never be
    switched off from the command line. This accepts the usual spellings.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):  # non-string defaults are passed through untouched
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser()
# NOTE: --A_type is overwritten below from --kernel_size/--sigma.
parser.add_argument('--A_type', type=str, default='', help='kernel size + sigma')
parser.add_argument('--noise_level', type=float, default=0.01)
parser.add_argument('--maxiters', type=int, default=2, help='Main max iterations')
parser.add_argument('--kernel_size', type=int, default=9)
parser.add_argument('--sigma', type=int, default=5)

parser.add_argument("--file_name", type=str, default="LU_CelebA/", help="saving directory")
parser.add_argument('--batch_size', type=int, default=1, help='Batch size, EVEN number')
parser.add_argument('--use_aux_loss', type=_str2bool, default=False)
parser.add_argument('--pretrain', type=_str2bool, default=False, help='pretrain')
parser.add_argument('--train', type=_str2bool, default=True, help='training or validation mode')
parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
parser.add_argument('--load_path', default='')
parser.add_argument('--save_path', default='')

parser.add_argument("--device", type=str, default="cuda", help="cpu or cuda")
parser.add_argument('--n_epochs', type=int, default=200)
parser.add_argument('--eta', type=float, default=1e-4, help='initial eta')
parser.add_argument('--batch_size_val', type=int, default=32, help='Validation set batch size, even')
parser.add_argument('--dataset', type=str, default='CelebA', help='MRI, CelebA, MRI')
parser.add_argument("--path", type=str, default="../saved_models/", help="network saving directory")
parser.add_argument('--nc', type=int, default=1, help='number of channels in an image')
parser.add_argument('--lr_gamma', type=float, default=0.2)
parser.add_argument('--sched_step', type=int, default=40)
args = parser.parse_args()
print(args)
cuda = torch.cuda.is_available()
print(os.getenv('CUDA_VISIBLE_DEVICES'), flush=True)
cudnn.benchmark = True

""" LOAD DATA and CREATE SAVING DIRECTORY"""
# The experiment tag is derived from the blur parameters; this replaces any --A_type given on the CLI.
args.A_type = "ks" + str(args.kernel_size) + "_sigma" + str(args.sigma)
args.file_name += args.A_type

train_loader, tr_length, val_loader, val_length, ts_loader, ts_length, save_path = load_data(args)
# Evaluation runs write into a sibling '<save_path>_val' directory so they never overwrite training outputs.
args.save_path = save_path if args.train else save_path + '_val'

print('save_path: ' + args.save_path)
print('load_path: ' + args.load_path)
os.makedirs(args.save_path, exist_ok=True)
# Dumped while args.device is still a plain string (torch.device is not JSON-serialisable).
with open(os.path.join(args.save_path, 'args.txt'), 'w') as f:
    json.dump(args.__dict__, f, indent=2)

# Honour --device, falling back to CPU only when CUDA was requested but is unavailable.
if args.device.startswith('cuda') and not cuda:
    print('CUDA requested but not available; falling back to CPU.')
    args.device = 'cpu'
args.device = torch.device(args.device)

""" Network Setup """
forward_operator = fm.GaussianBlur(kernel_size=args.kernel_size, sigma=args.sigma, n_channels=3).to(args.device)
measurement_process = op.OperatorPlusNoise(forward_operator, noise_sigma=args.noise_level).to(args.device)

DnCNN = nets.DnCNN(channels=3, num_of_layers=17).to(args.device)
invBlock = nets.LU_prox(forward_operator, DnCNN.dncnn, args).to(args.device)
print("# Parmeters: ", sum(a.numel() for a in invBlock.parameters()))

if args.pretrain and os.path.exists(args.load_path):
    checkpoint = torch.load(args.load_path)
    invBlock.load_state_dict(checkpoint['state_dict'])
    print('Model loaded successfully!')
else:
    print('No Model loaded!!')

""" BEGIN TRAINING or VALIDATION """
if args.train:
    print('Begin training...')
    invBlock.train()
    opt = torch.optim.Adam(DnCNN.parameters(), lr=args.lr)
    scheduler = optim.lr_scheduler.StepLR(optimizer=opt, step_size=int(args.sched_step), gamma=float(args.lr_gamma))
    criteria = nn.MSELoss()

    X_val, _ = next(iter(val_loader))
    for epoch in range(args.n_epochs):
        num_meters = args.maxiters if args.use_aux_loss else 1
        loss_meters = [AverageMeter() for _ in range(num_meters)]
        val_meters = AverageMeter()
        with tqdm(total=(tr_length - tr_length % args.batch_size)) as _tqdm:
            _tqdm.set_description('epoch: {}/{}'.format(epoch + 1, args.n_epochs))
            for X, _ in train_loader:  # X, _ = next(iter(train_loader))
                bs = X.shape[0]
                X = X.to(args.device)
                y = measurement_process(X)
                y = torch.clamp(y, min=0, max=1)
                X0 = torch.clone(y)
                Xk = torch.clone(X0)

                loss = 0.0
                if args.use_aux_loss:
                    for k in range(args.maxiters):
                        Xk = invBlock.forward_singlestep(Xk, y)
                        loss_k = criteria(Xk, X)
                        loss += loss_k
                        loss_meters[k].update(loss_k.item(), bs)
                    loss += criteria(Xk, X)  # double the weight of final
                else:
                    Xk = invBlock(Xk, y)
                    loss = criteria(Xk, X)
                    loss_meters[0].update(loss.item(), bs)

                opt.zero_grad()
                loss.backward()
                opt.step()

                # loss_meters.update(loss.item(), bs)
                torch.cuda.empty_cache()
                _tqdm.set_postfix({f'x{k}': f'{loss_meters[k].avg:.6f}' for k in range(len(loss_meters))})
                _tqdm.update(bs)

                state = {
                    'epoch': epoch,
                    'state_dict': invBlock.state_dict(),
                    'optimizer': opt.state_dict(),
                    'scheduler': scheduler.state_dict(),
                }
                torch.save(state, os.path.join(args.save_path, f'epoch_{epoch}.state'))
                # break
            with torch.no_grad():
                for X, _ in val_loader:
                    bs = X.shape[0]
                    X = X.to(args.device)
                    y = measurement_process(X)
                    y = torch.clamp(y, min=0, max=1)
                    X0 = torch.clone(y)
                    Xk = invBlock(X0, y)
                    loss = criteria(Xk.squeeze(), X.squeeze())
                    val_meters.update(loss.item(), bs)
                    if args.use_aux_loss:
                        dict = {f'x{k}': f'{loss_meters[k].avg:.6f}' for k in range(args.maxiters)}
                        dict.update({'ts_mse': f'{val_meters.avg:.6f}'})
                        _tqdm.set_postfix(dict)
                        _tqdm.update(bs)
                    else:
                        _tqdm.set_postfix({'tr_mse': f'{loss_meters[-1].avg:.6f}', 'val_mse': f'{val_meters.avg:.6f}'})
                        _tqdm.update(bs)
                plot_CelebA(Xk, X0, X, criteria, args.save_path, epoch)

else:
    print('Testing mode')
    resize = torchvision.transforms.Resize(args.img_dim)
    criteria = nn.MSELoss()
    criteria_title = ['mse', 'avgInit', 'avgPSNR', 'deltaPSNR', 'avgSSIM']
    len_meter = len(criteria_title)
    loss_meters = [AverageMeter() for _ in range(len_meter)]
    i = 0
    with tqdm(total=(ts_length - ts_length % args.batch_size)) as _tqdm:
        _tqdm.set_description('epoch: {}/{}'.format(1, args.n_epochs))
        for X, _ in ts_loader:
            with torch.no_grad():
                bs = X.shape[0]
                X = X.to(args.device)
                y = measurement_process(X)
                y = torch.clamp(y, min=0, max=1)
                X0 = torch.clone(y)
                Xk = invBlock(X0, y)
                mse = criteria(Xk, X)
                avg_init_psnr, avg_recon_psnr, avg_delta_psnr, avg_ssim = compute_metrics3chan(Xk, X, X0)

                criteria_list = [mse, avg_init_psnr, avg_recon_psnr, avg_delta_psnr, avg_ssim]
                for k in range(len_meter):
                    loss_meters[k].update(criteria_list[k].item(), bs)

                torch.cuda.empty_cache()
                _tqdm.set_postfix({f'{criteria_title[k]}': f'{loss_meters[k].avg:.6f}' for k in range(len_meter)})
                _tqdm.update(bs)
            plot_CelebA(Xk, X0, X, criteria, args.save_path, 999)
