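""" Training script for text2graph models using Hugging Face Accelerate: performs forward and
    backward passes through a deep neural network on batches of data stored in dictionaries,
    checkpoints the model on a validation metric, and evaluates the reloaded saved model on
    the test split.
"""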
import argparse
import datetime
from collections import defaultdict
import os
import logging
import time
from typing import Dict, Optional
import yaml

from accelerate import Accelerator
from accelerate.tracking import WandBTracker
from google.cloud import secretmanager
import numpy as np
import torch
from torch import optim
from tqdm import tqdm
from transformers import get_linear_schedule_with_warmup

from text2graph.training.losses_and_metrics import MetricFactory
from text2graph.models.model_loader import init_and_load_models
from text2graph.data.data_loader import init_dataloader

logger = logging.getLogger(name=__name__)
# Set TRANSFORMERS_NO_ADVISORY_WARNINGS so the `transformers` library skips its advisory warnings.
os.environ.update(TRANSFORMERS_NO_ADVISORY_WARNINGS='true')


def read_wandb_secret() -> Optional[str]:
    """ Read the Weights & Biases API key from Google Secret Manager

        If the ``WANDB_SECRET`` environment variable is set, its value is used as the
        ``name`` of a Secret Manager secret version. The payload of that version is decoded
        as UTF-8 and exported as the ``WANDB_API_KEY`` environment variable.

        Returns:
            the value of ``WANDB_SECRET`` ("" when it is not set)
    """
    wandb_secret_location = os.environ.get("WANDB_SECRET", "")
    if wandb_secret_location:
        logger.info("Found W&B secret in the environment, attempting to retrieve API key")
        secrets_client = secretmanager.SecretManagerServiceClient()
        client_response = secrets_client.access_secret_version(
            request={"name": wandb_secret_location}
        )
        os.environ["WANDB_API_KEY"] = client_response.payload.data.decode("UTF-8")
    return wandb_secret_location


def make_run_logging_directory(
    logging_directory: str,
    run_name: str,
    accelerator: Accelerator
) -> str:
    """ Creates (if needed) and returns a per-run directory in which to save a model

        The directory is ``<logging_directory>/<YYYYMMDD>_<run_name>_<suffix>/`` (note the
        trailing slash). Every process draws a random integer and the sum of the gathered
        values is used as the zero-padded suffix, so all processes compute the same path.

        Args:
            logging_directory: parent directory; created (with parents) if it does not exist
            run_name: run name embedded in the directory name
            accelerator: Accelerator used to synchronize and gather across processes

        Returns:
            the run directory path, ending with "/"
    """
    date_str = datetime.datetime.now().strftime('%Y%m%d')
    # Every process executes this line: exist_ok avoids a FileExistsError race between
    # processes, and makedirs also creates missing parent directories.
    os.makedirs(logging_directory, exist_ok=True)
    run_description = f"{date_str}_{run_name}"
    random_number = torch.from_numpy(
        np.array([np.random.choice(int(1e6))])
    ).float().to(accelerator.device).detach()
    accelerator.wait_for_everyone()
    # One value per process (not per-sample outputs), so use a plain gather.
    random_number = str(int(accelerator.gather(random_number).sum().item())).zfill(7)
    run_log_dir = os.path.join(logging_directory, f"{run_description}_{random_number}/")
    if accelerator.is_main_process:
        os.makedirs(run_log_dir, exist_ok=True)
    # Make sure the directory exists before any process uses it.
    accelerator.wait_for_everyone()
    return run_log_dir


class Trainer():
    """ A class for training a single model using batch data loading

        The YAML config read by the constructor must provide ``training_params``,
        ``logging_params``, ``info_flow`` and ``dataset_config``. ``train`` alternates a
        training pass and a validation pass each epoch. It saves the model whenever the
        checkpointing metric improves on validation. Finally it reloads the saved parameters
        and evaluates on the test split.
    """
    def __init__(self, config: str):
        """ Inits a Trainer instance

            Args:
                config: path to the YAML configuration file

            Creates the run directory and writes the config to ``metadata.yaml`` in it. Loads
            the model and builds the train/val/test dataloaders, an Adam optimizer and a linear
            learning-rate schedule. Prepares all of them with the Accelerator and, on the main
            process only, creates a W&B tracker.
        """
        with open(config, "r", encoding='utf-8') as ymlfile:
            cfg = yaml.safe_load(ymlfile)
        self.accelerator = Accelerator()
        self.train_dict = cfg['training_params']
        # Best sign-adjusted validation checkpointing metric so far (lower is better);
        # only updated on the main process.
        self.checkpointing_value = np.inf
        # Batch counter within the current split, used for the generation cadence.
        self.steps = 0
        # Per split: logging name -> list of per-batch metric values.
        self.tracking_dict = {
            'train': defaultdict(list),
            'val': defaultdict(list),
            'test': defaultdict(list)
        }
        model_dir = make_run_logging_directory(
            logging_directory=cfg['logging_params']['logging_dir'],
            run_name=cfg['logging_params']['run_name'],
            accelerator=self.accelerator
        )
        model_dict = init_and_load_models(
            info_flow=cfg['info_flow'],
            saving_directory=model_dir,
            accelerator=self.accelerator
        )
        self.accelerator.print("\nFinished Initialization and Loading")
        if self.accelerator.is_main_process:
            with open(os.path.join(model_dir, "metadata.yaml"), 'w', encoding='utf-8') as ymlfile2:
                yaml.dump(cfg, ymlfile2)
        assert len(model_dict) == 1, "Insufficient infrastructure for training multiple LLMs"
        self.model_name = list(model_dict.keys())[0]
        self.model = model_dict[self.model_name]
        self.model = self.model.to(self.accelerator.device)
        self.model_info = cfg['info_flow'][self.model_name]
        train_dataloader = init_dataloader(
            split_name='train',
            dataset_config=cfg['dataset_config'],
            collate_function=self.model.get_collate_fn()
        )
        val_dataloader = init_dataloader(
            split_name='val',
            dataset_config=cfg['dataset_config'],
            collate_function=self.model.get_collate_fn()
        )
        test_dataloader = init_dataloader(
            split_name='test',
            dataset_config=cfg['dataset_config'],
            collate_function=self.model.get_collate_fn(),
            shuffle=False
        )
        self.eval_function = train_dataloader.dataset.calculate_metrics_batch
        self.metric_names = train_dataloader.dataset.metric_names()
        num_params = sum(p.numel() for p in self.model.parameters())
        self.accelerator.print(f"\nTraining {self.model_name} with {num_params} parameters")
        optimizer = optim.Adam(
            params=self.model.parameters(),
            lr=self.train_dict['lrn_rate'],
            betas=(self.train_dict['beta1'], self.train_dict['beta2']),
            weight_decay = self.train_dict['regularization_weight']
        )
        # The scheduler is stepped once per batch, so its horizon must be the number of
        # batches, not samples. The *unprepared* dataloader length is used deliberately:
        # Accelerate's prepared scheduler accounts for the sharding across processes.
        try:
            steps_per_epoch = len(train_dataloader)
        except TypeError:
            # The dataloader has no defined length (e.g. a batch sampler without __len__).
            steps_per_epoch = len(train_dataloader.dataset)
        lr_scheduler = get_linear_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=0,
            num_training_steps=steps_per_epoch * self.train_dict['max_training_epochs'],
        )
        (
            self.train_dataloader,
            self.val_dataloader,
            self.test_dataloader,
            self.model,
            self.optimizer,
            self.lr_scheduler
        ) = self.accelerator.prepare(
            train_dataloader,
            val_dataloader,
            test_dataloader,
            model_dict[self.model_name],
            optimizer,
            lr_scheduler
        )
        if self.accelerator.is_main_process:
            read_wandb_secret()
            self.tracker = WandBTracker(
                run_name=f"text2graph-{self.train_dataloader.dataset.name}",
                name=model_dict[self.model_name].metadata['model_dir'].split("/")[-2],
                config=cfg
            )

    def train(self) -> None:
        """ Trains a deep learning model using gradient descent

            Each epoch runs a training pass, then a validation pass without gradients (which
            may save a checkpoint). Afterwards the saved parameters are reloaded into the model
            and the test split is evaluated. The W&B tracker is closed on the main process.
        """
        i_epoch = 0  # keeps the test pass' epoch label defined when max_training_epochs == 0
        for i_epoch in range(self.train_dict['max_training_epochs']):
            epoch_start = time.time()
            self.accelerator.print(f'Training epoch #{i_epoch}')
            self.process_dataset(dataset_type="train", epoch=i_epoch)
            self.tracking_dict["train"].clear()
            with torch.no_grad():
                self.process_dataset(dataset_type="val", epoch=i_epoch)
                self.tracking_dict["val"].clear()
            # Reported after every epoch (train + val), including the last one.
            self.accelerator.print(f"Epoch took {time.time() - epoch_start} seconds")
        self.accelerator.wait_for_everyone()
        model = self.accelerator.unwrap_model(self.model)
        model.load_parameters(model.metadata['model_dir'], self.accelerator)
        self.model = self.accelerator.prepare_model(model)
        with torch.no_grad():
            self.process_dataset(dataset_type="test", epoch=i_epoch)
        if self.accelerator.is_main_process:
            self.tracker.finish()

    def process_dataset(self, dataset_type: str, epoch: int) -> None:
        """
            Loops through one data split: training on 'train', evaluating on 'val'/'test'.

            Per-batch metrics are accumulated in ``self.tracking_dict[dataset_type]`` and
            shown in the progress bar on the main process. At the end, each process' nan-mean
            per metric is gathered from all processes. On the main process these values are
            averaged and logged to W&B. For 'val', the model is also saved when the
            sign-adjusted checkpointing metric improves (lower is better).

            Batching of the variably-sized graphs is not done in this method; presumably the
            dataloader / collate function performs the dynamic batching. TODO confirm.

            Args:
                dataset_type: one of 'train', 'val' or 'test'
                epoch: epoch index, used for the progress bar description
        """
        assert dataset_type in ['train', 'val', 'test'], 'unknown data split name'
        if dataset_type == 'train':
            dataloader = self.train_dataloader
            self.model.train()
        elif dataset_type == 'val':
            dataloader = self.val_dataloader
            self.model.eval()
        else:
            dataloader = self.test_dataloader
            self.model.eval()
        self.steps = 0
        with tqdm(
            dataloader,
            unit=" batch",
            disable=not self.accelerator.is_main_process
        ) as tepoch:
            tepoch.set_description(f"Epoch {epoch}")
            for sample_batched in tepoch:
                self.process_batch(sample_batched, dataset_type=dataset_type)
                if self.accelerator.is_main_process:
                    running_metrics = {
                        key: np.round(np.nanmean(value), 3)
                        for key, value in self.tracking_dict[dataset_type].items()
                    }
                    tepoch.set_postfix(**running_metrics)
        metrics = {
            key: torch.from_numpy(np.nanmean(value).reshape(-1)).float().to(
                self.accelerator.device
            ).detach()
            for key, value in self.tracking_dict[dataset_type].items()
        }
        self.accelerator.wait_for_everyone()
        # One value per process and metric: use a plain gather. gather_for_metrics is meant
        # for per-sample outputs and, depending on the Accelerate version, may truncate the
        # result to the dataloader remainder, which would drop whole processes' values here.
        metrics = self.accelerator.gather(metrics)
        if self.accelerator.is_main_process:
            if dataset_type == 'val':
                metric_name = self.train_dict['checkpointing_metric']
                checkpointing_metric = metrics[metric_name].mean().item()
                # The configured sign (presumably -1 for higher-is-better metrics) lets one
                # lower-is-better comparison serve every metric.
                checkpointing_metric *= self.train_dict['checkpointing_metric_sign']
                if checkpointing_metric < self.checkpointing_value:
                    self.checkpointing_value = checkpointing_metric
                    model = self.accelerator.unwrap_model(self.model)
                    model.save(self.accelerator)
            self.tracker.log(
                {f"{dataset_type}_{key}": value.mean().item() for key, value in metrics.items()}
            )

    def process_batch(self, sample_batched: Dict[str, torch.Tensor], dataset_type: str) -> None:
        """ Runs one batch through the model, computes its metrics and, for 'train', steps.

            Builds a local ``model_outputs`` dict keyed by source:
            - 'dataset' holds the batch (tensors moved to the device), the eval function, the
              accelerator, and each metric name mapped to itself.
            - ``self.model_name`` holds the model's outputs.

            When the model config has a 'gen' section, ground-truth graphs and generated
            graphs/text are added. This happens every ``cadence`` training batches and on every
            val/test batch.

            Each configured evaluation whose inputs are all available is computed through
            MetricFactory and recorded in ``self.tracking_dict``. Evaluations flagged
            ``is_loss`` are also added, weighted, to the loss. During training the loss is
            backpropagated, gradients are clipped to norm 1.0, and the optimizer and LR
            scheduler are stepped.

            Args:
                sample_batched: the batch dictionary produced by the dataloader
                dataset_type: one of 'train', 'val' or 'test'
        """
        model_outputs = {
            'dataset' : {
                key: sample_batched[key].to(self.accelerator.device)
                if isinstance(sample_batched[key], torch.Tensor) else sample_batched[key]
                for key in sample_batched.keys()
            }
        }
        model_outputs['dataset']['eval_function'] = self.eval_function
        model_outputs['dataset']['accelerator'] = self.accelerator
        for name in self.metric_names:
            model_outputs['dataset'][name] = name
        loss = torch.zeros(1).float().to(self.accelerator.device)
        model_outputs[self.model_name] = self.model(model_outputs['dataset'])
        if (
            'gen' in self.model_info
            and (
                self.steps % self.model_info['gen']['cadence'] == 0
                or dataset_type != 'train'
            )
        ):
            # Custom methods/attributes are not reachable through wrappers such as
            # DistributedDataParallel, so call them on the unwrapped model.
            base_model = self.accelerator.unwrap_model(self.model)
            model_outputs[self.model_name]['ground_truths'] = base_model.inputs2graph(
                model_outputs['dataset'],
                base_model.tokenizer,
                file_paths=model_outputs['dataset'].get('file_path', None)
            )
            for call_name, call_dict in self.model_info['gen']['calls'].items():
                generated_graphs, generated_text = base_model.generate_graph(
                    **{**model_outputs['dataset'], **call_dict}
                )
                model_outputs[self.model_name][f"{call_name}_graphs_gen"] = generated_graphs
                model_outputs[self.model_name][f"{call_name}_text_gen"] = generated_text
        for eval_name, eval_args in self.model_info['evaluations'].items():
            metric_name = eval_name.split(',')[0]
            input_list = [
                model_outputs[i_source][i_name]
                for i_name, i_source in eval_args['inputs'].items()
                if i_source in model_outputs and i_name in model_outputs[i_source]
            ]
            # Skip evaluations whose inputs were not produced for this batch
            # (e.g. generation outputs on non-cadence training steps).
            if len(input_list) != len(eval_args['inputs']):
                continue
            metric_value = MetricFactory[metric_name].value(*tuple(input_list)).mean()
            if eval_args['is_loss']:
                loss += metric_value * eval_args.get('weight', 1.0)
            self.tracking_dict[dataset_type][eval_args['logging_name']].append(metric_value.item())
        self.tracking_dict[dataset_type]['aggregate_loss'].append(loss.item())
        if dataset_type == 'train':
            self.accelerator.backward(loss)
            self.accelerator.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()
            self.lr_scheduler.step()
            self.optimizer.zero_grad()
        self.steps += 1


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Training text graph generators')
    parser.add_argument('--config', type=str, required=True)
    args = parser.parse_args()
    trainer = Trainer(config=args.config)
    trainer.train()
    if trainer.accelerator.is_main_process:
        # Write the saved model directory to <parent of this file's dir>/evaluation/model_path.txt
        training_path = os.path.dirname(os.path.realpath(__file__))
        evaluation_path = os.path.join(os.path.dirname(training_path), "evaluation")
        # After train() the model may be wrapped (e.g. DistributedDataParallel), which does
        # not expose custom attributes such as `metadata`.
        trained_model = trainer.accelerator.unwrap_model(trainer.model)
        model_path_file = os.path.join(evaluation_path, "model_path.txt")
        with open(model_path_file, 'w', encoding='utf-8') as text_file:
            text_file.write(trained_model.metadata['model_dir'])
