import jax.numpy as jnp
import flax.linen as nn
from networks.mlp import ID, MLP
from networks.direct_pred import DP


class ssl_agent(nn.Module):
    """Encoder with per-layer projection and prediction heads plus LEP heads.

    For every entry of the activation dict returned by the encoder, the
    activation is mean-pooled, passed through a projection head and then
    through a prediction head. Two extra single-layer heads with
    ``num_classes`` outputs (the "LEP" heads) are attached to the pooled
    activation and to the projection of the last entry of that dict.
    """
    # Called as ``encoder(greedy=..., dataset=...)`` and the result is then
    # applied to the input, i.e. it is used as a module factory/class.
    encoder: nn.Module
    # Projection head selector: 'mlp' (MLP built from ``proj_args``) or 'id'.
    proj_type: str
    # Prediction head selector: 'mlp', 'dp' (DP built from ``dp_args``) or 'id'.
    pred_type: str
    # Forwarded unchanged to the encoder constructor.
    greedy: bool
    # Keyword arguments for the DP predictor (used when pred_type == 'dp').
    dp_args: dict
    # Keyword arguments for the MLP projector; 'hidden_sizes' is also read to
    # size the MLP predictor when pred_type == 'mlp'.
    proj_args: dict
    # Output width of the LEP heads.
    num_classes: int
    # Forwarded unchanged to the encoder constructor.
    dataset: str
    # Passed as 'detach_head' to the MLP predictor.
    iso: bool

    def setup(self):
        """Prepare the keyword arguments of the predictor and LEP heads."""
        if self.pred_type == 'mlp':
            # Only the MLP predictor needs the projector width, so look it up
            # here; configurations that do not use it need not provide
            # 'hidden_sizes' in ``proj_args``.
            size = self.proj_args['hidden_sizes'][-1]
            # Bottleneck of size // 4, then back to the projector width.
            self.pred_args = {'hidden_sizes': (size // 4, size),
                              'bnorm': (True, False),
                              'detach_head': self.iso}
        elif self.pred_type == 'dp':
            self.pred_args = self.dp_args
        # pred_type == 'id' takes no arguments; unknown values are rejected
        # in ``__call__``.

        # Single layer, no batch norm, identity activation.
        # NOTE(review): 'detach_head' presumably blocks gradients from the
        # LEP heads into the features - confirm against networks.mlp.MLP.
        self.lep_args = {'hidden_sizes': (self.num_classes,),
                         'bnorm': (False,),
                         'detach_head': True,
                         'act': lambda x: x}

    @nn.compact
    def __call__(self, x, train=True, is_target_net=False):
        """Run the encoder and all heads on ``x``.

        Args:
            x: Input passed to the encoder.
            train: Forwarded as ``train`` to the encoder and to every head.
            is_target_net: Forwarded to the predictor only when
                ``pred_type == 'dp'``; ignored otherwise.

        Returns:
            A tuple ``(acts, projs, preds, logits_lep)``. ``acts`` is the
            encoder's activation dict (unpooled); ``projs`` and ``preds`` map
            the same keys to projector and predictor outputs; ``logits_lep``
            has keys ``'embd'`` and ``'proj'`` holding the LEP head outputs
            for the last entry of ``acts``.

        Raises:
            ValueError: If ``proj_type`` or ``pred_type`` is not supported,
                or if the encoder returns no activations.
        """
        # Resolve the head types first so a misconfiguration fails with a
        # clear message instead of an UnboundLocalError further down.
        if self.proj_type == 'mlp':
            proj_module, proj_args = MLP, self.proj_args
        elif self.proj_type == 'id':
            proj_module, proj_args = ID, {}
        else:
            raise ValueError(
                f"unknown proj_type {self.proj_type!r}; expected 'mlp' or 'id'")

        if self.pred_type == 'mlp':
            pred_module, pred_args, add_kwargs = MLP, self.pred_args, {}
        elif self.pred_type == 'dp':
            pred_module, pred_args = DP, self.pred_args
            add_kwargs = {'is_target_net': is_target_net}
        elif self.pred_type == 'id':
            pred_module, pred_args, add_kwargs = ID, {}, {}
        else:
            raise ValueError(
                f"unknown pred_type {self.pred_type!r}; "
                "expected 'mlp', 'dp' or 'id'")

        acts = self.encoder(greedy=self.greedy, dataset=self.dataset)(x, train=train)
        if not acts:
            # The LEP heads below need at least one activation to attach to.
            raise ValueError('encoder returned no activations')

        projs, preds = {}, {}
        pooled = proj = None
        for key, act in acts.items():
            # Average pooling over axes 1 and 2.
            # NOTE(review): assumes these are the spatial axes - confirm
            # against the encoder's output layout.
            pooled = jnp.mean(act, axis=(1, 2))
            proj = proj_module(name=f'proj_{key}', **proj_args)(pooled, train=train)
            preds[key] = pred_module(name=f'pred_{key}', **pred_args)(
                proj, train=train, **add_kwargs)
            projs[key] = proj

        # LEP on the last layer's pooled activation and projection, i.e. the
        # values left over from the final loop iteration.
        logits_lep = {
            'embd': MLP(name='lep_embedding', **self.lep_args)(pooled, train=train),
            'proj': MLP(name='lep_projection', **self.lep_args)(proj, train=train),
        }

        return acts, projs, preds, logits_lep
    

class supervised_agent(ssl_agent):
    """Variant of ``ssl_agent`` whose predictions are per-layer class logits.

    The parent forward pass is run unchanged (so its projection, prediction
    and LEP heads are still created), but the parent's predictor outputs are
    discarded and replaced by the outputs of a single-layer classifier with
    ``num_classes`` outputs applied to each projection.
    """

    def setup(self):
        """Extend the parent setup with the classifier head arguments."""
        super().setup()
        # Single layer of width num_classes, no batch norm.
        self.classifier_args = {'hidden_sizes': (self.num_classes,), 'bnorm': (False,)}

    @nn.compact
    def __call__(self, x, train=True, is_target_net=False):
        """Run the parent forward pass and classify every projection.

        Args:
            x: Input passed to the encoder.
            train: Forwarded as ``train`` to the parent call and classifiers.
            is_target_net: Forwarded to the parent call. Accepted so this
                class can be called with the same arguments as ``ssl_agent``;
                the default matches the value the parent used before.

        Returns:
            A tuple ``(acts, projs, preds, logits_lep)`` as returned by the
            parent, except that ``preds`` maps each key of ``projs`` to the
            classifier output for that projection.
        """
        acts, projs, _, logits_lep = super().__call__(
            x, train=train, is_target_net=is_target_net)

        # One classifier per projected layer, named after the layer key.
        preds = {
            key: MLP(name=f'class_{key}', **self.classifier_args)(z, train=train)
            for key, z in projs.items()
        }

        return acts, projs, preds, logits_lep
