from collections import OrderedDict
from typing import Any, Callable, Dict, Iterable, Optional, Tuple
import logging

import torch
import torch.nn as nn

from cheem.models.utils import running_stat

from .wrappers import ContinualMoE, SimilaritySupernet
from .utils import running_stat


# Module-level logger; PromptedSimilaritySupernet.forward logs sampled expert ids through it.
_logger = logging.getLogger(name=__name__)

class PromptWrapper(nn.Module):
    """Prepend learnable per-task prompt tokens to a timm-style ViT backbone.

    Task 0 runs the backbone unchanged (class token + backbone head). For a task
    ``t > 0`` the prompt ``prompts[t-1]`` is prepended to the token sequence and,
    unless the backbone head is requested, ``heads[t-1]`` is applied to the
    pooled prompt features.

    Args:
        backbone: ViT-like module; this class reads its ``patch_embed``,
            ``pos_embed``, ``cls_token``, ``pos_drop``, ``norm_pre``, ``blocks``,
            ``norm``, ``fc_norm``, ``head``, ``global_pool``, ``no_embed_class``,
            ``num_prefix_tokens`` and ``num_features`` attributes.
        task_idx: index of the active task (0 = plain backbone).
        prompt_len: number of prompt tokens per task.
        head_factory: callable building the per-task heads from ``backbone``;
            one prompt is created per head.
        transfer_prompts: if True, ``load_state_dict`` initialises the active
            task's prompt from the previous task's prompt.
    """

    def __init__(self, backbone: nn.Module, task_idx, prompt_len, head_factory, transfer_prompts=False):
        super().__init__()

        # A subclass may already have set these attributes; keep its values.
        if getattr(self, "backbone", None) is None:
            self.backbone = backbone
        if getattr(self, "heads", None) is None:
            self.heads = head_factory(backbone)
        if getattr(self, "task_idx", None) is None:
            self.task_idx = task_idx

        self.prompt_len = prompt_len
        # One prompt per head (i.e. per task > 0), initialised U(-0.01, 0.01).
        self.prompts = nn.ParameterList([
            nn.Parameter(torch.zeros(1, self.prompt_len, backbone.num_features).uniform_(-0.01, 0.01), requires_grad=True)
            for _ in range(len(self.heads))
        ])

        self.transfer_prompts = transfer_prompts

    def set_task_idx(self, task_idx, train=False):
        """Activate ``task_idx``; when training a task > 0, only its head and prompt get gradients."""
        self.task_idx = task_idx
        self.heads.requires_grad_(False)
        self.prompts.requires_grad_(False)
        if task_idx > 0 and train:
            self.heads[task_idx-1].requires_grad_(True)
            self.prompts[task_idx-1].requires_grad_(True)

    def _pos_embed(self, x):
        """Add position embeddings, prepend the class token and, for task > 0, the task prompt.

        The resulting token layout is ``[prompt (task > 0 only), cls (if any), patches]``
        for both position-embedding conventions.
        """
        if self.backbone.no_embed_class:
            # deit-3, updated JAX (big vision)
            # position embedding does not overlap with class token, add then concat
            x = x + self.backbone.pos_embed
            if self.backbone.cls_token is not None:
                x = torch.cat((self.backbone.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        else:
            # original timm, JAX, and deit vit impl
            # pos_embed has entry for class token, concat then add
            if self.backbone.cls_token is not None:
                x = torch.cat((self.backbone.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
            x = x + self.backbone.pos_embed
        # timm applies pos_drop in both branches; the prompt itself is not dropped.
        x = self.backbone.pos_drop(x)
        if self.task_idx > 0:
            prompt = self.prompts[self.task_idx-1].expand(x.shape[0], -1, -1)
            x = torch.cat((prompt, x), dim=1)
        return x

    def forward_features(self, x):
        """Embed patches, add prompt/position tokens and run the backbone blocks."""
        x = self.backbone.patch_embed(x)
        x = self._pos_embed(x)
        x = self.backbone.norm_pre(x)
        x = self.backbone.blocks(x)
        x = self.backbone.norm(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        """Pool the token sequence and optionally apply the backbone head.

        Avg pooling averages the patch tokens only. Token pooling averages the
        prompt tokens for task > 0 and takes the class token for task 0.
        """
        if self.backbone.global_pool:
            if self.backbone.global_pool == "avg":
                # Prompts are prepended only for task > 0 (see _pos_embed).
                num_prompt_tokens = self.prompt_len if self.task_idx > 0 else 0
                x = x[:, self.backbone.num_prefix_tokens + num_prompt_tokens:].mean(dim=1)
            elif self.task_idx > 0:
                x = x[:, :self.prompt_len, :].mean(dim=1)
            else:
                x = x[:, 0]
        x = self.backbone.fc_norm(x)
        return x if pre_logits else self.backbone.head(x)

    def forward(self, x, use_backbone_head=False):
        """Return backbone-head logits (task 0 or ``use_backbone_head``), else task-head outputs."""
        x = self.forward_features(x)
        pre_logits = not (use_backbone_head or self.task_idx == 0)
        x = self.forward_head(x, pre_logits=pre_logits)

        if pre_logits and self.task_idx > 0:
            x = self.heads[self.task_idx-1](x)
        return x

    def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
        """Return the module state without backbone / ``pos_embed`` entries."""
        _state_dict = super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)
        filtered = OrderedDict(
            (key, value) for key, value in _state_dict.items()
            if "backbone" not in key and "pos_embed" not in key
        )
        # Keep torch's per-module version metadata, which load_state_dict consumes.
        metadata = getattr(_state_dict, "_metadata", None)
        if metadata is not None:
            filtered._metadata = metadata
        return filtered

    def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = False):
        """Load state non-strictly (default changed from True to False), then optionally transfer prompts."""
        result = super().load_state_dict(state_dict, strict)

        if self.transfer_prompts:
            self.copy_prompts()

        return result

    def copy_prompts(self):
        """Initialise the active task's prompt from the previous task's prompt (tasks > 1)."""
        assert self.task_idx is not None
        if self.task_idx > 1:
            with torch.no_grad():
                self.prompts[self.task_idx-1].copy_(self.prompts[self.task_idx-2])


class PromptedContinualMoE(ContinualMoE):
    """ContinualMoE whose ViT backbone is conditioned on learnable per-task prompt tokens.

    Task 0 runs the backbone as-is (class token + backbone head). For a task
    ``t > 0`` the class token is replaced by ``prompts[t-1]`` and, unless the
    backbone head is requested, ``heads[t-1]`` is applied to the pooled prompt
    tokens. The active task is read from ``self.task_idx``.
    """

    def __init__(
            self,
            task_idx,
            backbone: nn.Module,
            config: dict,
            iter_backbone: Callable[[nn.Module], Iterable[Tuple[nn.Module, nn.Module]]],
            head_factory: Callable[[nn.Module], nn.ModuleList],
            op_factory_generator: Callable[[nn.Module, str], Callable],
            prompt_len: int,
            transfer_prompts: bool = False,
            stat_funcs: Optional[Dict[str, Callable[[torch.Tensor], torch.Tensor]]] = None,
            stat_update_function: Callable[[torch.Tensor, torch.Tensor, Any], torch.Tensor] = running_stat,
            initial_expert_ids: str = "expert_0",
            **kwargs):

        # NOTE(review): `task_idx` is neither forwarded to ContinualMoE nor stored
        # here; presumably self.task_idx is set by the parent or the caller - confirm.
        super().__init__(backbone, config, iter_backbone, head_factory, op_factory_generator, stat_funcs, stat_update_function, initial_expert_ids, **kwargs)

        self.prompt_len = prompt_len
        # One prompt per head (i.e. per task > 0), initialised U(-0.01, 0.01).
        self.prompts = nn.ParameterList([
            nn.Parameter(torch.zeros(1, self.prompt_len, backbone.num_features).uniform_(-0.01, 0.01), requires_grad=True)
            for _ in range(len(self.heads))
        ])

        self.transfer_prompts = transfer_prompts

    def _pos_embed(self, x):
        """Add position embeddings and prepend the class token (task 0) or the task prompt (task > 0).

        The resulting token layout is ``[cls (if any), patches]`` for task 0 and
        ``[prompt, patches]`` for task > 0, for both position-embedding conventions.
        """
        batch_size = x.shape[0]
        if self.task_idx > 0:
            prefix = self.prompts[self.task_idx-1].expand(batch_size, -1, -1)
        elif self.backbone.cls_token is not None:
            prefix = self.backbone.cls_token.expand(batch_size, -1, -1)
        else:
            prefix = None

        if self.backbone.no_embed_class:
            # deit-3, updated JAX (big vision)
            # position embedding does not overlap with the prefix tokens: add then concat
            x = x + self.backbone.pos_embed
            if prefix is not None:
                x = torch.cat((prefix, x), dim=1)
        else:
            # original timm, JAX, and deit vit impl
            # pos_embed has an entry for the prefix tokens: concat then add
            # NOTE(review): for task > 0 this needs pos_embed to broadcast against
            # prompt_len + num_patches tokens (e.g. prompt_len == 1 with a standard
            # ViT pos_embed) - confirm.
            if prefix is not None:
                x = torch.cat((prefix, x), dim=1)
            x = x + self.backbone.pos_embed
        # timm applies pos_drop in both branches.
        return self.backbone.pos_drop(x)

    def forward_features(self, x):
        """Embed patches, add prompt/position tokens and run the backbone blocks."""
        x = self.backbone.patch_embed(x)
        x = self._pos_embed(x)
        x = self.backbone.norm_pre(x)
        x = self.backbone.blocks(x)
        x = self.backbone.norm(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        """Pool the token sequence and optionally apply the backbone head.

        Avg pooling averages the patch tokens only. Token pooling averages the
        prompt tokens for task > 0 and takes the class token for task 0.
        """
        if self.backbone.global_pool:
            if self.backbone.global_pool == "avg":
                # For task > 0 the prompt replaces the class token (see _pos_embed).
                num_prefix = self.prompt_len if self.task_idx > 0 else self.backbone.num_prefix_tokens
                x = x[:, num_prefix:].mean(dim=1)
            elif self.task_idx > 0:
                x = x[:, :self.prompt_len, :].mean(dim=1)
            else:
                x = x[:, 0]
        x = self.backbone.fc_norm(x)
        return x if pre_logits else self.backbone.head(x)

    def _forward(self, x, use_backbone_head=False):
        """Run the prompted backbone; apply the task head unless the backbone head is used."""
        x = self.forward_features(x)
        pre_logits = not (use_backbone_head or self.task_idx == 0)
        x = self.forward_head(x, pre_logits=pre_logits)

        if pre_logits and self.task_idx > 0:
            x = self.heads[self.task_idx-1](x)
        return x

    def forward(self, x, expert_ids: Optional[list[int]] = None, pre_logits: bool = False, use_backbone_head=False, use_backbone_for_prompts=False, **kwargs):
        """Run a forward pass with the given (or the task-0) expert ids.

        Args:
            x: input batch.
            expert_ids: per-layer expert ids exposed to the MoE hooks via
                ``self._expert_ids_to_use``.
            pre_logits: not read by this method.
            use_backbone_head: use the backbone head instead of the task head.
            use_backbone_for_prompts: override ``expert_ids`` with each layer's
                task-0 expert.
            **kwargs: exposed to the hooks via ``self.forward_state_args``.
        """
        if use_backbone_for_prompts:
            expert_ids = [self.task_to_expert_map[l][0] for l in range(len(self.task_to_expert_map))]

        # The hooks registered by ContinualMoE read this per-call state.
        self.forward_state_args = kwargs
        self._expert_ids_to_use = expert_ids
        try:
            return self._forward(x, use_backbone_head=use_backbone_head)
        finally:
            # Always reset, so a failed call does not leak state into the next one.
            self.forward_state_args = None
            self._expert_ids_to_use = None

    def prompt_train_status(self, train=False):
        """Freeze all prompts, then set ``requires_grad=train`` on the active task's prompt."""
        assert self.task_idx is not None
        self.prompts.requires_grad_(False)
        if self.task_idx > 0:
            self.prompts[self.task_idx-1].requires_grad_(train)

    def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
        """Return the module state without backbone / ``pos_embed`` entries."""
        _state_dict = super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)
        filtered = OrderedDict(
            (key, value) for key, value in _state_dict.items()
            if "backbone" not in key and "pos_embed" not in key
        )
        # Keep torch's per-module version metadata, which load_state_dict consumes.
        metadata = getattr(_state_dict, "_metadata", None)
        if metadata is not None:
            filtered._metadata = metadata
        return filtered

    def load_state_dict(self, state_dict: OrderedDict[str, torch.Tensor], strict: bool = False):
        """Load state non-strictly (default changed from True to False), then optionally transfer prompts."""
        result = super().load_state_dict(state_dict, strict)

        if self.transfer_prompts:
            self.copy_prompts()

        return result

    def copy_prompts(self):
        """Initialise the active task's prompt from the previous task's prompt (tasks > 1)."""
        assert self.task_idx is not None
        if self.task_idx > 1:
            with torch.no_grad():
                self.prompts[self.task_idx-1].copy_(self.prompts[self.task_idx-2])

    def prompt_params(self):
        """Return the active task's prompt followed by its head parameters."""
        assert self.task_idx is not None
        return [self.prompts[self.task_idx-1]] + list(self.heads[self.task_idx-1].parameters())

    def expert_params(self):
        """Return the active task's expert parameters (every layer) followed by its head parameters."""
        assert self.task_idx is not None
        parameters = []
        for l, experts in enumerate(self.experts):
            layer_expert = self.task_to_expert_map[l][self.task_idx]
            parameters.extend(experts[layer_expert].parameters())
        return parameters + list(self.heads[self.task_idx-1].parameters())


class PromptedSimilaritySupernet(SimilaritySupernet):
    """SimilaritySupernet whose ViT backbone is conditioned on learnable per-task prompt tokens.

    Task 0 runs the backbone as-is (class token + backbone head). For a task
    ``t > 0`` the class token is replaced by ``prompts[t-1]`` and, unless the
    backbone head is requested, ``heads[t-1]`` is applied to the pooled prompt
    tokens. The task used for a pass defaults to ``self.task_idx`` and can be
    overridden per call with ``forward(..., task_idx=...)``.
    """

    def __init__(
            self,
            task_idx: int,
            backbone: nn.Module,
            config: dict,
            iter_backbone: Callable[[nn.Module], Iterable[Tuple[nn.Module, nn.Module]]],
            head_factory: Callable[[nn.Module], nn.ModuleList],
            op_factory_generator,
            sampler,
            prompt_len: int,
            transfer_prompts: bool = False,
            stat_funcs: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
            stat_update_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = running_stat,
            similarity_stat=None,
            initial_expert_ids: str = "expert_0",
            normalize_similarities: bool = True,
            **kwargs):

        super().__init__(
            task_idx,
            backbone,
            config,
            iter_backbone,
            head_factory,
            op_factory_generator,
            sampler,
            stat_funcs,
            stat_update_function,
            similarity_stat,
            initial_expert_ids,
            normalize_similarities,
            **kwargs)

        self.prompt_len = prompt_len
        # One prompt per head (i.e. per task > 0), initialised U(-0.01, 0.01).
        self.prompts = nn.ParameterList([
            nn.Parameter(torch.zeros(1, self.prompt_len, backbone.num_features).uniform_(-0.01, 0.01), requires_grad=True)
            for _ in range(len(self.heads))
        ])

        self.transfer_prompts = transfer_prompts

    def _pos_embed(self, x, task_idx):
        """Add position embeddings and prepend the class token (task 0) or the task prompt (task > 0).

        The resulting token layout is ``[cls (if any), patches]`` for task 0 and
        ``[prompt, patches]`` for task > 0, for both position-embedding conventions.
        """
        batch_size = x.shape[0]
        if task_idx > 0:
            prefix = self.prompts[task_idx-1].expand(batch_size, -1, -1)
        elif self.backbone.cls_token is not None:
            prefix = self.backbone.cls_token.expand(batch_size, -1, -1)
        else:
            prefix = None

        if self.backbone.no_embed_class:
            # deit-3, updated JAX (big vision)
            # position embedding does not overlap with the prefix tokens: add then concat
            x = x + self.backbone.pos_embed
            if prefix is not None:
                x = torch.cat((prefix, x), dim=1)
        else:
            # original timm, JAX, and deit vit impl
            # pos_embed has an entry for the prefix tokens: concat then add
            # NOTE(review): for task > 0 this needs pos_embed to broadcast against
            # prompt_len + num_patches tokens (e.g. prompt_len == 1 with a standard
            # ViT pos_embed) - confirm.
            if prefix is not None:
                x = torch.cat((prefix, x), dim=1)
            x = x + self.backbone.pos_embed
        # timm applies pos_drop in both branches.
        return self.backbone.pos_drop(x)

    def forward_features(self, x, task_idx):
        """Embed patches, add prompt/position tokens for ``task_idx`` and run the backbone blocks."""
        x = self.backbone.patch_embed(x)
        x = self._pos_embed(x, task_idx)
        x = self.backbone.norm_pre(x)
        x = self.backbone.blocks(x)
        x = self.backbone.norm(x)
        return x

    def forward_head(self, x, task_idx, pre_logits: bool = False):
        """Pool the token sequence and optionally apply the backbone head.

        Avg pooling averages the patch tokens only. Token pooling averages the
        prompt tokens for task > 0 and takes the class token for task 0.
        """
        if self.backbone.global_pool:
            if self.backbone.global_pool == "avg":
                # For task > 0 the prompt replaces the class token (see _pos_embed).
                num_prefix = self.prompt_len if task_idx > 0 else self.backbone.num_prefix_tokens
                x = x[:, num_prefix:].mean(dim=1)
            elif task_idx > 0:
                x = x[:, :self.prompt_len, :].mean(dim=1)
            else:
                x = x[:, 0]
        x = self.backbone.fc_norm(x)
        return x if pre_logits else self.backbone.head(x)

    def _forward(self, x, task_idx, use_backbone_head=False):
        """Run the prompted backbone for ``task_idx`` and apply that task's head.

        The same ``task_idx`` drives the prompt, the pooling and the head, so
        evaluating another task uses that task's own head.
        """
        x = self.forward_features(x, task_idx)
        pre_logits = not (use_backbone_head or task_idx == 0)
        x = self.forward_head(x, task_idx, pre_logits=pre_logits)

        if pre_logits and task_idx > 0:
            x = self.heads[task_idx-1](x)
        return x

    def forward(
            self, x, task_idx: Optional[int] = None,
            expert_ids: Optional[list[int]] = None, pre_logits: bool = False,
            seed: int = 0, constant_sample_per_worker: bool = True,
            verbose: bool = False, worker_id: Optional[int] = None, use_backbone_head=False,
            use_backbone_for_prompts=False,
            **kwargs):
        """Run a forward pass with explicit, task-derived or sampled expert ids.

        Exactly one source of expert ids is used:
          * ``expert_ids`` (or the task-0 experts if ``use_backbone_for_prompts``);
            ``task_idx`` must then be None.
          * ``task_idx``: each layer's expert mapped to that task.
          * otherwise they are sampled via ``self._sample_nas_ops`` (requires samplers).

        ``task_idx`` (default ``self.task_idx``) also selects the prompt and head.
        ``pre_logits`` is not read by this method; ``**kwargs`` are forwarded to
        the sampler and exposed to the hooks via ``self.forward_state_args``.
        """
        if use_backbone_for_prompts:
            # Override all expert ids and use the expert ids from ImageNet
            expert_ids = [self.task_to_expert_map[l][0] for l in range(len(self.task_to_expert_map))]

        if expert_ids is not None:
            # expert_ids will be defined when calculating the task_stats or
            # running the evolutionary search
            assert task_idx is None
        elif task_idx is not None:
            expert_ids = [self.task_to_expert_map[l][task_idx] for l in range(len(self.task_to_expert_map))]
        else:
            assert self.samplers is not None, "The sampler must be defined when expert_ids are not defined."
            expert_ids = self._sample_nas_ops(seed=seed, constant_sample_per_worker=constant_sample_per_worker, worker_id=worker_id, **kwargs)

        if verbose:
            _logger.info("[Worker ID: %s] NAS expert ids: %s", worker_id, expert_ids)

        _task_idx = self.task_idx if task_idx is None else task_idx

        # The hooks registered by the supernet read this per-call state.
        self._expert_ids_to_use = expert_ids
        self.forward_state_args = kwargs
        try:
            return self._forward(x, _task_idx, use_backbone_head=use_backbone_head)
        finally:
            # Always reset, so a failed call does not leak state into the next one.
            self.forward_state_args = None
            self._expert_ids_to_use = None

    def prompt_train_status(self, train=False):
        """Freeze all prompts, then set ``requires_grad=train`` on the active task's prompt."""
        assert self.task_idx is not None
        self.prompts.requires_grad_(False)
        if self.task_idx > 0:
            self.prompts[self.task_idx-1].requires_grad_(train)

    def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
        """Return the module state without backbone / ``pos_embed`` entries."""
        _state_dict = super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)
        filtered = OrderedDict(
            (key, value) for key, value in _state_dict.items()
            if "backbone" not in key and "pos_embed" not in key
        )
        # Keep torch's per-module version metadata, which load_state_dict consumes.
        metadata = getattr(_state_dict, "_metadata", None)
        if metadata is not None:
            filtered._metadata = metadata
        return filtered

    def load_state_dict(self, state_dict: OrderedDict[str, torch.Tensor], strict: bool = False):
        """Load state non-strictly (default changed from True to False), then optionally transfer prompts."""
        result = super().load_state_dict(state_dict, strict)

        if self.transfer_prompts:
            self.copy_prompts()

        return result

    def copy_prompts(self):
        """Initialise the active task's prompt from the previous task's prompt (tasks > 1)."""
        assert self.task_idx is not None
        if self.task_idx > 1:
            with torch.no_grad():
                self.prompts[self.task_idx-1].copy_(self.prompts[self.task_idx-2])

    def prompt_params(self):
        """Return the active task's prompt followed by its head parameters."""
        assert self.task_idx is not None
        return [self.prompts[self.task_idx-1]] + list(self.heads[self.task_idx-1].parameters())

    def expert_params(self):
        """Return the parameters of every op in the NAS search space, followed by the active head's."""
        assert self.task_idx is not None
        parameters = []
        for l, layer_nas_ops in enumerate(self.nas_search_space):
            for nas_op in layer_nas_ops:
                parameters.extend(self.experts[l][nas_op].parameters())
        return parameters + list(self.heads[self.task_idx-1].parameters())
