import torch
from torch import nn

from utils import common, logs_handler, misc
from models import SwinTransformer3D, GPT2, ConvEncoder
from models.modules import SequenceEmbedding, PredictionNetwork, PredictionHead
from learn import losses

# Module-level logger; StateEncoder reports its chosen dimensions through it.
logger = logs_handler.get_logger(__name__)

class TokensEmbeddings(nn.Module):
    """Embed (action) tokens into ``dim``-wide vectors squashed by ``tanh``.

    Args:
        num_tokens: number of token values (or input width when not discrete),
            forwarded to ``SequenceEmbedding``.
        dim: embedding width.
        discrete: whether tokens are integer ids. When true, one extra id is
            allocated (presumably for an "unknown"/padding token — confirm
            against callers).
    """

    def __init__(self, num_tokens, dim, discrete):
        super().__init__()
        # Keep the caller-facing count, not the enlarged vocabulary size.
        self.num_tokens = num_tokens
        self.dim = dim
        self.discrete = discrete
        vocab_size = num_tokens + 1 if discrete else num_tokens
        self.embeddings = nn.Sequential(
            SequenceEmbedding(vocab_size, dim, discrete),
            nn.Tanh(),
        )

    def forward(self, tokens):
        """Return ``tanh(SequenceEmbedding(tokens))``."""
        return self.embeddings(tokens)
    
class StateEncoder(nn.Module):
    """Wrap either a ``ConvEncoder`` or a ``SwinTransformer3D`` behind one interface.

    Args:
        encoder_type: ``'conv'`` or ``'transformer'``.
        embed_dim, num_heads, depths, qkv_bias, drop_path_rate, causal: forwarded
            to ``SwinTransformer3D`` only.
        num_channels: input channels, forwarded to either encoder.
        temporal_window, spatial_window: combined into the Swin ``window_size``
            ``(temporal_window, *spatial_window)``.
        temporal_patches: forwarded to ``ConvEncoder``; for Swin it is prepended
            to ``patch_size`` to form ``(temporal_patches, *patch_size)``.
        patch_size: spatial patch size for Swin.
        pool_size: forwarded to either encoder.

    Attributes:
        state_dim: ``self.encoder.out_features``.
    """

    def __init__(self, encoder_type, embed_dim=96, num_channels=1, temporal_window=4, spatial_window=(7, 7), temporal_patches=4,
                 patch_size=(4, 4), pool_size=(None, 1, 1), num_heads=None, depths=None, qkv_bias=True, drop_path_rate=0.1, causal=False):
        super().__init__()
        assert encoder_type in {'conv', 'transformer'}, '...'
        self.encoder_type = encoder_type

        # Built unconditionally (as before) so malformed tuples fail for either type.
        swin_patch = (temporal_patches, *patch_size)
        swin_window = (temporal_window, *spatial_window)

        if encoder_type == 'conv':
            self.encoder = ConvEncoder(num_channels=num_channels,
                                       temporal_patches=temporal_patches,
                                       pool_size=pool_size)
        else:
            self.encoder = SwinTransformer3D(patch_size=swin_patch,
                                             window_size=swin_window,
                                             embed_dim=embed_dim,
                                             num_heads=num_heads,
                                             depths=depths,
                                             num_channels=num_channels,
                                             drop_path_rate=drop_path_rate,
                                             qkv_bias=qkv_bias,
                                             pool_size=pool_size,
                                             causal=causal)
            logger.info(f'StateEncoder > swin/embed_dim: {embed_dim}')

        self.state_dim = self.encoder.out_features
        logger.info(f'StateEncoder > state_dim: {self.state_dim}')

    def forward(self, states, return_dict=False, **kwargs):
        """Encode ``states``.

        Extra keyword arguments go straight to the wrapped encoder. Returns the
        encoder's full output dict when ``return_dict`` is true, otherwise its
        ``'outputs'`` entry (``None`` if that key is absent).
        """
        result = self.encoder(states, **kwargs)
        return result if return_dict else result.get('outputs')

class GPTModel(nn.Module):
    """Run ``GPT2`` over several interleaved token streams with a shared time embedding.

    Args:
        num_layers, num_heads, **kwargs: forwarded to ``GPT2``.
        context_dim: token width; every input stream must have this last dim.
        attn_span: passed to ``GPT2`` as ``attn_span`` on every call.
        embed_drop: dropout applied to tokens after adding the time embedding.
        max_seq_len: number of time-index embeddings available.
    """

    def __init__(self, num_layers, context_dim, num_heads=8, attn_span=1, embed_drop=0.1, max_seq_len=128, **kwargs):
        super().__init__()
        self.context_dim = context_dim
        self.temporal_embeder = SequenceEmbedding(max_seq_len, context_dim, True)
        self.drop = nn.Dropout(embed_drop)
        self.gpt2 = GPT2(num_layers, context_dim, num_heads, **kwargs)
        self.attn_span = attn_span

    def forward(self, *inputs, return_dict=False, **kwargs):
        """Interleave ``inputs`` per time step, run GPT2 and split the streams back out.

        Each input is assumed to be ``(batch, steps, context_dim)`` (the reshape
        below requires it). Tokens are ordered step-major: step 0 of every
        stream, then step 1, and so on, and all tokens of one step share that
        step's time embedding. The output ``'outputs'`` is reshaped to
        ``(batch, steps, len(inputs), context_dim)``; the whole dict is returned
        when ``return_dict`` is true, otherwise only that tensor.
        """
        num_streams = len(inputs)
        first = inputs[0]
        batch, steps = first.shape[:2]
        # [0]*num_streams, [1]*num_streams, ... as a (1, steps*num_streams) index row.
        positions = torch.arange(steps, device=first.device).repeat_interleave(num_streams).unsqueeze(0)
        tokens = torch.stack(inputs, dim=2).reshape(batch, steps * num_streams, self.context_dim)
        tokens = self.drop(tokens + self.temporal_embeder(positions))
        result = self.gpt2(tokens, attn_span=self.attn_span, **kwargs)
        result['outputs'] = result['outputs'].reshape(batch, steps, num_streams, self.context_dim)
        return result if return_dict else result.get('outputs')

class ContextDecoder(nn.Module):
    """Decode a context tensor with ``window`` independent heads (one per offset).

    Args:
        window: number of heads.
        context_dim: input width.
        hidden_dims: optional hidden widths of a shared ``PredictionNetwork``
            (followed by GELU) applied before the heads; ``None`` skips it.
        out_dim: width produced by each head.
    """

    def __init__(self, window, context_dim, hidden_dims, out_dim):
        super().__init__()
        self.window = window
        self.context_dim = context_dim
        self.out_dim = out_dim
        if hidden_dims is None:
            self.network = None
            self.hidden_dim = context_dim
        else:
            self.network = nn.Sequential(PredictionNetwork(context_dim, hidden_dims), nn.GELU())
            self.hidden_dim = hidden_dims[-1]
        self.decoders_head = nn.ModuleList(
            PredictionHead(self.hidden_dim, out_dim, use_norm=True) for _ in range(window)
        )

    def forward(self, ct, j=None, start=None, end=None):
        """Apply one head or a range of heads to ``ct``.

        ``ct`` may be ``(batch, steps, dim)`` or 2-D (treated as ``steps == 1``).
        Each head output is reshaped to ``(batch, steps, out_dim)``.

        * ``j`` given: 1-based head index; returns that head's output.
        * ``j`` is None: heads ``start .. end-1`` (0-based; ``start`` falls back
          to 0 and ``end`` to ``window`` when falsy) are concatenated along dim 1,
          giving ``(batch, n_heads * steps, out_dim)``.
        """
        batch = ct.size(0)
        steps = ct.size(1) if ct.ndim == 3 else 1
        features = ct if self.network is None else self.network(ct)

        def decode(idx):
            return self.decoders_head[idx](features).reshape(batch, steps, self.out_dim)

        if j is not None:
            assert 1 <= j <= self.window, '...'
            return decode(j - 1)

        first = start or 0
        last = end or self.window
        return torch.concat([decode(idx) for idx in range(first, last)], dim=1)
    
class CAStRL(nn.Module):
    """Two-encoder state-representation model with a transition-context GPT and a BC head.

    Components built in ``__init__``:

    * ``conv_encoder`` / ``transformer_encoder``: two ``StateEncoder`` s over
      flattened frame stacks; in pretraining their expanded outputs are matched
      with ``losses.vic_loss``.
    * ``context_gpt``: a ``GPTModel`` over interleaved (state, action) tokens
      producing a per-step transition context of width ``2 * context_dim``.
    * heads for the optional context objective chosen by
      ``strl_cfg['context_type']`` and a behaviour-cloning head ``bc_head``.

    ``forward(stage, ...)`` dispatches to ``pretrain_step``, ``finetune_bc_step``
    or ``evaluate_bc_step``.

    Args:
        state_encoder_cfg: keyword arguments forwarded to both ``StateEncoder`` s;
            must contain ``'temporal_patches'``.
        context_gpt_cfg: extra keyword arguments forwarded to ``GPTModel``.
        strl_cfg: pretraining options. Keys read directly: ``'window'``,
            ``'mask_window'`` and ``'context_type'`` (one of ``'none'``,
            ``'next_state'``, ``'next_action'``, ``'masked_state'``,
            ``'masked_action'``). The whole dict is also passed as keyword
            arguments to ``losses.vic_loss``.
        num_action_tokens: output width of the action heads; the vocabulary
            size for discrete actions (presumably the action dimension otherwise).
        context_dim: width of the state/action tokens fed to the GPT.
        num_channels: channels per frame in the input stacks.
        expander_dims: hidden dims of the two ``PredictionNetwork`` expanders.
        state_hidden_dims, action_hidden_dims: optional hidden dims of the
            next-state / next-action ``ContextDecoder`` s.
        max_seq_len: maximum number of time steps; the GPT time embedding is
            sized ``2 * max_seq_len``.
        action_discrete: actions are ``torch.long`` tokens when true,
            ``torch.float32`` values otherwise.
        actions_weights: optional per-action weights for
            ``losses.action_prediction_loss``; registered as a buffer when given.
        unknown_action: fill value for missing/masked actions. When ``None`` it
            defaults to ``num_action_tokens`` (discrete) or ``-100.0`` (continuous).
        use_actions: feed previous actions to the GPT; when false the action
            tokens are always ``unknown_action``.
        actions_pretrained: whether ``freeze(freeze_all=True)`` also freezes the
            action embeddings.
    """

    def __init__(self, state_encoder_cfg, context_gpt_cfg, strl_cfg, num_action_tokens, context_dim=192, num_channels=1,
                 expander_dims=None, state_hidden_dims=None, action_hidden_dims=None, max_seq_len=32,
                 action_discrete=True, actions_weights=None, unknown_action=None, use_actions=False,
                 actions_pretrained=False):
        super(CAStRL, self).__init__()
        # Stage names accepted by ``forward``.
        self.stages = {'pretrain', 'finetune_bc', 'evaluate_bc'}

        self.state_encoder_cfg = state_encoder_cfg
        self.context_gpt_cfg = context_gpt_cfg
        self.strl_cfg = strl_cfg

        self.num_channels = num_channels
        self.num_action_tokens = num_action_tokens
        self.action_discrete = action_discrete
        # Explicit ``is None`` check: a falsy but valid value (token id 0, or 0.0
        # for continuous actions) must not be replaced by the default.
        if unknown_action is None:
            unknown_action = num_action_tokens if self.action_discrete else -100.0
        self.unknown_action = unknown_action
        self.action_dtype = torch.long if self.action_discrete else torch.float32
        self.use_actions = use_actions
        self.context_dim = context_dim
        self.expander_dims = expander_dims
        self.state_hidden_dims = state_hidden_dims
        self.action_hidden_dims = action_hidden_dims
        self.actions_pretrained = actions_pretrained

        # NOTE: submodule creation order is kept as-is so that ``misc.init_weights``
        # (applied at the end) consumes RNG in the same order as before.

        # actions tokens embeddings
        self.actions_embeddings = TokensEmbeddings(num_tokens=num_action_tokens,
                                                   dim=context_dim,
                                                   discrete=action_discrete)

        # state encoders
        self.conv_encoder = StateEncoder(encoder_type='conv', num_channels=num_channels, **state_encoder_cfg)
        self.transformer_encoder = StateEncoder(encoder_type='transformer', num_channels=num_channels, **state_encoder_cfg)

        # expanders
        self.conv_expander = PredictionNetwork(self.conv_encoder.state_dim, expander_dims, use_norm=True, drop=0.0)
        self.transformer_expander = PredictionNetwork(self.transformer_encoder.state_dim, expander_dims, use_norm=True, drop=0.0)

        self.state_dim = self.transformer_encoder.state_dim

        # state-context projection [intermediate context representation]
        self.state_context_proj = nn.Sequential(PredictionHead(self.state_dim, context_dim, use_norm=True), nn.Tanh())

        # summarize all transitions [context prediction]
        self.context_gpt = GPTModel(context_dim=context_dim,
                                    max_seq_len=2*max_seq_len,
                                    **context_gpt_cfg)

        # context to state [decoders networks]
        self.context2state_networks = ContextDecoder(strl_cfg['window'], context_dim,
                                                     state_hidden_dims, self.state_dim)

        # context to action [decoders networks]
        self.context2action_networks = ContextDecoder(strl_cfg['window'], 2*context_dim,
                                                      action_hidden_dims, num_action_tokens)

        # masked state prediction
        self.state_head = PredictionHead(context_dim, self.state_dim, use_norm=True)

        # masked action prediction
        self.action_head = PredictionHead(2*context_dim, num_action_tokens, use_norm=True)

        # behaviour cloning head [finetune]
        self.bc_head = PredictionHead(2*context_dim, num_action_tokens, use_norm=True)

        if actions_weights is not None:
            self.register_buffer('actions_weights', torch.tensor(actions_weights, dtype=torch.float32))
        else:
            self.actions_weights = actions_weights

        self.apply(misc.init_weights)

    def process_actions(self, actions, batch_size, seq_len, device):
        """Return ``actions`` cast to ``action_dtype``, or an all-unknown tensor if ``None``.

        When given, ``actions.size(1)`` must equal ``seq_len``. The placeholder
        has shape ``(batch_size, seq_len)`` filled with ``unknown_action``.
        """
        assert (actions is None) or (actions.size(1) == seq_len), '...'
        if actions is None:
            actions = torch.full((batch_size, seq_len), fill_value=self.unknown_action,
                                 dtype=self.action_dtype, device=device)
        return actions.type(self.action_dtype)

    def process_transitions_stack(self, states):
        """Unstack frames: ``(b, t, tp*c, h, w)`` -> ``(b, t*tp, c, h, w)``.

        ``tp`` is ``state_encoder_cfg['temporal_patches']`` and ``c`` is
        ``num_channels``; the channel axis must equal ``tp * c``.
        """
        b, t, c, h, w = states.shape
        temporal_patches = self.state_encoder_cfg['temporal_patches']
        assert c == (temporal_patches*self.num_channels), '...'
        states = states.reshape(b, t, temporal_patches, self.num_channels, h, w)
        states = states.reshape(b, t*temporal_patches, self.num_channels, h, w)
        return states

    def values2actions(self, values, **kwargs):
        """Convert head outputs into ``(actions, certainty)`` via the ``misc`` helpers."""
        if self.action_discrete:
            y, certainty = misc.make_discrete_actions(values, **kwargs)
        else:
            y, certainty = misc.make_continuous_actions(values, **kwargs)
        return y, certainty

    def predict_transitions_context(self, state_rep, input_actions):
        """Run the context GPT and return ``(b, t, 2*context_dim)`` transition contexts.

        The first ``context_dim`` features are the GPT outputs at the state
        tokens, the last ``context_dim`` those at the action tokens.
        """
        state_ctx = self.state_context_proj(state_rep)
        action_embeds = self.actions_embeddings(input_actions)
        x = self.context_gpt(state_ctx, action_embeds)
        trans_ctx = torch.concat([x[:,:,0], x[:,:,1]], dim=-1)
        return trans_ctx

    def compute_cgpt_loss(self, trans_ctx, state_rep, actions=None,
                          gt_masked_states=None, gt_masked_prev_actions=None):
        """Compute the context objective selected by ``strl_cfg['context_type']``.

        * ``'next_state'``: for each offset ``tk < window``, decoder head ``tk``
          predicts ``state_rep[:, i+tk+1]`` from the state half of
          ``trans_ctx[:, i]`` (``i < t - window``); ``vic_loss`` terms are
          summed with weight ``(window - tk) / window``.
        * ``'next_action'``: same scheme with the full context predicting
          ``actions[:, i+tk]`` through ``losses.action_prediction_loss``.
        * ``'masked_state'``: ``state_head`` on the state half; ``vic_loss`` on
          entries where ``gt_masked_states`` is not ``-inf``.
        * ``'masked_action'``: ``action_head`` vs ``gt_masked_prev_actions``.

        Raises:
            ValueError: for any other ``context_type``.
        """
        b, t, d = state_rep.shape
        context_type = self.strl_cfg['context_type']
        w = self.strl_cfg['mask_window']
        if not context_type.startswith('masked_'):
            w = self.strl_cfg['window']
            assert t - w > 0, '...'

        if context_type == 'next_state':
            cgpt_loss = 0.0
            state_preds = self.context2state_networks(trans_ctx[:,:t-w,:self.context_dim], end=w)
            # Heads are concatenated along dim 1, so this splits them back out.
            state_preds = state_preds.reshape(b, w, t-w, d)
            for tk in range(w):
                za, zb = state_rep[:,tk+1:t-w+tk+1], state_preds[:,tk]
                cgpt_loss += ((w - tk) / w) * losses.vic_loss(za, zb, **self.strl_cfg)

        elif context_type == 'next_action':
            cgpt_loss = 0.0
            action_preds = self.context2action_networks(trans_ctx[:,:t-w], end=w)
            action_preds = action_preds.reshape(b, w, t-w, self.num_action_tokens)
            for tk in range(w):
                cgpt_loss += ((w - tk) / w) * losses.action_prediction_loss(action_preds[:,tk], actions[:,tk:t-w+tk],
                                                                            self.unknown_action, self.action_discrete,
                                                                            self.actions_weights)
        elif context_type == 'masked_state':
            state_preds = self.state_head(trans_ctx[:,:,:self.context_dim])
            # -inf marks positions that were not masked (see pretrain_step).
            za = state_preds[gt_masked_states != -float('inf')].reshape(-1, self.state_dim)
            zb = gt_masked_states[gt_masked_states != -float('inf')].reshape(-1, self.state_dim)
            cgpt_loss = losses.vic_loss(za, zb, **self.strl_cfg)

        elif context_type == 'masked_action':
            action_preds = self.action_head(trans_ctx)
            cgpt_loss = losses.action_prediction_loss(action_preds, gt_masked_prev_actions, self.unknown_action,
                                                      self.action_discrete, self.actions_weights)
        else:
            raise ValueError(f'unsupported context_type: {context_type!r}')
        return cgpt_loss

    def pretrain_step(self, states, states_prime, actions=None, **_):
        """Compute pretraining losses.

        ``states`` go through the conv encoder and ``states_prime`` through the
        transformer encoder; their expanded outputs give ``'state_rep_loss'``.
        Unless ``context_type`` is ``'none'``, the transformer representation
        is randomly masked, run through the context GPT together with previous
        actions, and the selected context loss is stored as ``'cgpt_loss'``.
        ``'loss'`` aggregates the entries via ``misc.accumulate_losses``.
        """
        assert self.strl_cfg['context_type'] in {'none', 'next_state', 'next_action',  'masked_state', 'masked_action'}, '...'
        # Action-based objectives need real actions as targets.
        assert (self.strl_cfg['context_type'] not in {'next_action', 'masked_action'}) or (actions is not None), '...'

        flat_states = self.process_transitions_stack(states)
        flat_states_prime = self.process_transitions_stack(states_prime)

        state_rep_a = self.conv_encoder(flat_states)
        state_rep_b = self.transformer_encoder(flat_states_prime)

        za = self.conv_expander(state_rep_a)
        zb = self.transformer_expander(state_rep_b)

        out = {'state_rep_loss': losses.vic_loss(za, zb, **self.strl_cfg)}

        if self.strl_cfg['context_type'] != 'none':
            b, t = states.shape[:2]
            processed_actions = self.process_actions(actions, b, t, states.device)
            # presumably mask_tensor_causal(..., k=1) shifts so step i holds action i-1 — confirm in misc.
            if self.use_actions:
                prev_actions = misc.mask_tensor_causal(processed_actions, self.unknown_action, forward=True, k=1)
                masked_prev_actions, actions_mask = misc.mask_tensor_random(prev_actions, self.unknown_action,
                                                                            k=self.strl_cfg['mask_window'])
                gt_masked_prev_actions = prev_actions.masked_fill(~actions_mask, self.unknown_action)
            else:
                masked_prev_actions = self.process_actions(None, b, t, states.device)
                gt_masked_prev_actions = misc.mask_tensor_causal(processed_actions, self.unknown_action, forward=True, k=1)

            masked_state_rep, states_mask = misc.mask_tensor_random(state_rep_b, 0.0, k=self.strl_cfg['mask_window'])
            # Keep ground truth only where states_mask is set; -inf elsewhere.
            gt_masked_states = state_rep_b.masked_fill(~states_mask, -float('inf'))

            trans_ctx = self.predict_transitions_context(masked_state_rep, masked_prev_actions)

            out['cgpt_loss'] = self.compute_cgpt_loss(trans_ctx, state_rep_b, processed_actions,
                                                      gt_masked_states, gt_masked_prev_actions)
        out['loss'] = misc.accumulate_losses(out, average=True)
        return out

    def finetune_bc_step(self, states, actions, **_):
        """Compute the behaviour-cloning loss of ``bc_head`` against ``actions``."""
        b, t = states.shape[:2]
        actions = actions.type(self.action_dtype)
        if self.use_actions:
            prev_actions = misc.mask_tensor_causal(actions, self.unknown_action, forward=True, k=1)
        else:
            prev_actions = self.process_actions(None, b, t, states.device)

        flat_states = self.process_transitions_stack(states)

        state_rep = self.transformer_encoder(flat_states)
        trans_ctx = self.predict_transitions_context(state_rep, prev_actions)
        action_preds = self.bc_head(trans_ctx)

        bc_loss = losses.action_prediction_loss(action_preds, actions, self.unknown_action,
                                                self.action_discrete, self.actions_weights)
        out = {'bc_loss': bc_loss}
        out['loss'] = misc.accumulate_losses(out, average=True)
        return out

    @torch.no_grad()
    def evaluate_bc_step(self, states, actions=None, **kwargs):
        """Predict actions with ``bc_head``; returns ``{'actions', 'certainty'}``.

        ``actions`` is optional: it is only read when ``use_actions`` is true,
        and missing actions are replaced by ``unknown_action``. Extra keyword
        arguments are forwarded to ``values2actions``.
        NOTE(review): unlike ``finetune_bc_step`` the given actions are not
        causally shifted here, so callers presumably pass previous actions.
        """
        b, t = states.shape[:2]
        if self.use_actions:
            prev_actions = self.process_actions(actions, b, t, states.device)
        else:
            prev_actions = self.process_actions(None, b, t, states.device)
        flat_states = self.process_transitions_stack(states)
        state_rep = self.transformer_encoder(flat_states)
        trans_ctx = self.predict_transitions_context(state_rep, prev_actions)
        action_preds = self.bc_head(trans_ctx)
        y, certainty = self.values2actions(action_preds, **kwargs)
        out = {'actions': y, 'certainty': certainty}
        return out

    def forward(self, stage, *args, **kwargs):
        """Dispatch to the ``<stage>_step`` method for ``stage`` in ``self.stages``."""
        assert stage in self.stages
        if stage == 'pretrain':
            return self.pretrain_step(*args, **kwargs)
        elif stage == 'finetune_bc':
            return self.finetune_bc_step(*args, **kwargs)
        elif stage == 'evaluate_bc':
            return self.evaluate_bc_step(*args, **kwargs)

    def freeze(self, freeze_all=True):
        """Disable gradients of the transformer encoder (and the context stack if ``freeze_all``)."""
        modules = [self.transformer_encoder]
        if freeze_all:
            modules.extend([self.state_context_proj, self.context_gpt])
            if self.actions_pretrained:
                modules.append(self.actions_embeddings)
        for module in modules:
            common.set_requires_grad(module, False)

    def unfreeze(self, unfreeze_all=True):
        """Re-enable gradients of the context stack (and the transformer encoder if ``unfreeze_all``)."""
        modules = [self.state_context_proj, self.actions_embeddings,
                   self.context_gpt]
        if unfreeze_all:
            modules.extend([self.transformer_encoder])
        for module in modules:
            common.set_requires_grad(module, True)
