import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin

from stream3r.models.aggregator import Aggregator
from stream3r.heads.camera_head import CameraHead
from stream3r.heads.dpt_head import DPTHead
from stream3r.heads.track_head import TrackHead
from transformers.file_utils import ModelOutput
from typing import Optional, Tuple, List, Any
from dataclasses import dataclass


@dataclass
class Stream3ROutput(ModelOutput):
    """Container returned by ``Stream3R.forward`` and ``Stream3R.inference``."""

    # One prediction dict per processed view/frame, in input order.
    ress: Optional[List[dict]] = None
    # The per-view input dicts the predictions were computed from, passed
    # through unchanged (a list of dicts, not a tensor).
    views: Optional[List[dict]] = None


class Stream3R(nn.Module, PyTorchModelHubMixin):
    """Aggregator backbone with camera, depth, point-map and tracking heads.

    ``forward`` processes all views in a single aggregator pass, while
    ``inference`` feeds frames to the aggregator one at a time and carries
    key/value caches from frame to frame.  Both return a ``Stream3ROutput``
    holding one result dict per view/frame.
    """

    def __init__(self, img_size=518, patch_size=14, embed_dim=1024):
        """Build the aggregator and the four prediction heads.

        Args:
            img_size: Forwarded to ``Aggregator``.
            patch_size: Forwarded to ``Aggregator`` and ``TrackHead``.
            embed_dim: Forwarded to ``Aggregator``; every head is built with
                ``dim_in=2 * embed_dim``.
        """
        super().__init__()

        # Every head is fed features of width 2 * embed_dim (presumably the
        # width of the aggregator's output tokens -- confirm in Aggregator).
        head_dim_in = 2 * embed_dim

        self.aggregator = Aggregator(
            img_size=img_size, patch_size=patch_size, embed_dim=embed_dim
        )
        self.camera_head = CameraHead(dim_in=head_dim_in)
        self.point_head = DPTHead(
            dim_in=head_dim_in,
            output_dim=4,
            activation="inv_log",
            conf_activation="expp1",
        )
        self.depth_head = DPTHead(
            dim_in=head_dim_in,
            output_dim=2,
            activation="exp",
            conf_activation="expp1",
        )
        self.track_head = TrackHead(dim_in=head_dim_in, patch_size=patch_size)

    def _run_dense_heads(self, tokens, images, patch_start_idx, query_points):
        """Run the depth, point and (when queried) track heads.

        Heads that are ``None`` are skipped, and the track head is also
        skipped when ``query_points`` is ``None``.

        Returns:
            Dict with the keys ``depth``/``depth_conf``,
            ``world_points``/``world_points_conf`` and ``track``/``vis``/
            ``conf`` for whichever heads actually ran.
        """
        predictions = {}

        if self.depth_head is not None:
            depth, depth_conf = self.depth_head(
                tokens, images=images, patch_start_idx=patch_start_idx
            )
            predictions["depth"] = depth
            predictions["depth_conf"] = depth_conf

        if self.point_head is not None:
            pts3d, pts3d_conf = self.point_head(
                tokens, images=images, patch_start_idx=patch_start_idx
            )
            predictions["world_points"] = pts3d
            predictions["world_points_conf"] = pts3d_conf

        if self.track_head is not None and query_points is not None:
            track_list, vis, conf = self.track_head(
                tokens,
                images=images,
                patch_start_idx=patch_start_idx,
                query_points=query_points,
            )
            # Only the last entry of the returned track list is kept.
            predictions["track"] = track_list[-1]
            predictions["vis"] = vis
            predictions["conf"] = conf

        return predictions

    @staticmethod
    def _view_result(predictions, index, view):
        """Collect the predictions for one view into its result dict.

        Args:
            predictions: Dict of head outputs whose axis 1 indexes views.
            index: Position of the view along axis 1.
            view: The input dict of that view; its ``"valid_mask"`` entry,
                if present, is copied into the result unchanged.

        Returns:
            Dict containing only the entries whose head produced an output,
            so a disabled head yields a missing key instead of an error.
        """
        result = {}
        for out_key, pred_key in (
            ("pts3d_in_other_view", "world_points"),
            ("conf", "world_points_conf"),
            ("depth", "depth"),
            ("depth_conf", "depth_conf"),
            ("camera_pose", "pose_enc"),
        ):
            if pred_key in predictions:
                result[out_key] = predictions[pred_key][:, index]

        if "valid_mask" in view:
            result["valid_mask"] = view["valid_mask"]

        if "track" in predictions:
            result["track"] = predictions["track"][:, index]
            result["vis"] = predictions["vis"][:, index]
            # The track head's "conf" is exposed as "track_conf" because
            # "conf" already holds the point-map confidence.
            result["track_conf"] = predictions["conf"][:, index]

        return result

    def forward(
        self,
        views: List[dict],
        query_points: Optional[torch.Tensor] = None,
        history_info: Optional[dict] = None,
        past_key_values=None,
        use_cache=False,
        past_frame_idx=0,
    ) -> "Stream3ROutput":
        """Predict camera, depth, points and tracks for all views at once.

        Args:
            views: Sequence of dicts, one per view.  Each needs an ``"img"``
                tensor of identical shape: either 4-D, whose leading axis is
                treated as the batch axis, or 3-D for a single unbatched
                sample.  An optional ``"valid_mask"`` entry is copied to the
                matching result.
            query_points: Optional points handed to the track head.  A 2-D
                tensor gets a leading batch axis; ``None`` skips tracking.
            history_info: Accepted for interface compatibility; unused here.
            past_key_values: Accepted for interface compatibility; unused here.
            use_cache: Accepted for interface compatibility; unused here.
            past_frame_idx: Accepted for interface compatibility; unused here.

        Returns:
            ``Stream3ROutput`` whose ``ress`` holds one result dict per view
            (keys ``pts3d_in_other_view``, ``conf``, ``depth``,
            ``depth_conf``, ``camera_pose`` and, when available,
            ``valid_mask``, ``track``, ``vis``, ``track_conf``) and whose
            ``views`` is the input ``views``.
        """
        # Stacking puts the view axis first.
        images = torch.stack([view["img"] for view in views], dim=0)
        if images.dim() == 4:
            # Unbatched views: prepend a batch axis of size 1.  This check
            # has to precede the permute, which only accepts 5-D input.
            images = images.unsqueeze(0)
        else:
            # Batched views: swap to (batch, view, ...) order.
            images = images.permute(1, 0, 2, 3, 4)

        if query_points is not None and query_points.dim() == 2:
            query_points = query_points.unsqueeze(0)

        aggregated_tokens_list, patch_start_idx = self.aggregator(images)

        predictions = {}
        # The heads run with autocast explicitly disabled.
        with torch.autocast(device_type="cuda", enabled=False):
            if self.camera_head is not None:
                pose_enc_list = self.camera_head(aggregated_tokens_list)
                predictions["pose_enc"] = pose_enc_list[-1]

            predictions.update(
                self._run_dense_heads(
                    aggregated_tokens_list, images, patch_start_idx, query_points
                )
            )

        ress = [
            self._view_result(predictions, s, view)
            for s, view in enumerate(views)
        ]
        return Stream3ROutput(ress=ress, views=views)

    def inference(
        self,
        frames,
        query_points: Optional[torch.Tensor] = None,
        past_key_values=None,
    ) -> "Stream3ROutput":
        """Process frames one by one, reusing caches from earlier frames.

        Args:
            frames: Iterable of dicts, one per frame, each with an ``"img"``
                tensor (assumed batched like the views given to ``forward``
                -- confirm against callers) and an optional
                ``"valid_mask"``.
            query_points: Optional points handed to the track head for the
                first frame; every later frame is queried with the track
                predicted for the previous one.  A 2-D tensor gets a leading
                batch axis; ``None`` skips tracking.
            past_key_values: Currently ignored -- fresh caches are created on
                every call.  NOTE(review): resuming a stream would also need
                the camera-head cache and a frame offset.

        Returns:
            ``Stream3ROutput`` whose ``ress`` holds one result dict per frame
            (same keys as ``forward``) and whose ``views`` lists the
            processed frames.
        """
        # Always start from empty caches (see the note on past_key_values).
        past_key_values = [None] * self.aggregator.depth
        past_key_values_camera = (
            [None] * self.camera_head.trunk_depth
            if self.camera_head is not None
            else None
        )

        if query_points is not None and query_points.dim() == 2:
            query_points = query_points.unsqueeze(0)

        all_ress = []
        processed_frames = []

        for frame_idx, frame in enumerate(frames):
            # Insert a view axis of length 1 after the batch axis, matching
            # the (batch, view, ...) layout used in forward().  Equivalent to
            # the former unsqueeze(0) whenever the batch size is 1.
            images = frame["img"].unsqueeze(1)

            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                aggregator_output = self.aggregator(
                    images,
                    past_key_values=past_key_values,
                    use_cache=True,
                    past_frame_idx=frame_idx,
                )

            # A 3-tuple additionally carries the updated cache.
            if isinstance(aggregator_output, tuple) and len(aggregator_output) == 3:
                aggregated_tokens, patch_start_idx, past_key_values = (
                    aggregator_output
                )
            else:
                aggregated_tokens, patch_start_idx = aggregator_output

            predictions = {}
            with torch.autocast(device_type="cuda", enabled=False):
                if self.camera_head is not None:
                    pose_enc_list, past_key_values_camera = self.camera_head(
                        aggregated_tokens,
                        past_key_values_camera=past_key_values_camera,
                        use_cache=True,
                    )
                    predictions["pose_enc"] = pose_enc_list[-1]

                predictions.update(
                    self._run_dense_heads(
                        aggregated_tokens, images, patch_start_idx, query_points
                    )
                )

            # Each call sees a single frame, so it sits at view index 0.
            result = self._view_result(predictions, 0, frame)
            if "track" in result:
                # The predicted track becomes the query for the next frame.
                query_points = result["track"]

            all_ress.append(result)
            processed_frames.append(frame)

        return Stream3ROutput(ress=all_ress, views=processed_frames)
