import sys, os, hashlib
sys.path.insert(0, 'external/fastMRI')

import pathlib
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateLogger
import torch
import torch.backends.cudnn as cudnn
from torch import nn
import random
import numpy as np

from training_utils.mri_model import MRIModel
import training_utils.data_augment as DA

from models.varnet.varnet import SensitivityModel, NormUnet, SSIM
from common.args import Args
from argparse import Namespace
from common.subsample import create_mask_for_mask_type
from data import transforms as T
from utils import load_args

class VarNetBlock(nn.Module):
    """One cascade step: soft data consistency plus a learned refinement.

    The update returned by ``forward`` is
    ``current_kspace - dc_weight * masked(current_kspace - ref_kspace) - refinement``
    where ``refinement`` is the wrapped ``model`` applied in the image domain
    and mapped back to k-space.
    """

    def __init__(self, model):
        """Wrap ``model`` and create the learnable data-consistency weight.

        Args:
            model: module applied to the image-domain data inside ``forward``.
        """
        super().__init__()
        self.model = model
        # Learnable scalar weighting the data-consistency term (initialised to 1).
        self.dc_weight = nn.Parameter(torch.ones(1))
        # 5-D zero kept as a buffer so it follows the module across devices;
        # used as the "else" value of torch.where in forward.
        self.register_buffer('zero', torch.zeros(1, 1, 1, 1, 1))

    def forward(self, current_kspace, ref_kspace, mask, sens_maps=None):
        """Apply one cascade update to ``current_kspace``.

        Args:
            current_kspace: current k-space estimate.
            ref_kspace: reference (measured) k-space the estimate is pulled towards.
            mask: condition tensor for ``torch.where``; where it is true the
                difference to ``ref_kspace`` is penalised.
            sens_maps: optional sensitivity maps. When given, the multi-coil
                path is used; when ``None``, the single-coil path is used.

        Returns:
            The updated k-space estimate.
        """
        # Soft data consistency: only positions selected by the mask contribute.
        residual = torch.where(mask, current_kspace - ref_kspace, self.zero)
        dc_term = residual * self.dc_weight

        if sens_maps is None:
            # Single-coil: refine directly on the inverse-transformed k-space.
            refinement = T.fft2(self.model(T.ifft2(current_kspace)))
        else:
            # Multi-coil: combine along dim 1 using the conjugate sensitivity
            # maps, refine, then expand back with the maps.
            coil_images = T.ifft2(current_kspace)
            combined = T.complex_mul(
                coil_images, T.complex_conj(sens_maps)
            ).sum(dim=1, keepdim=True)
            refined = self.model(combined)
            refinement = T.fft2(T.complex_mul(refined, sens_maps))

        return current_kspace - dc_term - refinement
    
class VariationalNetworkModel(MRIModel):
    """Variational network built from a cascade of ``VarNetBlock`` modules.

    For the ``'multicoil'`` challenge a ``SensitivityModel`` is instantiated
    and its output is passed to every cascade; otherwise the cascades run
    without sensitivity maps. Training and validation both use the ``SSIM``
    module as the loss.
    """

    def __init__(self, hparams):
        """Build the network from hyperparameters.

        Args:
            hparams: ``argparse.Namespace`` (or a plain ``dict`` with the same
                keys) providing at least ``challenge``, ``chans``, ``pools``,
                ``num_cascades``, ``train_resolution`` and ``num_epochs``, plus
                ``sens_chans`` / ``sens_pools`` when ``challenge`` is
                ``'multicoil'``.
        """
        # Fix model loading issue: hparams may arrive as a plain dict, but the
        # rest of the code relies on attribute access.
        if isinstance(hparams, dict):
            hparams = Namespace(**hparams)

        super().__init__(hparams)
        self.challenge = hparams.challenge
        if self.challenge == 'multicoil':
            self.sens_net = SensitivityModel(
                hparams.sens_chans, hparams.sens_pools)
        self.cascades = nn.ModuleList([
            VarNetBlock(NormUnet(hparams.chans, hparams.pools))
            for _ in range(hparams.num_cascades)
        ])
        self.ssim_loss = SSIM()

        self.train_resolution = hparams.train_resolution
        self.num_epochs = hparams.num_epochs
        print(hparams)

    def forward(self, masked_kspace, mask):
        """Run all cascades on ``masked_kspace`` and return the image estimate.

        Args:
            masked_kspace: input k-space; also used as the reference for every
                cascade's data-consistency term.
            mask: mask forwarded to the sensitivity network and the cascades.

        Returns:
            ``T.root_sum_of_squares`` over ``dim=1`` of the complex magnitude
            of the inverse-transformed predicted k-space.
        """
        if self.challenge == 'multicoil':
            sens_maps = self.sens_net(masked_kspace, mask)
        else:
            sens_maps = None
        # Clone so the cascades never alias the reference k-space.
        kspace_pred = masked_kspace.clone()
        for cascade in self.cascades:
            kspace_pred = cascade(kspace_pred, masked_kspace, mask, sens_maps)
        return T.root_sum_of_squares(T.complex_abs(T.ifft2(kspace_pred)), dim=1)

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch.

        Returns:
            Dict with the loss tensor under ``'loss'`` and its Python float
            value under ``'log'`` -> ``'train_loss'``.
        """
        masked_kspace, mask, target, fname, _, max_value = batch
        output = self.forward(masked_kspace, mask)
        target, output = T.center_crop_to_smallest(target, output)
        ssim_loss = self.ssim_loss(
            output.unsqueeze(1), target.unsqueeze(1), data_range=max_value)
        return {'loss': ssim_loss, 'log': {'train_loss': ssim_loss.item()}}

    def on_epoch_start(self):
        """Report the augmentation probability ``p`` for the current epoch.

        When ``hparams.aug_on`` is set, ``p`` comes from ``DA.schedule_p`` and
        is printed; otherwise ``0.0`` is reported.
        """
        if self.hparams.aug_on:
            p = DA.schedule_p(self.hparams, self.current_epoch)
            print('augmentation p: ', p)
            return {'log': {'p': p}}
        else:
            return {'log': {'p': 0.0}}

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch.

        Returns:
            Dict with hashed file names (``'fname'``), slice numbers
            (``'slice'``), the cropped ``'output'`` and ``'target'``, and the
            loss under ``'val_loss'``.
        """
        masked_kspace, mask, target, fname, slice_num, max_value = batch

        output = self.forward(masked_kspace, mask)
        target, output = T.center_crop_to_smallest(target, output)

        # hash strings to int so pytorch can concat them
        fnumber = torch.tensor(
            [
                int(hashlib.sha256(fn.encode("utf-8")).hexdigest(), 16) % 10 ** 12
                for fn in fname
            ],
            dtype=torch.long,
            device=output.device,
        )

        return {
            "fname": fnumber,
            "slice": slice_num,
            "output": output,
            "target": target,
            "val_loss": self.ssim_loss(
                output.unsqueeze(1), target.unsqueeze(1), data_range=max_value
            ),
        }

    def test_step(self, batch, batch_idx):
        """Test batches are processed exactly like validation batches."""
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        """Create the Adam optimizer and a StepLR schedule from ``hparams``.

        Returns:
            ``([optimizer], [scheduler])``.
        """
        optim = torch.optim.Adam(
            self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        scheduler = torch.optim.lr_scheduler.StepLR(
            optim, self.hparams.lr_step_size, self.hparams.lr_gamma)
        return [optim], [scheduler]

    def train_data_transform(self):
        """Build the training transform; it receives a callable for the current epoch."""
        # Passed as a callable so the transform sees the live epoch value
        # rather than the value at construction time.
        current_epoch_fn = lambda: self.current_epoch
        return DA.create_data_transform(self.hparams, mode='train', current_epoch_fn=current_epoch_fn)

    def val_data_transform(self):
        """Build the validation transform."""
        return DA.create_data_transform(self.hparams, mode='val')

    def test_data_transform(self):
        """Build the test transform."""
        return DA.create_data_transform(self.hparams, mode='test')

    @staticmethod
    def add_model_specific_args(parser):
        """Register the model and optimisation command-line options.

        Args:
            parser: argparse-compatible parser to extend.

        Returns:
            The same ``parser``, with the options added.
        """
        # The help text here previously read "Number of U-Net channels",
        # which describes --chans, not --num-cascades.
        parser.add_argument('--num-cascades', type=int, default=12,
                            help='Number of VarNet cascades')
        parser.add_argument('--pools', type=int, default=4,
                            help='Number of U-Net pooling layers')
        parser.add_argument('--chans', type=int, default=18,
                            help='Number of U-Net channels')
        parser.add_argument('--sens-pools', type=int, default=4,
                            help='Number of U-Net pooling layers')
        parser.add_argument('--sens-chans', type=int, default=8,
                            help='Number of U-Net channels')
        parser.add_argument('--batch-size', default=1,
                            type=int, help='Mini batch size')
        parser.add_argument('--lr', type=float, default=0.0003,
                            help='Learning rate')
        parser.add_argument('--lr-step-size', type=int, default=40,
                            help='Period of learning rate decay')
        parser.add_argument('--lr-gamma', type=float, default=0.1,
                            help='Multiplicative factor of learning rate decay')
        parser.add_argument('--weight-decay', type=float, default=0.,
                            help='Strength of weight decay regularization')
        parser.add_argument('--contrast-type', type=str,
                            help='Contrast type of MRI images used, options: both, nfs, fs', default='both')
        parser.add_argument('--train-resolution', default=None, nargs='+', type=int,
                            help='Resolution of input images during training')
        parser.add_argument('--val-sample-rate', type=float, default=1.0,
                            help='Sampling rate for validation set.')

        return parser


def run(args):
    """Train a new model or evaluate a checkpoint, depending on ``args.mode``.

    Args:
        args: parsed command-line namespace. If ``args.load_config`` is set,
            the namespace is first replaced by the result of ``load_args``.
            In any mode other than ``'train'``, ``args.checkpoint`` must point
            to a saved checkpoint.

    Raises:
        AssertionError: if evaluation is requested without a checkpoint.
    """
    cudnn.benchmark = True
    cudnn.enabled = True

    if args.load_config is not None:
        args = load_args(args)

    if args.mode == 'train':
        model = VariationalNetworkModel(args)
        trainer = create_trainer(args)
        print('Network parameters: ', num_param(model))
        trainer.fit(model)
    else:  # Evaluating model
        assert args.checkpoint is not None, \
            'Evaluation requires --checkpoint to be set'
        eval_checkpoint = str(args.checkpoint)
        model = VariationalNetworkModel.load_from_checkpoint(eval_checkpoint)
        # Override the hyperparameters stored in the checkpoint for evaluation.
        model.hparams.gpus = [0]
        # This must be an assignment: the previous `==` comparison was a no-op,
        # so create_trainer kept seeing the mode saved in the checkpoint.
        model.hparams.mode = 'eval'
        model.hparams.sample_rate = 1.
        args = model.hparams
        trainer = create_trainer(args)
        trainer.test(model=model, ckpt_path=eval_checkpoint)


def create_trainer(args):
    """Construct the Lightning ``Trainer`` configured from ``args``.

    A ``ModelCheckpoint`` monitoring ``'ssim'`` (maximised, top-1 plus last)
    is always attached. Training is resumed from ``args.checkpoint`` only when
    ``args.resume`` is truthy, and a ``LearningRateLogger`` callback is added
    only when ``args.mode`` is ``'train'``.

    Args:
        args: namespace providing ``resume``, ``checkpoint``, ``mode``,
            ``exp_dir``, ``num_epochs``, ``check_val_every_n_epoch``, ``gpus``
            and ``nodes``.

    Returns:
        The configured ``Trainer`` instance.
    """
    best_ssim_checkpoint = ModelCheckpoint(
        monitor='ssim',
        save_last=True,
        save_top_k=1,
        mode='max',
        period=1,
    )

    resume_path = str(args.checkpoint) if args.resume else None
    extra_callbacks = [LearningRateLogger()] if args.mode == 'train' else []

    trainer_options = dict(
        default_root_dir=args.exp_dir,
        max_epochs=args.num_epochs,
        check_val_every_n_epoch=args.check_val_every_n_epoch,
        gpus=args.gpus,
        num_nodes=args.nodes,
        weights_summary=None,
        distributed_backend='ddp',
        replace_sampler_ddp=False,
        progress_bar_refresh_rate=100,
        resume_from_checkpoint=resume_path,
        checkpoint_callback=best_ssim_checkpoint,
        callbacks=extra_callbacks,
        deterministic=True,
    )
    return Trainer(**trainer_options)

def num_param(net):
    """Return the total number of scalar elements across ``net.parameters()``.

    Uses ``Tensor.numel()`` so the result is always an exact Python ``int``,
    including for 0-dim (scalar) parameters, where taking ``np.prod`` of the
    empty size list would yield the float ``1.0``.

    Args:
        net: object exposing ``parameters()`` that yields tensors.

    Returns:
        int: total element count (``0`` when there are no parameters).
    """
    return sum(p.numel() for p in net.parameters())


def main(args=None):
    """Parse command-line options, seed the RNGs and start ``run``.

    Args:
        args: optional dict of option overrides. When given, its entries
            replace the parser defaults before the command line is parsed, so
            explicit command-line flags still take precedence.
    """
    parser = Args()
    parser.add_argument('--mode', choices=['train', 'eval'], default='train')
    parser.add_argument('--num-epochs', type=int, default=50,
                        help='Number of training epochs')
    parser.add_argument('--check-val-every-n-epoch', type=int, default=1,
                        help='Frequency of evaluating validation metrics.')
    parser.add_argument('--gpus', type=int, nargs='+', default=[0])
    parser.add_argument('--nodes', type=int, default=1)
    parser.add_argument('--exp-dir', type=pathlib.Path, default='experiments',
                        help='Path where model and results should be saved')
    parser.add_argument('--exp', type=str,
                        help='Name of the experiment', default='default')
    # The help text previously said "--mode test", but the only valid modes
    # are 'train' and 'eval'.
    parser.add_argument('--checkpoint', type=pathlib.Path,
                        help='Path to pre-trained model. Use with --mode eval or --resume')
    parser.add_argument('--resume', help='resume training from checkpoint', default=False, action='store_true')
    parser.add_argument('--load-config', type=pathlib.Path, default=None,
                        help='If given, experiment configuration will be loaded from the yaml file from here.')

    parser = VariationalNetworkModel.add_model_specific_args(parser)
    parser = DA.add_augmentation_specific_args(parser)
    if args is not None:
        parser.set_defaults(**args)
    # Unknown flags are tolerated (and discarded) rather than raising.
    args, _ = parser.parse_known_args()

    seed_everything(args.seed)
    run(args)


# Script entry point: only runs when the file is executed directly, not on import.
if __name__ == '__main__':
    main()
