from abc import ABC, abstractmethod
import os
from datetime import timedelta
from functools import partial

import numpy as np
import torch
import torch.distributed as dist
from torch.optim.lr_scheduler import LambdaLR
from torch.optim import AdamW
from tqdm import tqdm

from pipeline.registry import registry
from pipeline.pipeline_mixin import (
    NormalDataloaderMixin,
    ModelOptimizationMixin,
    ModelEvaluationMixin,
    ModelMetricMixin,
    ModelLossMixin,
)
from utils.experiment_manager import create_experiment_from_config
from utils.seed import set_seed_from_config


class _ExperimentSaverWrapper:
    def __init__(self, experiment_manager, cfg):
        self.exp_manager = experiment_manager
        saver_args = cfg.get("saver", {}).get("args", {})
        self.load_dir = saver_args.get("load_dir")
        self.load_name = saver_args.get("load_name")
        if self.load_dir and self.load_name:
            self.load_path = os.path.join(self.load_dir, self.load_name)
        else:
            self.load_path = None

    def save_model(self, model, is_best: bool = False):
        if self.exp_manager:
            self.exp_manager.save_checkpoint(model.state_dict(), filename="model_latest.pth")
            if is_best:
                self.exp_manager.save_checkpoint(model.state_dict(), filename="model_best.pth")

    def save_dict(self, state_dict, is_best: bool = False):
        if self.exp_manager:
            self.exp_manager.save_checkpoint(state_dict, filename="model_latest.pth")
            if is_best:
                self.exp_manager.save_checkpoint(state_dict, filename="model_best.pth")

    def restore_dict(self):
        if self.load_path and os.path.exists(self.load_path):
            return torch.load(self.load_path)
        raise FileNotFoundError(f"Pretrained weights not found: {self.load_path}")


class Pipeline(ABC):
    """Abstract three-phase lifecycle for a training/evaluation pipeline.

    Subclasses must implement :meth:`initialize` and :meth:`run`. :meth:`end`
    is an optional teardown hook whose default does nothing. :meth:`run_all`
    drives the three phases in order.

    Note: ``end`` was previously declared twice, and the second, abstract
    declaration overrode the no-op default. That made every subclass lacking
    its own ``end`` uninstantiable.
    """

    @abstractmethod
    def initialize(self):
        """Prepare any state needed before :meth:`run`."""

    @abstractmethod
    def run(self):
        """Execute the main body of the pipeline."""

    def end(self):
        """Optional teardown hook called after :meth:`run`; the default does nothing."""

    def run_all(self):
        """Run ``initialize``, ``run`` and ``end`` in sequence."""
        self.initialize()
        self.run()
        self.end()


class OptimusPrimePipeline(
    Pipeline,
    NormalDataloaderMixin,
    ModelOptimizationMixin,
    ModelEvaluationMixin,
    ModelMetricMixin,
    ModelLossMixin,
):
    """Joint training / evaluation pipeline for the ``3dsav`` question-answering task.

    Five sub-models are built from the registry and placed on ``self.device``:
    the language encoder, point encoder, unified encoder, grounding head and QA head.
    They are optionally wrapped in ``DistributedDataParallel`` and optimised together
    with AdamW and a ``LambdaLR`` schedule driven by ``self.warmup_cosine``.
    Dataloaders, parameter grouping, loss/metric computation and QA evaluation are
    provided by the mixins.

    Args:
        cfg: Experiment configuration mapping (see ``__init__`` for the keys read).
    """

    # Attribute names of the trainable sub-models. They are also the top-level keys
    # of saved/restored checkpoints. The order matches construction/wrapping order.
    _COMPONENT_NAMES = ("lang_encoder", "point_encoder", "unified_encoder", "ground_head", "qa_head")

    # Optional auxiliary losses logged during training when present and non-zero.
    _AUX_LOSS_KEYS = (
        "lpss_entropy_loss",
        "lpss_load_balance_loss",
        "lpss_orthogonal_loss",
        "self_diff_lpss_aux_loss",
    )

    @staticmethod
    def _unwrap_ddp(model):
        """Return ``model.module`` for DDP/DataParallel-style wrappers, else ``model`` itself."""
        return model.module if hasattr(model, "module") else model

    @staticmethod
    def _strip_module_prefix(name):
        """Drop a single leading ``module.`` (added by DDP) from a parameter name.

        Unlike ``str.replace`` this leaves interior occurrences intact, e.g.
        ``fusion_module.weight`` stays unchanged.
        """
        prefix = "module."
        return name[len(prefix):] if name.startswith(prefix) else name

    def __init__(self, cfg):
        """Build devices, models, data, optimiser, scheduler and loss from ``cfg``.

        Required keys:
            ``task``, ``eval_task``, ``lang_encoder``, ``point_encoder``,
            ``unified_encoder``, ``ground_head``, ``qa_head``, ``qa_dataset``,
            ``batch_size``, ``learning_rate``, ``grad_norm``, ``epochs``,
            ``warmup_steps``, ``lang_lr_mul``, ``point_lr_mul``,
            ``unified_lr_mul``, ``beta1``, ``beta2``, ``qa_loss``,
            ``restore_model``.

        Optional keys:
            ``distributed``, ``ddp``, ``experiment``, ``lpss``,
            ``self_diff_lpss``.

        Conditionally required keys:
            ``logger`` is needed when training without ``experiment``.
            ``saver`` is needed without ``experiment``.

        Raises:
            ValueError: if ``cfg["task"]`` is not ``"3dsav"``, or an explicit
                ``qa_dataset.args.split`` is given while ``eval_task`` is not ``True``.
        """
        self.seed_info = set_seed_from_config(cfg, verbose=True)
        self.seed = self.seed_info.get("seed") if self.seed_info else None

        self.cfg = cfg
        # Validate the task up front, before any process group, experiment
        # directory or model is created.
        self.task = cfg["task"]
        self.eval_task = cfg["eval_task"]
        if self.task != "3dsav":
            raise ValueError("Only task '3dsav' is supported in this trimmed training pipeline.")

        # ---- Device / distributed setup ------------------------------------------
        world_size_env = int(os.environ.get("WORLD_SIZE", 1))
        distributed_config = cfg.get("distributed", False)
        # Distributed mode is enabled explicitly via config, or implicitly when the
        # launcher exports WORLD_SIZE > 1.
        self.use_distributed = distributed_config or world_size_env > 1

        if self.use_distributed:
            self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
            self.world_size = world_size_env
            self.rank = int(os.environ.get("RANK", 0))
            if not dist.is_initialized():
                dist.init_process_group(backend="nccl", timeout=timedelta(minutes=30))
            torch.cuda.set_device(self.local_rank)
            self.device = torch.device(f"cuda:{self.local_rank}")
        else:
            self.local_rank = 0
            self.world_size = 1
            self.rank = 0
            # Fall back to CPU so single-process runs also work without CUDA.
            self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        self.is_main_process = self.rank == 0

        # ---- Experiment bookkeeping ----------------------------------------------
        self.use_experiment_manager = "experiment" in cfg
        self.experiment_manager = None
        if self.use_experiment_manager and self.is_main_process:
            self.experiment_manager = create_experiment_from_config(cfg)
            self.exp_dir = self.experiment_manager.exp_dir
            print(f"[Pipeline] Using unified experiment manager: {self.exp_dir}")

        # Only the main process logs, and only when not in eval-only mode.
        if not self.eval_task and self.is_main_process:
            if self.use_experiment_manager:
                self.logger = self.experiment_manager
            else:
                self.logger = registry.get_utils(cfg["logger"]["name"])(cfg)
        else:
            self.logger = None

        if self.use_experiment_manager:
            self.saver = _ExperimentSaverWrapper(self.experiment_manager, cfg)
        else:
            self.saver = registry.get_utils(cfg["saver"]["name"])(**cfg["saver"]["args"])

        # ---- Models ---------------------------------------------------------------
        # Construction order is kept fixed: model initialisation consumes the seeded RNG.
        self.lang_encoder = registry.get_language_model(cfg["lang_encoder"]["name"])(
            **cfg["lang_encoder"]["args"]
        ).to(self.device)
        self.point_encoder = registry.get_vision_model(cfg["point_encoder"]["name"])(
            **cfg["point_encoder"]["args"]
        ).to(self.device)
        self.unified_encoder = registry.get_vision_model(cfg["unified_encoder"]["name"])(
            **cfg["unified_encoder"]["args"]
        ).to(self.device)
        self.ground_head = registry.get_other_model(cfg["ground_head"]["name"])(
            **cfg["ground_head"]["args"]
        ).to(self.device)
        self.qa_head = registry.get_other_model(cfg["qa_head"]["name"])(**cfg["qa_head"]["args"]).to(self.device)

        if self.use_distributed:
            find_unused = cfg.get("ddp", {}).get("find_unused_parameters", False)
            # Wrap every component, in the same order they were constructed.
            for attr in self._COMPONENT_NAMES:
                setattr(
                    self,
                    attr,
                    torch.nn.parallel.DistributedDataParallel(
                        getattr(self, attr),
                        device_ids=[self.local_rank],
                        output_device=self.local_rank,
                        find_unused_parameters=find_unused,
                    ),
                )

        # ---- Datasets / dataloaders -----------------------------------------------
        qa_dataset_name = cfg["qa_dataset"]["name"]
        qa_args = cfg["qa_dataset"]["args"]
        if qa_args.get("split") is not None:
            # A fixed split is evaluation-only: "train" and "test" both use it.
            if self.eval_task is not True:
                raise ValueError("qa_dataset.args.split may only be set when eval_task is True.")
            self.train_dataset = registry.get_dataset(qa_dataset_name)(**qa_args)
            self.test_dataset = registry.get_dataset(qa_dataset_name)(**qa_args)
        else:
            self.train_dataset = registry.get_dataset(qa_dataset_name)(split="train", **qa_args)
            self.test_dataset = registry.get_dataset(qa_dataset_name)(split="val", **qa_args)

        self.batch_size = cfg["batch_size"]
        self.learning_rate = cfg["learning_rate"]
        self.grad_norm = cfg["grad_norm"]
        self.epochs = cfg["epochs"]
        self.warmup_steps = cfg["warmup_steps"]

        self.build_train_test_dataloader()

        # ---- Optimiser parameter groups -------------------------------------------
        lr = self.learning_rate
        optimizer_grouped_parameters = []
        optimizer_grouped_parameters += self.no_decay_param_group(
            self.lang_encoder.named_parameters(), lr * cfg["lang_lr_mul"]
        )
        optimizer_grouped_parameters += self.no_decay_param_group(
            self.point_encoder.named_parameters(), lr * cfg["point_lr_mul"]
        )

        # Unified encoder: parameters under an optional ``self_diff_lpss`` module can
        # be trained with their own LR multiplier.
        self_diff_lpss_lr_mul = cfg.get("self_diff_lpss", {}).get("lr_multiplier", 5.0)
        unified_encoder_module = self._unwrap_ddp(self.unified_encoder)
        has_self_diff_lpss = getattr(unified_encoder_module, "self_diff_lpss", None) is not None

        if has_self_diff_lpss and self_diff_lpss_lr_mul != 1.0:
            lpss_like_params = []
            other_params = []
            for name, param in self.unified_encoder.named_parameters():
                # Substring match: any parameter whose path mentions ``self_diff_lpss``.
                if "self_diff_lpss" in self._strip_module_prefix(name):
                    lpss_like_params.append((name, param))
                else:
                    other_params.append((name, param))
            optimizer_grouped_parameters += self.no_decay_param_group(other_params, lr * cfg["unified_lr_mul"])
            optimizer_grouped_parameters += self.no_decay_param_group(lpss_like_params, lr * self_diff_lpss_lr_mul)
        else:
            optimizer_grouped_parameters += self.no_decay_param_group(
                self.unified_encoder.named_parameters(), lr * cfg["unified_lr_mul"]
            )

        optimizer_grouped_parameters += self.no_decay_param_group(self.ground_head.named_parameters(), lr)

        # QA head: parameters of its optional ``lpss`` submodule can get a larger LR.
        lpss_lr_mul = cfg.get("lpss", {}).get("lr_multiplier", 1.0)
        # NOTE(review): ``self_diff_lpss_lr_mul`` defaults to 5.0 even when the unified
        # encoder has no ``self_diff_lpss`` module. The QA-head LPSS parameters therefore
        # get at least that multiplier unless ``self_diff_lpss.lr_multiplier`` is set
        # explicitly. Confirm this coupling is intended.
        effective_lpss_lr_mul = max(lpss_lr_mul, self_diff_lpss_lr_mul)
        qa_head_module = self._unwrap_ddp(self.qa_head)
        qa_lpss = getattr(qa_head_module, "lpss", None)

        if qa_lpss is not None and effective_lpss_lr_mul != 1.0:
            lpss_param_names = {f"lpss.{name}" for name, _ in qa_lpss.named_parameters()}
            qa_lpss_params = []
            qa_other_params = []
            for name, param in self.qa_head.named_parameters():
                # Exact membership on the DDP-stripped name. A substring test would
                # also capture unrelated parameters such as ``pre_lpss.weight``.
                if self._strip_module_prefix(name) in lpss_param_names:
                    qa_lpss_params.append((name, param))
                else:
                    qa_other_params.append((name, param))
            optimizer_grouped_parameters += self.no_decay_param_group(qa_other_params, lr)
            optimizer_grouped_parameters += self.no_decay_param_group(qa_lpss_params, lr * effective_lpss_lr_mul)
        else:
            optimizer_grouped_parameters += self.no_decay_param_group(self.qa_head.named_parameters(), lr)

        self.optimizer = AdamW(optimizer_grouped_parameters, betas=[cfg["beta1"], cfg["beta2"]])
        # Flat list of every optimised tensor; used for gradient clipping in ``train``.
        self.parameters = [param for group in optimizer_grouped_parameters for param in group["params"]]

        # ---- LR schedule and loss -------------------------------------------------
        total_steps = self.epochs * len(self.train_data_loader)
        self.total_steps = total_steps
        self.scheduler = LambdaLR(
            optimizer=self.optimizer,
            lr_lambda=lambda step: self.warmup_cosine(step, self.warmup_steps, total_steps),
        )

        loss_name = cfg["qa_loss"]["name"]
        loss_fn = registry.get_optimizer(loss_name)
        loss_args = cfg["qa_loss"].get("args", {})
        # Bind configured keyword arguments once so callers only pass batch data.
        self.qa_loss = partial(loss_fn, **loss_args) if loss_args else loss_fn

        if cfg["restore_model"]:
            self.restore_model()

    def initialize(self):
        """No extra initialisation: everything is built in ``__init__``."""

    def run(self):
        """Train for ``self.epochs`` epochs, evaluating and checkpointing after each.

        A checkpoint is flagged as best when the evaluation metric ties or beats
        the best value seen so far. Only the main process saves.
        """
        best_target_metric = -np.inf
        for epoch in range(self.epochs):
            self.train(epoch)
            target_metric = self.eval(epoch)
            if not self.is_main_process:
                continue
            is_best = target_metric >= best_target_metric
            if is_best:
                best_target_metric = target_metric
            self.save_model(is_best=is_best)

    def train(self, epoch):
        """Run one training epoch: forward, loss, backward, clip, step, log."""
        self.set_model_state("train")
        self.current_epoch = epoch
        train_sampler = getattr(self, "train_sampler", None)
        if train_sampler is not None:
            # Presumably a DistributedSampler: re-seed its shuffle for this epoch.
            train_sampler.set_epoch(epoch)

        steps_per_epoch = len(self.train_data_loader)
        data_iter = tqdm(self.train_data_loader) if self.is_main_process else self.train_data_loader
        for i, data_dict in enumerate(data_iter):
            step = epoch * steps_per_epoch + i
            data_dict["cur_step"] = step
            data_dict["total_steps"] = self.total_steps
            data_dict = self.forward_one(data_dict)
            data_dict = self.get_loss(data_dict)
            data_dict = self.get_metrics(data_dict)

            data_dict["total_loss"].backward()
            data_dict["grad_norm"] = torch.nn.utils.clip_grad_norm_(self.parameters, self.grad_norm)
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.scheduler.step()

            self.record_train_step(data_dict, step)

    def eval(self, epoch):
        """Evaluate on the test set and return the target metric from ``eval_qa``."""
        if self.is_main_process:
            print("start evaluation on test set")
        self.set_model_state("eval")
        return self.eval_qa(epoch)

    def forward_one(self, data_dict):
        """Run one forward pass through all sub-models and store their outputs.

        Args:
            data_dict: Batch dictionary. After ``prepare_data`` it must provide
                ``txt_ids``, ``txt_masks``, ``obj_fts``, ``obj_locs``,
                ``obj_masks``, ``obj_sem_masks`` and ``obj_labels``.
                ``cur_step`` / ``total_steps`` default to 1 when ``cur_step`` is
                absent; ``train`` always sets them.

        Returns:
            The same dictionary, extended with ``answer_scores``, the grounding
            logits and, when the corresponding modules provide them,
            ``early_lpss_info`` and ``lpss_info``.
        """
        self.prepare_data(data_dict)
        if "cur_step" not in data_dict:
            data_dict["cur_step"] = 1
            data_dict["total_steps"] = 1

        lang_basic_features = self.lang_encoder(data_dict["txt_ids"], data_dict["txt_masks"]).last_hidden_state
        point_basic_features, point_features_pre, obj_qr_raw_logits = self.point_encoder(
            data_dict["obj_fts"].float(),
            data_dict["obj_locs"],
            data_dict["obj_masks"],
            data_dict["obj_sem_masks"],
            data_dict["obj_labels"],
            data_dict["cur_step"],
            data_dict["total_steps"],
        )

        language_fuse_feature, point_fuse_feature = self.unified_encoder(
            lang_basic_features, data_dict["txt_masks"], point_basic_features, data_dict["obj_locs"], data_dict["obj_masks"]
        )

        txt_qr_logits, obj_qr_post_logits, obj_qr_pre_logits, og3d_logits = self.ground_head(
            language_fuse_feature, point_fuse_feature, point_features_pre, data_dict["obj_masks"]
        )

        early_lpss_info = self._unwrap_ddp(self.unified_encoder).get_last_early_lpss_info()
        if early_lpss_info is not None:
            data_dict["early_lpss_info"] = early_lpss_info

        # The QA head returns extra LPSS diagnostics only when it owns an ``lpss`` module.
        if getattr(self._unwrap_ddp(self.qa_head), "lpss", None) is not None:
            answer_scores, lpss_info = self.qa_head(
                point_fuse_feature,
                data_dict["obj_masks"],
                language_fuse_feature,
                data_dict["txt_masks"],
                return_lpss_info=True,
            )
            data_dict["lpss_info"] = lpss_info
        else:
            answer_scores = self.qa_head(
                point_fuse_feature, data_dict["obj_masks"], language_fuse_feature, data_dict["txt_masks"]
            )

        data_dict["answer_scores"] = answer_scores
        data_dict["txt_qr_logits"] = txt_qr_logits
        data_dict["obj_qr_post_logits"] = obj_qr_post_logits
        data_dict["obj_qr_pre_logits"] = obj_qr_pre_logits
        data_dict["obj_qr_raw_logits"] = obj_qr_raw_logits
        data_dict["og3d_logits"] = og3d_logits

        return data_dict

    def get_loss(self, data_dict):
        """Compute the QA loss terms (delegates to ``get_qa_loss``)."""
        return self.get_qa_loss(data_dict)

    def get_metrics(self, data_dict):
        """Compute the QA metrics (delegates to ``get_qa_metrics``)."""
        return self.get_qa_metrics(data_dict)

    def record_train_step(self, data_dict, step):
        """Log training metrics, plus any present non-zero auxiliary losses, at ``step``.

        No-op unless this is the main process and a logger is configured.
        """
        if not self.logger or not self.is_main_process:
            return
        log_dict = {
            "train/acc": data_dict["ans1_acc"],
            "train/Acc": data_dict["ans1_acc"],
            "train/Pre": data_dict.get("ans_precision", 0.0),
            "train/Rec": data_dict.get("ans_recall", 0.0),
            "train/F1": data_dict.get("ans_f1", 0.0),
        }
        for key in self._AUX_LOSS_KEYS:
            if key not in data_dict:
                continue
            # ``.item()`` may force a device sync, so call it only once per loss.
            value = data_dict[key].item()
            if value != 0:
                log_dict["train/" + key] = value
        self.logger.log(log_dict, step=step)

    def record_eval_step(self, eval_dict, epoch):
        """Print and log the available test metrics (main process only).

        Metrics are logged at the global step reached at the end of ``epoch``.
        """
        if not self.is_main_process:
            return
        step = (epoch + 1) * len(self.train_data_loader)
        for key in ("Acc", "Pre", "Rec", "F1"):
            if key not in eval_dict:
                continue
            print("test_" + key, eval_dict[key])
            if self.logger:
                self.logger.log({"test/" + key: eval_dict[key]}, step=step)

    def restore_model(self):
        """Load pretrained weights for every sub-model from ``self.saver``.

        Weights are loaded with ``strict=False``. Legacy ``*_clf_*`` / ``*_cls_*``
        key prefixes are renamed to their current ``*_qr_*`` names. A failure to
        load the QA head is reported but not fatal.
        """
        state_dict = self.saver.restore_dict()

        def remap_state_dict_keys(weights, rename_rules):
            # Rewrite the first matching ``old_prefix`` of each key to ``new_prefix``.
            if not rename_rules:
                return weights
            remapped = {}
            for key, value in weights.items():
                new_key = key
                for old_prefix, new_prefix in rename_rules:
                    if key.startswith(old_prefix):
                        new_key = new_prefix + key[len(old_prefix):]
                        break
                remapped[new_key] = value
            return remapped

        def load_state_dict_safe(model, state_dict_key, filter_keys=None, rename_rules=None):
            # Drop keys containing any ``filter_keys`` substring, remap, then load
            # into the unwrapped module non-strictly.
            weights = state_dict[state_dict_key]
            if filter_keys:
                weights = {k: v for k, v in weights.items() if not any(fk in k for fk in filter_keys)}
            weights = remap_state_dict_keys(weights, rename_rules)
            if isinstance(model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
                model.module.load_state_dict(weights, strict=False)
            else:
                model.load_state_dict(weights, strict=False)

        load_state_dict_safe(self.lang_encoder, "lang_encoder", filter_keys=["pooler"])
        load_state_dict_safe(
            self.point_encoder,
            "point_encoder",
            rename_rules=[
                ("point_cls_head.", "point_qr_head."),
                ("sem_cls_embed_layer.", "sem_qr_embed_layer."),
                ("obj3d_clf_pre_head.", "obj3d_qr_pre_head."),
            ],
        )
        load_state_dict_safe(self.unified_encoder, "unified_encoder")
        load_state_dict_safe(
            self.ground_head,
            "ground_head",
            rename_rules=[
                ("txt_clf_head.", "txt_qr_head."),
                ("obj3d_clf_head.", "obj3d_qr_head."),
                ("obj3d_clf_pre_head.", "obj3d_qr_pre_head."),
            ],
        )
        try:
            load_state_dict_safe(self.qa_head, "qa_head", rename_rules=[("answer_cls.", "answer_qr.")])
        except Exception as err:
            # Best effort: the QA head may be missing from, or incompatible with, the
            # checkpoint. Keep its fresh initialisation but say why loading failed.
            if self.is_main_process:
                print(f"fail to load qa params: {err!r}")

    def save_model(self, is_best: bool = False):
        """Save the unwrapped state dict of every sub-model (main process only)."""
        if not self.is_main_process:
            return
        state_dict = {name: self._unwrap_ddp(getattr(self, name)).state_dict() for name in self._COMPONENT_NAMES}
        self.saver.save_dict(state_dict, is_best=is_best)

    def set_model_state(self, state="train"):
        """Put every sub-model into ``"train"`` or ``"eval"`` mode.

        Calls ``torch.cuda.empty_cache()`` first to release unused cached CUDA memory.

        Raises:
            ValueError: if ``state`` is neither ``"train"`` nor ``"eval"``.
        """
        if state not in ("train", "eval"):
            raise ValueError(f"state must be 'train' or 'eval', got {state!r}")
        torch.cuda.empty_cache()
        training = state == "train"
        for name in self._COMPONENT_NAMES:
            # ``Module.train(False)`` is exactly what ``Module.eval()`` does.
            getattr(self, name).train(training)