import torchvision.datasets
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import argparse, json, sys, matplotlib
import os, h5py, yaml, math, time
import torch.backends.cudnn as cudnn

sys.path.append("../")

import operators.operator as op
import operators.single_coil_mri as mrimodel
import networks.network as nets
import torch.optim as optim
from networks.network import *
from utils.dataloader import *
from utils.misc import *
import operators.forward_models as fm
from tqdm import tqdm

import utils.parsers as parsers


def make_parser():
    """Build the command-line parser for the CelebA loop-unrolling script.

    Starts from the project's base parser, lets the CelebA and LU helper
    functions in ``utils.parsers`` register their options on it, and finally
    adds ``--file_name`` (the saving directory prefix).

    Returns:
        The fully configured parser object produced by ``utils.parsers``.
    """
    arg_parser = parsers.add_lu_args(
        parsers.add_celeba_args(
            parsers.make_base_parser()
        )
    )
    arg_parser.add(
        "--file_name",
        type=str,
        default="lu/celeba/",
        help="saving directory",
    )
    return arg_parser


# Parse the command line and echo the resulting configuration.
parser = make_parser()
args = parser.parse_args()
print(args)

# Let cuDNN auto-tune its convolution algorithms for the input sizes it sees.
cudnn.benchmark = True

# Record GPU availability and show which devices this process may use.
cuda = torch.cuda.is_available()
print(os.environ.get('CUDA_VISIBLE_DEVICES'), flush=True)

""" LOAD DATA and CREATE SAVING DIRECTORY"""
# Tag the output location with the blur configuration so runs with different
# kernels do not overwrite each other.
filestr = f"ks{args.kernel_size}_sigma{args.sigma}"
args.file_name += filestr

train_loader, tr_length, val_loader, val_length, ts_loader, ts_length, save_path = load_data(args)
# Evaluation runs write to a sibling "<save_path>_val" directory.
args.save_path = save_path + '_val' if args.eval else save_path

# f-strings (rather than '+' concatenation) so that a missing / non-string
# load_path such as None is printed instead of raising TypeError.
print(f'save_path: {args.save_path}')
print(f'load_path: {args.load_path}')
os.makedirs(args.save_path, exist_ok=True)
with open(os.path.join(args.save_path, 'args.txt'), 'w') as f:
    json.dump(vars(args), f, indent=2)
# Keep this AFTER the JSON dump above: a torch.device object is not JSON
# serializable, so adding it to args earlier would make json.dump fail.
args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

""" Network Setup """
# Measurement model: Gaussian blur on 3-channel images followed by additive
# noise of level args.noise_level.
forward_operator = fm.GaussianBlur(kernel_size=args.kernel_size, sigma=args.sigma, n_channels=3).to(args.device)
measurement_process = op.OperatorPlusNoise(forward_operator, noise_sigma=args.noise_level).to(args.device)

# Reconstruction network: LU_prox is built from the forward operator and the
# DnCNN's inner `dncnn` module.
DnCNN = nets.DnCNN(channels=3, num_of_layers=17).to(args.device)
invBlock = nets.LU_prox(forward_operator, DnCNN.dncnn, args).to(args.device)
print("# Parmeters: ", sum(p.numel() for p in invBlock.parameters()))

# Optionally restore pretrained weights. The explicit `args.load_path` check
# avoids os.path.exists(None) raising when --pretrain is given without a path.
if args.pretrain and args.load_path and os.path.exists(args.load_path):
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only
    # host (and vice versa) instead of failing while deserializing.
    checkpoint = torch.load(args.load_path, map_location=args.device)
    invBlock.load_state_dict(checkpoint['state_dict'])
    print('Model loaded successfully!')
else:
    print('No Model loaded!!')

""" BEGIN TRAINING or VALIDATION """
if not args.eval:
    print('Begin training...')
    # NOTE(review): only DnCNN's parameters are given to the optimizer. If
    # invBlock owns trainable parameters outside DnCNN they are never updated
    # -- confirm this is intended.
    opt = torch.optim.Adam(DnCNN.parameters(), lr=args.lr)
    scheduler = optim.lr_scheduler.StepLR(optimizer=opt, step_size=int(args.sched_step), gamma=float(args.lr_gamma))
    criteria = nn.MSELoss()

    for epoch in range(args.n_epochs):
        # Validation at the end of each epoch switches to eval mode, so
        # training mode has to be re-enabled at the start of every epoch.
        invBlock.train()

        # With the auxiliary loss there is one meter per unrolled iteration;
        # otherwise a single meter tracks the final-output loss.
        num_meters = args.maxiters if args.use_aux_loss else 1
        loss_meters = [AverageMeter() for _ in range(num_meters)]
        val_meters = AverageMeter()
        with tqdm(total=(tr_length - tr_length % args.batch_size)) as _tqdm:
            _tqdm.set_description('epoch: {}/{}'.format(epoch + 1, args.n_epochs))

            # ---- training pass ----
            for X, _ in train_loader:
                bs = X.shape[0]
                X = X.to(args.device)
                # Simulate the degraded observation and keep it in [0, 1].
                y = measurement_process(X)
                y = torch.clamp(y, min=0, max=1)
                # The measurement itself is the initial estimate.
                X0 = torch.clone(y)
                Xk = torch.clone(X0)

                loss = 0.0
                if args.use_aux_loss:
                    # Supervise the output of every unrolled iteration.
                    for k in range(args.maxiters):
                        Xk = invBlock.forward_singlestep(Xk, y)
                        loss_k = criteria(Xk, X)
                        loss += loss_k
                        loss_meters[k].update(loss_k.item(), bs)
                    loss += criteria(Xk, X)  # double the weight of final
                else:
                    Xk = invBlock(Xk, y)
                    loss = criteria(Xk, X)
                    loss_meters[0].update(loss.item(), bs)

                opt.zero_grad()
                loss.backward()
                opt.step()

                torch.cuda.empty_cache()
                _tqdm.set_postfix({f'x{k}': f'{meter.avg:.6f}' for k, meter in enumerate(loss_meters)})
                _tqdm.update(bs)

            # Advance the LR schedule once per epoch; without this call the
            # StepLR decay configured above never takes effect.
            scheduler.step()

            # Checkpoint once per epoch (the file name depends only on the
            # epoch, so writing it after every batch just rewrote the same
            # file over and over).
            state = {
                'epoch': epoch,
                'state_dict': invBlock.state_dict(),
                'optimizer': opt.state_dict(),
                'scheduler': scheduler.state_dict(),
            }
            torch.save(state, os.path.join(args.save_path, f'epoch_{epoch}.state'))

            # ---- validation pass ----
            # eval() so that mode-dependent layers (e.g. batch norm, dropout),
            # if the network has any, neither use nor update batch statistics
            # on validation data.
            invBlock.eval()
            with torch.no_grad():
                for X, _ in val_loader:
                    bs = X.shape[0]
                    X = X.to(args.device)
                    y = measurement_process(X)
                    y = torch.clamp(y, min=0, max=1)
                    X0 = torch.clone(y)
                    Xk = invBlock(X0, y)
                    loss = criteria(Xk.squeeze(), X.squeeze())
                    val_meters.update(loss.item(), bs)
                    if args.use_aux_loss:
                        postfix = {f'x{k}': f'{meter.avg:.6f}' for k, meter in enumerate(loss_meters)}
                        # Key kept as 'ts_mse' for continuity with existing
                        # logs; the value is the validation MSE.
                        postfix['ts_mse'] = f'{val_meters.avg:.6f}'
                    else:
                        postfix = {'tr_mse': f'{loss_meters[-1].avg:.6f}', 'val_mse': f'{val_meters.avg:.6f}'}
                    _tqdm.set_postfix(postfix)
                    _tqdm.update(bs)
                # Plot the last validation batch: reconstruction, input, target.
                plot_CelebA(Xk, X0, X, criteria, args.save_path, epoch)

else:
    print('Testing mode')
    # Evaluate in eval mode so mode-dependent layers (if any) behave
    # deterministically and are not updated by test batches.
    invBlock.eval()
    criteria = nn.MSELoss()
    criteria_title = ['mse', 'avgInit', 'avgPSNR', 'deltaPSNR', 'avgSSIM']
    len_meter = len(criteria_title)
    loss_meters = [AverageMeter() for _ in range(len_meter)]
    with tqdm(total=(ts_length - ts_length % args.batch_size)) as _tqdm:
        _tqdm.set_description('epoch: {}/{}'.format(1, args.n_epochs))
        for X, _ in ts_loader:
            with torch.no_grad():
                bs = X.shape[0]
                X = X.to(args.device)
                y = measurement_process(X)
                y = torch.clamp(y, min=0, max=1)
                X0 = torch.clone(y)
                Xk = invBlock(X0, y)
                mse = criteria(Xk, X)
                avg_init_psnr, avg_recon_psnr, avg_delta_psnr, avg_ssim = compute_metrics3chan(Xk, X, X0)

                # Same order as criteria_title.
                criteria_list = [mse, avg_init_psnr, avg_recon_psnr, avg_delta_psnr, avg_ssim]
                for meter, value in zip(loss_meters, criteria_list):
                    meter.update(value.item(), bs)

                torch.cuda.empty_cache()
                _tqdm.set_postfix({title: f'{meter.avg:.6f}' for title, meter in zip(criteria_title, loss_meters)})
                _tqdm.update(bs)
        # Plot the last test batch; 999 is the tag passed in place of an epoch.
        plot_CelebA(Xk, X0, X, criteria, args.save_path, 999)
