import argparse
import json
import logging
import os
import shutil
from time import time
import uuid

import importlib
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from pydantic.dataclasses import dataclass

from pathlib import Path

import wandb
from pythae.data.preprocessors import DataProcessor
from pythae.models import AutoModel
from pythae.config import BaseConfig
from pythae.trainers import BaseTrainer, BaseTrainerConfig
from pythae.models.base.base_utils import ModelOutput


# Absolute directory of this script; the `data/<dataset>` folders are looked up
# relative to it.
PATH = os.path.split(os.path.abspath(__file__))[0]

# Module-level logger: records of level INFO and above go through a stream
# handler.
console = logging.StreamHandler()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(console)

# Command-line interface of the script. Only the parser is built here; the
# arguments themselves are parsed right after this block.
ap = argparse.ArgumentParser(
    description=(
        "Train a single-layer classifier on the latent codes of a trained "
        "model and report its eval/test accuracy over several runs."
    )
)

ap.add_argument(
    "--models_path",
    help=(
        "Directory of the trained model: its first entry must be a folder "
        "containing a 'final_model' sub-folder to reload the model from"
    ),
    required=True,
)
ap.add_argument(
    "--n_runs",
    help="number of classifier trainings (run i uses seed i) to average over",
    type=int,
    default=20,
)
ap.add_argument(
    "--use_wandb",
    help="whether to log the metrics in wandb",
    action="store_true",
)
ap.add_argument(
    "--wandb_project",
    help="wandb project name",
    default="latent_dim_sensi_classifications",
)
ap.add_argument(
    "--wandb_entity",
    help="wandb entity name",
    default="benchmark_team",
)

# Parsed at module level so that `args` exists as a global; it is handed to
# `main` under the `__main__` guard at the bottom of the file.
args = ap.parse_args()

# Use the GPU whenever PyTorch reports one as available, the CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


@dataclass
class SingleLayerClassifierConfig(BaseConfig):
    """Configuration of the :class:`SingleLayerClassifier`."""
    # Number of input features of the linear layer (size of one flattened input).
    input_shape: int = 16
    # Number of output units of the linear layer, i.e. number of classes.
    n_classes: int = 10

class SingleLayerClassifier(nn.Module):
    """Classifier made of one linear layer trained with a cross-entropy loss.

    Args:
        model_config (SingleLayerClassifierConfig): Provides ``input_shape``
            (input features of the layer) and ``n_classes`` (output units).
    """

    def __init__(self, model_config: SingleLayerClassifierConfig) -> None:
        super().__init__()

        self.linear = nn.Linear(model_config.input_shape, model_config.n_classes)
        self.criterion = nn.CrossEntropyLoss()
        self.model_name = 'SingleLayerClassifier'
        self.model_config = model_config

    def forward(self, inputs, **kwargs):
        """Compute the class scores and the cross-entropy loss of a batch.

        Args:
            inputs: Mapping with a ``"data"`` tensor, flattened to
                ``(batch, -1)`` before the linear layer, and a ``"labels"``
                tensor of class indices (cast to ``long`` for the loss).
            **kwargs: Accepted and ignored.

        Returns:
            ModelOutput: With ``predictions`` (output of the linear layer) and
            ``loss`` (cross-entropy between predictions and labels).
        """
        # Follow the device the layer actually lives on rather than a
        # module-level global, so the inputs always match the weights.
        target_device = self.linear.weight.device

        x = inputs["data"].to(target_device)
        y = inputs['labels'].to(target_device)

        out = self.linear(x.reshape(x.shape[0], -1))

        loss = self.criterion(out, y.long())

        return ModelOutput(
            predictions=out,
            loss=loss
        )

    def update(self):
        """Do nothing (presumably a hook expected by the trainer - TODO confirm)."""
        pass

    def save(self, dir_path):
        """Do nothing: this classifier is never written to ``dir_path``."""
        pass


def _infer_dataset_name(input_dim):
    """Return the data folder name matching a model's ``input_dim``.

    Raises:
        ValueError: If ``input_dim`` is not one of the supported shapes.
    """
    known_shapes = {
        (1, 28, 28): 'mnist',
        (3, 32, 32): 'cifar10',
        (3, 64, 64): 'celeba',
    }
    shape = tuple(input_dim) if input_dim is not None else None
    if shape not in known_shapes:
        raise ValueError(
            f"Cannot infer the dataset from input_dim={input_dim!r}; "
            f"supported shapes are {sorted(known_shapes)}."
        )
    return known_shapes[shape]


def _load_split(dataset, split):
    """Load ``<split>_data.npz`` (scaled by 1/255) and ``<split>_labels.npz``."""
    folder = os.path.join(PATH, f"data/{dataset}")
    # `np.load` on a .npz returns an archive holding an open file: close it
    # once the arrays have been read.
    with np.load(os.path.join(folder, f"{split}_data.npz")) as archive:
        data = archive["data"] / 255.0
    with np.load(os.path.join(folder, f"{split}_labels.npz")) as archive:
        targets = archive["targets"]
    return data, targets


def _encode_latents(model, loader):
    """Run ``model`` over ``loader`` and return the concatenated ``z`` outputs."""
    try:
        with torch.no_grad():
            codes = [model(inputs).z for inputs in loader]
    except RuntimeError:
        # NOTE(review): presumably for models whose forward pass itself needs
        # autograd - confirm. The list is rebuilt from scratch so that batches
        # encoded before the failure are not duplicated.
        codes = [model(inputs).z.detach() for inputs in loader]
    return torch.cat(codes)


def _latent_dataset(model, data_processor, data, targets, dataset_type):
    """Encode ``data`` with ``model`` and pair the latent codes with ``targets``."""
    processed = data_processor.process_data(data).to(device)
    image_dataset = data_processor.to_dataset(processed, dataset_type=dataset_type)
    image_loader = DataLoader(dataset=image_dataset, batch_size=100, shuffle=False)
    latents = _encode_latents(model, image_loader)
    return data_processor.to_dataset(
        data=latents, labels=torch.tensor(targets).type(torch.long)
    )


def _accuracy(classifier, loader):
    """Return the fraction of correct predictions over ``loader`` (4 decimals)."""
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs in loader:
            labels = inputs['labels'].to(device)
            outputs = classifier(inputs)
            _, predicted = torch.max(outputs.predictions, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return np.round(correct / total, 4)


def main(args):
    """Evaluate the latent codes of a trained model with a linear classifier.

    The model is reloaded from ``<models_path>/<first entry>/final_model``, the
    dataset is chosen from the model's ``input_dim``, the train/eval/test
    images are encoded into latent codes, and a ``SingleLayerClassifier`` is
    trained ``n_runs`` times (run ``i`` uses seed ``i``) on the train codes.
    Mean and standard deviation of the eval/test accuracies are printed and,
    if requested, logged to wandb.

    Args:
        args: Parsed command-line arguments providing ``models_path``,
            ``n_runs``, ``use_wandb``, ``wandb_project`` and ``wandb_entity``.

    Raises:
        ValueError: If the model's ``input_dim`` matches no supported dataset.
        FileNotFoundError: If the data of the dataset cannot be loaded.
        AssertionError: If eval or test data are outside the range [0-1].
        ModuleNotFoundError: If wandb logging is requested but not installed.
    """
    # NOTE(review): the module-level `device` was already resolved before this
    # assignment - confirm that setting it here still has the intended effect.
    os.environ['CUDA_VISIBLE_DEVICES'] = "0"

    model_signature = os.listdir(args.models_path)[0]
    model_path = os.path.join(args.models_path, model_signature, "final_model")

    # reload the model
    trained_model = AutoModel.load_from_folder(model_path).to(device)
    trained_model.eval()

    logger.info(f"Successfully reloaded {trained_model.model_name.upper()} model !\n")

    # Fail early with an explicit message instead of leaving `dataset` unbound.
    dataset = _infer_dataset_name(trained_model.model_config.input_dim)

    try:
        logger.info(f"\nLoading {dataset} data...\n")
        train_data, train_targets = _load_split(dataset, "train")
        eval_data, eval_targets = _load_split(dataset, "eval")
        test_data, test_targets = _load_split(dataset, "test")

    except Exception as e:
        raise FileNotFoundError(
            f"Unable to load the data from 'data/{dataset}' folder. Please check that both a "
            "'train_data.npz' and 'eval_data.npz' are present in the folder.\n Data must be "
            " under the key 'data', in the range [0-255] and shaped with channel in first "
            "position\n"
            f"Exception raised: {type(e)} with message: " + str(e)
        ) from e

    logger.info("Successfully loaded data !\n")
    logger.info("------------------------------------------------------------")
    logger.info("Dataset \t \t Shape \t \t \t Range")
    logger.info(
            f"{dataset.upper()} train data: \t {train_data.shape, train_targets.shape} \t [{train_data.min()}-{train_data.max()}] "
        )
    logger.info(
        f"{dataset.upper()} eval data: \t {eval_data.shape, eval_targets.shape} \t [{eval_data.min()}-{eval_data.max()}] "
    )
    logger.info(
        f"{dataset.upper()} test data: \t {test_data.shape, test_targets.shape} \t [{test_data.min()}-{test_data.max()}]"
    )
    logger.info("------------------------------------------------------------\n")

    # The data were divided by 255 when loaded, so they must lie within [0-1].
    assert (
        eval_data.max() <= 1 and eval_data.min() >= 0
    ), "Eval data must in the range [0-1]"
    assert (
        test_data.max() <= 1 and test_data.min() >= 0
    ), "Test data must in the range [0-1]"

    dataset_type = (
        "DoubleBatchDataset"
        if trained_model.model_name == "FactorVAE"
        else "BaseDataset"
    )

    data_processor = DataProcessor()

    # Replace every split by the latent codes of its images, paired with labels.
    train_dataset = _latent_dataset(
        trained_model, data_processor, train_data, train_targets, dataset_type
    )
    eval_dataset = _latent_dataset(
        trained_model, data_processor, eval_data, eval_targets, dataset_type
    )
    test_dataset = _latent_dataset(
        trained_model, data_processor, test_data, test_targets, dataset_type
    )

    eval_loader = DataLoader(dataset=eval_dataset, batch_size=100, shuffle=False)
    test_loader = DataLoader(dataset=test_dataset, batch_size=100, shuffle=False)

    # Outside a SLURM array job the variable is absent: fall back to task 0.
    task_id = int(os.environ.get('SLURM_ARRAY_TASK_ID', "0"))

    eval_acc = []
    test_acc = []

    for i in range(args.n_runs):

        mlp_model_config = SingleLayerClassifierConfig(input_shape=trained_model.latent_dim)
        mlp_model = SingleLayerClassifier(model_config=mlp_model_config)

        trainer = BaseTrainer(
            model=mlp_model,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            training_config=BaseTrainerConfig(
                num_epochs=100,
                seed=i,
                # Unique sub-folder per run; it is deleted right after training.
                output_dir=os.path.join(
                    "dummy_output_dir", f"{task_id}", uuid.uuid4().hex
                ),
            ),
        )

        trainer.train()

        trained_mlp_model = trainer._best_model
        trained_mlp_model.eval()

        shutil.rmtree(trainer.training_dir)

        eval_acc.append(_accuracy(trained_mlp_model, eval_loader))
        test_acc.append(_accuracy(trained_mlp_model, test_loader))

    eval_mean, eval_std = np.mean(eval_acc), np.std(eval_acc)
    test_mean, test_std = np.mean(test_acc), np.std(test_acc)

    print("-------------------------------------")
    print(f"mean eval accuracy : {eval_mean}")
    print(f"std eval accuracy : {eval_std}")
    print("-------------------------------------")
    print(f"mean test accuracy : {test_mean}")
    print(f"std test accuracy : {test_std}")
    print("-------------------------------------")

    if args.use_wandb:
        from importlib.util import find_spec

        if find_spec("wandb") is None:
            raise ModuleNotFoundError(
                "`wandb` package must be installed. Run `pip install wandb`"
            )

        wandb.init(project=args.wandb_project, entity=args.wandb_entity)
        wandb.config.update(
            {
                "n_runs": args.n_runs,
                "model_path": model_path,
                "model_config": trained_model.model_config.to_dict()
            }
        )

        # log the aggregated accuracies
        wandb.log(
            {
                "eval/mean_accuracy": eval_mean,
                "eval/std_accuracy": eval_std,
                "test/mean_accuracy": test_mean,
                "test/std_accuracy": test_std,
            }
        )

if __name__ == "__main__":

    # Script entry point: run the evaluation with the command-line arguments
    # parsed at module level.
    main(args)
