import random
import wandb

import torch
from torch_geometric.loader import DataLoader

import hienet._keys as KEY
from hienet.model_build import build_E3_equivariant_model
from hienet.scripts.processing_continue import processing_continue
from hienet.scripts.processing_dataset import processing_dataset
from hienet.hienet_logger import Logger
from hienet.train.trainer_lightning import LightningModel

import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint

class LightningDataset(L.LightningDataModule):
    """Lightning data module around the splits built by ``processing_dataset``.

    The train and validation lists come from a single ``processing_dataset``
    call; its third return value is discarded. When
    ``config[KEY.USE_FULL_TRAINING]`` is truthy, the validation samples are
    also appended to the training set. The validation loader still iterates
    the validation split.
    """

    def __init__(self, config, working_dir):
        super().__init__()
        self.working_dir = working_dir
        self.config = config
        # Filled in by ``setup``; ``None`` means "not loaded yet".
        self.train_list = None
        self.valid_list = None

    def setup(self, stage=None):
        """Load the train/validation samples on the first call.

        Loading depends on whether data is already present, not on ``stage``.
        The module therefore works both when it is pre-loaded via
        ``init_loaders`` (``stage=None``) and when Lightning calls
        ``setup('fit')`` on a fresh instance. Later calls are no-ops.
        """
        if self.train_list is not None:
            return
        self.train_list, self.valid_list, _ = processing_dataset(
            self.config, self.working_dir
        )
        if self.config[KEY.USE_FULL_TRAINING]:
            # Creates a new list so ``valid_list`` itself is left untouched.
            self.train_list = self.train_list + self.valid_list

    def train_dataloader(self):
        """Training loader: ``KEY.BATCH_SIZE`` batches, shuffled per ``KEY.TRAIN_SHUFFLE``."""
        return DataLoader(
            self.train_list,
            batch_size=self.config[KEY.BATCH_SIZE],
            shuffle=self.config[KEY.TRAIN_SHUFFLE],
        )

    def val_dataloader(self):
        """Validation loader: ``KEY.BATCH_SIZE`` batches, DataLoader's default shuffle."""
        return DataLoader(
            self.valid_list, batch_size=self.config[KEY.BATCH_SIZE]
        )

    def test_dataloader(self):
        """No test split is provided by this module."""
        return None

    def teardown(self, stage=None):
        """Release loaded samples by emptying the lists in place.

        Safe to call before ``setup`` has loaded anything. The splits are
        presumably Python lists, since they are concatenated with ``+`` and
        emptied with ``.clear()`` here.
        """
        for samples in (self.train_list, self.valid_list):
            if samples is not None:
                samples.clear()

def init_loaders(config, working_dir):
    """Create a ``LightningDataset`` and eagerly load its data.

    ``setup()`` is called with no stage, so the returned module already holds
    the train/validation splits.
    """
    datamodule = LightningDataset(config, working_dir)
    datamodule.setup()
    return datamodule


# TODO: E3_equivariant model assumed
def train(config, working_dir: str, experiment_name: str):
    """
    Main program flow: seed RNGs, optionally resume from a checkpoint, load
    the data, build the model and run ``trainer.fit``.

    Args:
        config: Run configuration indexed by ``hienet._keys`` constants. The
            "config updated" notes below indicate that ``processing_continue``
            and dataset processing may modify it (not verifiable from here).
        working_dir: Passed to dataset processing. Also used as the checkpoint
            directory and as the Trainer's ``default_root_dir``.
        experiment_name: Forwarded to ``LightningModel``.
    """
    Logger().timer_start('total')
    seed = config[KEY.RANDOM_SEED]
    random.seed(seed)
    torch.manual_seed(seed)

    # config updated
    # ``start_epoch`` is not used below; resumption goes through ``ckpt_path``.
    if config[KEY.CONTINUE][KEY.CHECKPOINT] is not False:
        state_dicts, start_epoch, init_csv = processing_continue(config)
    else:
        state_dicts, start_epoch, init_csv = None, 1, True

    # config updated
    # Note that continue and dataset cannot be separated completely
    loaders = init_loaders(config, working_dir)

    Logger().write('\nModel building...\n')
    model = build_E3_equivariant_model(config)
    Logger().write('Model building was successful\n')

    if state_dicts is not None:
        # NOTE(review): trainer.fit(ckpt_path=...) below restores this same
        # checkpoint again, so loading weights here is presumably redundant.
        # Kept to preserve existing behaviour.
        model_L = LightningModel.load_from_checkpoint(
            config[KEY.CONTINUE][KEY.CHECKPOINT], model=model,
            config=config, experiment_name=experiment_name, init_csv=init_csv)
    else:
        model_L = LightningModel(model, config, experiment_name, init_csv)

    Logger().print_model_info(model, config)

    Logger().write('Trainer initialized, ready to training\n')
    Logger().bar()

    callbacks = _build_callbacks(config, working_dir)
    trainer = L.Trainer(**_build_trainer_args(config, working_dir, callbacks))

    fit_args = {'model': model_L, 'datamodule': loaders}
    if state_dicts is not None:
        # Let Lightning resume training from the checkpoint file.
        fit_args['ckpt_path'] = config[KEY.CONTINUE][KEY.CHECKPOINT]

    trainer.fit(**fit_args)

    # Best-effort shutdown: failure here must not fail a finished run, but
    # the actual error is reported instead of being silently swallowed.
    try:
        wandb.finish()
    except Exception as err:
        print('\n Wandb was not initialized on this node, but training was successfully finished')
        print(f' (wandb.finish() raised {type(err).__name__}: {err})')

    Logger().timer_end('total', message='Total wall time')


def _build_callbacks(config, working_dir):
    """Create the two ``ModelCheckpoint`` callbacks used by ``train``.

    * ``bestmodel-...``: keeps only the checkpoint with the lowest ``val``.
    * ``checkpoint-...``: saved every ``config[KEY.PER_EPOCH]`` epochs,
      keeping the 20 with the highest epoch number (i.e. the most recent).
    """
    bestmodel_callback = ModelCheckpoint(
        save_top_k=1,
        monitor="val",
        mode="min",
        dirpath=working_dir,
        filename="bestmodel-{epoch}-{val:.3f}",
    )
    checkpoint_callback = ModelCheckpoint(
        save_top_k=20,
        monitor="epoch",
        mode="max",
        every_n_epochs=config[KEY.PER_EPOCH],
        dirpath=working_dir,
        filename="checkpoint-{epoch}-{val:.3f}",
    )
    return [bestmodel_callback, checkpoint_callback]


def _build_trainer_args(config, working_dir, callbacks):
    """Assemble keyword arguments for ``L.Trainer``.

    Device selection:

    * Uses every visible GPU when ``config[KEY.IS_DDP]`` is set.
    * Uses one GPU otherwise.
    * Falls back to CPU when no GPU is visible.

    In DDP mode the ``'ddp'`` strategy is used across ``config[KEY.N_NODES]``
    nodes. When a GPU is present this also changes torch's process-wide
    float32 matmul precision to ``'high'``.
    """
    n_gpus_per_node = torch.cuda.device_count()
    use_gpu = n_gpus_per_node > 0
    trainer_args = {
        'accelerator': 'gpu' if use_gpu else 'cpu',
        'max_epochs': config[KEY.EPOCH],
        'max_steps': -1,  # no step cap; run length is bounded by max_epochs
        'callbacks': callbacks,
        'default_root_dir': working_dir,
    }
    if use_gpu:
        trainer_args['devices'] = n_gpus_per_node if config[KEY.IS_DDP] else 1
        # Trade some float32 matmul precision for GPU throughput.
        torch.set_float32_matmul_precision('high')
    if config[KEY.IS_DDP]:
        trainer_args['num_nodes'] = config[KEY.N_NODES]
        trainer_args['strategy'] = 'ddp'
    return trainer_args
