import json
import os
from argparse import ArgumentParser

import numpy as np
import torch
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
from tqdm import tqdm

from dataset.attribute_data import AttributeImageNetDataModule
from dataset.data import ImageNetDataModule
from models.module import ImageNetExplainModule
from parse_tree import Parse_Tree


class SaveFeaturesHook:
    """Forward hook that stores a module's output, flattened per sample.

    The hook starts disabled. While enabled, every call appends
    ``output`` (detached and reshaped to ``(output.size(0), -1)``) to
    ``self.features``.
    """

    def __init__(self):
        self.enabled = False
        self.features = []

    def __call__(self, module, input, output):
        # ``module`` and ``input`` are accepted but not used.
        if not self.enabled:
            return
        batch_size = output.size(0)
        flattened = output.detach().view(batch_size, -1)
        self.features.append(flattened)

    def enable(self):
        """Start recording on subsequent calls."""
        self.enabled = True

    def disable(self):
        """Stop recording; features captured so far are kept."""
        self.enabled = False

    def reset(self):
        """Replace the feature list with a new empty one."""
        self.features = []


class ViTFeaturesHook:
    """Forward hook that stores the ``output[:, 0]`` slice of a module's output.

    The hook starts disabled. While enabled, every call takes
    ``output[:, 0]`` (labelled the class token by the original author),
    detaches it, reshapes it to ``(output.size(0), -1)``, moves it to the
    CPU and appends it to ``self.features``.
    """

    def __init__(self):
        self.enabled = False
        self.features = []

    def __call__(self, module, input, output):
        # ``module`` and ``input`` are accepted but not used.
        if not self.enabled:
            return
        batch_size = output.size(0)
        # class token
        cls_token = output[:, 0].detach()
        self.features.append(cls_token.view(batch_size, -1).cpu())

    def enable(self):
        """Start recording on subsequent calls."""
        self.enabled = True

    def disable(self):
        """Stop recording; features captured so far are kept."""
        self.enabled = False

    def reset(self):
        """Replace the feature list with a new empty one."""
        self.features = []


class SaveFeaturesInputHook:
    """Forward hook that stores a module's first input, flattened per sample.

    The hook starts disabled. While enabled, every call appends
    ``input[0]`` (detached and reshaped to ``(input[0].size(0), -1)``) to
    ``self.features``.
    """

    def __init__(self):
        self.enabled = False
        self.features = []

    def __call__(self, module, input, output):
        # ``module`` and ``output`` are accepted but not used.
        if not self.enabled:
            return
        first_arg = input[0]
        batch_size = first_arg.size(0)
        self.features.append(first_arg.detach().view(batch_size, -1))

    def enable(self):
        """Start recording on subsequent calls."""
        self.enabled = True

    def disable(self):
        """Stop recording; features captured so far are kept."""
        self.enabled = False

    def reset(self):
        """Replace the feature list with a new empty one."""
        self.features = []


def test_classifier_with_tree(model, test_loader, hook, trees, device, output_file):
    """Run ``model`` over ``test_loader`` and dump a tree match per prediction.

    For every sample, the predicted class index selects an entry of
    ``trees``; that entry's ``top_matches`` is called with the sample's
    hooked feature vector, and the ``to_dict()`` of the result is stored
    together with the prediction. All records are written to
    ``output_file`` as a single JSON list of
    ``{'path_dict': ..., 'pred': ...}`` objects.

    Args:
        model: Classifier to evaluate. ``hook`` must already be registered
            on it so that one forward pass appends one batch of features
            to ``hook.features``.
        test_loader: Iterable yielding ``(inputs, targets)`` batches; the
            targets are not used.
        hook: Feature hook exposing ``enable()``, ``reset()`` and a
            ``features`` list.
        trees: Sequence indexed by predicted class index whose items
            provide ``top_matches(feature)``.
        device: Device on which the forward passes are run.
        output_file: Path of the JSON file to write.
    """
    # Evaluation mode for the forward passes below.
    model.eval()
    # Inputs are moved to ``device`` below, so the model has to be there too;
    # otherwise the forward pass fails with a device mismatch whenever the
    # caller hands in a model that lives elsewhere.
    model.to(device)
    # Make sure the hook records during the forward passes.
    hook.enable()

    path_dicts = []

    with torch.no_grad():
        for inputs, _targets in tqdm(test_loader):
            # Drop features of the previous batch so that ``features[0]``
            # always belongs to the batch processed right now.
            hook.reset()
            outputs = model(inputs.to(device))
            _, predicted = torch.max(outputs, 1)
            # One bulk conversion instead of a ``.item()`` call per sample.
            for pred, feat in zip(predicted.tolist(), hook.features[0]):
                path = trees[pred].top_matches(feat)
                path_dicts.append({'path_dict': path.to_dict(), 'pred': pred})

    # Save every record to a single JSON file.
    with open(output_file, "w") as f:
        json.dump(path_dicts, f)


def _load_pretrained_weights(model, args):
    """Load the weights from ``args.checkpoint`` (and the optional head) into ``model``."""
    print("Using Provided Weights {}".format(args.checkpoint))
    saved_state_dict = torch.load(args.checkpoint)
    print(saved_state_dict.keys())
    # The weights are read from the checkpoint's 'teacher' entry.
    saved_state_dict = saved_state_dict['teacher']

    # Strip the 'module.' and 'backbone.' fragments from every key.
    plain_state_dict = {
        key.replace('module.', '').replace('backbone.', ''): value
        for key, value in saved_state_dict.items()
    }

    # Stays None when no usable head checkpoint was given, so the head load
    # below can be skipped instead of failing on an undefined name.
    head_state_dict = None
    if args.head_checkpoint is not None and os.path.exists(args.head_checkpoint):
        raw_head_state_dict = torch.load(args.head_checkpoint)['state_dict']
        head_state_dict = {
            key.replace('module.', ''): value
            for key, value in raw_head_state_dict.items()
        }
        plain_state_dict.update(head_state_dict)

    msg = model.model.model.load_state_dict(plain_state_dict, strict=False)
    print(msg)
    if head_state_dict is not None:
        msg = model.model.head.load_state_dict(head_state_dict, strict=False)
        print(msg)


def _register_feature_hook(model, classifier):
    """Attach an input-capturing hook to the layer chosen from ``classifier`` and return it."""
    hook = SaveFeaturesInputHook()
    if 'vit' in classifier:
        model.model.head.linear.register_forward_hook(hook)
        return hook

    print(f"Using {classifier}")
    # "repvgg" has to be tested before "vgg", which is a substring of it.
    if "repvgg" in classifier:
        layer = model.model.linear
    elif "resnet" in classifier or "shufflenetv2" in classifier:
        layer = model.model.fc
    elif "vgg" in classifier or "mobilenetv2" in classifier:
        layer = model.model.classifier
    else:
        raise NotImplementedError(
            f"The {classifier} is not implemented")
    layer.register_forward_hook(hook)
    return hook


def main(args):
    """Fill per-category parse trees with hooked features and dump tree matches.

    Steps, in order:
      1. Seed everything, restrict the visible GPUs to ``args.gpu_id`` and
         build the logger, checkpoint callback and ``Trainer``.
      2. Build ``ImageNetExplainModule`` and, if ``args.checkpoint`` exists,
         load its weights (plus ``args.head_checkpoint`` when that exists).
      3. Register a ``SaveFeaturesInputHook`` on a layer chosen from
         ``args.classifier`` and run ``trainer.test`` on the attribute
         dataset while the hook records.
      4. Build one ``Parse_Tree`` per entry of ``args.prompt_file`` and give
         each category's tree the recorded features grouped by attribute.
      5. Call ``test_classifier_with_tree`` on the ImageNet test loader,
         which writes ``args.output_file``.

    Args:
        args: Parsed command-line namespace (see the argument parser at the
            bottom of this file).

    Raises:
        NotImplementedError: If ``args.classifier`` matches none of the
            supported architecture names.
    """
    seed_everything(0)
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id

    if args.logger == "wandb":
        logger = WandbLogger(name=args.classifier, project="cifar10")
    elif args.logger == "tensorboard":
        logger = TensorBoardLogger("cifar10", name=args.classifier)

    checkpoint = ModelCheckpoint(
        monitor="acc/val", mode="max", save_last=False)

    trainer = Trainer(
        fast_dev_run=bool(args.dev),
        # No logger for dev runs or when only testing.
        logger=logger if not bool(args.dev + args.test_phase) else None,
        gpus=-1,
        deterministic=True,
        weights_summary=None,
        log_every_n_steps=1,
        max_epochs=args.max_epochs,
        checkpoint_callback=checkpoint,
        precision=args.precision,
    )

    # Create the model, optionally load weights, then register the hook.
    model = ImageNetExplainModule(args)
    if args.checkpoint is not None and os.path.exists(args.checkpoint):
        _load_pretrained_weights(model, args)

    hook = _register_feature_hook(model, args.classifier)
    hook.enable()

    data = AttributeImageNetDataModule(args, prefix=None)
    test_data = ImageNetDataModule(args)

    # The hook accumulates one feature batch per forward pass of this run.
    test_loader = data.test_dataloader()
    trainer.test(model, test_loader)

    with open(args.prompt_file) as file:
        prompt_dict = json.load(file)

    # One tree slot per category; slots of categories that are missing from
    # the prompt file stay None.
    category_to_index = test_loader.dataset.category_to_index
    trees = [None for _ in test_loader.dataset.categories]
    for category, prompt in prompt_dict.items():
        assert isinstance(prompt, dict), "prompt_dict must be dict"
        trees[category_to_index[category]] = Parse_Tree.from_dict(prompt)

    # Concatenate the feature tensors of all batches
    all_features = torch.cat(hook.features)

    # Retrieve attributes and labels from the model.
    # NOTE(review): presumably both are collected by the module during
    # trainer.test and are aligned with all_features - confirm.
    attribute = np.array(model.attributes)
    labels = torch.cat(model.labels).detach().cpu().numpy()

    # Give every category's tree its features, grouped by attribute.
    for i, c in enumerate(test_loader.dataset.categories):
        # Mask selecting the samples labelled with the current category
        c_mask = labels == i

        c_attr = attribute[c_mask]
        c_feat = all_features[c_mask]

        # Map each attribute seen for this category to its feature rows
        values = {attr: c_feat[c_attr == attr] for attr in set(c_attr)}
        trees[category_to_index[c]].set_values(values)

    test_classifier_with_tree(model.model,
                              output_file=args.output_file,
                              test_loader=test_data.test_dataloader(),
                              hook=hook,
                              trees=trees,
                              device='cuda:0')


if __name__ == "__main__":
    # Every entry is (flag, keyword arguments for ``add_argument``); the
    # order below is the order in which the options are registered.
    cli_options = [
        # PROGRAM level args
        ("--data_dir", dict(type=str, default="data/imagenet")),
        ("--support_data_dir", dict(type=str, default="data/cifar10")),
        ("--test_phase", dict(action='store_true')),
        ("--dev", dict(action='store_true')),
        ("--logger", dict(type=str, default="tensorboard",
                          choices=["tensorboard", "wandb"])),
        # TRAINER args
        ("--prompt_file", dict(type=str,
                               default="imagenet/imagenet_prompt.json")),
        ("--output_file", dict(type=str,
                               default="imagenet_test_tree_{}.json")),
        ("--classifier", dict(type=str, default="resnet18")),
        # Passing --pretrained stores False; the default is True.
        ("--pretrained", dict(action='store_false', default=True)),
        ("--checkpoint", dict(type=str, default=None)),
        ("--head_checkpoint", dict(type=str, default=None)),
        ("--precision", dict(type=int, default=32, choices=[16, 32])),
        ("--batch_size", dict(type=int, default=256)),
        ("--max_epochs", dict(type=int, default=100)),
        ("--num_workers", dict(type=int, default=8)),
        ("--gpu_id", dict(type=str, default="0")),
        ("--learning_rate", dict(type=float, default=1e-2)),
        ("--weight_decay", dict(type=float, default=1e-2)),
    ]

    parser = ArgumentParser()
    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    args = parser.parse_args()
    # The '{}' placeholder of the output file name receives the classifier name.
    args.output_file = args.output_file.format(args.classifier)
    main(args)
