import argparse, json, sys, pdb, configargparse

sys.path.append("../")
import torch.backends.cudnn as cudnn
import operators.forward_models as fm
import networks.network as nets
import torch.optim as optim
from utils.dataloader import *
from utils.misc import *
from tqdm import tqdm
import utils.parsers as parsers

def make_parser():
    """Build the command-line parser for LU-proximal CT training / evaluation.

    Starts from the project's base parser, layers on the CT-specific and then the
    LU-specific argument groups, and finally registers ``--file_name``.

    Returns:
        The fully populated parser object returned by the ``parsers`` helpers.
    """
    base_parser = parsers.make_base_parser()
    ct_parser = parsers.add_ct_args(base_parser)
    full_parser = parsers.add_lu_args(ct_parser)
    full_parser.add("--file_name", type=str, default="lu/ct/", help="saving directory")
    return full_parser
# --- Command-line arguments and run-time environment -----------------------
parser = make_parser()
args = parser.parse_args()
print(args)
cuda = torch.cuda.is_available()
print(os.getenv('CUDA_VISIBLE_DEVICES'), flush=True)
# Let cuDNN auto-tune convolution algorithms for the fixed input sizes.
cudnn.benchmark = True

# --- Data: the helpers also decide the output directory (args.save_path) ---
if args.eval:
    test_loader, args.save_path, ts_length, ray_trafo = load_val_data(args)
else:
    train_loader, tr_length, val_loader, val_length, ray_trafo, args.save_path = load_data(args)

print('save_path: ' + args.save_path)
os.makedirs(args.save_path, exist_ok=True)
# Persist the run configuration before args.device (not JSON-serializable) is attached.
with open(os.path.join(args.save_path, 'args.txt'), 'w') as f:
    json.dump(vars(args), f, indent=2)
args.device = torch.device('cuda' if cuda else 'cpu')

""" Network Setup """
# CT forward operator built from the ray transform returned by the data loader.
forward_operator = fm.CT(args.batch_size, ray_trafo)  # z_shape: product of batch and channel dimension
DnCNN = nets.DnCNN(channels=1, num_of_layers=17).to(args.device)
# invBlock receives the DnCNN's `.dncnn` submodule itself (shared, not copied),
# so optimizing DnCNN.parameters() updates the weights used inside invBlock.
invBlock = nets.LU_prox_CT(forward_operator, DnCNN.dncnn, args).to(args.device)
print("# Parmeters: ", sum(a.numel() for a in invBlock.parameters()))

if args.pretrain and os.path.exists(args.load_path):
    # map_location allows a checkpoint saved on GPU to be restored on a CPU-only
    # machine (and vice versa); load_state_dict then copies into invBlock's tensors.
    checkpoint = torch.load(args.load_path, map_location=args.device)
    invBlock.load_state_dict(checkpoint['state_dict'])
    print('Model loaded successfully!')
else:
    if args.pretrain:
        # Pretraining was requested but the checkpoint is missing: say so explicitly
        # instead of silently training/testing from scratch.
        print(f'Warning: --pretrain is set but no checkpoint was found at {args.load_path}')
    print('No Model loaded!!')

""" BEGIN TRAINING or TESTING """
def _progress_postfix(loss_meters, val_meters):
    """Return the tqdm postfix dict with the current running averages.

    With ``args.use_aux_loss`` one training MSE per unrolled iteration is shown
    under keys ``x0 .. x{maxiters-1}`` plus the validation MSE under ``ts_mse``.
    Otherwise the single training meter is shown as ``tr_mse`` next to ``val_mse``.
    """
    if args.use_aux_loss:
        postfix = {f'x{k}': f'{loss_meters[k].avg:.6f}' for k in range(args.maxiters)}
        postfix['ts_mse'] = f'{val_meters.avg:.6f}'
        return postfix
    return {'tr_mse': f'{loss_meters[-1].avg:.6f}', 'val_mse': f'{val_meters.avg:.6f}'}


if not args.eval:
    print('Begin training...')
    # NOTE(review): only the DnCNN weights are handed to the optimizer; any other
    # parameters registered on invBlock are not trained -- confirm this is intended.
    opt = torch.optim.Adam(DnCNN.parameters(), lr=args.lr)
    scheduler = optim.lr_scheduler.StepLR(optimizer=opt, step_size=int(args.sched_step), gamma=float(args.lr_gamma))
    criteria = nn.MSELoss()

    for epoch in range(args.n_epochs):
        # Validation at the end of each epoch switches the network to eval mode,
        # so training mode must be re-enabled at the start of every epoch.
        invBlock.train()
        num_meters = args.maxiters if args.use_aux_loss else 1
        loss_meters = [AverageMeter() for _ in range(num_meters)]
        val_meters = AverageMeter()
        with tqdm(total=(tr_length - tr_length % args.batch_size)) as _tqdm:
            _tqdm.set_description('epoch: {}/{}'.format(epoch + 1, args.n_epochs))
            for y, X in train_loader:  # (measurements y, target images X)
                bs = X.shape[0]
                X, y = X.to(args.device), y.to(args.device)
                # The adjoint of the forward operator applied to y is the initial iterate.
                X0 = forward_operator.adjoint(y)
                Xk = torch.clone(X0)

                if args.use_aux_loss:
                    # Supervise every unrolled iterate, not only the final output.
                    loss = 0.0
                    for k in range(args.maxiters):
                        Xk = invBlock.forward_singlestep(Xk, y)
                        loss_k = criteria(Xk, X)
                        loss += loss_k
                        loss_meters[k].update(loss_k.item(), bs)
                    loss += criteria(Xk, X)  # double the weight of the final iterate
                else:
                    Xk = invBlock(Xk, y)
                    loss = criteria(Xk, X)
                    loss_meters[0].update(loss.item(), bs)
                opt.zero_grad()
                loss.backward()
                opt.step()

                torch.cuda.empty_cache()
                _tqdm.set_postfix(_progress_postfix(loss_meters, val_meters))
                _tqdm.update(bs)

            # Evaluate with inference behaviour (e.g. BatchNorm running statistics,
            # no dropout) and without tracking gradients.
            invBlock.eval()
            with torch.no_grad():
                for y, X in val_loader:
                    bs = X.shape[0]
                    X, y = X.to(args.device), y.to(args.device)
                    X0 = forward_operator.adjoint(y)
                    Xk = invBlock(X0, y)
                    loss = criteria(Xk.squeeze(), X.squeeze())
                    val_meters.update(loss.item(), bs)
                    _tqdm.set_postfix(_progress_postfix(loss_meters, val_meters))
                    _tqdm.update(bs)
                # Plot the last validation batch of this epoch.
                plot_CT(Xk, X0, X, criteria, args.save_path, epoch)

        # Advance the StepLR schedule once per epoch so the learning rate is
        # multiplied by lr_gamma every sched_step epochs.
        scheduler.step()

        # Checkpoint once per epoch (previously rewritten after every batch).
        state = {
            'epoch': epoch,
            'state_dict': invBlock.state_dict(),
            'optimizer': opt.state_dict(),
            'scheduler': scheduler.state_dict(),
        }
        torch.save(state, os.path.join(args.save_path, f'epoch_{epoch}.state'))
else:
    print('Testing mode')
    # Inference behaviour for evaluation (matters if the network has BatchNorm/Dropout).
    invBlock.eval()
    criteria = nn.MSELoss()
    criteria_title = ['mse', 'avgInit', 'avgPSNR', 'deltaPSNR', 'avgSSIM']
    len_meter = len(criteria_title)
    loss_meters = [AverageMeter() for _ in range(len_meter)]
    with tqdm(total=(ts_length - ts_length % args.batch_size_val)) as _tqdm:
        _tqdm.set_description('epoch: {}/{}'.format(1, args.n_epochs))
        # NOTE(review): only the first test batch is evaluated, although the progress
        # bar total is sized for the whole test set -- confirm this is intended.
        y, X = next(iter(test_loader))
        with torch.no_grad():
            bs = X.shape[0]
            X, y = X.to(args.device), y.to(args.device)
            X0 = forward_operator.adjoint(y)
            Xk = invBlock(X0, y)
            mse = criteria(Xk.squeeze(), X.squeeze())
            avg_init_psnr, avg_recon_psnr, avg_delta_psnr, avg_ssim = compute_metrics1chan(Xk.unsqueeze(1),
                                                                                           X.unsqueeze(1),
                                                                                           X0.unsqueeze(1))
            # Order must match criteria_title above.
            criteria_list = [mse, avg_init_psnr, avg_recon_psnr, avg_delta_psnr, avg_ssim]
            for meter, value in zip(loss_meters, criteria_list):
                meter.update(value.item(), bs)

            torch.cuda.empty_cache()
            _tqdm.set_postfix({title: f'{meter.avg:.6f}' for title, meter in zip(criteria_title, loss_meters)})
            _tqdm.update(bs)
        plot_CT(Xk, X0, X, criteria, args.save_path, 999)
