import csv
import os
import random
import torch


class InverseAugmentationConfig:
    """Inverse-relation augmentation driven by a CSV rule table.

    Each usable CSV row maps a relation class to its inverse class plus three
    metadata values (``hbt``, ``super1``, ``super2``).  At augmentation time a
    relation ``(subj, obj, rel)`` whose rule passes the configured filters is
    either duplicated as ``(obj, subj, inverse)`` (strategy ``"ewai"``) or, for
    any other strategy value, replaced by it with probability 0.5.

    Attributes:
        enabled: False when the config is missing/disabled, the CSV cannot be
            read, or no CSV row maps onto ``rel_classes``.
        strategy: Lower-cased strategy name; ``"ewai"`` appends, anything else
            replaces in place.
        mode: Lower-cased list of filter names selecting which checks apply.
        hbt_filter: Set of normalized hbt tokens (see ``_normalize_hbt_token``).
        super_filter: Set of lower-cased super-category names.
        rules: ``{relation_id: {"inverse_id", "hbt", "super1", "super2"}}``.
    """

    # Zero-based CSV column positions. Column 1 is never read.
    _COL_ORIG = 0
    _COL_INVERSE = 2
    _COL_HBT = 3
    _COL_SUPER1 = 4
    _COL_SUPER2 = 5
    _NUM_COLS = 6

    # Accepted spellings for the three hbt groups.
    _HBT_ALIASES = {
        "1": "h", "head": "h", "h": "h",
        "2": "b", "body": "b", "b": "b",
        "3": "t", "tail": "t", "t": "t",
    }

    # Mode combinations (sorted) that enable augmentation; every listed check
    # must pass. Any other combination disables augmentation.
    _SUPPORTED_MODES = frozenset({
        ("hbt",),
        ("super1",),
        ("super2",),
        ("hbt", "super1"),
        ("hbt", "super2"),
    })

    def __init__(self, aug_cfg, rel_classes):
        """Read settings from ``aug_cfg`` and load the rule table.

        Args:
            aug_cfg: Object exposing the optional attributes ``ENABLED``,
                ``STRATEGY``, ``MODE``, ``HBT``, ``SUPER`` and ``CSV_PATH``,
                or ``None`` to disable augmentation.
            rel_classes: Sequence of relation class names; the position of a
                name is used as its relation id.
        """
        self.rules = {}
        self.strategy = "ewai"
        self.mode = []
        self.hbt_filter = set()
        self.super_filter = set()
        self.enabled = False

        if aug_cfg is None:
            return

        self.enabled = bool(getattr(aug_cfg, "ENABLED", False))
        if not self.enabled:
            return

        raw_strategy = getattr(aug_cfg, "STRATEGY", "ewai") or "ewai"
        self.strategy = str(raw_strategy).strip().lower()
        # Lower-cased like the other settings so "HBT" and "hbt" are equivalent.
        self.mode = [m.lower() for m in self._split(getattr(aug_cfg, "MODE", ""))]
        self.hbt_filter = self._normalize_hbt_list(self._split(getattr(aug_cfg, "HBT", "")))
        self.super_filter = {s.lower() for s in self._split(getattr(aug_cfg, "SUPER", ""))}

        csv_path = getattr(aug_cfg, "CSV_PATH", "")
        if not csv_path or not os.path.isfile(csv_path):
            print(f"[InverseAugmentation] CSV 未找到: {csv_path}")
            self.enabled = False
            return

        try:
            self.rules = self._load_rules(csv_path, rel_classes)
        except (OSError, csv.Error, UnicodeError) as exc:
            print(f"[InverseAugmentation] 读取 CSV 失败: {exc}")
            self.enabled = False
            return

        if not self.rules:
            print("[InverseAugmentation] CSV 中无有效映射，禁用增强")
            self.enabled = False

    def _load_rules(self, csv_path, rel_classes):
        """Parse the rule CSV into ``{relation_id: metadata}``.

        The first row is treated as a header and skipped; columns are read by
        position, so header spelling and case do not matter.  Rows whose
        original or inverse name is not in ``rel_classes`` are ignored.
        """
        rel_to_id = {name.lower(): idx for idx, name in enumerate(rel_classes)}
        rules = {}
        with open(csv_path, encoding="utf-8-sig", newline="") as f:
            reader = csv.reader(f)
            next(reader, None)  # skip header row
            for row in reader:
                if not row:
                    continue
                # Pad short rows so missing trailing cells read as empty.
                cells = list(row) + [""] * (self._NUM_COLS - len(row))

                orig_id = rel_to_id.get(cells[self._COL_ORIG].strip().lower())
                inv_id = rel_to_id.get(cells[self._COL_INVERSE].strip().lower())
                if orig_id is None or inv_id is None:
                    continue

                rules[orig_id] = {
                    "inverse_id": inv_id,
                    "hbt": self._normalize_hbt_token(cells[self._COL_HBT]),
                    "super1": cells[self._COL_SUPER1].strip(),
                    "super2": cells[self._COL_SUPER2].strip(),
                }
        return rules

    def _split(self, raw):
        """Split a comma-separated setting into stripped, non-empty tokens.

        ``None`` yields an empty list; a list/tuple/set is flattened as if its
        items had been written comma-separated.
        """
        if raw is None:
            return []
        if isinstance(raw, (list, tuple, set)):
            raw = ",".join(str(item) for item in raw)
        return [part.strip() for part in str(raw).split(",") if part.strip()]

    def _match_super(self, value):
        """Return True if ``value`` passes the super-category filter.

        An empty filter accepts everything.  Otherwise the lower-cased value
        must be in the filter, or both the value and the filter must mention
        "semantic" or "action" (the two are treated interchangeably).
        """
        if not self.super_filter:
            return True
        value_lower = value.lower()
        if value_lower in self.super_filter:
            return True
        linked = ("semantic", "action")
        value_is_linked = any(key in value_lower for key in linked)
        filter_is_linked = any(key in self.super_filter for key in linked)
        return value_is_linked and filter_is_linked

    def _should_augment(self, meta):
        """Decide whether a rule's metadata passes the checks named by ``mode``.

        Returns False when ``mode`` is empty or is not one of the supported
        combinations (``hbt``, ``super1``, ``super2``, ``hbt,super1``,
        ``hbt,super2``; order-insensitive).
        """
        if not self.mode:
            return False

        selected = tuple(sorted(self.mode))
        if selected not in self._SUPPORTED_MODES:
            return False

        checks = {
            # An empty hbt filter means "no restriction".
            "hbt": lambda: (not self.hbt_filter) or (meta["hbt"] in self.hbt_filter),
            "super1": lambda: self._match_super(meta["super1"]),
            "super2": lambda: self._match_super(meta["super2"]),
        }
        return all(checks[name]() for name in selected)

    def augment_relations(self, target):
        """Augment the ``relation_tuple`` field of ``target``.

        ``target`` must provide ``has_field``, ``get_field`` and ``add_field``.
        Each row of the field is unpacked as ``(subj, obj, rel)``.  The stored
        tensor is not modified; a new tensor is written back with
        ``add_field``.  Returns ``target``.
        """
        if not self.enabled or not target.has_field("relation_tuple"):
            return target

        triplets = target.get_field("relation_tuple")
        if triplets.numel() == 0:
            return target

        triplets = triplets.clone()
        appended = []

        for idx in range(triplets.size(0)):
            subj, obj, rel = triplets[idx]
            meta = self.rules.get(int(rel.item()))
            if not meta or not self._should_augment(meta):
                continue

            flipped = [int(obj.item()), int(subj.item()), meta["inverse_id"]]
            if self.strategy == "ewai":
                appended.append(flipped)
            elif random.random() < 0.5:
                # Replacement strategy: swap the triplet for its inverse half the time.
                triplets[idx] = torch.tensor(
                    flipped, device=triplets.device, dtype=triplets.dtype
                )

        if appended:
            extra = torch.tensor(appended, device=triplets.device, dtype=triplets.dtype)
            triplets = torch.cat([triplets, extra], dim=0)

        target.add_field("relation_tuple", triplets)
        return target

    def _check_hbt(self, value):
        """Return True if ``value`` (any accepted hbt spelling) is in the filter.

        NOTE: unlike ``_should_augment``, an empty filter rejects everything
        here; only a ``None`` filter accepts everything.
        """
        if self.hbt_filter is None:
            return True
        if not self.hbt_filter:
            return False
        # The filter holds normalized tokens, so normalize the probe as well.
        return self._normalize_hbt_token(value) in self.hbt_filter

    def _normalize_hbt_list(self, values):
        """Normalize hbt tokens into a set, dropping empty entries."""
        if not values:
            return set()
        normalized = {self._normalize_hbt_token(v) for v in values}
        normalized.discard("")
        return normalized

    @staticmethod
    def _normalize_hbt_token(value):
        """Map an hbt spelling ("1"/"head"/"h", ...) to "h", "b" or "t".

        Falsy input becomes ""; unrecognized tokens are returned stripped and
        lower-cased.
        """
        val = str(value or "").strip().lower()
        return InverseAugmentationConfig._HBT_ALIASES.get(val, val)

    def augment_sampled_pairs(self, pair_list, label_list, binary_list=None, inverse_flags_list=None):
        """Augment sampled pairs/labels, one list entry at a time.

        Args:
            pair_list: List of pair tensors; each row is reversed to form the
                inverse pair.
            label_list: List of label tensors aligned with ``pair_list``.
            binary_list: Optional list aligned with ``pair_list``.  An entry
                is extended alongside the labels only when it is a tensor
                whose first dimension equals the number of labels; any other
                entry is passed through unchanged.
            inverse_flags_list: Optional list of per-sample flags marking
                samples that are already inverses.

        Returns:
            ``(pairs, labels, binaries, inverse_flags)``.  When augmentation
            is disabled the four inputs are returned untouched.  Otherwise
            ``binaries`` is ``None`` iff ``binary_list`` was ``None`` and
            ``inverse_flags`` is always a list.
        """
        if not self.enabled:
            return pair_list, label_list, binary_list, inverse_flags_list

        new_pairs, new_labels, new_inverse = [], [], []
        new_binaries = [] if binary_list is not None else None

        for i, (pairs, labels) in enumerate(zip(pair_list, label_list)):
            binaries = None if binary_list is None else binary_list[i]
            inv_flags = None if inverse_flags_list is None else inverse_flags_list[i]

            if pairs is None or labels is None or labels.numel() == 0:
                result = (pairs, labels, binaries, inv_flags)
            else:
                result = self._augment_sample(pairs, labels, binaries, inv_flags)

            new_pairs.append(result[0])
            new_labels.append(result[1])
            if new_binaries is not None:
                new_binaries.append(result[2])
            new_inverse.append(result[3])

        return new_pairs, new_labels, new_binaries, new_inverse

    def _augment_sample(self, pairs, labels, binaries, inv_flags):
        """Augment one non-empty (pairs, labels) entry; inputs are not mutated."""
        num = labels.size(0)
        pairs_aug = pairs.clone()
        labels_aug = labels.clone()

        per_sample_binary = (
            isinstance(binaries, torch.Tensor)
            and binaries.dim() > 0
            and binaries.size(0) == num
        )
        # Anything that is not a per-sample tensor is handed back as-is
        # instead of being dropped.
        binary_aug = binaries.clone().reshape(-1) if per_sample_binary else binaries

        if isinstance(inv_flags, torch.Tensor) and inv_flags.dim() > 0 and inv_flags.size(0) == num:
            inv_aug = inv_flags.clone().to(torch.bool)
        else:
            inv_aug = torch.zeros(num, dtype=torch.bool, device=labels.device)

        extra_pairs, extra_labels, extra_binaries = [], [], []

        for idx in range(num):
            meta = self.rules.get(int(labels[idx].item()))
            if not meta or not self._should_augment(meta):
                continue

            flipped = pairs_aug[idx].flip(0)
            inverse_id = meta["inverse_id"]

            if self.strategy == "ewai":
                extra_pairs.append(flipped.unsqueeze(0))
                extra_labels.append(labels.new_full((1,), inverse_id))
                if per_sample_binary:
                    # The inverse sample inherits the binary value of its source.
                    if idx < binary_aug.numel():
                        val = binary_aug[idx]
                    else:
                        val = binary_aug.new_zeros(())
                    extra_binaries.append(val.reshape(1))
            elif random.random() < 0.5:
                pairs_aug[idx] = flipped
                labels_aug[idx] = inverse_id
                inv_aug[idx] = True

        if extra_pairs:
            pairs_aug = torch.cat([pairs_aug] + extra_pairs, dim=0)
            labels_aug = torch.cat([labels_aug] + extra_labels, dim=0)
            if extra_binaries:
                binary_aug = torch.cat([binary_aug] + extra_binaries, dim=0)
            appended_flags = torch.ones(
                len(extra_pairs), dtype=torch.bool, device=labels.device
            )
            inv_aug = torch.cat([inv_aug, appended_flags], dim=0)

        return pairs_aug, labels_aug, binary_aug, inv_aug