import argparse, json, sys, pdb, configargparse

sys.path.append("../")

import torch.backends.cudnn as cudnn
import operators.forward_models as fm
import operators.operator_deq as deqop
import networks.network_luser as nets
import torch.optim as optim
from networks.network import *
from utils.dataloader import *
from utils.misc import *
import matplotlib
from tqdm import tqdm
import utils.parsers as parsers

def make_parser():
    """Build the command-line parser for LUSER-DW CT training/evaluation.

    Starts from the shared base parser, layers on the CT-specific and the
    LUSER-specific option groups (in that order), and finally registers
    ``--file_name``, the output sub-directory used when saving results.

    Returns:
        The configured parser object produced by ``utils.parsers``.
    """
    cli = parsers.make_base_parser()
    for extend in (parsers.add_ct_args, parsers.add_luser_args):
        cli = extend(cli)
    cli.add("--file_name", type=str, default="luser_dw/ct/", help="saving directory")
    return cli
# --- Command-line arguments and run environment -----------------------------
parser = make_parser()
args = parser.parse_args()
print(args)
# Module-level flag; torch.cuda.is_available() already returns a bool.
cuda = torch.cuda.is_available()
print(os.getenv('CUDA_VISIBLE_DEVICES'), flush=True)
cudnn.benchmark = True


# Separate output trees for per-iteration weights (diffW) vs. shared weights.
args.file_name = args.file_name + ('diffW/' if args.diffW else 'shareW/')

# --- Data -------------------------------------------------------------------
# load_data / load_val_data also decide args.save_path.
if not args.eval:
    train_loader, tr_length, val_loader, val_length, ray_trafo, args.save_path = load_data(args)
else:
    test_loader, args.save_path, ts_length, ray_trafo = load_val_data(args)
print('save_path: ' + args.save_path)
os.makedirs(args.save_path, exist_ok=True)

# Record the run configuration. This must happen before args.device is set:
# torch.device objects are not JSON-serialisable.
with open(os.path.join(args.save_path, 'args.txt'), 'w', encoding='utf-8') as f:
    json.dump(vars(args), f, indent=2)
args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- Network setup ----------------------------------------------------------
forward_operator = fm.CT(args.batch_size, ray_trafo)

# With --diffW every unrolled iteration gets its own layer; otherwise up to
# 4 layers are shared by consecutive groups of `factor` iterations.
nBlocks = args.maxiters if args.diffW else 4
# Never build more layers than there are iterations. The original code
# divided by zero here when maxiters < 4 in shared-weight mode.
nBlocks = max(1, min(nBlocks, args.maxiters))
factor = args.maxiters // nBlocks
g_list = [nets.single_layer(num_channel=1, num_features=64, kernel_size=3, stride=1, padding=1, img_dim=args.img_dim)
          for _ in range(nBlocks)]

# Iteration i uses layer i // factor. Clamp the index so the trailing
# remainder iterations (maxiters not a multiple of nBlocks) reuse the last
# layer instead of indexing past the end of g_list.
deq_list = [
    nets.DEQFixedPoint(g_list[min(i // factor, nBlocks - 1)], deqop.anderson, in_channels=1, out_channels=32,
                       kernel_size=3, stride=1, padding=1, m=args.and_m, tol=args.and_tol,
                       max_iter=args.and_maxiters, beta=args.and_beta)
    for i in range(args.maxiters)]
invBlock = nets.LUSER_DW_CT(forward_operator, deq_list, args).to(args.device)
print("# Parmeters: ", sum(a.numel() for a in invBlock.parameters()))

# --- Optimisation setup (runs in both train and eval mode) ------------------
# NOTE(review): the model is left in train() mode even for --eval; confirm
# the network has no BatchNorm/Dropout that would make that matter.
invBlock.train()
forward_operator.eval()
opt = optim.Adam(invBlock.parameters(), lr=args.lr)
scheduler = optim.lr_scheduler.StepLR(optimizer=opt, step_size=int(args.sched_step), gamma=float(args.lr_gamma))
criteria = nn.MSELoss()

# Load pretrained weights. Only the model weights are restored; the saved
# 'optimizer' / 'scheduler' / 'epoch' entries are currently ignored.
if args.pretrain:
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host.
    checkpoint = torch.load(args.load_path, map_location=args.device)
    invBlock.load_state_dict(checkpoint['state_dict'])
    print('Model loaded successfully!')


# --- Training / evaluation --------------------------------------------------
if not args.eval:
    for epoch in range(args.n_epochs):
        loss_meters = AverageMeter()
        val_meters = AverageMeter()
        # The bar counts training samples only (partial final batch excluded).
        with tqdm(total=(tr_length - tr_length % args.batch_size)) as _tqdm:
            _tqdm.set_description('epoch: {}/{}'.format(epoch + 1, args.n_epochs))
            for y, X in train_loader:  # loaders yield (measurement, ground truth)
                bs = X.shape[0]
                X, y = X.to(args.device), y.to(args.device)
                X0 = forward_operator.adjoint(y)  # initial estimate from the adjoint
                Xk = torch.clone(X0)

                if args.use_aux_loss:
                    # Deep supervision: sum the loss of every unrolled iteration.
                    loss = 0.0
                    for i in range(args.maxiters):
                        Xk = invBlock.forward_singlestep(Xk, y, i, True)
                        loss += criteria(Xk, X)
                else:
                    Xk = invBlock(Xk, y, train=True)
                    loss = criteria(Xk, X)

                opt.zero_grad()
                loss.backward()
                opt.step()
                loss_meters.update(loss.item(), bs)

                torch.cuda.empty_cache()
                _tqdm.set_postfix({'tr_mse': f'{loss_meters.avg:.6f}', 'val_mse': f'{val_meters.avg:.6f}'})
                _tqdm.update(bs)

                # Save the result within epoch (overwrites this epoch's file after every batch).
                state = {
                    'epoch': epoch,
                    'state_dict': invBlock.state_dict(),  # save invBlock.R
                    'optimizer': opt.state_dict(),
                    'scheduler': scheduler.state_dict(),
                }
                torch.save(state, os.path.join(args.save_path, f'epoch_{epoch}.state'))

            # Validation. Only the postfix is refreshed here: advancing the bar
            # would overflow its training-sample total.
            with torch.no_grad():
                for y, X in val_loader:
                    bs = X.shape[0]
                    X, y = X.to(args.device), y.to(args.device)
                    X0 = forward_operator.adjoint(y)
                    Xk = invBlock(X0, y, train=True)

                    loss = criteria(Xk, X)
                    val_meters.update(loss.item(), bs)
                    _tqdm.set_postfix({'tr_mse': f'{loss_meters.avg:.6f}', 'val_mse': f'{val_meters.avg:.6f}'})
                # Plots the last validation batch.
                plot_CT(Xk, X0, X, criteria, args.save_path, epoch)

        # Advance the StepLR schedule once per epoch. Without this call the
        # learning rate never decays.
        scheduler.step()

else:
    # `criteria` (MSE) is defined during optimisation setup for both modes.
    criteria_title = ['mse', 'avgInit', 'avgPSNR', 'deltaPSNR', 'avgSSIM']
    len_meter = len(criteria_title)
    loss_meters = [AverageMeter() for _ in range(len_meter)]

    with tqdm(total=(ts_length - ts_length % args.batch_size)) as _tqdm:
        _tqdm.set_description('epoch: {}/{}'.format(1, args.n_epochs))
        with torch.no_grad():
            # Average the metrics over the whole test set, not just its first batch.
            for y, X in test_loader:
                bs = X.shape[0]
                X, y = X.to(args.device), y.to(args.device)
                X0 = forward_operator.adjoint(y)
                Xk = invBlock(X0, y, train=True)

                mse = criteria(Xk, X)
                avg_init_psnr, avg_recon_psnr, avg_delta_psnr, avg_ssim = compute_metrics1chan(
                    Xk.unsqueeze(1), X.unsqueeze(1), X0.unsqueeze(1))
                criteria_list = [mse, avg_init_psnr, avg_recon_psnr, avg_delta_psnr, avg_ssim]
                for meter, value in zip(loss_meters, criteria_list):
                    meter.update(value.item(), bs)

                torch.cuda.empty_cache()
                _tqdm.set_postfix({f'{title}': f'{meter.avg:.6f}' for title, meter in zip(criteria_title, loss_meters)})
                _tqdm.update(bs)
        # Plots the last test batch.
        plot_CT(Xk, X0, X, criteria, args.save_path, 999)
