import json
import os
import warnings
from argparse import ArgumentParser

import numpy as np
import torch
from module import CIFAR10ExplainModule, CIFAR10PGDExplainModule
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
from tqdm import tqdm

from dataset.attribute_data import AttributeCIFAR10DataModule
from dataset.data import CIFAR10Data
from parse_tree import Parse_Tree


class SaveFeaturesHook:
    """Forward hook that records flattened module outputs while enabled.

    Register an instance with ``module.register_forward_hook(hook)``. Every
    forward pass that happens while the hook is enabled appends a detached
    copy of the module's output, viewed as ``(batch, -1)``, to
    ``self.features``. Recording starts disabled.
    """

    def __init__(self):
        # One entry per recorded forward pass.
        self.features = []
        # Nothing is recorded until enable() is called.
        self.enabled = False

    def __call__(self, module, input, output):
        if not self.enabled:
            return
        batch_size = output.size(0)
        self.features.append(output.detach().view(batch_size, -1))

    def reset(self):
        """Drop every recorded feature tensor."""
        self.features = []

    def enable(self):
        """Start recording outputs on subsequent forward passes."""
        self.enabled = True

    def disable(self):
        """Stop recording; features already recorded are kept."""
        self.enabled = False


class SaveSpatialFeaturesHook(SaveFeaturesHook):
    """SaveFeaturesHook variant that keeps each output's original shape.

    Outputs are detached but not flattened, so any spatial dimensions are
    preserved in the tensors appended to ``self.features``.
    """

    def __call__(self, module, input, output):
        if not self.enabled:
            return
        self.features.append(output.detach())


class SaveElemtFeaturesHook(SaveFeaturesHook):
    """SaveFeaturesHook variant that keeps only the most recent output.

    Instead of appending, each recorded forward pass *replaces*
    ``self.features`` with a single detached tensor viewed as
    ``(batch, -1)``. After a recorded call ``self.features`` is therefore a
    tensor rather than a list; the inherited ``reset()`` sets it back to an
    empty list.
    """

    def __call__(self, module, input, output):
        if self.enabled:
            latest = output.detach()
            self.features = latest.view(latest.size(0), -1)


def test_classifier_with_tree(model, test_loader, hook, trees, device, output_file):
    """Match every test prediction to a path in its predicted class's parse tree.

    For each test sample, the classifier's predicted class selects a tree from
    ``trees``. The feature vector captured by ``hook`` during the same forward
    pass is matched against that tree with ``top_matches``. The matches are
    written to ``output_file`` as a JSON list of records of the form
    ``{"path_dict": ..., "pred": int, "target": int}``, in loader order.

    Args:
        model: classifier whose forward pass triggers ``hook``. It is moved to
            ``device`` and switched to eval mode.
        test_loader: iterable of ``(inputs, targets)`` batches.
        hook: a SaveFeaturesHook registered on a layer of ``model``; it must
            record exactly one feature batch per forward pass.
        trees: sequence indexed by class id; each entry provides
            ``top_matches(feature)`` returning an object with ``to_dict()``.
        device: device the model and the inputs are moved to.
        output_file: path of the JSON file to write.

    Raises:
        RuntimeError: if the hook recorded nothing during a forward pass
            (e.g. it is not registered on ``model``).
        ValueError: if there is no tree for a predicted class.
    """
    # Inputs are sent to `device`, so the model must live there too; it may
    # have been left on another device (e.g. CPU) by earlier processing.
    model.to(device)
    model.eval()

    was_enabled = hook.enabled
    hook.enable()

    records = []
    try:
        with torch.no_grad():
            for inputs, targets in tqdm(test_loader):
                # Keep only the features produced by this batch's forward pass.
                hook.reset()
                outputs = model(inputs.to(device))
                if len(hook.features) == 0:
                    raise RuntimeError(
                        "Feature hook recorded nothing during the forward "
                        "pass; is it registered on a layer of `model`?"
                    )
                predicted = outputs.argmax(dim=1).tolist()
                batch_features = hook.features[0]
                for pred, feat, target in zip(predicted, batch_features, targets.tolist()):
                    tree = trees[pred]
                    if tree is None:
                        raise ValueError(f"No parse tree for predicted class {pred}")
                    path = tree.top_matches(feat)
                    records.append({
                        'path_dict': path.to_dict(),
                        'pred': pred,
                        'target': target,
                    })
    finally:
        # Leave the hook in the recording state the caller handed us.
        if not was_enabled:
            hook.disable()

    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(records, f)


def _load_classifier_weights(classifier, args):
    """Load pretrained weights into ``classifier`` (the wrapped torch model).

    If ``args.checkpoint`` names an existing file, that file is loaded and its
    ``'model'`` entry is used as the state dict. Otherwise the bare state dict
    at ``cifar10_models/state_dicts/<args.classifier>.pt`` is loaded. A
    ``--checkpoint`` path that does not exist produces a warning instead of
    being silently ignored.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    if args.checkpoint is not None:
        if os.path.exists(args.checkpoint):
            checkpoint = torch.load(args.checkpoint, map_location="cpu")
            classifier.load_state_dict(checkpoint['model'])
            return
        warnings.warn(
            f"Checkpoint {args.checkpoint!r} does not exist; "
            "falling back to the default pretrained weights."
        )
    default_path = os.path.join(
        "cifar10_models", "state_dicts", args.classifier + ".pt"
    )
    classifier.load_state_dict(torch.load(default_path, map_location="cpu"))


def _build_trees(prompt_file, dataset):
    """Build one Parse_Tree per category from the JSON prompt file.

    Returns a list indexed by ``dataset.category_to_index``.

    Raises:
        TypeError: if a prompt entry is not a dict.
        KeyError: if the prompt file names a category the dataset lacks.
        ValueError: if some dataset category has no prompt entry.
    """
    with open(prompt_file) as file:
        prompt_dict = json.load(file)

    category_to_index = dataset.category_to_index
    trees = [None for _ in dataset.categories]
    for category, spec in prompt_dict.items():
        if not isinstance(spec, dict):
            raise TypeError("prompt_dict must be dict")
        if category not in category_to_index:
            raise KeyError(f"Unknown category {category!r} in {prompt_file}")
        trees[category_to_index[category]] = Parse_Tree.from_dict(spec)

    missing = [c for c in dataset.categories if trees[category_to_index[c]] is None]
    if missing:
        raise ValueError(f"No prompt entry in {prompt_file} for categories: {missing}")
    return trees


def _populate_tree_values(trees, hook, model, dataset):
    """Attach per-attribute feature tensors recorded during ``trainer.test``.

    ``hook.features`` holds one flattened feature batch per forward pass;
    ``model.attributes`` and ``model.labels`` must describe the same samples
    in the same order (assumed to be filled by the module's test step —
    TODO confirm against CIFAR10ExplainModule).
    """
    all_features = torch.cat(hook.features)
    attribute = np.array(model.attributes)
    labels = torch.cat(model.labels).detach().cpu().numpy()
    if not len(all_features) == len(labels) == len(attribute):
        raise ValueError(
            "Mismatched feature/label/attribute counts: "
            f"{len(all_features)}, {len(labels)}, {len(attribute)}"
        )

    category_to_index = dataset.category_to_index
    for category in dataset.categories:
        # Use the same index for the label mask and the tree so features of
        # a class always land in that class's tree.
        idx = category_to_index[category]
        c_mask = labels == idx
        c_attr = attribute[c_mask]
        c_feat = all_features[c_mask]
        # dict.fromkeys gives unique attributes in first-occurrence order,
        # unlike set(), whose string order changes with hash randomisation.
        values = {attr: c_feat[c_attr == attr] for attr in dict.fromkeys(c_attr)}
        trees[idx].set_values(values)


def main(args):
    """Fit per-class parse trees on hooked features and score the test set.

    Steps: load the classifier, run ``trainer.test`` on the attribute test
    split while a hook on ``avgpool`` records features, attach those features
    to the per-class trees from ``args.prompt_file``, then match every
    CIFAR-10 test prediction against the trees and write the matches to
    ``args.output_file``.
    """
    seed_everything(0)
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id

    if args.logger == "wandb":
        logger = WandbLogger(name=args.classifier, project="cifar10")
    elif args.logger == "tensorboard":
        logger = TensorBoardLogger("cifar10", name=args.classifier)
    else:
        raise ValueError(f"Unsupported logger: {args.logger!r}")

    checkpoint = ModelCheckpoint(
        monitor="acc/val", mode="max", save_last=False)

    trainer = Trainer(
        fast_dev_run=bool(args.dev),
        logger=logger if not bool(args.dev + args.test_phase) else None,
        gpus=-1,
        deterministic=True,
        weights_summary=None,
        log_every_n_steps=1,
        max_epochs=args.max_epochs,
        checkpoint_callback=checkpoint,
        precision=args.precision,
    )

    # Create the model and record avgpool outputs during trainer.test.
    model = CIFAR10ExplainModule(args)
    hook = SaveFeaturesHook()
    model.model.avgpool.register_forward_hook(hook)
    hook.enable()
    data = AttributeCIFAR10DataModule(args, prefix=None)
    test_data = CIFAR10Data(args)

    _load_classifier_weights(model.model, args)
    test_loader = data.test_dataloader()

    # Parse the prompt file before the (slow) test pass so errors surface early.
    trees = _build_trees(args.prompt_file, test_loader.dataset)

    trainer.test(model, test_loader)
    _populate_tree_values(trees, hook, model, test_loader.dataset)

    test_classifier_with_tree(model.model,
                              output_file=args.output_file,
                              test_loader=test_data.test_dataloader(),
                              hook=hook,
                              trees=trees,
                              device='cuda:0')


if __name__ == "__main__":
    parser = ArgumentParser()
    add_arg = parser.add_argument

    # Program-level options
    add_arg("--data_dir", type=str, default="data/cifar10")
    add_arg("--support_data_dir", type=str, default="data/cifar10")
    add_arg("--test_phase", action='store_true')
    add_arg("--dev", action='store_true')
    add_arg("--logger", type=str, default="tensorboard",
            choices=["tensorboard", "wandb"])

    # Input/output files and model selection
    add_arg("--prompt_file", type=str, default="cifar10_prompt.json")
    # A "{}" in the output path is replaced by the classifier name below.
    add_arg("--output_file", type=str, default="cifar10_test_tree_{}.json")
    add_arg("--classifier", type=str, default="resnet18")
    add_arg("--pretrained", action='store_true')

    # Trainer and data-loading options
    add_arg("--precision", type=int, default=32, choices=[16, 32])
    add_arg("--batch_size", type=int, default=256)
    add_arg("--max_epochs", type=int, default=100)
    add_arg("--num_workers", type=int, default=8)
    add_arg("--gpu_id", type=str, default="0")
    add_arg("--checkpoint", type=str, default=None)

    # Optimization hyper-parameters
    add_arg("--learning_rate", type=float, default=1e-2)
    add_arg("--weight_decay", type=float, default=1e-2)

    args = parser.parse_args()
    args.output_file = args.output_file.format(args.classifier)
    main(args)
