import torch
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
import math
from tqdm import tqdm
import json
import numpy as np

from utils.box_util import get_3d_box
from utils.seed import worker_init_fn, get_generator


class NormalDataloaderMixin:
    """Mixin that builds ``torch.utils.data.DataLoader`` objects for its host.

    The host must define ``batch_size`` and may define ``seed``,
    ``train_dataset`` and ``test_dataset``.
    """

    def __init__(self) -> None:
        pass

    def build_dataloader(self, dataset):
        """Return a shuffling loader over *dataset* that drops the last partial batch.

        When the host sets ``seed``, worker seeding and the sampling generator are
        derived from it via ``worker_init_fn`` / ``get_generator``.
        """
        seed = getattr(self, "seed", None)
        seeded = seed is not None
        rng = get_generator(seed) if seeded else None

        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=8,
            pin_memory=True,
            shuffle=True,
            drop_last=True,
            worker_init_fn=worker_init_fn if seeded else None,
            generator=rng,
        )

    def build_train_test_dataloader(self):
        """Create ``train_data_loader`` / ``test_data_loader`` for whichever datasets exist.

        Under an initialized ``torch.distributed`` process group, each loader gets a
        ``DistributedSampler`` (stored in ``train_sampler`` / ``test_sampler``);
        otherwise the samplers are set to ``None`` and the train loader shuffles itself.
        """
        seed = getattr(self, "seed", None)
        rng = None if seed is None else get_generator(seed)
        init_fn = None if seed is None else worker_init_fn
        distributed = dist.is_initialized()

        if hasattr(self, "train_dataset"):
            if distributed:
                self.train_sampler = DistributedSampler(
                    self.train_dataset, shuffle=True, seed=0 if seed is None else seed
                )
            else:
                self.train_sampler = None

            self.train_data_loader = torch.utils.data.DataLoader(
                self.train_dataset,
                batch_size=self.batch_size,
                num_workers=8,
                pin_memory=True,
                # DataLoader forbids shuffle=True alongside an explicit sampler;
                # the DistributedSampler does its own shuffling.
                shuffle=self.train_sampler is None,
                sampler=self.train_sampler,
                drop_last=True,
                worker_init_fn=init_fn,
                generator=rng,
            )

        if hasattr(self, "test_dataset"):
            self.test_sampler = (
                DistributedSampler(self.test_dataset, shuffle=False) if distributed else None
            )
            # Both the distributed and single-process cases share one construction;
            # only the sampler differs.
            self.test_data_loader = torch.utils.data.DataLoader(
                self.test_dataset,
                batch_size=self.batch_size,
                num_workers=8,
                pin_memory=True,
                shuffle=False,
                sampler=self.test_sampler,
                worker_init_fn=init_fn,
            )

    def prepare_data(self, data_dict):
        """Move every tensor value in *data_dict* to CUDA, in place (returns None)."""
        for name, value in list(data_dict.items()):
            if torch.is_tensor(value):
                data_dict[name] = value.cuda()


class ModelOptimizationMixin(object):
    """Learning-rate schedule and optimizer parameter-grouping helpers."""

    def __init__(self):
        pass

    @staticmethod
    def warmup_cosine(step, warmup_step, tot_step):
        """Return the LR multiplier for *step* under linear warmup + cosine decay.

        During warmup the multiplier rises linearly as ``step / warmup_step``.
        Afterwards it follows half a cosine from 1 down to 0 at ``tot_step``,
        floored at 1e-5. Steps beyond ``tot_step`` stay at the end of the schedule
        (the floor) instead of letting the cosine swing back up.

        Args:
            step: Current scheduler step.
            warmup_step: Number of warmup steps; ``<= 0`` disables warmup. Values
                larger than ``tot_step`` are capped to ``tot_step``.
            tot_step: Total number of scheduled steps (treated as at least 1).

        Returns:
            Float multiplier for the base learning rate.
        """
        warmup_step = int(warmup_step)
        tot_step = max(int(tot_step), 1)

        if warmup_step <= 0:
            progress = min(step, tot_step)
            return max(0.5 * (1 + math.cos(progress / tot_step * math.pi)), 1e-5)

        warmup_step = min(warmup_step, tot_step)
        if step <= warmup_step:
            return step / warmup_step

        decay_denom = max(tot_step - warmup_step, 1)
        # Clamp like the no-warmup branch: without it, cos() turns back upward once
        # step exceeds tot_step and the learning rate climbs again.
        decay_progress = min(step - warmup_step, decay_denom)
        return max(0.5 * (1 + math.cos(decay_progress / decay_denom * math.pi)), 1e-5)

    def no_decay_param_group(self, parameters, lr, weight_decay=0.01):
        """Split trainable ``(name, param)`` pairs into weight-decay / no-decay groups.

        Parameters whose name contains ``"bias"``, ``"LayerNorm.bias"`` or
        ``"LayerNorm.weight"`` get no weight decay; frozen parameters
        (``requires_grad == False``) are skipped entirely.

        Args:
            parameters: Iterable of ``(name, parameter)`` pairs, e.g. from
                ``module.named_parameters()``.
            lr: Learning rate assigned to both groups.
            weight_decay: Decay applied to the decayed group (default 0.01).

        Returns:
            A two-element list of optimizer param-group dicts: decayed first,
            then non-decayed.
        """
        no_decay = ("bias", "LayerNorm.bias", "LayerNorm.weight")
        decay_params = []
        no_decay_params = []
        for name, param in parameters:
            if not param.requires_grad:
                continue
            if any(nd in name for nd in no_decay):
                no_decay_params.append(param)
            else:
                decay_params.append(param)
        return [
            {"params": decay_params, "weight_decay": weight_decay, "lr": lr},
            {"params": no_decay_params, "weight_decay": 0.0, "lr": lr},
        ]


class ModelEvaluationMixin(object):
    """QA evaluation loop for trainers that mix this class in.

    The host must provide ``test_data_loader``, ``eval_task``, ``forward_one``,
    ``get_loss``, ``get_metrics`` and ``record_eval_step``; ``is_main_process``
    is optional and only controls whether a progress bar is shown.
    """

    def eval_qa(self, epoch):
        """Evaluate on ``self.test_data_loader`` and record the aggregated metrics.

        Per-batch ``target_metric``/``ans1_acc`` and losses are weighted by batch
        size and averaged over all samples (summed across ranks when
        ``torch.distributed`` runs with more than one process). Answer
        precision/recall/F1 come from the summed ``ans_tp``/``ans_fp``/``ans_fn``.

        When ``self.eval_task`` is truthy, per-question predictions are collected
        (from every rank when distributed) and written once to
        ``3dsav_result.json``.

        Args:
            epoch: Passed through to ``self.record_eval_step``.

        Returns:
            Sample-weighted mean of ``target_metric`` (0.0 if no samples were seen).
        """
        eval_dict = {"target_metric": [], "ans1_acc": []}
        loss_dict = {"total_loss": [], "ans_loss": []}
        # Reported loss name -> alternative key it may be stored under in data_dict.
        loss_src_keys = {"ans_loss": "answer_loss"}
        tp_sum = 0.0
        fp_sum = 0.0
        fn_sum = 0.0
        eval_results = []

        total_count = 0
        data_iter = tqdm(self.test_data_loader) if getattr(self, "is_main_process", True) else self.test_data_loader
        for data_dict in data_iter:
            # Nothing here back-propagates, so keep the forward pass out of autograd
            # as well (not just the loss) to avoid building and holding a graph.
            with torch.no_grad():
                data_dict = self.forward_one(data_dict)
                data_dict = self.get_loss(data_dict)
                data_dict = self.get_metrics(data_dict)
            count = data_dict["obj_fts"].shape[0]
            total_count += count
            tp_sum += float(data_dict.get("ans_tp", 0.0))
            fp_sum += float(data_dict.get("ans_fp", 0.0))
            fn_sum += float(data_dict.get("ans_fn", 0.0))

            if self.eval_task:
                # One argmax per batch rather than one per sample.
                og3d_pred = torch.argmax(data_dict["og3d_logits"], dim=1)
                for j in range(count):
                    box = data_dict["obj_boxes"][j, og3d_pred[j]].cpu().numpy()
                    question_id = data_dict["data_idx"][j]
                    if torch.is_tensor(question_id):
                        # json cannot serialize tensors (e.g. ids collated by DataLoader).
                        question_id = question_id.item()
                    eval_results.append(
                        {
                            "scene_id": data_dict["scan_id"][j],
                            "question_id": question_id,
                            "answer_top10": data_dict["answer_top10"][j],
                            "bbox": get_3d_box(box[0:3], box[3:6]).tolist(),
                        }
                    )

            for key in eval_dict:
                eval_dict[key].append(float(data_dict[key]) * count)

            for key in loss_dict:
                src_key = loss_src_keys.get(key)
                data_key = src_key if src_key is not None and src_key in data_dict else key
                if data_key in data_dict:
                    loss_val = data_dict[data_key]
                    if isinstance(loss_val, torch.Tensor):
                        loss_val = loss_val.item()
                    loss_dict[key].append(float(loss_val) * count)

        # Flatten all local sums into one vector so the distributed case needs a
        # single all_reduce; the local case uses the same layout.
        metric_keys = list(eval_dict)
        loss_keys = list(loss_dict)
        sums = [float(total_count)]
        sums += [float(np.sum(eval_dict[k])) for k in metric_keys]
        sums += [float(np.sum(loss_dict[k])) for k in loss_keys]
        sums += [tp_sum, fp_sum, fn_sum]

        is_distributed = dist.is_initialized() and dist.get_world_size() > 1
        if is_distributed:
            # float64 so large counts and sums keep full precision in the reduction.
            stats_tensor = torch.tensor(sums, dtype=torch.float64, device="cuda")
            dist.all_reduce(stats_tensor, op=dist.ReduceOp.SUM)
            sums = stats_tensor.tolist()

        global_count = sums[0]
        # An empty loader leaves every sum at 0; divide by 1 instead of 0.
        denom = global_count if global_count > 0 else 1.0
        n_metric = len(metric_keys)
        n_loss = len(loss_keys)
        for k, s in zip(metric_keys, sums[1:1 + n_metric]):
            eval_dict[k] = s / denom
        for k, s in zip(loss_keys, sums[1 + n_metric:1 + n_metric + n_loss]):
            loss_dict[k] = s / denom
        tp_sum, fp_sum, fn_sum = sums[1 + n_metric + n_loss:]

        eval_dict.update(loss_dict)
        precision = tp_sum / max(tp_sum + fp_sum, 1e-10)
        recall = tp_sum / max(tp_sum + fn_sum, 1e-10)
        f1 = (2 * precision * recall) / max(precision + recall, 1e-10)
        eval_dict["Acc"] = eval_dict["ans1_acc"]
        eval_dict["Pre"] = precision
        eval_dict["Rec"] = recall
        eval_dict["F1"] = f1
        self.record_eval_step(eval_dict, epoch)

        if self.eval_task:
            if is_distributed:
                # Each rank only saw its own shard: collect everything before writing,
                # instead of every rank overwriting the file with a partial result.
                # NOTE(review): a padding DistributedSampler may repeat a few questions.
                gathered = [None] * dist.get_world_size()
                dist.all_gather_object(gathered, eval_results)
                eval_results = [item for part in gathered for item in part]
            if not is_distributed or dist.get_rank() == 0:
                with open("3dsav_result.json", "w") as f:
                    json.dump(eval_results, f, indent=4)

        return eval_dict["target_metric"]


class ModelLossMixin(object):
    """Builds the QA training objective from ``self.qa_loss`` plus optional regularizers."""

    def get_qa_loss(self, data_dict):
        """Compute the QA loss and its components and store them in *data_dict*.

        ``self.qa_loss`` returns a total plus six component losses. When
        ``cfg.staged_training.enabled`` is set, the total is rebuilt as
        ``aux_w * (og3d + txt_qr + obj_qr_raw + obj_qr_pre + obj_qr_post)
        + answer_w * answer``, with weights chosen by whether
        ``self.current_epoch`` is still in stage 1. Otherwise the total returned
        by ``self.qa_loss`` is used unchanged; any answer/aux weighting must happen
        inside ``self.qa_loss``. Optional LPSS regularizers (``cfg.lpss``) and the
        self-diff LPSS auxiliary loss (``cfg.self_diff_lpss``) are then added.

        Args:
            data_dict: Batch/output dict with the logits and labels consumed by
                ``self.qa_loss`` and, optionally, an ``lpss_info`` dict.

        Returns:
            The same dict with ``total_loss``, each component loss and the four
            regularizer terms written into it (zero tensors when a term is inactive).
        """
        staged_cfg = self.cfg.get("staged_training", {})
        staged_enabled = staged_cfg.get("enabled", False)
        if staged_enabled:
            current_epoch = getattr(self, "current_epoch", 0)
            if current_epoch < staged_cfg.get("stage1_epochs", 20):
                answer_loss_weight = staged_cfg.get("stage1_answer_weight", 1.0)
                aux_loss_weight = staged_cfg.get("stage1_aux_weight", 1.0)
            else:
                answer_loss_weight = staged_cfg.get("stage2_answer_weight", 30.0)
                aux_loss_weight = staged_cfg.get("stage2_aux_weight", 0.3)

        total_loss, og3d_loss, txt_qr_loss, obj_qr_raw_loss, obj_qr_pre_loss, obj_qr_post_loss, answer_loss = self.qa_loss(
            data_dict["txt_qr_logits"],
            data_dict["obj_qr_post_logits"],
            data_dict["obj_qr_pre_logits"],
            data_dict["obj_qr_raw_logits"],
            data_dict["og3d_logits"],
            data_dict["answer_scores"],
            data_dict["tgt_object_label"],
            data_dict["tgt_object_id"],
            data_dict["obj_labels"],
            data_dict["obj_masks"],
            data_dict["answer_label"],
        )

        if staged_enabled:
            aux_loss = og3d_loss + txt_qr_loss + obj_qr_raw_loss + obj_qr_pre_loss + obj_qr_post_loss
            total_loss = aux_loss_weight * aux_loss + answer_loss_weight * answer_loss

        # Unwrap a container exposing the real head as `.module`, if any.
        qa_head = self.qa_head.module if hasattr(self.qa_head, "module") else self.qa_head
        # LPSS terms only apply when the head actually has an LPSS component.
        lpss_info = data_dict.get("lpss_info", None) if getattr(qa_head, "lpss", None) is not None else None

        device = total_loss.device
        lpss_entropy_loss = torch.tensor(0.0, device=device)
        lpss_load_balance_loss = torch.tensor(0.0, device=device)
        lpss_orthogonal_loss = torch.tensor(0.0, device=device)

        lpss_config = self.cfg.get("lpss", {})
        lpss_entropy_weight = lpss_config.get("entropy_weight", 0.0)
        lpss_min_entropy = lpss_config.get("min_entropy", 0.5)
        # Configurable; defaults to the previously hard-coded 0.8.
        lpss_target_entropy = lpss_config.get("target_entropy", 0.8)
        lpss_min_entropy_multiplier = lpss_config.get("min_entropy_multiplier", 10.0)
        lpss_load_balance_weight = lpss_config.get("load_balance_weight", 0.0)
        lpss_orthogonal_weight = lpss_config.get("orthogonal_weight", 0.0)

        if lpss_info is not None:
            if lpss_entropy_weight > 0 and "routing_entropy_tensor" in lpss_info:
                routing_entropy_tensor = lpss_info["routing_entropy_tensor"]
                routing_entropy = routing_entropy_tensor.item()
                if routing_entropy < lpss_min_entropy:
                    # Below the floor: penalize with an amplified weight.
                    effective_weight = lpss_entropy_weight * lpss_min_entropy_multiplier
                    lpss_entropy_loss = effective_weight * (lpss_min_entropy - routing_entropy_tensor)
                elif routing_entropy < lpss_target_entropy:
                    lpss_entropy_loss = lpss_entropy_weight * (lpss_target_entropy - routing_entropy_tensor)
                total_loss = total_loss + lpss_entropy_loss

            if lpss_load_balance_weight > 0 and "load_balance_loss" in lpss_info:
                lpss_load_balance_loss = lpss_load_balance_weight * lpss_info["load_balance_loss"]
                total_loss = total_loss + lpss_load_balance_loss

            if lpss_orthogonal_weight > 0 and "orthogonal_loss" in lpss_info:
                lpss_orthogonal_loss = lpss_orthogonal_weight * lpss_info["orthogonal_loss"]
                total_loss = total_loss + lpss_orthogonal_loss

        self_diff_lpss_aux_loss = torch.tensor(0.0, device=total_loss.device)
        self_diff_cfg = self.cfg.get("self_diff_lpss", {})
        if self_diff_cfg.get("enabled", False):
            unified_encoder = self.unified_encoder.module if hasattr(self.unified_encoder, "module") else self.unified_encoder
            if getattr(unified_encoder, "use_self_diff_lpss", False):
                aux_loss = unified_encoder.get_self_diff_lpss_auxiliary_loss()
                if aux_loss is not None:
                    self_diff_lpss_aux_loss = aux_loss
                    total_loss = total_loss + self_diff_lpss_aux_loss

        data_dict["total_loss"] = total_loss
        data_dict["og3d_loss"] = og3d_loss
        data_dict["txt_qr_loss"] = txt_qr_loss
        data_dict["obj_qr_raw_loss"] = obj_qr_raw_loss
        data_dict["obj_qr_pre_loss"] = obj_qr_pre_loss
        data_dict["obj_qr_post_loss"] = obj_qr_post_loss
        data_dict["answer_loss"] = answer_loss
        data_dict["lpss_entropy_loss"] = lpss_entropy_loss
        data_dict["lpss_load_balance_loss"] = lpss_load_balance_loss
        data_dict["lpss_orthogonal_loss"] = lpss_orthogonal_loss
        data_dict["self_diff_lpss_aux_loss"] = self_diff_lpss_aux_loss
        return data_dict


class ModelMetricMixin(object):
    """Per-batch QA metrics: grounding, text/object classification and answer accuracy."""

    def get_qa_metrics(self, data_dict):
        """Compute per-batch QA metrics and store them in *data_dict*.

        Writes ``og_acc`` / ``txt_acc`` (tensors), ``ans1_acc`` / ``ans10_acc``,
        ``answer_top10`` (answer strings via the train dataset's vocab), top-1
        answer ``ans_tp`` / ``ans_fp`` / ``ans_fn`` with precision/recall/F1,
        masked object classification accuracies ``obj_qr_{post,pre,raw}_acc``,
        and ``target_metric`` (= ``ans1_acc``).

        Label tensors are indexed as ``[sample, index]``, so they are presumably
        2-D ``(batch, num_candidates)`` — confirm against the dataset.

        Returns:
            The same dict, updated in place.
        """
        # Grounding / text accuracy: read each sample's label at its predicted index.
        og3d_argmax = torch.argmax(data_dict["og3d_logits"], dim=1)
        txt_argmax = torch.argmax(data_dict["txt_qr_logits"], dim=1)
        og_hits = data_dict["tgt_object_id"].gather(1, og3d_argmax.unsqueeze(1))
        txt_hits = data_dict["tgt_object_label"].gather(1, txt_argmax.unsqueeze(1))
        data_dict["og_acc"] = og_hits.sum() / float(data_dict["tgt_object_id"].shape[0])
        data_dict["txt_acc"] = txt_hits.sum() / float(data_dict["tgt_object_label"].shape[0])

        answer_scores = data_dict["answer_scores"]
        answer_label = data_dict["answer_label"]
        top_k = min(10, answer_scores.shape[-1])
        choice_1 = answer_scores.argmax(dim=-1)
        choice_10 = torch.topk(answer_scores.detach(), top_k, -1)[1]
        # Vectorized hit tests (label exactly 1 at a chosen index) instead of a
        # per-sample Python loop that synchronizes with the device on every `if`.
        hit1 = answer_label.gather(1, choice_1.unsqueeze(1)).squeeze(1) == 1
        hit10 = (answer_label.gather(1, choice_10) == 1).any(dim=1)
        num_samples = float(len(choice_1))
        data_dict["ans1_acc"] = hit1.sum().item() / num_samples
        data_dict["ans10_acc"] = hit10.sum().item() / num_samples

        # Single device->host transfer for all top-k ids.
        answer_vocab = self.train_dataset.dataset.answer_vocab
        data_dict["answer_top10"] = [
            [answer_vocab.itos(idx) for idx in row] for row in choice_10.tolist()
        ]

        # Top-1 answer counts; any positive label counts as correct here.
        batch_indices = torch.arange(choice_1.shape[0], device=choice_1.device)
        positive = answer_label > 0
        tp = positive[batch_indices, choice_1].sum().item()
        fp = float(choice_1.numel() - tp)
        fn = float(positive.sum().item() - tp)
        precision = tp / max(tp + fp, 1e-10)
        recall = tp / max(tp + fn, 1e-10)
        f1 = (2 * precision * recall) / max(precision + recall, 1e-10)
        data_dict["ans_tp"] = tp
        data_dict["ans_fp"] = fp
        data_dict["ans_fn"] = fn
        data_dict["ans_precision"] = precision
        data_dict["ans_recall"] = recall
        data_dict["ans_f1"] = f1

        # Cast so the mask selects elements even if it arrives as 0/1 integers;
        # an integer tensor would otherwise be treated as row indices.
        obj_masks = data_dict["obj_masks"].bool()
        valid_labels = data_dict["obj_labels"][obj_masks]
        num_valid = float(obj_masks.sum().item() + 1e-10)
        for stage in ("post", "pre", "raw"):
            preds = torch.argmax(data_dict[f"obj_qr_{stage}_logits"], dim=2)[obj_masks]
            data_dict[f"obj_qr_{stage}_acc"] = torch.sum(preds == valid_labels).item() / num_valid

        data_dict["target_metric"] = data_dict["ans1_acc"]
        return data_dict
