import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat, einsum
from einops.layers.torch import Rearrange

from mia.models.layers import (
    FixedFourierFeatures,
    LearnableFourierFeatures,
    PreNorm,
    Attention,
    Transformer,
    Perceiver,
    LatentReshape2D,
    LatentReshape1D,
    Mlp,
)

from mia.models.inr import (
    Siren,
    FourierFeatureINR,
    RandomFourierFeatureINR,
    BasicINR,
    GeneralizableINR,
)

from mia.utils import (
    get_input_dims,
    get_output_dims,
    get_out_bias,
    get_input_range
)

class CrossModalTransferModel(nn.Module):
    """Per-mode INRs driven by learned latents, with optional meta-learning add-ons.

    For every mode in ``modes`` this builds a ``GeneralizableINR``, a learned
    latent prior (``latent_prior_embeds``) and Meta-SGD inner-loop learning
    rates (``meta_lr``). Depending on the config dicts it also builds:

    * a context encoder (per-mode encoders, a multimodal ``Transformer`` and a
      cross-attention ``Pooler``) used by ``pool_latent`` to turn observed
      ``(x, y)`` pairs into initial latents;
    * either a transformer gradient encoder or, with ``use_alfa``, one small
      MLP per mode, used by ``fuse_grad`` to rewrite inner-loop gradients and
      (for ALFA) latents;
    * per-mode log-variances (``logvars``) when
      ``args.loss_weight_mode == 'uncertainty'``.

    Args:
        args: Run configuration; only ``args.loss_weight_mode`` is read here.
        modes: Sequence of mode names, used as keys into every per-mode dict.
        latent_spatial_shapes: Mapping mode -> number of latent tokens. Also
            passed as the INR ``rank`` and used as the per-axis grid size of
            the fixed Fourier position embeddings.
        inr_dict: INR config (``dim_hidden``, ``num_layers``, ``inr_type``,
            ``ff_dim``, ``sigma``).
        context_encoder_dict: Context encoder / pooler config.
        grad_encoder_dict: Gradient encoder / ALFA config.
        meta_sgd_dict: Meta-SGD config (``inner_lr_init``, ``use_meta_sgd``).

    Raises:
        ValueError: If a ``type`` / ``pos_embed_type`` /
            ``pooler_pos_embed_type`` / ``mm_attn_type`` option is not one of
            the supported values.
    """

    def __init__(
        self,
        args,
        modes,
        latent_spatial_shapes,
        inr_dict,
        context_encoder_dict,
        grad_encoder_dict,
        meta_sgd_dict,
    ) -> None:
        super().__init__()

        self.args = args
        self.modes = modes
        self.num_modes = len(modes)
        self.dim_hidden = inr_dict['dim_hidden']
        self.num_layers = inr_dict['num_layers']

        self.inr_type = inr_dict['inr_type']
        self.inr_dict = inr_dict
        self.latent_spatial_shapes = latent_spatial_shapes

        # Every mode currently uses a latent width equal to the INR hidden width.
        self.latent_dims = {mode: self.dim_hidden for mode in modes}

        self.context_encoder_dict = context_encoder_dict
        self.grad_encoder_dict = grad_encoder_dict
        self.meta_sgd_dict = meta_sgd_dict

        # Init INR layers
        # ---------------------------------------------------------------------------------------
        self.dims_in = get_input_dims(modes)
        self.dims_out = get_output_dims(modes)

        self.inr = nn.ModuleDict()
        for mode in self.modes:
            self.inr[mode] = GeneralizableINR(
                dim_in = self.dims_in[mode],
                dim_out = self.dims_out[mode],
                dim_hidden = self.dim_hidden,
                num_layers = self.num_layers,
                rank = latent_spatial_shapes[mode],
                ff_dim = inr_dict['ff_dim'],
                sigma = inr_dict['sigma'],
                use_bias = True,
                out_bias = get_out_bias(mode),
            )

        self.modulate_dim = self.dim_hidden
        self.latent_shapes = latent_spatial_shapes

        # Learned latent prior per mode: (1, num_latent_tokens, latent_dim).
        self.latent_prior_embeds = nn.ParameterDict()
        for mode in self.modes:
            self.latent_prior_embeds[mode] = nn.Parameter(
                torch.randn(1, self.latent_shapes[mode], self.latent_dims[mode]) * 0.2, requires_grad = True
            )

        # Meta-SGD
        # ---------------------------------------------------------------------------------------
        lr_init = meta_sgd_dict['inner_lr_init']
        self.meta_lr = nn.ParameterDict()
        for mode in self.modes:
            # With ALFA there is one learning rate per latent token; otherwise
            # one per latent entry (token x channel).
            lr_channels = 1 if grad_encoder_dict['use_alfa'] else self.latent_dims[mode]
            self.meta_lr[mode] = nn.Parameter(
                torch.zeros(1, self.latent_shapes[mode], lr_channels) + lr_init,
                requires_grad = meta_sgd_dict['use_meta_sgd'],
            )
            if lr_init == 0:
                # An initial rate of 0 means "randomise": draw rates from U(0.005, 1.0).
                nn.init.uniform_(self.meta_lr[mode], 0.005, 1.0)

        # Uncertainty-based loss weighting: one learnable log-variance per mode.
        if self.args.loss_weight_mode in ['uncertainty']:
            self.logvars = nn.ParameterDict()
            for mode in self.modes:
                self.logvars[mode] = nn.Parameter(torch.zeros(1,), requires_grad=True)

        self._init_context_encoder(context_encoder_dict)
        self._init_grad_encoder(grad_encoder_dict, meta_sgd_dict)

    def _init_context_encoder(self, context_encoder_dict):
        """Build the context encoder and latent pooler (disabled when both depths are 0).

        Sets ``self.context_encoder`` to ``True`` when enabled (``False``
        otherwise) and ``self.context_pooler`` to the ``Pooler`` module (or
        ``False``).
        """
        self.context_encoder = self.context_pooler = False
        if not context_encoder_dict['um_depth'] + context_encoder_dict['mm_depth'] > 0:
            return

        self.context_encoder_type = context_encoder_dict['type']
        if self.context_encoder_type not in ['transformer', 'perceiver', 'perceiver-nopool']:
            raise ValueError(f"Unsupported context encoder type: {self.context_encoder_type!r}")
        if self.context_encoder_type in ['perceiver', 'perceiver-nopool']:
            assert context_encoder_dict['um_depth'] > 0

        dim = context_encoder_dict['dim']
        mlp_dim = int(dim * context_encoder_dict['mlp_ratio'])

        self.context_encoder = self.context_pooler = True
        self.context_encoder_to_embeds = nn.ModuleDict()
        self.context_encoder_pos_embeds = nn.ParameterDict()
        # NOTE(review): this dropout is not applied anywhere in this class.
        self.context_encoder_embeds_dropout = nn.Dropout(context_encoder_dict['embed_dropout'])
        self.context_encoder_pos_embed_type = context_encoder_dict['pos_embed_type']

        # Token embedding of context points: either value projection + Fourier
        # features of the coordinates, or a projection of [coords, values].
        for mode in self.modes:
            if self.context_encoder_pos_embed_type in ['fixed', 'learned']:
                self.context_encoder_to_embeds[mode] = nn.Linear(self.dims_out[mode], dim)
                if self.context_encoder_pos_embed_type == 'fixed':
                    self.context_encoder_pos_embeds[mode] = FixedFourierFeatures(dim // self.dims_in[mode])
                else:
                    self.context_encoder_pos_embeds[mode] = LearnableFourierFeatures(self.dims_in[mode], dim)
            elif self.context_encoder_pos_embed_type in ['concat']:
                self.context_encoder_to_embeds[mode] = nn.Linear(self.dims_out[mode] + self.dims_in[mode], dim)
                self.context_encoder_pos_embeds[mode] = None
            else:
                raise ValueError(
                    f"Unsupported context encoder pos_embed_type: {self.context_encoder_pos_embed_type!r}"
                )

        # Unimodal (per-mode) encoders.
        self.context_encoder_um = nn.ModuleDict()
        for mode in self.modes:
            if self.context_encoder_type == 'transformer':
                self.context_encoder_um[mode] = Transformer(
                    dim = dim,
                    depth = context_encoder_dict['um_depth'],
                    heads = context_encoder_dict['heads'],
                    dim_head = context_encoder_dict['dim_head'],
                    mlp_dim = mlp_dim,
                    dropout = context_encoder_dict['dropout'],
                )
            else:
                self.context_encoder_um[mode] = Perceiver(
                    dim = dim,
                    context_dim = dim,
                    num_querys = self.latent_shapes[mode],
                    depth = context_encoder_dict['um_depth'],
                    heads = context_encoder_dict['heads'],
                    dim_head = context_encoder_dict['dim_head'],
                    mlp_dim = mlp_dim,
                    dropout = context_encoder_dict['dropout'],
                )

        # Multimodal encoder over the concatenation of all modes' tokens.
        self.context_encoder_mm = Transformer(
            dim = dim,
            depth = context_encoder_dict['mm_depth'],
            heads = context_encoder_dict['heads'],
            dim_head = context_encoder_dict['dim_head'],
            mlp_dim = mlp_dim,
            dropout = context_encoder_dict['dropout'],
        )

        self.context_pooler_pos_embeds = nn.ParameterDict()
        self.context_pooler_pos_embed_type = context_encoder_dict['pooler_pos_embed_type']

        if self.context_pooler_pos_embed_type in ['fixed']:
            for mode in self.modes:
                self.context_pooler_pos_embeds[mode] = self._create_fourier_embeds(dim, mode)
        elif self.context_pooler_pos_embed_type in ['learned']:
            for mode in self.modes:
                self.context_pooler_pos_embeds[mode] = nn.Parameter(
                    torch.randn(self.latent_shapes[mode], dim) * 0.2, requires_grad = True
                )
        else:
            raise ValueError(
                f"Unsupported context pooler pos_embed_type: {self.context_pooler_pos_embed_type!r}"
            )

        self.context_pooler_latent_to_token = nn.ModuleDict()
        self.context_pooler_token_to_latent = nn.ModuleDict()
        for mode in self.modes:
            self.context_pooler_latent_to_token[mode] = nn.Linear(self.latent_dims[mode], dim)
            self.context_pooler_token_to_latent[mode] = nn.Linear(dim, self.latent_dims[mode])

        self.context_pooler = Pooler(
            dim = dim,
            context_dim = dim,
            depth = context_encoder_dict['pooler_depth'],
            heads = context_encoder_dict['heads'],
            dim_head = context_encoder_dict['dim_head'],
            dropout = context_encoder_dict['dropout'],
        )

    def _init_grad_encoder(self, grad_encoder_dict, meta_sgd_dict):
        """Build the transformer gradient encoder, or the ALFA MLPs, or neither.

        The transformer encoder takes precedence when its depths sum to > 0;
        otherwise ALFA is built if ``use_alfa`` is set.
        """
        self.grad_encoder_mm_attn_type = grad_encoder_dict['mm_attn_type']

        self.alfa = False
        self.grad_encoder = False
        if grad_encoder_dict['um_depth'] + grad_encoder_dict['mm_depth'] > 0:
            self.grad_encoder = True
            dim = grad_encoder_dict['dim']
            mlp_dim = int(dim * grad_encoder_dict['mlp_ratio'])

            if self.grad_encoder_mm_attn_type in ['spatial']:
                # Concatenate per-mode token sequences along the token axis:
                # (bsz, n, ld), (bsz, n, ld) -> (bsz, m*n, ld) -> (bsz, m*n, dim)
                self.grad_encoder_mm_cat_dim = 1
            else:
                raise ValueError(
                    f"Unsupported grad encoder mm_attn_type: {self.grad_encoder_mm_attn_type!r}"
                )

            self.grad_encoder_projection_mlp = nn.ModuleDict()
            self.grad_encoder_state_to_gradient = nn.ModuleDict()
            self.grad_encoder_pos_embed_type = grad_encoder_dict['pos_embed_type']
            self.grad_encoder_fuser = nn.ModuleDict() if grad_encoder_dict['use_fuser'] else False
            # Remembered so fuse_grad only uses the latent LayerNorms when they exist.
            self.grad_encoder_use_latent = bool(grad_encoder_dict['use_latent'])

            self.grad_encoder_grad_ln = nn.ModuleDict()
            if self.grad_encoder_use_latent:
                self.grad_encoder_latent_ln = nn.ModuleDict()

            for mode in self.modes:
                num_channels = self.latent_dims[mode]
                if self.grad_encoder_use_latent:
                    # The state is [latent, grad] concatenated on the channel axis.
                    num_channels += self.latent_dims[mode]
                    self.grad_encoder_latent_ln[mode] = nn.LayerNorm(self.latent_dims[mode], eps = 1e-5)

                self.grad_encoder_grad_ln[mode] = nn.LayerNorm(self.latent_dims[mode], eps = 1e-5)

                self.grad_encoder_projection_mlp[mode] = nn.Sequential(
                    Mlp(num_channels, dim, dim, depth = grad_encoder_dict['projection_mlp_depth']),
                )

                if grad_encoder_dict['use_fuser']:
                    # Input is [ordinal, unimodal, multimodal] states (3 * dim
                    # channels); GroupNorm with 3 groups normalizes each separately.
                    self.grad_encoder_fuser[mode] = nn.Sequential(
                        Rearrange('b n d -> b d n'),
                        nn.GroupNorm(
                            num_groups = 3,
                            num_channels = dim * 3,
                        ),
                        Rearrange('b d n -> b n d'),
                        Mlp(
                            in_features = dim * 3,
                            hidden_features = dim,
                            out_features = dim,
                            depth = grad_encoder_dict['depth_fuser'],
                        ),
                        nn.LayerNorm(dim),
                    )

                self.grad_encoder_state_to_gradient[mode] = nn.Linear(
                    dim,
                    self.latent_dims[mode],
                )

            self.grad_encoder_pos_embeds = nn.ParameterDict()
            if self.grad_encoder_pos_embed_type in ['fixed']:
                for mode in self.modes:
                    self.grad_encoder_pos_embeds[mode] = self._create_fourier_embeds(dim, mode)
            elif self.grad_encoder_pos_embed_type in ['learned']:
                for mode in self.modes:
                    self.grad_encoder_pos_embeds[mode] = nn.Parameter(
                        torch.randn(self.latent_shapes[mode], dim) * 0.2, requires_grad = True
                    )

            self.grad_encoder_um = nn.ModuleDict()
            for mode in self.modes:
                self.grad_encoder_um[mode] = Transformer(
                    dim = dim,
                    depth = grad_encoder_dict['um_depth'],
                    heads = grad_encoder_dict['heads'],
                    dim_head = grad_encoder_dict['dim_head'],
                    mlp_dim = mlp_dim,
                    dropout = grad_encoder_dict['dropout'],
                )
            self.grad_encoder_mm = Transformer(
                dim = dim,
                depth = grad_encoder_dict['mm_depth'],
                heads = grad_encoder_dict['heads'],
                dim_head = grad_encoder_dict['dim_head'],
                mlp_dim = mlp_dim,
                dropout = grad_encoder_dict['dropout'],
            )
        elif grad_encoder_dict['use_alfa']:
            self.alfa = nn.ModuleDict()
            self.beta_init_dict = nn.ParameterDict()
            for mode in self.modes:
                # Input: per-token channel means of gradient and latent, concatenated.
                input_dim = self.latent_shapes[mode] * 2
                if grad_encoder_dict['dim_alfa'] > 0:
                    hidden_dim = grad_encoder_dict['dim_alfa']
                else:
                    hidden_dim = input_dim
                assert grad_encoder_dict['depth_alfa'] >= 2
                layers = [nn.Linear(input_dim, hidden_dim), nn.ReLU()]
                for _ in range(grad_encoder_dict['depth_alfa'] - 2):
                    layers += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU()]
                layers += [nn.Linear(hidden_dim, input_dim)]
                self.alfa[mode] = nn.Sequential(*layers)
                self.beta_init_dict[mode] = nn.Parameter(
                    torch.ones(1, self.latent_shapes[mode], 1), requires_grad = meta_sgd_dict['use_meta_sgd']
                )

    def _select_params(self, keywords, include):
        """Return ``{name: param}`` filtered by keyword substrings of the name.

        With ``include=True`` keeps parameters whose name contains any keyword;
        with ``include=False`` keeps those whose name contains none of them.
        Iteration order follows ``named_parameters()``.
        """
        return {
            name: param
            for name, param in self.named_parameters()
            if any(keyword in name for keyword in keywords) == include
        }

    def get_inr_params(self):
        """Return every parameter not belonging to the encoders/pooler or ``logvars``."""
        return self._select_params(['context_enc', 'context_pool', 'grad_enc', 'logvars'], include = False)

    def get_non_inr_params(self):
        """Return the context-encoder, pooler and gradient-encoder parameters."""
        return self._select_params(['context_enc', 'context_pool', 'grad_enc'], include = True)

    def get_logvars(self):
        """Return the uncertainty-weighting log-variance parameters (may be empty)."""
        return self._select_params(['logvars'], include = True)

    def get_parameters(self, keys=None):
        """Return a list of parameters, optionally filtered by name substrings.

        Args:
            keys: ``None`` for all parameters, a single substring, or a
                list/tuple of substrings (a parameter is kept if its name
                contains any of them).

        Raises:
            TypeError: If ``keys`` is not ``None``, a ``str``, a list or a tuple.
        """
        if keys is None:
            return [param for _, param in self.named_parameters()]
        if isinstance(keys, str):
            keys = [keys]
        elif not isinstance(keys, (list, tuple)):
            raise TypeError(f"keys must be None, a str, or a list/tuple of str, got {type(keys).__name__}")
        return [param for name, param in self.named_parameters() if any(key in name for key in keys)]

    def _create_fourier_embeds(self, dim, mode):
        """Build fixed sinusoidal position embeddings for ``mode``'s latent grid.

        A regular grid with ``latent_spatial_shapes[mode]`` points along each of
        the ``dims_in[mode]`` axes (spanning ``[-r, r]`` with
        ``r = get_input_range(mode)``, or the single point 0 when the size is 1)
        is multiplied by ``dim // 2 // dims_in[mode]`` log-spaced frequencies in
        ``[1, e^8]`` per axis and encoded as ``[cos(pi * .), sin(pi * .)]``.

        Returns:
            A non-trainable ``nn.Parameter`` reshaped to ``(1, -1, dim)``.
            NOTE(review): assumes ``dim`` is a multiple of ``2 * dims_in[mode]``;
            otherwise the final reshape fails or mixes grid points.
        """
        num_axes = self.dims_in[mode]
        size = self.latent_spatial_shapes[mode]
        freqs = torch.exp(torch.linspace(0, 8, dim // 2 // num_axes))

        input_range = get_input_range(mode)
        if size == 1:
            axis = torch.linspace(-0, 0, size)
        else:
            axis = torch.linspace(-input_range, input_range, size)

        # (*grid, num_axes) coordinates -> (*grid, num_axes * num_freqs) phases.
        grid = torch.stack(torch.meshgrid(*([axis] * num_axes), indexing='ij'), dim = -1)
        phases = einsum(grid, freqs, '... d, fdim -> ... d fdim')
        phases = phases.reshape(*grid.shape[:-1], -1)
        embeds = torch.cat([torch.cos(torch.pi * phases), torch.sin(torch.pi * phases)], dim = -1)
        embeds = embeds.reshape(1, -1, dim)
        return nn.Parameter(embeds, requires_grad = False)

    def embed(self, xs, ys, mode):
        """Turn context coordinates ``xs`` and values ``ys`` into context-encoder tokens.

        For ``'fixed'``/``'learned'`` position embeddings the values are
        projected and the coordinate features are added; for ``'concat'`` the
        concatenation ``[xs, ys]`` is projected.
        """
        if self.context_encoder_pos_embed_type in ['fixed', 'learned']:
            tokens = self.context_encoder_to_embeds[mode](ys)
            return tokens + self.context_encoder_pos_embeds[mode](xs)

        tokens = torch.cat([xs, ys], dim = -1)
        return self.context_encoder_to_embeds[mode](tokens)

    def init_latent(self, bsz):
        """Return ``{mode: latent prior}`` with each prior tiled to batch size ``bsz``."""
        return {
            mode: repeat(self.latent_prior_embeds[mode], '1 ... -> bsz ...', bsz = bsz)
            for mode in self.modes
        }

    def pool_latent(self, latent_prior_dict, xs_dict, ys_dict, ms_dict, ids_keep):
        """Compute initial latents from the observed context of every mode.

        Args:
            latent_prior_dict: Mapping mode -> latent prior (e.g. from ``init_latent``).
            xs_dict, ys_dict: Mapping mode -> context coordinates / values.
            ms_dict: Mapping mode -> context mask, rearranged here with the
                einops pattern ``'B nK -> B H nQ nK'`` (so presumably ``(B, nK)``).
            ids_keep: Unused.

        Returns:
            Mapping mode -> initial latent. When the context encoder is
            disabled, ``latent_prior_dict`` is returned unchanged.

        Raises:
            ValueError: If ``context_encoder_type`` is not supported.
        """
        if not self.context_encoder:
            return latent_prior_dict

        n_tokens = [0]  # prefix sums of per-mode token counts in the concatenation
        contexts = []
        masks = []
        for mode in self.modes:
            tokens = self.embed(xs_dict[mode], ys_dict[mode], mode)
            mask = repeat(ms_dict[mode], 'B nK -> B H nQ nK', H = 1, nQ = 1)
            tokens = self.context_encoder_um[mode](tokens, mask = mask)
            contexts.append(tokens)
            masks.append(mask)
            n_tokens.append(n_tokens[-1] + tokens.shape[1])

        contexts = torch.cat(contexts, dim = 1)
        masks = torch.cat(masks, dim = -1)
        # The masks are only passed on for the transformer encoder.
        # NOTE(review): presumably because Perceiver outputs are query tokens
        # that no longer line up with the context mask — confirm.
        mm_mask = masks if self.context_encoder_type == 'transformer' else None
        contexts = self.context_encoder_mm(contexts, mm_mask)

        latent_init_dict = {}
        if self.context_encoder_type in ['transformer', 'perceiver']:
            for mode in self.modes:
                latent = self.context_pooler_latent_to_token[mode](latent_prior_dict[mode])
                latent = latent + self.context_pooler_pos_embeds[mode]
                latent = self.context_pooler(latent, context = contexts, mask = mm_mask)
                latent_init_dict[mode] = self.context_pooler_token_to_latent[mode](latent)
        elif self.context_encoder_type == 'perceiver-nopool':
            # Use each mode's own slice of the encoded tokens directly as its latent.
            for i, mode in enumerate(self.modes):
                latent = contexts[:, n_tokens[i]:n_tokens[i + 1], :]
                latent_init_dict[mode] = self.context_pooler_token_to_latent[mode](latent)
        else:
            raise ValueError(f"Unsupported context encoder type: {self.context_encoder_type!r}")

        return latent_init_dict

    def fuse_grad(self, grad_dict, latent_dict):
        """Transform inner-loop gradients (and, for ALFA, latents) before an update.

        Args:
            grad_dict: Mapping mode -> gradient with respect to that mode's latent.
            latent_dict: Mapping mode -> current latent.

        Returns:
            ``(grad_dict, latent_dict)``:

            * inputs unchanged when neither encoder is configured;
            * new gradients plus the original ``latent_dict`` for the
              transformer gradient encoder;
            * ALFA-rescaled gradients and latents otherwise.
        """
        if not self.grad_encoder and not self.alfa:
            return grad_dict, latent_dict

        if self.grad_encoder:
            ordinal_state_dict = {}
            unimodal_state_dict = {}
            n_tokens = [0]
            states = []
            for mode in self.modes:
                features = []
                if self.grad_encoder_use_latent:
                    features.append(self.grad_encoder_latent_ln[mode](latent_dict[mode]))
                features.append(self.grad_encoder_grad_ln[mode](grad_dict[mode]))
                state = torch.cat(features, dim = -1)
                state = self.grad_encoder_projection_mlp[mode](state)
                ordinal_state_dict[mode] = state

                # Out-of-place on purpose: `+=` would also change the tensor kept
                # in ordinal_state_dict that is later fed to the fuser.
                state = state + self.grad_encoder_pos_embeds[mode]
                state = self.grad_encoder_um[mode](state)
                if self.grad_encoder_fuser:
                    unimodal_state_dict[mode] = state

                if self.grad_encoder_mm_attn_type == 'spatial':
                    n_tokens.append(n_tokens[-1] + state.shape[1])
                else:
                    raise NotImplementedError

                states.append(state)

            states = torch.cat(states, dim = self.grad_encoder_mm_cat_dim)
            states = self.grad_encoder_mm(states)

            modified_grad_dict = {}
            for i, mode in enumerate(self.modes):
                state = states[:, n_tokens[i]:n_tokens[i + 1], :]
                if self.grad_encoder_fuser:
                    state_features = torch.cat([
                        ordinal_state_dict[mode],
                        unimodal_state_dict[mode],
                        state,
                    ], dim = -1)
                    state = self.grad_encoder_fuser[mode](state_features)
                modified_grad_dict[mode] = self.grad_encoder_state_to_gradient[mode](state)

            return modified_grad_dict, latent_dict

        # ALFA: an MLP predicts per-token (beta, alpha) multipliers for latent and gradient.
        modified_latent_dict = {}
        modified_grad_dict = {}
        for mode in self.modes:
            # Channel means of gradient and latent, concatenated and flattened to
            # (B, 2 * num_tokens) — assuming (B, num_tokens, D) inputs.
            states = torch.cat([grad_dict[mode].mean(-1), latent_dict[mode].mean(-1)], -1).flatten(1)
            states = self.alfa[mode](states)
            states = states.reshape(states.shape[0], -1, 1, 2)
            beta, alpha = states[..., 0], states[..., 1]
            modified_latent_dict[mode] = self.beta_init_dict[mode] * beta * latent_dict[mode]
            modified_grad_dict[mode] = alpha * grad_dict[mode]

        return modified_grad_dict, modified_latent_dict

    def modulated_forward_single(self, x, latent, mode):
        """Evaluate ``mode``'s INR at coordinates ``x``.

        When ``inr_type`` contains ``'composer'`` the latent is passed as the
        modulation to ``lowrank_modulated_forward``; otherwise ``latent`` is
        ignored and the plain INR is evaluated.
        """
        # 1D - (bsz, lss, ld) -> (bsz, D, lss)
        # 2D - (bsz, lss * lss, ld) -> (bsz, D, lss, lss)
        if 'composer' in self.inr_type:
            return self.inr[mode].lowrank_modulated_forward(x, latent)
        return self.inr[mode](x)

class Pooler(nn.Module):
    """Stack of residual pre-norm cross-attention layers.

    Each of the ``depth`` layers is ``PreNorm(dim, Attention(...))``. In
    ``forward`` it is applied residually, ``x <- x + layer(x, context)``, so
    the output has the same shape as the query tokens ``x``. With
    ``depth == 0`` the input is returned unchanged.
    """

    def __init__(
        self,
        dim,
        context_dim,
        depth,
        heads,
        dim_head,
        dropout = 0.,
    ) -> None:
        super().__init__()

        self.cross_attn_layers = nn.ModuleList([
            PreNorm(
                dim,
                Attention(dim, context_dim, heads, dim_head, dropout),
                context_dim = context_dim,
            )
            for _ in range(depth)
        ])

    def forward(self, x, context, mask = None, topk = None):
        """Attend from ``x`` to ``context`` through every layer, adding residuals."""
        out = x
        for layer in self.cross_attn_layers:
            update = layer(out, context = context, mask = mask, topk = topk)
            out = out + update
        return out
