import os
import argparse
import torch


from deepinv.loss.regularisers import JacobianSpectralNorm  # fixed needed for MM

from denflow.utils import load_cfg_from_cfg_file, merge_cfg_from_list
from denflow.dataloaders import DataLoaders
from denflow.train_denoisers import GENERAL_DENOISER
from denflow.optimal_denoiser import OPTIMAL_DENOISER
from denflow.ten_models_denoiser import TEN_MODELS_DENOISER
from denflow.concatenated_denoiser import CONCATENATED_DENOISER
from denflow.methods.basic_psnr import BASIC_PSNR
from denflow.sampling_methods.basic_sampler import BASIC_SAMPLER
from denflow.sampling_methods.two_models_sampler import TWO_MODELS_SAMPLER
from denflow.sampling_methods.noise_perturb_sampler import NOISE_PERTURB_SAMPLER
from denflow.methods.noise_perturb_psnr import NOISE_PERTURB_PSNR
from denflow.methods.pnp_flow_denoiser import PNP_FLOW_DENOISER
from denflow.investigation.investigate_approx import INVESTIGATE_APPROX
from denflow.investigation.investigate_lip import INVESTIGATE_LIP
from denflow.utils import define_model, load_model_runid, build_degradation_and_noise, set_seed

# torch.cuda.empty_cache()
# os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'


def parse_args():
    """Assemble the run configuration from YAML files and command-line overrides.

    Sources are applied in this order (later ones override earlier ones):

    1. ``./config/main_config.yaml`` (relative to the current working directory);
    2. ``<root>config/dataset_config/<dataset>.yaml``;
    3. ``<root>config/method_config/<method>.yaml``;
    4. ``--opts KEY VALUE ...`` given on the command line.

    The ``--opts`` overrides are additionally applied right after step 1 so
    that overriding ``root``, ``dataset`` or ``method`` changes which files
    are loaded in steps 2 and 3.

    Returns:
        The merged configuration object (whatever ``load_cfg_from_cfg_file`` /
        ``merge_cfg_from_list`` return; it is NOT an ``argparse.Namespace``).
        Its ``dict_cfg_method`` attribute maps every key declared in the
        method config file to its final (post-override) value.
    """
    parser = argparse.ArgumentParser(description='Main')
    cfg = load_cfg_from_cfg_file('./config/main_config.yaml')
    parser.add_argument('--opts', default=None, nargs=argparse.REMAINDER)
    cli = parser.parse_args()

    # First pass: let the command line pick root / dataset / method.
    if cli.opts is not None:
        cfg = merge_cfg_from_list(cfg, cli.opts)

    dataset_config = cfg.root + 'config/dataset_config/{}.yaml'.format(cfg.dataset)
    cfg.update(load_cfg_from_cfg_file(dataset_config))

    method_config_file = cfg.root + 'config/method_config/{}.yaml'.format(cfg.method)
    # Load once: the same object supplies both the values and the list of
    # method-specific keys used for dict_cfg_method below.
    method_cfg = load_cfg_from_cfg_file(method_config_file)
    cfg.update(method_cfg)

    # Second pass: command-line values win over the dataset/method files.
    if cli.opts is not None:
        cfg = merge_cfg_from_list(cfg, cli.opts)

    # Snapshot of the method's own settings, with any command-line overrides applied.
    cfg.dict_cfg_method = {key: cfg[key] for key in method_cfg.keys()}
    return cfg


def build_save_path(args):
    """Build the directory path where results of the current run are written.

    The path is ``os.path.join(root, <results_dir>, *parts)``. ``parts`` starts
    from identifiers of the dataset, method, trained object and run, and is
    then extended with application-specific fields.

    Args:
        args: configuration object (as returned by ``parse_args``). Always
            reads ``root``, ``train_object``, ``dataset``, ``method``,
            ``loss_type``, ``class_object`` and ``dim_image``. Depending on
            those values it also reads ``run_id`` or ``run_ids``,
            ``model_fold``, ``t_sep``, ``application``, ``problem``,
            ``t_inf``, ``t_sup``, ``base_run_id``, ``ema`` and ``position_FM``.

    Returns:
        The results directory as a string. The directory is not created here.

    Raises:
        ValueError: if ``dim_image`` is not 2 and ``application`` is not one of
            'inverse_problem', 'sampling' or 'investigation'.
    """
    root = args.root
    if args.train_object in ('optimal_denoiser', 'ten_models_denoiser'):
        parts = [args.dataset, args.method, args.train_object, args.loss_type,
                 args.class_object]
    elif args.train_object == 'concatenated_denoiser':
        parts = [args.dataset, args.method, args.train_object, args.loss_type,
                 args.class_object, *args.run_ids, args.model_fold]
    else:
        parts = [args.dataset, args.method, args.loss_type,
                 args.class_object, args.run_id, args.model_fold]

    # 2D experiments get their own tree, whatever the application.
    if args.dim_image == 2:
        parts.insert(3, str(args.t_sep))
        return os.path.join(root, 'results_2D', *parts)

    if args.application == 'inverse_problem':
        if args.method == 'noise_perturb_psnr':
            base = 'noise_perturb_psnr'
        elif args.method in ('basic_psnr', 'ten_models_psnr'):
            base = 'results_psnr'
        else:
            base = 'results_ip'

        # Only non-PSNR methods get a ``problem`` directory level. An explicit
        # membership test avoids the and/or precedence trap of the earlier
        # `a != x or a != y and a != z` form, which was always true.
        if args.method not in ('basic_psnr', 'ten_models_psnr', 'noise_perturb_psnr'):
            parts.insert(2, args.problem)

        return os.path.join(root, base, *parts)

    if args.application == 'sampling':
        if args.method == 'two_models_sampler':
            parts.insert(2, '[{}, {}]'.format(args.t_inf, args.t_sup))
            parts.insert(3, args.base_run_id)
            parts.insert(4, str(args.ema))
            return os.path.join(root, 'results_sampling_2models', *parts)
        if args.method == 'noise_perturb_sampler':
            parts.insert(3, str(args.ema))
            return os.path.join(root, 'results_sampling_noise_perturb', *parts)
        if args.method == 'ten_models_sampler':
            # NOTE(review): shares the 'results_sampling_2models' tree with
            # two_models_sampler. Kept as-is; confirm this is intended.
            parts.insert(3, args.position_FM)
            parts.insert(4, str(args.t_sep))
            parts.insert(5, str(args.ema))
            return os.path.join(root, 'results_sampling_2models', *parts)
        # Every other sampler (e.g. basic_sampler, sampler_and_drunet).
        return os.path.join(root, 'results_sampling', *parts)

    if args.application == 'investigation':
        return os.path.join(root, 'results_investigations', *parts)

    raise ValueError(
        "application must be either 'inverse_problem', 'sampling' or 'investigation'")


def build_method(generative_method, device, args):
    """Instantiate the run method selected by ``args.application`` and ``args.method``.

    Args:
        generative_method: the denoiser wrapper built in ``main``. It is
            forwarded to the method constructor.
        device: torch device, forwarded to the method constructor.
        args: configuration object. ``application`` and ``method`` select the
            class, and ``args`` itself is forwarded to the constructor.

    Returns:
        An instance constructed as ``Cls(generative_method, device, args)``.

    Raises:
        ValueError: if the application is unknown, or the method is unknown
            for that application.
    """
    if args.application == 'sampling':
        if args.method == 'basic_sampler':
            return BASIC_SAMPLER(generative_method, device, args)
        if args.method == 'two_models_sampler':
            return TWO_MODELS_SAMPLER(generative_method, device, args)
        if args.method == 'noise_perturb_sampler':
            return NOISE_PERTURB_SAMPLER(generative_method, device, args)
        raise ValueError(f"Unknown sampling method: {args.method}")

    if args.application == 'inverse_problem':
        if args.method == 'pnp_flow_denoiser':
            return PNP_FLOW_DENOISER(generative_method, device, args)
        if args.method == 'basic_psnr':
            return BASIC_PSNR(generative_method, device, args)
        if args.method == 'noise_perturb_psnr':
            return NOISE_PERTURB_PSNR(generative_method, device, args)
        raise ValueError(f"Unknown inverse_problem method: {args.method}")

    if args.application == 'investigation':
        if args.method == 'approx':
            return INVESTIGATE_APPROX(generative_method, device, args)
        if args.method == 'lip':
            return INVESTIGATE_LIP(generative_method, device, args)
        raise ValueError(f"Unknown investigation method: {args.method}")

    raise ValueError(f"Unknown application: {args.application}")


def _build_generative_method(model, device, args):
    """Return the denoiser wrapper selected by ``args.train_object``.

    Raises:
        ValueError: if ``args.train_object`` is not a supported value.
    """
    if args.train_object == "denoiser":
        return GENERAL_DENOISER(
            model, loss_denoising=args.loss_type, class_denoiser=args.class_object, device=device, args=args)
    if args.train_object == "optimal_denoiser":
        # Constructed without the network returned by define_model.
        return OPTIMAL_DENOISER(device=device, args=args)
    if args.train_object == "ten_models_denoiser":
        return TEN_MODELS_DENOISER(
            model, loss_denoising=args.loss_type, class_denoiser=args.class_object, device=device, args=args)
    if args.train_object == "concatenated_denoiser":
        return CONCATENATED_DENOISER(
            loss_denoising=args.loss_type, class_denoiser=args.class_object, device=device, args=args)
    raise ValueError(
        "train_object must be one of 'denoiser', 'optimal_denoiser', "
        "'ten_models_denoiser' or 'concatenated_denoiser', got {!r}".format(args.train_object))


def main():
    """Entry point: build the configured denoiser, then train or evaluate it.

    If ``args.train`` is set, the denoiser is trained on the dataset loaders.
    Otherwise, if ``args.eval`` is set:

    * the denoiser is prepared. The optimal denoiser is fitted from a data
      loader, the concatenated denoiser is used as-is, and any other denoiser
      loads the weights of run ``args.run_id``;
    * the degradation and noise level are built;
    * the method selected by ``args.application``/``args.method`` runs, with
      outputs under ``build_save_path(args)``.

    If neither flag is set, nothing is run after setup.
    """
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("device", device)
    # ``seed`` is optional in the config; fall back to None when it is absent.
    seed = getattr(args, 'seed', None)
    set_seed(seed)
    print('seed', seed)

    model, _ = define_model(args)
    generative_method = _build_generative_method(model, device, args)

    if args.train:
        print('Training...')
        # NOTE(review): presumably read by downstream code as the active batch size.
        args.batch_size = args.batch_size_train
        data_loaders = DataLoaders(
            args.dataset, args.batch_size_train, args.batch_size_train, args).load_data()
        generative_method.train(data_loaders)
        print('Training done!')
        return

    if not args.eval:
        return

    print("args.train_object", args.train_object)
    if args.train_object == "optimal_denoiser":
        # NOTE(review): 20000 occupies DataLoaders' second positional slot (the
        # training batch size elsewhere in this file); confirm its intended meaning.
        data_loaders = DataLoaders(
            args.dataset, 20000, args.batch_size_ip, args).load_data()
        generative_method.prepare_optimal_denoiser(data_loaders)
    elif args.train_object != "concatenated_denoiser":
        # Trained denoisers: load the checkpoint of run ``run_id``, optionally
        # from ``model_fold`` when one is configured.
        if args.model_fold:
            model = load_model_runid(
                model, model_fold=args.model_fold, run_id=args.run_id, device=device)
        else:
            model = load_model_runid(model, run_id=args.run_id, device=device)
        model.eval()
        generative_method.model = model

    degradation, sigma_noise = build_degradation_and_noise(args, device)
    print(
        f"Solving the {args.problem} inverse problem with the method {args.method}...")

    data_loaders = DataLoaders(
        args.dataset, args.batch_size_ip, args.batch_size_ip, args).load_data()

    args.save_path = build_save_path(args)
    os.makedirs(args.save_path, exist_ok=True)

    method = build_method(generative_method, device, args)
    # Re-seed just before running so the run does not depend on how much
    # randomness the setup steps above consumed.
    set_seed(getattr(args, 'seed', None))
    method.run_method(data_loaders, degradation, sigma_noise)


if __name__ == '__main__':
    # Run the pipeline only when this file is executed as a script, not on import.
    main()
