import os
import sys
import time
import warnings
import argparse

import torch
import torch.optim as optim
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

from config import Config
from dataset import ShuffleDataset
from environment_generator import EnvironmentDataset
from model.model_trainer import RewardModelTrainer
from model.model_validator import RewardModelValidator
from model.reward_model import RewardModel, load_all_saved_models
from model.minibatch_stats import RewardModelStats
from model.head import head_state_to_tensor, HeadModel
from model.model_interactor_callbacks import (
    TruncatedLossGraphCallback, LossGraphCallback, FilterGraphCallback, MonitorLossCallback)
from model.model_pipeline import RewardModelPipeline
from driving_gridworld.actions import ACTIONS

def seed_everything(random_state):
    """
    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch so that
    dataset generation and training are reproducible for a given seed.

    :param random_state: The integer seed applied to every generator.
    """
    import random  # local import: only this function needs it

    random.seed(random_state)
    np.random.seed(random_state)
    # Per the PyTorch docs, torch.manual_seed seeds the RNG on all devices.
    torch.manual_seed(random_state)

def _is_strange_example(stats: RewardModelStats) -> bool:
    """
    Filter for the "strange" transition being monitored during training.

    Matches when the car-layer cell is at (2, 1) with speed 2 before the step,
    its second coordinate is 0 with speed 2 after the step, and the obstacle
    layer has exactly one cell, at (2, 0). Asserts that such examples carry a
    reward of -7.

    :param stats: Minibatch statistics; ``stats.state`` is unpacked as
    ``(board, speed, board_after, speed_after)``.
    :return: True if ``stats`` is the strange example.
    """
    board, speed, board_after, speed_after = stats.state
    obstacle_layer, car_layer = 3, 2
    car_x, car_y = np.where(board[car_layer].numpy().astype(int))
    _, car_y_after = np.where(board_after[car_layer].numpy().astype(int))
    # Boards without a car cell cannot match; skip them instead of raising
    # IndexError in the middle of training.
    if len(car_x) == 0 or len(car_y_after) == 0:
        return False
    if not (car_x[0] == 2 and car_y[0] == 1 and speed == 2
            and speed_after == 2 and car_y_after[0] == 0):
        return False
    assert len(car_x) == 1, f'expected a single car cell, found {len(car_x)}'
    obstacle_x, obstacle_y = np.where(board[obstacle_layer].numpy().astype(int))
    if len(obstacle_x) == 1 and obstacle_x[0] == 2 and obstacle_y[0] == 0:
        assert stats.reward_array == -7
        return True
    return False


def train_model(
        model: nn.Module,
        config: Config,
        dataset: ShuffleDataset,
        valid_dataset: ShuffleDataset,
        model_name: str,
) -> nn.Module:
    """
    Runs the full training / validation pipeline on a given model.

    Train loss, validation loss and the strange example's output/reward are
    plotted on one figure, saved to
    ``{config.model_figs_dir}/model_loss/{model_name}.png``. The path
    ``{config.models_dir}/{model_name}.pt`` is handed to ``pipeline.cleanup``
    (presumably where the final weights are written — confirm).

    :param model: The model to train; it is trained in place.
    :param config: The configuration which specifies the number of epochs,
    the loss criterion, whether to resume from checkpoints, the loss-graph
    truncation, the model directory and the figs directory.
    :param dataset: The training dataset.
    :param valid_dataset: The validation dataset.
    :param model_name: The id tag of the model, used in file names and passed
    to the pipeline. ``obtain_models`` passes the integer seed here.
    :return: The same ``model`` instance, after training.
    """
    fig, axs = plt.subplots(4, 1)
    fig.set_size_inches(20, 20)
    try:
        truncated_train_loss = TruncatedLossGraphCallback(
            "Train loss over n epochs (Truncated)",
            "n epochs",
            "Loss",
            axs[0],
            model_name,
            config.loss_graph_truncation,
        )
        # Output and reward of the strange example share axs[2] (presumably so
        # both curves are overlaid); axs[3] is currently left empty.
        strange_example_output = FilterGraphCallback(
            _is_strange_example,
            "Strange Example Output",
            "n epochs",
            "Output",
            axs[2],
            model_name,
            lambda stats: stats.output,
        )
        strange_example_reward = FilterGraphCallback(
            _is_strange_example,
            "Strange Example Output",
            "n epochs",
            "Reward",
            axs[2],
            model_name,
            lambda stats: stats.reward_array,
        )
        # Plain (non-truncated) loss graph, so the title says so.
        valid_loss = LossGraphCallback(
            "Valid loss over n epochs",
            "n epochs",
            "Loss",
            axs[1],
            model_name,
        )

        optimizer = optim.Adam(
            model.parameters(),
            lr=0.0001,
            weight_decay=1e-5,
        )
        model_trainer = RewardModelTrainer(
            model,
            dataset,
            config.criterion,
            optimizer,
            callbacks=[
                truncated_train_loss,
                strange_example_output,
                strange_example_reward,
            ],
        )
        model_validator = RewardModelValidator(
            model,
            valid_dataset,
            config.criterion,
            [valid_loss],
        )

        model_path = os.path.join(config.models_dir, f'{model_name}.pt')
        fig_path = os.path.join(config.model_figs_dir, 'model_loss', f'{model_name}.png')

        pipeline = RewardModelPipeline(
            model_trainer,
            model_validator,
            validate_every=50,
            checkpoint_every=50,
            model_name=model_name,
        )
        if config.use_checkpoints:
            pipeline.load_checkpoint()
        else:
            print('not using checkpoints')
        pipeline.n_epochs(config.n_epochs)
        pipeline.cleanup(model_path)

        # savefig does not create missing directories.
        os.makedirs(os.path.dirname(fig_path), exist_ok=True)
        # Save the figure the callbacks drew on, not whichever is current.
        fig.savefig(fig_path)
    finally:
        # Release the figure even if training raises.
        plt.close(fig)
    return model

def _strange_example_rewards(dataset: ShuffleDataset) -> list:
    """
    Collects the reward of every example in ``dataset`` matching the
    "strange" transition that train_model's filter callback also watches.

    A match has the car-layer cell at (2, 1) with speed 2 before the step, a
    second coordinate of 0 with speed 2 after it, and exactly one obstacle
    cell, at (2, 0).

    Minibatches are indexed as ``batch[0][0] = (boards, speeds)`` before the
    step, ``batch[0][1] = (boards_after, speeds_after)`` after it, and
    ``batch[1] = rewards``.

    :param dataset: The training dataset to scan.
    :return: The rewards of all matching examples, in dataset order.
    """
    obstacle_layer, car_layer = 3, 2
    matches = []
    for i in range(len(dataset.minibatch_dataset)):
        batch = dataset[i]  # fetch each minibatch once
        boards, speeds = batch[0][0][0], batch[0][0][1]
        boards_after, speeds_after = batch[0][1][0], batch[0][1][1]
        rewards = batch[1]
        for j in range(len(boards)):
            car_x, car_y = np.where(boards[j][car_layer].numpy().astype(int))
            _, car_y_after = np.where(boards_after[j][car_layer].numpy().astype(int))
            if len(car_x) == 0 or len(car_y_after) == 0:
                continue
            if not (car_x[0] == 2 and car_y[0] == 1 and speeds[j] == 2
                    and speeds_after[j] == 2 and car_y_after[0] == 0):
                continue
            assert len(car_x) == 1, f'expected a single car cell, found {len(car_x)}'
            obstacle_x, obstacle_y = np.where(boards[j][obstacle_layer].numpy().astype(int))
            if len(obstacle_x) == 1 and obstacle_x[0] == 2 and obstacle_y[0] == 0:
                matches.append(rewards[j])
    return matches


def obtain_models(config):
    """
    Trains one reward model per seed in ``range(config.min_seed, config.max_seed)``.

    When ``config.use_pretrained_models`` is set, seeds whose model file
    ``{config.models_dir}/{seed}.pt`` already exists are skipped.

    :param config: The experiment configuration.
    """
    env = EnvironmentDataset.obtain_train_env()

    valid_env = EnvironmentDataset.obtain_test_env()
    valid_dataset = valid_env.obtain_dataset(
        batchsize=config.batchsize,
    )
    for seed in range(config.min_seed, config.max_seed):
        print('training', seed)
        model_path = os.path.join(config.models_dir, f'{seed}.pt')
        if config.use_pretrained_models and os.path.exists(model_path):
            print('not training this model as it already exists')
            continue

        seed_everything(seed)
        start_time = time.time()
        dataset = env.obtain_dataset(
            batchsize=config.batchsize,
            seed=seed,
        )

        # Rough sanity check for the strange behavior seen earlier: every
        # matching example must carry a reward of -7. Matches are recomputed
        # per seed so a previous seed's indices are never reused.
        strange_rewards = _strange_example_rewards(dataset)
        if not strange_rewards:
            warnings.warn(f'seed {seed}: no strange example found; sanity check skipped')
        for reward in strange_rewards:
            assert reward == -7, f'seed {seed}: strange example has reward {reward}, expected -7'

        train_model(
            RewardModel().to(config.device),
            config,
            dataset,
            valid_dataset,
            seed,
        )
        print(f'took {time.time() - start_time} seconds to train this model')

def parse_bool_flag(value) -> bool:
    """
    Parses a command-line boolean such as ``true``/``false``/``1``/``0``.

    argparse's ``type=bool`` treats every non-empty string (including
    ``"False"``) as True, so this converter is used instead.

    :param value: The raw argument (a string, or a bool passed as ``const``).
    :return: The parsed boolean; an empty string parses as False.
    :raises argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean, got {value!r}')


if __name__ == "__main__":
    config = Config()

    parser = argparse.ArgumentParser(description='Set some experiment specific parameters.')
    parser.add_argument("--min-seed", type=int, default=None)
    parser.add_argument("--max-seed", type=int, default=None)
    # Accept both a bare ``--train-head`` and ``--train-head true/false``.
    parser.add_argument("--train-head", type=parse_bool_flag, nargs='?',
                        const=True, default=False)
    args = parser.parse_args()

    # Override each seed bound independently so a single flag is not ignored.
    if args.min_seed is not None:
        config.min_seed = args.min_seed
    if args.max_seed is not None:
        config.max_seed = args.max_seed

    if args.train_head:
        # NOTE(review): obtain_head_models is neither defined nor imported in
        # this file, so this branch raises NameError unless it is provided
        # elsewhere — TODO restore or import it.
        obtain_head_models(config)
    else:
        obtain_models(config)
