import argparse
from pathlib import Path
from typing import Any

import torch
from easy_tpp.config_factory import Config
from easy_tpp.model.torch_model.torch_basemodel import TorchBaseModel
from easy_tpp.model.torch_model.torch_intensity_free import IntensityFree
from easy_tpp.runner import Runner
from omegaconf import OmegaConf
from torch import Tensor, nn


class RenewalGRU(nn.Module):
    """
    Position-wise MLP used in place of the GRU layer of the IntensityFree model.

    There is no recurrence: a stack of Linear + ReLU layers is applied to the
    last dimension of the input, so the output at each sequence position is
    computed from that position's input alone.

    Args:
        input_size (int): size of the last dimension of the input tensor
        hidden_size (int): size of the last dimension of the output tensor
        num_layers (int): number of hidden Linear + ReLU blocks stacked after
            the initial input projection (``num_layers + 1`` Linear layers in total)
        **kwargs: accepted and ignored
    """
    def __init__(self, input_size: int, hidden_size: int, num_layers: int, **kwargs) -> None:
        super().__init__()
        # Input projection first, then the hidden blocks. Keep this build order:
        # it fixes the Sequential indices (state_dict keys) and init RNG order.
        modules: list[nn.Module] = [nn.Linear(input_size, hidden_size), nn.ReLU()]
        for _block_idx in range(num_layers):
            modules += [nn.Linear(hidden_size, hidden_size), nn.ReLU()]
        self.net = nn.Sequential(*modules)

    def forward(self, x: Tensor) -> tuple[Tensor, None]:
        """
        Args:
            x (Tensor): input tensor of shape (batch_size, seq_len, input_size)

        Returns:
            tuple: the MLP output of shape (batch_size, seq_len, hidden_size),
            and ``None`` in place of a hidden state
        """
        hidden = self.net(x)
        return hidden, None


class RenewalIntensityFree(TorchBaseModel):
    """
    Variant of the IntensityFree model without a recurrent history encoder.

    The ``layer_rnn`` of the wrapped ``IntensityFree`` module is replaced by a
    ``RenewalGRU`` (a position-wise MLP), so the context at each step is
    computed from that step's input only and earlier history is not carried over.
    All model methods delegate to the wrapped module.
    """
    def __init__(self, model_config: OmegaConf) -> None:
        super().__init__(model_config)
        self.module = IntensityFree(model_config)
        # Mirror the depth of the recurrent layer being replaced. Passing
        # ``hidden_size`` as ``num_layers`` (as before) built one Linear layer
        # per hidden unit. Fall back to 1 if the replaced layer exposes no
        # ``num_layers`` attribute.
        # NOTE: checkpoints saved with the old (hidden_size-deep) MLP will not
        # load into this architecture.
        num_layers = getattr(self.module.layer_rnn, 'num_layers', 1)
        self.module.layer_rnn = RenewalGRU(
            input_size=self.module.num_features,
            hidden_size=self.module.hidden_size,
            num_layers=num_layers,
        )

    def forward(self, *args, **kwargs) -> Any:
        """Delegate to the wrapped module's ``forward``."""
        return self.module.forward(*args, **kwargs)

    def loglike_loss(self, *args, **kwargs) -> Any:
        """Delegate to the wrapped module's ``loglike_loss``."""
        return self.module.loglike_loss(*args, **kwargs)

    def predict_one_step_at_every_event(self, *args, **kwargs) -> Any:
        """Delegate to the wrapped module's ``predict_one_step_at_every_event``."""
        return self.module.predict_one_step_at_every_event(*args, **kwargs)


class NoHistGRU(nn.Module):
    """
    Stand-in for a GRU layer that ignores its input values.

    Only the batch and sequence sizes (and dtype/device) of the input are used;
    the returned context is all ones.

    Args:
        hidden_size (int): size of the last dimension of the returned context
    """
    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self.hidden_size = hidden_size

    def forward(self, x: Tensor) -> tuple[Tensor, None]:
        """
        Args:
            x (Tensor): input tensor with at least two dimensions; only
                ``x.shape[0]``, ``x.shape[1]``, its dtype and device are read

        Returns:
            tuple: an all-ones tensor of shape (x.shape[0], x.shape[1],
            hidden_size) with ``x``'s dtype and device, and ``None`` in place
            of a hidden state
        """
        batch_size, seq_len = x.shape[0], x.shape[1]
        ones = x.new_ones((batch_size, seq_len, self.hidden_size))
        return ones, None


class NoHistIntensityFree(TorchBaseModel):
    """
    IntensityFree variant whose history encoder is disabled.

    The wrapped ``IntensityFree`` module keeps all of its layers except
    ``layer_rnn``, which is replaced by ``NoHistGRU``. That stand-in returns an
    all-ones context of the expected shape for every input, so the context
    carries no information about past events (assuming IntensityFree reads
    history only through ``layer_rnn`` -- TODO confirm). All model methods
    delegate to the wrapped module.
    """
    def __init__(self, model_config: OmegaConf) -> None:
        super().__init__(model_config)
        wrapped = IntensityFree(model_config)
        # Replace the recurrent encoder with the constant-context stand-in.
        wrapped.layer_rnn = NoHistGRU(hidden_size=wrapped.hidden_size)
        self.module = wrapped

    def forward(self, *args, **kwargs) -> Any:
        """Call the wrapped module's ``forward`` with the same arguments."""
        inner = self.module
        return inner.forward(*args, **kwargs)

    def loglike_loss(self, *args, **kwargs) -> Any:
        """Call the wrapped module's ``loglike_loss`` with the same arguments."""
        inner = self.module
        return inner.loglike_loss(*args, **kwargs)

    def predict_one_step_at_every_event(self, *args, **kwargs) -> Any:
        """Call the wrapped module's ``predict_one_step_at_every_event`` with the same arguments."""
        inner = self.module
        return inner.predict_one_step_at_every_event(*args, **kwargs)


def parse_args(argv: 'list[str] | None' = None) -> argparse.Namespace:
    """
    Parse the command-line arguments of this experiment script.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``config_dir``, ``experiment_id``, ``data``,
        ``model``, ``save_dir``, ``gpu`` and ``seed``.

    Raises:
        SystemExit: if a required argument is missing or a value is invalid.
    """
    parser = argparse.ArgumentParser()
    # Despite the name, this is the path of the config file passed to OmegaConf.load.
    parser.add_argument('--config-dir', type=str, required=True,
                        help='path to the experiment YAML config file')
    parser.add_argument('--experiment-id', type=str, required=True,
                        help='top-level key of the experiment inside the config file')
    parser.add_argument('--data', type=str, required=True,
                        help='dataset id written to base_config.dataset_id')
    parser.add_argument('--model', type=str, required=True,
                        choices=['NoHistIntensityFree', 'RenewalIntensityFree', 'IntensityFree'],
                        help='model id written to base_config.model_id')
    parser.add_argument('--save-dir', type=str, required=True,
                        help='root directory under which the run directory is created')
    parser.add_argument('--gpu', type=int,
                        help='gpu value written to model_config and trainer_config')
    parser.add_argument('--seed', type=int,
                        help='trainer seed; also part of the run directory name')
    return parser.parse_args(argv)


def main(argv: 'list[str] | None' = None) -> None:
    """
    Build the EasyTPP config from the YAML file plus CLI overrides and run it.

    For any stage other than ``'train'``, exactly one saved model must exist
    under the run directory; it is used as the pretrained model.

    Raises:
        KeyError: if the experiment id or ``pipeline_config_id`` is missing
            from the config file.
        ValueError: if, outside training, zero or several saved models are found.
    """
    args = parse_args(argv)

    raw_config = OmegaConf.load(args.config_dir)
    if args.experiment_id not in raw_config:
        raise KeyError(f'Experiment {args.experiment_id!r} not found in {args.config_dir}')

    # Set base config from command line arguments
    exp_config = raw_config[args.experiment_id]
    save_dir = Path(args.save_dir).resolve() / f'{args.model}_{args.data}_{args.seed}'
    exp_config.base_config.base_dir = str(save_dir)
    exp_config.base_config.model_id = args.model
    exp_config.base_config.dataset_id = args.data
    exp_config.model_config.gpu = args.gpu
    exp_config.trainer_config.gpu = args.gpu
    exp_config.trainer_config.seed = args.seed

    # Finalize config
    pipeline_config_id = raw_config.get('pipeline_config_id')
    if pipeline_config_id is None:
        raise KeyError(f"'pipeline_config_id' is missing from {args.config_dir}")
    config_cls = Config.by_name(pipeline_config_id.lower())
    config = config_cls.parse_from_yaml_config(raw_config, experiment_id=args.experiment_id)

    # Set pretrained model path for evaluation
    if config.base_config.stage != 'train':
        search_root = Path(config.base_config.base_dir).resolve()
        model_paths = list(search_root.glob('**/models/saved_model'))
        if len(model_paths) != 1:
            raise ValueError(f'Found {len(model_paths)} saved models under {search_root}. Expected 1.')
        config.model_config.pretrained_model_dir = str(model_paths[0])

    model_runner = Runner.build_from_config(config)
    model_runner.run()


if __name__ == '__main__':
    main()
