import copy
import distutils
import os
import random
from argparse import ArgumentParser, Namespace
from distutils.util import strtobool
from itertools import chain
from pprint import pprint
from typing import Any, List

import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import wandb
from pandas import DataFrame
from pytorch_lightning.metrics.functional import (
    accuracy,
    f1_score,
    precision,
    precision_recall,
    recall,
)
from sklearn.preprocessing import LabelEncoder
from torch import nn
from torch.nn import Embedding
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import transforms
from wandb import Table

from zendo.learning.datasets import FixedRuleOnlineDataset, get_train_test_structures


class StructureClassifier(pl.LightningModule):
    """Binary classifier over piece-string structures for a single rule.

    A structure is a string over the pieces ``a A b B .``; it is label-encoded
    to integer ids, embedded, flattened and fed through a three-layer MLP that
    outputs two class scores (index 0 -> ``False``, index 1 -> ``True``).
    Training/validation examples come from ``FixedRuleOnlineDataset`` built
    for ``hparams.rule``.

    Hyperparameters read from ``hparams``: ``rule``, ``num_different_pieces``,
    ``embedding_dim``, ``structure_dim``, ``learning_rate``, ``batch_size``,
    ``train_shuffle``, ``dataset_name``, ``test_size`` and, optionally,
    ``num_workers`` (default 8). ``structure_dim`` and ``dataset_name`` are not
    declared by ``add_model_specific_args`` and must come from the parent
    parser.
    """

    def __init__(self, hparams: Namespace):
        """Build the piece encoder and the network layers.

        Args:
            hparams: parsed command-line arguments (accessed by attribute).
        """
        super().__init__()
        # not the best model...
        self.hparams = hparams
        self.rule = hparams.rule

        # Fixed vocabulary: piece character <-> integer id. The ids feed the
        # embedding; ``inverse_transform`` turns them back into text for logs.
        self.structure_encoder = LabelEncoder()
        self.structure_encoder.fit(["a", "A", "b", "B", "."])

        # Filled in by ``prepare_data``.
        self.train_structures = None
        self.val_structures = None

        self.piece_embedding = Embedding(
            num_embeddings=self.hparams.num_different_pieces,
            embedding_dim=self.hparams.embedding_dim,
        )
        self.linear1 = torch.nn.Linear(
            in_features=self.hparams.structure_dim * self.hparams.embedding_dim,
            out_features=32,
        )
        self.linear2 = torch.nn.Linear(in_features=32, out_features=10)
        self.linear3 = torch.nn.Linear(in_features=10, out_features=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalised class scores of shape ``(batch, 2)``.

        Args:
            x: integer piece ids; presumably ``(batch, structure_dim)`` —
                every row is flattened to ``structure_dim * embedding_dim``
                features after embedding.
        """
        out = self.piece_embedding(x)
        out = out.view(-1, self.hparams.structure_dim * self.hparams.embedding_dim)

        out = F.relu(self.linear1(out))
        out = F.relu(self.linear2(out))
        # No activation on the output layer: ``F.cross_entropy`` expects raw
        # logits. A ReLU here would clamp negative scores to 0, zeroing their
        # gradients and producing argmax ties that always resolve to class 0.
        return self.linear3(out)

    def training_step(self, batch, batch_idx):
        """Compute the summed cross-entropy loss of one training batch."""
        # REQUIRED
        x, y = batch
        y_hat = self.forward(x)
        loss = F.cross_entropy(y_hat, y, reduction="sum")

        tensorboard_logs = {"train_loss": loss}

        return {"loss": loss, "log": tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        """Score one validation batch and build a per-structure recap for logging.

        Returns a dict with the summed loss, the softmax probabilities (under
        the key ``"logits"``), hard predictions, targets, and a list of
        ``{"Structure", "Predicted Label", "True Label"}`` records.
        """
        # OPTIONAL
        x, y = batch
        out = self.forward(x)
        y_pred = torch.argmax(out, dim=-1)

        # Decode the whole batch in a single sklearn call. sklearn works on
        # host NumPy arrays, so move ids off the (possibly CUDA) device first.
        ids = x.detach().cpu().numpy()
        pieces = self.structure_encoder.inverse_transform(ids.reshape(-1)).reshape(
            ids.shape
        )
        predictions_recap = [
            {
                "Structure": "".join(row),
                "Predicted Label": bool(predicted),
                "True Label": bool(truth),
            }
            for row, predicted, truth in zip(pieces, y_pred.tolist(), y.tolist())
        ]

        return {
            "val_loss": F.cross_entropy(out, y, reduction="sum"),
            # Probabilities, not raw logits; the key name is kept for
            # compatibility with ``validation_epoch_end``.
            "logits": out.softmax(dim=-1),
            "y_pred": y_pred,
            "y": y,
            "predictions_recap": predictions_recap,
        }

    def validation_epoch_end(self, outputs):
        """Aggregate validation batches, log a table and ROC to W&B, compute metrics.

        ``val_loss`` is the mean over batches of each batch's *summed* loss, so
        its scale depends on ``batch_size``.
        """
        # OPTIONAL
        avg_loss = torch.stack([o["val_loss"] for o in outputs]).mean()

        probs = torch.cat([o["logits"] for o in outputs]).view(-1, 2)
        y_pred = torch.cat([o["y_pred"] for o in outputs]).view(-1)
        y = torch.cat([o["y"] for o in outputs]).view(-1)

        predictions = list(
            chain.from_iterable(o["predictions_recap"] for o in outputs)
        )
        self.logger.experiment.log(
            {"predictions": Table(dataframe=DataFrame(predictions))}
        )
        # Hand wandb host-side NumPy arrays; CUDA tensors cannot be converted
        # implicitly.
        self.logger.experiment.log(
            {
                "roc": wandb.plots.ROC(
                    y.detach().cpu().numpy(),
                    probs.detach().cpu().numpy(),
                    labels=["False", "True"],
                )
            }
        )

        prec, rec = precision_recall(pred=y_pred, target=y, num_classes=2)
        metrics = {
            "accuracy": accuracy(pred=y_pred, target=y, num_classes=2),
            "precision": prec,
            "recall": rec,
            "f1_score": f1_score(pred=y_pred, target=y, num_classes=2),
            "avg_val_loss": avg_loss,
        }
        return {
            "val_loss": avg_loss,
            "log": metrics,
            "progress_bar": {"precision": prec},
        }

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with ``hparams.learning_rate``."""
        # REQUIRED
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

    def prepare_data(self) -> None:
        """Split ``hparams.dataset_name`` into train/validation structures.

        The split is deterministic (``random_state=0``) with
        ``hparams.test_size`` held out for validation.
        """
        self.train_structures, self.val_structures = get_train_test_structures(
            self.hparams.dataset_name, test_size=self.hparams.test_size, random_state=0
        )

    def _build_dataloader(self, structures, shuffle: bool) -> DataLoader:
        """Wrap ``structures`` in a ``FixedRuleOnlineDataset`` for ``self.rule``.

        Uses ``hparams.num_workers`` when present, otherwise 8 workers.
        """
        return DataLoader(
            FixedRuleOnlineDataset(
                self.rule, structures, structure_encoder=self.structure_encoder
            ),
            batch_size=self.hparams.batch_size,
            shuffle=shuffle,
            num_workers=getattr(self.hparams, "num_workers", 8),
        )

    def train_dataloader(self):
        """Training loader; shuffling follows the ``--train_shuffle`` string flag."""
        # REQUIRED
        return self._build_dataloader(
            self.train_structures, shuffle=bool(strtobool(self.hparams.train_shuffle))
        )

    def val_dataloader(self):
        """Validation loader, never shuffled."""
        # OPTIONAL
        return self._build_dataloader(self.val_structures, shuffle=False)

    @staticmethod
    def add_model_specific_args(parent_parser):
        """
        Specify the hyperparams for this LightningModule.

        Returns a new parser inheriting from ``parent_parser``. Note that
        ``structure_dim`` and ``dataset_name`` (also required by this module)
        are expected to be defined by ``parent_parser``.
        """
        # MODEL specific
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument("--rule", default="at_least 1 blue", type=str)

        parser.add_argument("--embedding_dim", default=256, type=int)
        # Matches the five-symbol vocabulary fitted in ``__init__``.
        parser.add_argument("--num_different_pieces", default=5, type=int)

        # training specific (for this model)
        parser.add_argument("--test_size", default=0.3, type=float)
        parser.add_argument("--learning_rate", default=0.0001, type=float)
        parser.add_argument("--batch_size", default=16, type=int)
        # Parsed later with ``strtobool`` ("True"/"False", "1"/"0", ...).
        parser.add_argument("--train_shuffle", default="True", type=str)

        return parser
