import os
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

from models.idx2word import Idx2Word

# Enabled globally at import time. Anomaly detection is a debugging aid and
# adds overhead to autograd; NOTE(review): confirm it is still wanted here.
torch.autograd.set_detect_anomaly(True)


class MLPClassifier(nn.Module):
    """Multi-layer perceptron that maps feature vectors to class logits.

    The network is a stack of hidden stages, each made of
    ``Linear -> ReLU -> BatchNorm1d -> Dropout``, followed by a final
    ``Linear`` projection to ``output_dim``. At least one hidden stage is
    always built; ``n_layers`` values above one add further
    ``latent_dim -> latent_dim`` stages.

    Args:
        input_dim: Size of each input feature vector.
        latent_dim: Width of every hidden stage.
        output_dim: Number of logits produced per input row.
        n_layers: Number of hidden stages (values below one still yield one).
        dropout_rate: Drop probability used by every ``Dropout`` layer.
    """

    def __init__(self, input_dim, latent_dim, output_dim, n_layers, dropout_rate):
        super().__init__()

        # Input width of each hidden stage: the first one sees the raw
        # features, every following one sees the latent representation.
        stage_inputs = [input_dim] + [latent_dim] * (n_layers - 1)

        modules = []
        for in_width in stage_inputs:
            modules.extend(
                (
                    nn.Linear(in_width, latent_dim),
                    nn.ReLU(),
                    nn.BatchNorm1d(latent_dim),
                    nn.Dropout(dropout_rate),
                )
            )
        modules.append(nn.Linear(latent_dim, output_dim))

        # Kept under the attribute name ``net`` so state-dict keys stay stable.
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        """Return the raw (un-normalised) logits for ``x``."""
        return self.net(x)


# If other_name (or other_relation) is True, one extra name (or relation) class is added to capture all other types.
#TODO control training in CUDA/GPU/MP4
class SceneGraphModel(nn.Module):
    """Three classifiers over scene-graph features: names, attributes, relations.

    The model owns ``name_clf`` and ``attr_clf`` (both fed with per-object
    features) and ``rela_clf`` (fed with per-pair features). They are either
    built here as ``MLPClassifier`` instances or supplied ready-made through
    ``models``.

    Args:
        feat_dim: Size of a single object feature vector.
        meta_info: Mapping providing ``["name"]["num"]``, ``["attr"]["num"]``
            and ``["rel"]["num"]``, the number of classes of each kind.
        model_dir: Directory holding ``*_best_epoch.pt`` checkpoints. When not
            ``None`` and the classifiers are built here, the checkpoints are
            loaded right after construction. Ignored when ``models`` is given.
        models: Optional mapping with keys ``"name"``, ``"rela"`` and
            ``"attr"`` holding pre-built classifiers to use as-is.
        other_name: If true, one extra name class is added.
        other_relation: If true, one extra relation class is added.
    """

    # (checkpoint file stem, attribute holding the classifier), in the order
    # in which checkpoints are read and written.
    _CHECKPOINTS = (
        ("name", "name_clf"),
        ("relation", "rela_clf"),
        ("attribute", "attr_clf"),
    )

    def __init__(
        self,
        feat_dim,
        meta_info,
        model_dir,
        *,
        models=None,
        other_name=False,
        other_relation=False
    ):
        super(SceneGraphModel, self).__init__()
        self.feat_dim = feat_dim
        self.n_names = meta_info["name"]["num"]
        self.n_attrs = meta_info["attr"]["num"]
        self.n_rels = meta_info["rel"]["num"]
        self.meta_info = meta_info

        if other_name:
            self.n_names += 1
        if other_relation:
            self.n_rels += 1

        if models is None:
            self.name_clf = MLPClassifier(
                input_dim=self.feat_dim,
                output_dim=self.n_names,
                latent_dim=1024,
                n_layers=2,
                dropout_rate=0.3,
            )

            self.rela_clf = MLPClassifier(
                input_dim=(self.feat_dim + 4) * 2,  # 4: bbox
                output_dim=self.n_rels,  # 1: None
                latent_dim=1024,
                n_layers=1,
                dropout_rate=0.5,
            )

            self.attr_clf = MLPClassifier(
                input_dim=self.feat_dim,
                output_dim=self.n_attrs,
                latent_dim=1024,
                n_layers=1,
                dropout_rate=0.3,
            )
            if model_dir is not None:
                self.load(model_dir)
        else:
            self.name_clf = models["name"]
            self.rela_clf = models["rela"]
            self.attr_clf = models["attr"]

    def load(self, model_dir):
        """Load the three ``*_best_epoch.pt`` checkpoints from ``model_dir``.

        Each classifier is switched to eval mode after its weights are loaded.
        """
        for stem, attr in self._CHECKPOINTS:
            model_f = os.path.join(model_dir, "%s_best_epoch.pt" % stem)
            print("loading model from %s" % model_f)
            clf = getattr(self, attr)
            # map_location="cpu" lets checkpoints written on a GPU be read on
            # a machine without one; load_state_dict then copies the values
            # into the classifier's own parameters, wherever those live.
            clf.load_state_dict(torch.load(model_f, map_location="cpu"))
            clf.eval()

    def save(self, model_dir):
        """Write one state-dict checkpoint per classifier into ``model_dir``.

        The directory (and any missing parents) is created if needed.
        """
        Path(model_dir).mkdir(parents=True, exist_ok=True)
        for stem, attr in self._CHECKPOINTS:
            save_f = os.path.join(model_dir, "%s_best_epoch.pt" % stem)
            torch.save(getattr(self, attr).state_dict(), save_f)

    @staticmethod
    def _softmax_with_splits(logits, split_ends, features):
        """Return ``(probs, logits)`` with a row-wise softmax over ``logits``.

        With an empty/falsy ``split_ends`` the softmax is applied to the whole
        tensor. Otherwise ``logits`` is cut into consecutive row ranges ending
        at each value of ``split_ends``, each range is soft-maxed along
        ``dim=1``, and both the probabilities and the sliced logits are
        concatenated again and reshaped to ``features.shape[0]`` rows.
        """
        if not split_ends:
            return F.softmax(logits, dim=1), logits

        start = 0
        logit_chunks = []
        prob_chunks = []
        for end in split_ends:
            chunk = logits[start:end]
            logit_chunks.append(chunk)
            prob_chunks.append(F.softmax(chunk, dim=1))
            start = end

        n_rows = features.shape[0]
        probs = torch.cat(prob_chunks).reshape(n_rows, -1)
        merged_logits = torch.cat(logit_chunks).reshape(n_rows, -1)
        return probs, merged_logits

    def forward(
        self,
        obj_features,
        rela_features,
        batch_obj_split,
        batch_rela_split,
        name_cls=True,
        attr_cls=True,
        softmax=True,
    ):
        """Run the requested classifiers and return probabilities and logits.

        Args:
            obj_features: Input for the name/attribute classifiers, or ``None``
                to skip both.
            rela_features: Input for the relation classifier, or ``None`` to
                skip it.
            batch_obj_split: Sequence of end row indices used to slice the name
                logits; an empty sequence disables slicing.
            batch_rela_split: Same as ``batch_obj_split`` for relation logits.
            name_cls: Whether to run the name classifier.
            attr_cls: Whether to run the attribute classifier.
            softmax: Accepted for backward compatibility; it currently has no
                effect. Name/relation probabilities are always a softmax over
                ``dim=1`` and attribute probabilities are always a sigmoid.

        Returns:
            ``(name_probs, name_logits, attr_probs, attr_logits, rela_probs,
            rela_logits)``; every entry belonging to a classifier that was not
            run is ``None``.
        """
        name_probs = name_logits = None
        attr_probs = attr_logits = None
        rela_probs = rela_logits = None

        if obj_features is not None:
            if attr_cls:
                attr_logits = self.attr_clf(obj_features)
                # Attributes are scored independently, hence sigmoid.
                attr_probs = torch.sigmoid(attr_logits)

            if name_cls:
                name_probs, name_logits = self._softmax_with_splits(
                    self.name_clf(obj_features), batch_obj_split, obj_features
                )

        if rela_features is not None:
            rela_probs, rela_logits = self._softmax_with_splits(
                self.rela_clf(rela_features), batch_rela_split, rela_features
            )

        return name_probs, name_logits, attr_probs, attr_logits, rela_probs, rela_logits

# TODO correctly compute the accuracy for the attr predicate
def _accuracy_percent(correct: int, total: int) -> float:
    """Return ``correct / total`` as a percentage, or 0.0 when ``total`` is 0."""
    return 100.0 * correct / total if total else 0.0


def test_SceneGraphModel(
    sg_model: SceneGraphModel,
    test_loader: DataLoader,
    test_samples: List,
    test_scene_graphs_and_features: Dict,
    idx2word: Idx2Word,
    gpu: int = -1,
) -> Tuple[float, float, float, float]:
    """Evaluate ``sg_model`` one sample at a time and report accuracies.

    Args:
        sg_model: Model to evaluate; it is switched to eval mode here. It is
            not moved to a device by this function.
        test_loader: Iterable yielding batches of indices into ``test_samples``.
        test_samples: Sequence of ``(image_id, t, objects, target)`` tuples.
            ``t`` is ``"name"``, ``"attr"`` or ``"rela"``; for ``"rela"``
            ``objects`` is a pair of object ids, otherwise a single object id.
        test_scene_graphs_and_features: Per-image mapping providing
            ``"object_ids"``, ``"object_feature"`` and
            ``["scene_graph"]["bboxes"]``.
        idx2word: Translates predicted indices back to names through
            ``idx_to_name`` / ``idx_to_attr`` / ``idx_to_rela``.
        gpu: CUDA device index for the inputs, or a negative value for CPU.

    Returns:
        ``(name_accuracy, attr_accuracy, rela_accuracy, overall_accuracy)`` as
        percentages. A category (or the overall score) with no samples is
        reported as 0.
    """
    # set up testing mode
    sg_model.eval()
    device = torch.device(f"cuda:{gpu}" if gpu >= 0 else "cpu")

    # check if each single prediction is correct
    singleCorrect = 0
    singleTotal = 0

    singleCorrectName = 0
    singleTotalName = 0

    singleCorrectAttr = 0
    singleTotalAttr = 0

    singleCorrectRela = 0
    singleTotalRela = 0
    with torch.no_grad():
        pbar = tqdm(test_loader, desc="Testing")
        for task_ids in pbar:
            for task_id in task_ids:
                image_id, t, objects, target = test_samples[task_id]
                image_metadata = test_scene_graphs_and_features[image_id]
                object_ids = image_metadata["object_ids"]
                object_features = image_metadata["object_feature"]
                bboxes = image_metadata["scene_graph"]["bboxes"]
                # Map each object id (as a string) to its feature vector.
                object_feature_dict = {
                    str(object_id): feature
                    for object_id, feature in zip(object_ids, object_features)
                }

                X = None
                Y = None
                if t == "rela":
                    o1, o2 = objects
                    # Pair input layout: subject features, object features,
                    # subject bbox, object bbox.
                    rela_features = np.concatenate(
                        [
                            object_feature_dict[str(o1)],
                            object_feature_dict[str(o2)],
                            np.array(bboxes[int(o1)]),
                            np.array(bboxes[int(o2)]),
                        ]
                    )
                    # A single sample is fed as a batch of one row.
                    Y = torch.from_numpy(rela_features).float().reshape(1, -1).to(device)
                else:
                    obj_features = object_feature_dict[str(objects)]
                    X = torch.from_numpy(obj_features).float().reshape(1, -1).to(device)

                # Only the classifier matching the task type is run.
                name_cls = t == "name"
                attr_cls = t == "attr"

                # Call the module itself (not .forward) so that any registered
                # hooks run as well.
                name_probs, _, attr_probs, _, rela_probs, _ = sg_model(
                    X, Y, [], [], name_cls, attr_cls, softmax=False
                )

                name_probs, attr_probs, rela_probs = (
                    probs.cpu() if probs is not None else None
                    for probs in (name_probs, attr_probs, rela_probs)
                )

                if t == "name":
                    pred = name_probs.argmax(dim=-1).item()
                    pred_name = idx2word.idx_to_name(pred)
                    assert pred_name is not None
                    if pred_name == target:
                        singleCorrect += 1
                        singleCorrectName += 1
                    singleTotalName += 1
                elif t == "attr":
                    # Every attribute scored above 0.5 counts as predicted.
                    indices = (attr_probs.flatten() > 0.5).nonzero().flatten().tolist()
                    pred_names = [idx2word.idx_to_attr(i) for i in indices]
                    # NOTE(review): each predicted attribute found in `target`
                    # adds one hit while the sample adds only one to the total,
                    # so this ratio can exceed 100%. Left unchanged until the
                    # intended attribute metric is decided.
                    for pred_name in pred_names:
                        if pred_name in target:
                            singleCorrect += 1
                            singleCorrectAttr += 1
                    singleTotalAttr += 1
                elif t == "rela":
                    if rela_probs is not None and rela_probs.numel() > 0:
                        pred = rela_probs.argmax(dim=-1).item()
                        pred_name = idx2word.idx_to_rela(pred)
                        assert pred_name is not None
                        if pred_name == target:
                            singleCorrect += 1
                            singleCorrectRela += 1
                    else:
                        # Handle case where rela_probs is None or empty
                        print(f"Warning: rela_probs is None or empty for task_id {task_id}")
                    singleTotalRela += 1
                singleTotal += 1

                pbar.set_description(
                    f"Testing: Name: {singleCorrectName}/{singleTotalName}, Attr: {singleCorrectAttr}/{singleTotalAttr}, Rela: {singleCorrectRela}/{singleTotalRela}, Total: {singleCorrect}/{singleTotal}"
                )

    # Every ratio is guarded so that an empty loader reports 0 instead of
    # raising ZeroDivisionError.
    singleAccuracyName = _accuracy_percent(singleCorrectName, singleTotalName)
    singleAccuracyAttr = _accuracy_percent(singleCorrectAttr, singleTotalAttr)
    singleAccuracyRela = _accuracy_percent(singleCorrectRela, singleTotalRela)
    singleAccuracy = _accuracy_percent(singleCorrect, singleTotal)
    print(
        f"Single accuracy: {singleAccuracy:.2f}%, "
        f"Name accuracy: {singleAccuracyName:.2f}%, "
        f"Attr accuracy: {singleAccuracyAttr:.2f}%, "
        f"Rela accuracy: {singleAccuracyRela:.2f}%"
    )
    return singleAccuracyName, singleAccuracyAttr, singleAccuracyRela, singleAccuracy
