import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
import torch.utils.checkpoint as cp

from collections import OrderedDict
from . import encoders
from . import classifiers
from . import decoders
from .modules import Module, BatchNorm2d, get_child_dict

import math
import sys
sys.path.append('..')
from train.util import gaussian_sampling

def make(enc_name, enc_args, clf_name, clf_args, image_size, device,
         prompt=False, prompt_args=None, dec_name=None, dec_args=None):
  """
  Builds a randomly initialized MAML model from component names and arguments.

  The encoder is created first. Its output dimension for `image_size` is then
  written into the classifier arguments and, when `prompt` is set, into the
  decoder arguments before those components are created.

  NOTE: `clf_args` is modified in place ('in_dim' is overwritten). When
  `prompt` is set, `dec_args` is modified in place as well ('in_dim' and
  'out_dim' are overwritten), so `prompt_args` and `dec_args` must be dicts
  in that case.

  Args:
    enc_name (str): name of the encoder (e.g., 'resnet12').
    enc_args (dict): arguments for the encoder.
    clf_name (str): name of the classifier (e.g., 'meta-nn').
    clf_args (dict): arguments for the classifier.
    image_size: forwarded to the encoder's get_out_dim().
    device: forwarded to every component factory and to MAML.
    prompt (bool): if True, a decoder is built and prompting is enabled.
    prompt_args (dict, optional): prompt settings; 'dim' is read here.
    dec_name (str, optional): name of the decoder (used only with prompt).
    dec_args (dict, optional): arguments for the decoder (used only with
      prompt).

  Returns:
    model (MAML): a meta classifier with a random encoder.
  """
  encoder = encoders.make(device, enc_name, **enc_args)

  # The decoder exists only in prompt mode: it maps the concatenation
  # [encoder feature, prompt] back to the encoder's feature size.
  decoder = None
  if prompt:
    dec_args['in_dim'] = encoder.get_out_dim(image_size) + prompt_args['dim']
    dec_args['out_dim'] = encoder.get_out_dim(image_size)
    decoder = decoders.make(device, dec_name, **dec_args)

  # In both modes the classifier consumes features of the encoder's size.
  clf_args['in_dim'] = encoder.get_out_dim(image_size)
  classifier = classifiers.make(device, clf_name, **clf_args)

  return MAML(encoder, classifier, prompt, prompt_args, decoder, device=device)


class MAML(Module):
  """
  MAML meta-learner composed of an encoder and a classifier, optionally
  extended with a learnable prompt vector and a decoder.

  In prompt mode the prompt is concatenated to every encoder feature, the
  result is passed through the decoder, and the decoder output is classified.
  In Bayesian mode the prompt used for an entire forward() call is drawn via
  gaussian_sampling() from `prompt_mean` and softplus(`prompt_cov`).
  """

  def __init__(self, encoder, classifier, prompt=False, prompt_args=None, decoder=None, device=None):
    """
    Args:
      encoder: feature extractor, called as encoder(x, params, episode).
      classifier: classification head, called as classifier(feat, params).
      prompt (bool): enables the prompt + decoder path.
      prompt_args (dict, optional): required when `prompt` is set. Keys read:
        'dim', 'bayesian', 'dynamic_lr', 'reparameterization', 'prune',
        'prune_rate', and 'reparam_way' (only if 'reparameterization' is set).
      decoder (optional): called as decoder(feat_prompt, params, episode);
        stored only when `prompt` is set.
      device (optional): device on which the prompt parameters are placed.
    """
    super(MAML, self).__init__()
    self.encoder = encoder
    self.classifier = classifier
    self.is_prompt = prompt
    self.is_bayesian = prompt_args['bayesian'] if self.is_prompt else False
    self.dynamic_lr = prompt_args['dynamic_lr'] if self.is_prompt else False
    self.is_reparam = prompt_args['reparameterization'] if self.is_prompt else False
    self.reparam_way = prompt_args['reparam_way'] if self.is_reparam else None
    self.prune = prompt_args['prune'] if self.is_prompt else False
    self.prune_rate = prompt_args['prune_rate'] if self.is_prompt else 0.0

    # NOTE: the creation order below determines both the consumption of the
    # random generator and the order of named_parameters(); keep it stable.
    if self.is_prompt:
      self.prompt = nn.Parameter(torch.randn(1, prompt_args['dim']).to(device))
      self.decoder = decoder
    if self.is_bayesian:
      self.prompt_mean = nn.Parameter(torch.randn(1, prompt_args['dim']).to(device))
      self.prompt_cov = nn.Parameter(torch.randn(1, prompt_args['dim']).to(device)+0.3)
    if self.is_reparam:
      self.reparam_emb = nn.Linear(prompt_args['dim'], prompt_args['dim'], bias=False).to(device)

  def reset_classifier(self):
    """ Re-initializes the classifier via its reset_parameters(). """
    self.classifier.reset_parameters()

  def forward(self, x_shot, x_query, y_shot, inner_args, meta_train, return_prompt=False):
    """
    Adapts the model on each episode's support set and evaluates the adapted
    parameters on the corresponding query set.

    Args:
      x_shot (float tensor, [n_episode, n_way * n_shot, C, H, W]): support sets.
      x_query (float tensor, [n_episode, n_way * n_query, C, H, W]): query sets.
        (C: channels, H: height, W: width)
      y_shot (int tensor, [n_episode, n_way * n_shot]): support set labels.
      inner_args (dict): inner-loop hyperparameters. 'frozen' (list of name
        fragments excluded from adaptation) is read here; the remaining keys
        are read by _adapt() / _inner_iter().
      meta_train (bool): if True, the model is in meta-training.
      return_prompt (bool): if True, only the adapted prompt is returned.

    Returns:
      If `return_prompt` is set: the adapted prompt of the LAST episode as a
        numpy array.
      Otherwise a tuple (logits, prompt):
        logits (float tensor, [n_episode, n_way * n_query, n_way]): predicted
          logits on the query sets.
        prompt: the sampled prompt in Bayesian mode, else torch.zeros(1) as a
          placeholder.
    """
    assert self.encoder is not None
    assert self.classifier is not None
    assert x_shot.dim() == 5 and x_query.dim() == 5
    assert x_shot.size(0) == x_query.size(0)

    # a dictionary of parameters that will be updated in the inner loop
    params = OrderedDict(self.named_parameters())

    frozen = inner_args['frozen'] + ['temp']
    for name in list(params.keys()):
      if not params[name].requires_grad or any(s in name for s in frozen):
        params.pop(name)

    # Bayesian prompt: the distribution parameters are never adapted in the
    # inner loop; a single sample drawn here plays the role of the prompt for
    # every episode of this call. pop(..., None) because the filter above may
    # already have removed these entries (frozen / requires_grad=False).
    prompt_sample = None
    if self.is_bayesian:
      params.pop('prompt_mean', None)
      params.pop('prompt_cov', None)
      # log(1 + exp(x)) is softplus(x): keeps the second argument positive.
      prompt_sample = gaussian_sampling(self.prompt_mean, torch.log(1. + self.prompt_cov.exp()))
      params['prompt'] = prompt_sample
    # The re-parameterization layer is applied with the module's own weight in
    # _inner_forward() and must not be adapted in the inner loop.
    if self.is_reparam:
      params.pop('reparam_emb.weight', None)

    logits = []
    for ep in range(x_shot.size(0)):
      # inner-loop training
      self.train()
      self.meta_train = True

      # set for the meta-test: non-episodic BatchNorm layers are put in eval
      # mode even during adaptation
      if not meta_train:
        for m in self.modules():
          if isinstance(m, BatchNorm2d) and not m.is_episodic():
            m.eval()
        self.meta_train = False

      # inner adaptation
      updated_params = self._adapt(
        x_shot[ep], y_shot[ep], params, ep, inner_args, meta_train)

      # inner-loop validation (graph is kept only in meta-training)
      with torch.set_grad_enabled(meta_train):
        self.eval()
        logits_ep = self._inner_forward(x_query[ep], updated_params, ep)
      logits.append(logits_ep)

    self.train(meta_train)
    logits = torch.stack(logits)
    if return_prompt:
      return updated_params['prompt'].cpu().detach().numpy()
    elif self.is_bayesian:
      return logits, prompt_sample
    else:
      return logits, torch.zeros(1)

  def _adapt(self, x, y, params, episode, inner_args, meta_train):
    """
    Performs inner-loop adaptation in MAML.

    Args:
      x (float tensor, [n_way * n_shot, C, H, W]): per-episode support set.
        (C: channels, H: height, W: width)
      y (int tensor, [n_way * n_shot]): per-episode support set labels.
      params (dict): a dictionary of parameters at meta-initialization.
      episode (int): the current episode index.
      inner_args (dict): inner-loop optimization hyperparameters
        ('momentum' and 'n_step' are read here).
      meta_train (bool): if True, the model is in meta-training.

    Returns:
      params (dict): model parameters AFTER inner-loop adaptation.
    """
    assert x.dim() == 4 and y.dim() == 1
    assert x.size(0) == y.size(0)

    # Initializes a dictionary of momentum buffer for gradient descent in the
    # inner loop. It has the same set of keys as the parameter dictionary
    # (and stays empty when momentum is disabled).
    mom_buffer = OrderedDict()
    if inner_args['momentum'] > 0:
      for name, param in params.items():
        mom_buffer[name] = torch.zeros_like(param)
    params_keys = tuple(params.keys())
    mom_buffer_keys = tuple(mom_buffer.keys())

    for m in self.modules():
      if isinstance(m, BatchNorm2d) and m.is_episodic():
        m.reset_episodic_running_stats(episode)

    def _inner_iter_cp(episode, *state):
      """
      Performs one inner-loop iteration when checkpointing is enabled.
      The code is executed twice:
        - 1st time with torch.no_grad() for creating checkpoints.
        - 2nd time with torch.enable_grad() for computing gradients.
      """
      # `state` is the flat tuple (params..., mom_buffer...); rebuild the dicts.
      params = OrderedDict(zip(params_keys, state[:len(params_keys)]))
      mom_buffer = OrderedDict(
        zip(mom_buffer_keys, state[-len(mom_buffer_keys):]))

      detach = not torch.is_grad_enabled()  # detach graph in the first pass
      self.is_first_pass(detach)
      params, mom_buffer = self._inner_iter(
        x, y, params, mom_buffer, int(episode), inner_args, detach)
      state = tuple(t if t.requires_grad else t.clone().requires_grad_(True)
        for t in tuple(params.values()) + tuple(mom_buffer.values()))
      return state

    # inner updating
    for _ in range(inner_args['n_step']):
      # `efficient` / `is_first_pass` are not defined in this class; they are
      # presumably provided by the Module base class -- TODO confirm.
      if self.efficient:  # checkpointing
        # NOTE (from the original authors): the checkpointed path is not
        # considered in this project.
        state = tuple(params.values()) + tuple(mom_buffer.values())
        state = cp.checkpoint(_inner_iter_cp, torch.as_tensor(episode), *state)
        params = OrderedDict(zip(params_keys, state[:len(params_keys)]))
        mom_buffer = OrderedDict(
          zip(mom_buffer_keys, state[-len(mom_buffer_keys):]))
      else:  # without checkpointing
        params, mom_buffer = self._inner_iter(
          x, y, params, mom_buffer, episode, inner_args, not meta_train)

    return params

  def _inner_lr(self, name, inner_args):
    """
    Selects the inner-loop learning rate for the parameter called `name`.

    Args:
      name (str): parameter name; matched by substring in the order
        'encoder', 'classifier', 'decoder', 'prompt'.
      inner_args (dict): provides 'encoder_lr', 'classifier_lr', 'decoder_lr',
        'prompt_lr' and the flag 'prompt_forze'.

    Returns:
      lr: a scalar, or (soft modulation) a tensor shaped like `prompt_cov`.

    Raises:
      ValueError: if `name` matches none of the known groups.
    """
    if 'encoder' in name:
      return inner_args['encoder_lr']
    if 'classifier' in name:
      return inner_args['classifier_lr']
    if 'decoder' in name:
      return inner_args['decoder_lr']
    if 'prompt' in name:
      # 'prompt_forze' (sic, the key spelling is part of the config contract)
      # freezes the prompt by using a zero step size.
      if inner_args["prompt_forze"]:
        return 0
      # Dynamic learning rates apply only at meta-test time in Bayesian mode.
      if self.dynamic_lr and not self.meta_train and self.is_bayesian:
        if self.prune:
          # hard modulation
          # NOTE(review): the masked importance tensor below is computed but
          # never applied, so the effective step size is the plain
          # 'prompt_lr'. Behavior is kept as-is; confirm with the authors
          # whether the mask was meant to scale the learning rate.
          importances = torch.abs(self.prompt_mean.detach())/torch.log(1.+self.prompt_cov.detach().exp())
          importances = self.mask_tensor(importances)
          return inner_args['prompt_lr']
        # soft modulation: per-dimension step size scaled by softplus(cov)
        return torch.log(1.+self.prompt_cov.detach().exp())*inner_args['prompt_lr']
      return inner_args['prompt_lr']
    raise ValueError('invalid parameter name')

  def _inner_iter(self, x, y, params, mom_buffer, episode, inner_args, detach):
    """
    Performs one inner-loop iteration of MAML including the forward and
    backward passes and the parameter update.

    Args:
      x (float tensor, [n_way * n_shot, C, H, W]): per-episode support set.
      y (int tensor, [n_way * n_shot]): per-episode support set labels.
      params (dict): the model parameters BEFORE the update.
      mom_buffer (dict): the momentum buffer BEFORE the update; it is updated
        in place and also returned.
      episode (int): the current episode index.
      inner_args (dict): inner-loop optimization hyperparameters
        ('first_order', 'weight_decay', 'momentum' and the learning rates).
      detach (bool): if True, detaches the graph for the current iteration.

    Returns:
      updated_params (dict): the model parameters AFTER the update.
      mom_buffer (dict): the momentum buffer AFTER the update.

    Raises:
      RuntimeError: if the re-parameterization layer appears in `params`.
      ValueError: if a parameter with a gradient has an unknown name.
    """
    with torch.enable_grad():
      # forward pass
      logits = self._inner_forward(x, params, episode)
      loss = F.cross_entropy(logits, y)

      # backward pass; second-order graph only when neither detached nor
      # running in first-order mode
      grads = autograd.grad(loss, params.values(),
        create_graph=(not detach and not inner_args['first_order']),
        only_inputs=True, allow_unused=True)

      # parameter update
      updated_params = OrderedDict()
      for (name, param), grad in zip(params.items(), grads):
        # The layer's parameter is named 'reparam_emb.weight', so match by
        # prefix (an exact comparison with 'reparam_emb' can never be true).
        if name.startswith('reparam_emb'):
          raise RuntimeError('the embedding layer cannot be updated.')

        if grad is None:
          # parameter did not take part in the loss: carried over unchanged
          updated_param = param
        else:
          if inner_args['weight_decay'] > 0:
            grad = grad + inner_args['weight_decay'] * param
          if inner_args['momentum'] > 0:
            grad = grad + inner_args['momentum'] * mom_buffer[name]
            mom_buffer[name] = grad
          lr = self._inner_lr(name, inner_args)
          updated_param = param - lr * grad
        if detach:
          updated_param = updated_param.detach().requires_grad_(True)
        updated_params[name] = updated_param

    return updated_params, mom_buffer

  def _inner_forward(self, x, params, episode):
    """
    Forward pass for the inner loop.

    Args:
      x: input batch passed to the encoder.
      params (dict): parameters to use; split per sub-module with
        get_child_dict(). Must contain 'prompt' in prompt mode.
      episode (int): the current episode index, forwarded to the encoder and
        the decoder.

    Returns:
      logits: output of the classifier.

    Raises:
      NameError: if re-parameterization is enabled with an unknown
        `reparam_way` (neither "MLP" nor "res").
    """
    feat = self.encoder(x, get_child_dict(params, 'encoder'), episode)
    if not self.is_prompt:
      return self.classifier(feat, get_child_dict(params, 'classifier'))

    prompt = params['prompt']
    if self.is_reparam:
      if self.reparam_way == "MLP":
        prompt = self.reparam_emb(prompt)
      elif self.reparam_way == "res":
        # residual re-parameterization
        prompt = prompt + self.reparam_emb(prompt)
      else:
        raise NameError('None is reproduced.')

    # The same prompt row is appended to the feature of every sample.
    feat_prompt = torch.cat([feat, prompt.repeat(feat.size(0), 1)], dim=1)
    merge_feat = self.decoder(feat_prompt, get_child_dict(params, 'decoder'), episode)
    return self.classifier(merge_feat, get_child_dict(params, 'classifier'))

  def mask_tensor(self, importances, largest=True):
    """
    Zeroes, in place, the largest entries of `importances` along its last
    dimension.

    The number of zeroed entries per row is floor(last_dim * prune_rate).
    Working along the last dimension makes this correct for the (1, dim)
    prompt tensors used by this class as well as for plain 1-D tensors.

    Args:
      importances (tensor): importance scores; modified in place.
      largest (bool): if False, nothing is masked.

    Returns:
      importances (tensor): the same tensor object that was passed in.
    """
    if largest:
      mask_nums = int(math.floor(importances.size(-1) * self.prune_rate))
      if mask_nums > 0:
        mask_indexes = torch.topk(importances, mask_nums, dim=-1)[1]
        importances.scatter_(-1, mask_indexes, 0.)

    return importances
