import copy
import torch
import torch.nn as nn
import numpy as np
import torchvision
from typing import Dict, Tuple, Union
import albumentations as A
from diffusion_policy.common.pytorch_util import dict_apply, replace_submodules
from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin
from diffusion_policy.model.vision.crop_randomizer import CropRandomizer
import kornia.augmentation as K

class ImageObsEncoder(ModuleAttrMixin):
    def __init__(self,
            shape_meta: dict,
            rgb_model: Union[nn.Module, Dict[str,nn.Module]],
            resize_shape: Union[Tuple[int,int], Dict[str,tuple], None]=None,
            crop_shape: Union[Tuple[int,int], Dict[str,tuple], None]=None,
            random_crop: bool=True,
            use_group_norm: bool=False,
            share_rgb_model: bool=False,
            imagenet_norm: bool=False,
            use_color_jitter: bool = False,
            color_jitter_params: dict = None,
            transforms: list = None,
        ):
        super().__init__()

        rgb_keys = []
        low_dim_keys = []
        key_model_map = nn.ModuleDict()
        train_transform_map = nn.ModuleDict()
        eval_transform_map = nn.ModuleDict()
        key_shape_map = {}

        if share_rgb_model:
            assert isinstance(rgb_model, nn.Module)
            key_model_map["rgb"] = rgb_model

        obs_shape_meta = shape_meta["obs"]

        for key, attr in obs_shape_meta.items():
            shape = tuple(attr["shape"])
            type = attr.get("type", "low_dim")
            key_shape_map[key] = shape

            if type == "rgb":
                rgb_keys.append(key)

                this_model = None
                if not share_rgb_model:
                    if isinstance(rgb_model, dict):
                        this_model = rgb_model[key]
                    else:
                        this_model = copy.deepcopy(rgb_model)

                if this_model is not None:
                    if use_group_norm:
                        this_model = replace_submodules(
                            root_module=this_model,
                            predicate=lambda x: isinstance(x, nn.BatchNorm2d),
                            func=lambda x: nn.GroupNorm(
                                num_groups=x.num_features//16,
                                num_channels=x.num_features)
                        )
                    key_model_map[key] = this_model

                C, H, W = shape
                input_shape = shape

                train_list = []
                eval_list = []

                if resize_shape is not None:
                    if isinstance(resize_shape, dict):
                        rh, rw = resize_shape[key]
                    else:
                        rh, rw = resize_shape

                    train_list.append(torchvision.transforms.Resize((rh, rw)))
                    eval_list.append(torchvision.transforms.Resize((rh, rw)))

                    input_shape = (C, rh, rw)

                if crop_shape is not None:
                    if isinstance(crop_shape, dict):
                        h, w = crop_shape[key]
                    elif isinstance(crop_shape, tuple):
                        h, w = crop_shape
                    else:
                        h, w = int(H * crop_shape), int(W * crop_shape)
                    if random_crop:
                        this_randomizer = CropRandomizer(
                            input_shape=input_shape,
                            crop_height=h,
                            crop_width=w,
                            num_crops=1,
                            pos_enc=False
                        )
                        train_list.append(this_randomizer)
                        eval_list.append(this_randomizer)

                elif random_crop and transforms is not None and transforms[0].type == "RandomResizedCrop":
                    cfg = transforms[0]

                    train_list.append(
                        K.RandomResizedCrop(
                            size=(input_shape[1], input_shape[2]),
                            scale=cfg.scale,
                            ratio=cfg.ratio,
                            p=cfg.crop_prob
                        )
                    )

                if use_color_jitter:
                    p = color_jitter_params or {}
                    train_list.append(
                        torchvision.transforms.ColorJitter(
                            brightness=p.get("brightness", 0.0),
                            contrast=p.get("contrast", 0.0),
                            saturation=p.get("saturation", 0.0),
                            hue=p.get("hue", 0.0),
                        )
                    )

                if imagenet_norm:
                    norm = torchvision.transforms.Normalize(
                        mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]
                    )
                    train_list.append(norm)
                    eval_list.append(norm)

                train_transform_map[key] = nn.Sequential(*train_list)
                eval_transform_map[key] = nn.Sequential(*eval_list)

            elif type == "low_dim":
                low_dim_keys.append(key)
            else:
                raise RuntimeError(f"Unsupported obs type: {type}")

        self.shape_meta = shape_meta
        self.key_model_map = key_model_map
        self.train_transform_map = train_transform_map
        self.eval_transform_map = eval_transform_map
        self.share_rgb_model = share_rgb_model
        self.rgb_keys = sorted(rgb_keys)
        self.low_dim_keys = sorted(low_dim_keys)
        self.key_shape_map = key_shape_map

    def forward(self, obs_dict):
        batch_size = None
        features = []

        if self.share_rgb_model:
            imgs = []
            for key in self.rgb_keys:
                img = obs_dict[key]
                if img.max() > 1.0:

                    img = img / 255.0
                batch_size = batch_size or img.shape[0]
                assert img.shape[1:] == self.key_shape_map[key]

                if self.training:
                    img = self.train_transform_map[key](img)
                else:
                    img = self.eval_transform_map[key](img)

                imgs.append(img)

            imgs = torch.cat(imgs, dim=0)
            feat = self.key_model_map["rgb"](imgs)

            feat = feat.reshape(-1, batch_size, *feat.shape[1:])
            feat = torch.moveaxis(feat, 0, 1)
            feat = feat.reshape(batch_size, -1)
            features.append(feat)

        else:
            for key in self.rgb_keys:
                img = obs_dict[key]
                batch_size = batch_size or img.shape[0]

                if self.training:
                    img = self.train_transform_map[key](img)
                else:
                    img = self.eval_transform_map[key](img)

                feat = self.key_model_map[key](img)
                features.append(feat)

        for key in self.low_dim_keys:
            data = obs_dict[key]
            batch_size = batch_size or data.shape[0]
            features.append(data)

        return torch.cat(features, dim=-1)

    @torch.no_grad()
    def output_shape(self):
        obs_shape_meta = self.shape_meta["obs"]
        example = {}
        for key, attr in obs_shape_meta.items():
            shape = tuple(attr["shape"])
            example[key] = torch.zeros((1,)+shape, dtype=self.dtype, device=self.device)
        output = self.forward(example)
        return output.shape[1:]
