from torch import Tensor
from torch.nn import Conv3d, ConvTranspose3d, ReLU, Module, Sequential
from pcdet.utils.spconv_utils import spconv
from functools import partial
import torch.nn as nn

import torch
import os

class Inversion_Model_Convout_CLS(Module):
    """Sparse 3-D decoder: four transposed-conv upsampling stages plus a 1-channel head.

    The network takes a 128-channel sparse tensor. It upsamples it through
    ``upsample_1`` .. ``upsample_4`` (128 -> 64 -> 64 -> 32 -> 16 channels),
    cropping the last slice of spatial axis 0 after stages 3 and 4. It then
    applies a submanifold 3x3x3 conv that outputs one feature per active voxel.

    NOTE(review): the name suggests this inverts a backbone's ``conv_out``
    output back to voxel occupancy logits. That is not verifiable from this
    file; confirm against the training code.
    """

    def __init__(
        self,
    ):
        super(Inversion_Model_Convout_CLS, self).__init__()

        # BatchNorm1d because spconv keeps features as a (num_voxels, C) matrix.
        norm_fn = partial(nn.BatchNorm1d, eps=1e-3, momentum=0.01)

        # NOTE: the layer order inside each SparseSequential defines the
        # state_dict keys ("upsample_1.0.weight", ...); keep it stable so
        # existing checkpoints keep loading.
        self.upsample_1 = spconv.SparseSequential(
            spconv.SparseConvTranspose3d(
                128,
                64,
                (3, 1, 1),
                (2, 1, 1),
                bias=False,
            ),
            norm_fn(64),
            ReLU(),
            spconv.SubMConv3d(
                64,
                64,
                (3, 3, 3),
                (1, 1, 1),
                (1, 1, 1),
                bias=False,
                indice_key="subm4_inv",
            ),
            norm_fn(64),
            ReLU(),
            spconv.SubMConv3d(
                64,
                64,
                (3, 3, 3),
                (1, 1, 1),
                (1, 1, 1),
                bias=False,
                indice_key="subm4_inv",
            ),
            norm_fn(64),
            ReLU(),
        )

        self.upsample_2 = spconv.SparseSequential(
            spconv.SparseConvTranspose3d(
                64,
                64,
                (3, 2, 2),
                (2, 2, 2),
                bias=False,
            ),
            norm_fn(64),
            ReLU(),
            spconv.SubMConv3d(
                64,
                64,
                (3, 3, 3),
                (1, 1, 1),
                (1, 1, 1),
                bias=False,
                indice_key="subm3_inv",
            ),
            norm_fn(64),
            ReLU(),
            spconv.SubMConv3d(
                64,
                64,
                (3, 3, 3),
                (1, 1, 1),
                (1, 1, 1),
                bias=False,
                indice_key="subm3_inv",
            ),
            norm_fn(64),
            ReLU(),
        )

        self.upsample_3 = spconv.SparseSequential(
            spconv.SparseConvTranspose3d(
                64,
                32,
                (2, 2, 2),
                (2, 2, 2),
                bias=False,
            ),
            norm_fn(32),
            ReLU(),
            spconv.SubMConv3d(
                32,
                32,
                (3, 3, 3),
                (1, 1, 1),
                (1, 1, 1),
                bias=False,
                indice_key="subm2_inv",
            ),
            norm_fn(32),
            ReLU(),
            spconv.SubMConv3d(
                32,
                32,
                (3, 3, 3),
                (1, 1, 1),
                (1, 1, 1),
                bias=False,
                indice_key="subm2_inv",
            ),
            norm_fn(32),
            ReLU(),
        )

        self.upsample_4 = spconv.SparseSequential(
            spconv.SparseConvTranspose3d(
                32,
                16,
                (2, 2, 2),
                (2, 2, 2),
                bias=False,
            ),
            norm_fn(16),
            ReLU(),
        )

        # Final 1-channel submanifold conv: one output value per active voxel.
        self.cls = spconv.SparseSequential(
            spconv.SubMConv3d(
                16,
                1,
                (3, 3, 3),
                (1, 1, 1),
                (1, 1, 1),
                bias=False,
                indice_key="subm1_cls",
            ),
        )

    @staticmethod
    def _drop_last_axis0_slice(sp_tensor):
        """Return a copy of *sp_tensor* without its last slice along spatial axis 0.

        Voxels whose axis-0 index equals ``spatial_shape[0] - 1`` are removed,
        and ``spatial_shape[0]`` shrinks by one. Column 1 of ``indices`` is
        spatial axis 0 (spconv's ``[batch, d0, d1, d2]`` index layout).

        The batch size is read from ``sp_tensor.batch_size`` rather than from
        ``sp_tensor.dense().shape[0]``. Densifying would allocate the whole
        ``[B, C, D, H, W]`` volume just to read one integer.
        """
        last = int(sp_tensor.spatial_shape[0]) - 1
        keep = sp_tensor.indices[:, 1] != last  # compute the mask once
        return spconv.SparseConvTensor(
            features=sp_tensor.features[keep],
            indices=sp_tensor.indices[keep].int(),
            spatial_shape=[
                last,
                sp_tensor.spatial_shape[1],
                sp_tensor.spatial_shape[2],
            ],
            batch_size=sp_tensor.batch_size,
        )

    def forward(
        self, x: "spconv.SparseConvTensor"
    ) -> "spconv.SparseConvTensor":
        """Decode a 128-channel sparse tensor into a 1-channel sparse tensor.

        After ``upsample_3`` and ``upsample_4``, the last slice along spatial
        axis 0 is removed and that axis shrinks by one. The cropped index is
        ``spatial_shape[0] - 1``. That equals the hard-coded 21 and 41 used
        previously when the input has axis-0 size 2
        (2 -> 5 -> 11 -> 22 -> crop 21 -> 42 -> crop 41).
        NOTE(review): that input size is inferred from those constants and the
        kernel/stride settings; confirm against the upstream backbone.
        """
        x_u1 = self.upsample_1(x)

        x_u2 = self.upsample_2(x_u1)

        x_u3 = self._drop_last_axis0_slice(self.upsample_3(x_u2))

        x_u4 = self._drop_last_axis0_slice(self.upsample_4(x_u3))

        cls_result = self.cls(x_u4)

        return cls_result

    @staticmethod
    def _resolve_logger(logger):
        """Return *logger*, or this module's standard logger when it is None."""
        if logger is not None:
            return logger
        import logging

        return logging.getLogger(__name__)

    def _load_state_dict(self, model_state_disk, *, strict=True):
        """Copy checkpoint tensors whose key and shape match this model's parameters.

        Entries of *model_state_disk* whose key is missing from the model, or
        whose shape differs, are ignored.

        Args:
            model_state_disk: mapping of parameter/buffer name -> tensor read
                from a checkpoint.
            strict: if True, load only the matching entries with
                ``load_state_dict``'s default strict checking. That raises if
                any model key is absent from the matched set. If False, merge
                the matched entries onto the model's current state first, so
                unmatched parameters keep their current values.

        Returns:
            ``(state_dict, update_model_state)``: the model's state snapshot
            (merged with the updates in the non-strict case) and the dict of
            entries that were actually taken from the checkpoint.
        """
        state_dict = self.state_dict()  # local cache of state_dict

        update_model_state = {
            key: val
            for key, val in model_state_disk.items()
            if key in state_dict and state_dict[key].shape == val.shape
        }

        if strict:
            self.load_state_dict(update_model_state)
        else:
            state_dict.update(update_model_state)
            self.load_state_dict(state_dict)
        return state_dict, update_model_state

    def load_params_from_file(self, filename, logger, to_cpu=False):
        """Load model weights from a checkpoint, skipping incompatible entries.

        Args:
            filename: path to a checkpoint containing a ``"model_state"`` dict
                (and optionally ``"version"``).
            logger: logger used for progress messages. If None, this module's
                standard logger is used.
            to_cpu: if True, map all tensors onto the CPU while loading.

        Raises:
            FileNotFoundError: if *filename* is not an existing file.
        """
        if not os.path.isfile(filename):
            raise FileNotFoundError(filename)
        logger = self._resolve_logger(logger)

        logger.info(
            "==> Loading parameters from checkpoint %s to %s"
            % (filename, "CPU" if to_cpu else "GPU")
        )
        loc_type = torch.device("cpu") if to_cpu else None
        # SECURITY NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(filename, map_location=loc_type)
        model_state_disk = checkpoint["model_state"]

        version = checkpoint.get("version", None)
        if version is not None:
            logger.info("==> Checkpoint trained from version: %s" % version)

        state_dict, update_model_state = self._load_state_dict(
            model_state_disk, strict=False
        )

        for key in state_dict:
            if key not in update_model_state:
                logger.info(
                    "Not updated weight %s: %s" % (key, str(state_dict[key].shape))
                )

        logger.info(
            "==> Done (loaded %d/%d)" % (len(update_model_state), len(state_dict))
        )

    def load_params_with_optimizer(
        self, filename, to_cpu=False, optimizer=None, logger=None
    ):
        """Strictly load model weights and, optionally, optimizer state.

        The optimizer state is read from the checkpoint's ``"optimizer_state"``
        entry. If that entry is missing or None, it is read from a sibling file
        named ``<stem>_optim<ext>`` (e.g. ``ckpt.pth`` -> ``ckpt_optim.pth``)
        when that file exists; otherwise the optimizer is left unchanged.

        Args:
            filename: path to a checkpoint containing ``"model_state"``.
            to_cpu: if True, map all tensors onto the CPU while loading.
            optimizer: optional optimizer whose state should be restored.
            logger: logger for progress messages; this module's standard
                logger is used when None.

        Returns:
            ``(it, epoch)`` read from the checkpoint, defaulting to
            ``(0.0, -1)`` when absent.

        Raises:
            FileNotFoundError: if *filename* is not an existing file.
        """
        if not os.path.isfile(filename):
            raise FileNotFoundError(filename)
        logger = self._resolve_logger(logger)

        logger.info(
            "==> Loading parameters from checkpoint %s to %s"
            % (filename, "CPU" if to_cpu else "GPU")
        )
        loc_type = torch.device("cpu") if to_cpu else None
        # SECURITY NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(filename, map_location=loc_type)
        epoch = checkpoint.get("epoch", -1)
        it = checkpoint.get("it", 0.0)

        self._load_state_dict(checkpoint["model_state"], strict=True)

        if optimizer is not None:
            if checkpoint.get("optimizer_state") is not None:
                logger.info(
                    "==> Loading optimizer parameters from checkpoint %s to %s"
                    % (filename, "CPU" if to_cpu else "GPU")
                )
                optimizer.load_state_dict(checkpoint["optimizer_state"])
            else:
                # splitext handles any extension length (".pth", ".pt", ...);
                # for ".pth" this matches the previous "%s_optim.%s" naming.
                src_file, ext = os.path.splitext(filename)
                optimizer_filename = "%s_optim%s" % (src_file, ext)
                if os.path.exists(optimizer_filename):
                    optimizer_ckpt = torch.load(
                        optimizer_filename, map_location=loc_type
                    )
                    optimizer.load_state_dict(optimizer_ckpt["optimizer_state"])

        if "version" in checkpoint:
            logger.info(
                "==> Checkpoint trained from version: %s" % checkpoint["version"]
            )
        logger.info("==> Done")

        return it, epoch
