import random

import torch
import numpy as np
import pytorch_lightning as pl

from file_io import Files
from .trainer import ICDCodingWrapper
from ..callbacks import LogProgressBar
from .configs_full import Config


def set_seed(seed):
    """Seed Python's, NumPy's and PyTorch's global RNGs with one value.

    The generators are seeded in this order: ``random``, ``numpy.random``
    (legacy global generator), then ``torch.manual_seed``.

    Args:
        seed: Seed value passed unchanged to each library's seeding call.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed)
    for seed_fn in seeders:
        seed_fn(seed)

def _build_trainer(config):
    """Create the PyTorch Lightning ``Trainer`` described by ``config.train``.

    Checkpointing, the logger and sanity-validation steps are all disabled;
    the only callback is ``LogProgressBar``.
    """
    return pl.Trainer(
        accelerator=config.train.accelerator,
        devices=config.train.devices,
        max_epochs=config.train.epochs,
        num_sanity_val_steps=0,
        accumulate_grad_batches=config.train.accumulate_grad_batches,
        enable_checkpointing=False,
        logger=False,
        callbacks=[LogProgressBar()],
    )


def run():
    """Build the model from the default ``Config`` and train it.

    Steps: create and print the config, seed the RNGs with ``config.seed``,
    initialise I/O paths via ``Files.init_path(config.io)``, create and print
    the ``ICDCodingWrapper``, then fit it with a Lightning ``Trainer``.
    """
    config = Config()
    print(config)

    set_seed(config.seed)
    Files.init_path(config.io)

    model = ICDCodingWrapper(config)
    print(model)

    # NOTE: warm-starting from saved weights is not wired in. Earlier
    # experiments did it at this point with
    # ``model.model.load_state_dict(torch.load(path, map_location='cpu'))``.

    trainer = _build_trainer(config)
    trainer.fit(model)
