"""Script for training the OrganDet project adapted from https://github.com/bwittmann/transoar."""

import argparse
import os,sys
import random
from pathlib import Path
import warnings
# Ignore warnings whose message starts with "TypedStorage".
warnings.filterwarnings("ignore", message="TypedStorage")
# base_dir is two levels above this file: the parent of the directory that
# contains this script.
base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
print("append to path & chdir:", base_dir)
# Both steps must happen before the project imports further down:
#  - chdir: train() resolves the 'runs' directory relative to os.getcwd().
#  - sys.path: presumably the `organ_detr` package lives in base_dir, so this
#    makes it importable when the script is launched from elsewhere.
os.chdir(base_dir)
sys.path.append(base_dir)
import numpy as np
import torch
import monai, re

from organ_detr.trainer import Trainer
from organ_detr.data.dataloader import get_loader
from organ_detr.utils.io import get_config, write_json, get_meta_data
from organ_detr.models.organdetr_net import OrganDetrNet
from organ_detr.models.build import build_criterion


def get_last_ckpt(filepath):
    """Return the location of the last-checkpoint file inside ``filepath``.

    The result is a plain string made of ``filepath``, a forward slash and the
    fixed file name ``model_last.pt``. No check is made that the file exists;
    callers that load it must handle a missing file themselves.

    Args:
        filepath: Directory of the run (anything that formats to a path string).

    Returns:
        The string ``"<filepath>/model_last.pt"``.
    """
    return "{}/{}".format(filepath, 'model_last.pt')

def match(n, keywords):
    """Return whether any entry of ``keywords`` occurs in ``n``.

    Args:
        n: Name to test, e.g. a model parameter name.
        keywords: Iterable of substrings to look for.

    Returns:
        True as soon as one keyword is contained in ``n``; False otherwise,
        including when ``keywords`` is empty.
    """
    return any(keyword in n for keyword in keywords)

def train(config, args):
    """Build all training components from ``config`` and run the trainer.

    Steps, in order: restrict the visible CUDA devices, build the train/val
    loaders, the model and the criterion(s), set up AdamW with per-group
    learning rates and a StepLR scheduler, optionally resume from
    ``runs/<experiment_name>/model_last.pt``, write the (augmented) config to
    ``runs/<experiment_name>/config.json`` and finally call ``Trainer.run()``.

    Args:
        config: Dict-like configuration. Keys read here: ``device``,
            ``overfit``, ``hybrid_dense_matching`` (optional), ``lr``,
            ``lr_backbone``, ``lr_linear_proj_mult`` (only if the model has
            ``reference_points``/``sampling_offsets`` parameters),
            ``weight_decay``, ``lr_drop``, ``experiment_name`` and ``resume``
            (optional). The object is mutated: parameter counts and, when
            available, meta data are added before it is written to disk.
        args: Parsed command-line arguments; only ``args.resume`` is read.

    Side effects:
        Sets the ``CUDA_VISIBLE_DEVICES`` environment variable, creates the
        run directory below the current working directory and writes
        ``config.json`` into it.
    """
    # Expose only the configured GPU. Everything after the last ':' is used so
    # multi-digit indices such as 'cuda:10' stay intact; taking just the final
    # character would silently select GPU 0 in that case.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(config['device']).rsplit(':', 1)[-1]
    device = 'cuda'

    # Build necessary components
    train_loader = get_loader(config, 'train')

    # When overfitting, validation runs on the training split as well
    val_loader = get_loader(config, 'train' if config['overfit'] else 'val')

    model = OrganDetrNet(config).to(device=device)

    if config.get('hybrid_dense_matching', False):
        # In this mode build_criterion returns a pair of criteria
        criterion, dense_hybrid_criterion = build_criterion(config)
        criterion = criterion.to(device=device)
        dense_hybrid_criterion = dense_hybrid_criterion.to(device=device)
    else:
        criterion = build_criterion(config).to(device=device)
        dense_hybrid_criterion = None

    # Walk the parameters once and reuse the result for statistics and groups
    named_params = list(model.named_parameters())
    trainable = [(n, p) for n, p in named_params if p.requires_grad]

    def count_params(keywords):
        """Number of trainable scalars whose parameter name matches ``keywords``."""
        return sum(p.numel() for n, p in trainable if match(n, keywords))

    # Analysis of model parameter distribution
    num_params = sum(p.numel() for _, p in trainable)
    num_backbone_params = count_params(['backbone', 'input_proj', 'skip'])
    num_neck_params = count_params(['neck', 'query'])
    num_head_params = count_params(['head'])

    backbone_keys = ['backbone']
    def_detr_keys = ['reference_points', 'sampling_offsets']

    param_dicts = [
        {
            # Everything that is neither backbone nor a def-detr projection
            'params': [p for n, p in trainable if not match(n, backbone_keys + def_detr_keys)],
            'lr': float(config['lr'])
        },
        {
            'params': [p for n, p in trainable if match(n, backbone_keys)],
            'lr': float(config['lr_backbone'])
        }
    ]

    # Append additional param dict for def detr. The check deliberately looks
    # at all parameters (not only trainable ones), as the group is added
    # whenever such parameters exist in the model.
    if any(match(n, def_detr_keys) for n, _ in named_params):
        param_dicts.append(
            {
                'params': [p for n, p in trainable if match(n, def_detr_keys)],
                'lr': float(config['lr']) * config['lr_linear_proj_mult']
            }
        )

    # Every group sets its own 'lr', so the top-level lr is only a fallback
    optim = torch.optim.AdamW(
        param_dicts, lr=float(config['lr_backbone']), weight_decay=float(config['weight_decay'])
    )
    scheduler = torch.optim.lr_scheduler.StepLR(optim, config['lr_drop'])

    # Init logging; parents=True so a missing 'runs' directory is created too
    path_to_run = Path(os.getcwd()) / 'runs' / config['experiment_name']
    path_to_run.mkdir(parents=True, exist_ok=True)

    # Load checkpoint if applicable
    if config.get('resume', False) or args.resume:
        ckpt_file = get_last_ckpt(path_to_run)
        print(f'[+] loading ckpt {ckpt_file} ...')
        checkpoint = torch.load(Path(ckpt_file))

        # The current config's lr_drop overrides the step size stored in the checkpoint
        checkpoint['scheduler_state_dict']['step_size'] = config['lr_drop']

        # Unpack and load content
        model.load_state_dict(checkpoint['model_state_dict'])
        optim.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        epoch = checkpoint['epoch']
        metric_start_val = checkpoint['metric_max_val']
    else:
        epoch = 0
        metric_start_val = 0

    # log num_params
    config.update({
        'num_params': num_params,
        'num_backbone_params': num_backbone_params,
        'num_neck_params': num_neck_params,
        'num_head_params': num_head_params,
    })

    # Get meta data and write config to run. Collecting meta data is best
    # effort: a failure is reported but must not abort training.
    try:
        config.update(get_meta_data())
    except Exception as exc:
        print(f'[!] could not collect meta data: {exc}')

    write_json(config, path_to_run / 'config.json')

    # Build trainer and start training
    trainer = Trainer(
        train_loader, val_loader, model, criterion, optim, scheduler, device, config,
        path_to_run, epoch, metric_start_val, dense_hybrid_criterion
    )
    trainer.run()
        

if __name__ == "__main__":
    # Only a minimal set of CLI arguments; most settings belong in the config files
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True, help="Config to use for training located in /config.")
    parser.add_argument("--resume", action='store_true', help="Auto-loads model_last.pt.")
    args = parser.parse_args()

    # Load the configuration selected on the command line
    config = get_config(args.config)

    # Seed every random number generator in use to get reproducible results
    seed = config['seed']
    torch.manual_seed(seed)
    np.random.seed(seed)
    monai.utils.set_determinism(seed=seed)
    random.seed(seed)

    # Trade cuDNN autotuning speed for reproducibility
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    train(config, args)
