import numpy as np
import torch 
from torch import nn
import copy
from tqdm import tqdm

from misc import batchify
import ipdb
import itertools
from rich import pretty
from torch.distributions import Dirichlet
# Enable rich's pretty-printing; this runs as a module-level side effect on import.
pretty.install()

class BAG(nn.Module):
  """Classifier built on a feature extractor ``Phi`` and a split latent code.

  ``Phi`` (deep-copied) maps inputs to features.  An encoder plus ``fc_mu`` /
  ``fc_logvar`` heads produce a latent vector ``z``: it is sampled with the
  reparameterization trick in training mode and set to the mean otherwise.  A
  decoder reconstructs the features from ``z``.

  ``z`` is split in two:

  - ``z_c = z[..., :z_c_dim]`` feeds a direct label head.
  - ``z_s = z[..., z_c_dim:]`` feeds an environment-mixture label head.  A
    softmax over ``environment_num`` learned environment embeddings weights the
    label logits computed from ``[z_s, embedding]`` for each embedding.

  :meth:`forward` combines the two heads with a learnable label prior
  ``y_distribution``.

  Several freeze/reset helpers refer to legacy ``etas`` / ``beta`` attributes.
  This class does not create them, so those helpers only act on them when some
  other code has set them.
  """

  def __init__(self, n_batch_envs, input_dim, Phi, config, out_dim=1, phi_dim = None):
    """Build all sub-modules.

    :param n_batch_envs: stored as ``self.n_batch_envs``; not used by this class.
    :param input_dim: stored as ``self.input_dim``; not used by this class.
    :param Phi: feature-extractor module; deep-copied.  When ``phi_dim`` is
        falsy, ``Phi[-1].out_features`` is used as the feature width.
    :param config: object providing ``classification``, ``z_dim``, ``z_c_dim``,
        ``z_s_dim``, ``hide_dim``, ``environment_num`` and ``environment_dim``.
    :param out_dim: number of label classes (output width of both label heads).
    :param phi_dim: optional explicit feature width of ``Phi``.
    :raises ValueError: if ``config.z_c_dim + config.z_s_dim != config.z_dim``,
        because :meth:`forward` slices ``z`` into exactly those two parts.
    """
    super(BAG, self).__init__()

    self.n_batch_envs = n_batch_envs
    self.input_dim = input_dim  # not used by this class; kept for callers
    self.classification = config.classification

    # Feature extractor Phi: a private copy, so the caller's module is not shared.
    self.Phi = copy.deepcopy(Phi)
    self.phi_odim = phi_dim if phi_dim else self.Phi[-1].out_features

    self.config = config
    self.z_dim = config.z_dim
    self.z_c_dim = config.z_c_dim
    self.z_s_dim = config.z_s_dim
    self.out_dim = out_dim
    self.hide_dim = config.hide_dim
    # forward() feeds z[..., z_c_dim:] into layers sized for z_s_dim, so the
    # split must be exact; fail here rather than with an opaque shape error later.
    if self.z_c_dim + self.z_s_dim != self.z_dim:
      raise ValueError(
          f"config.z_c_dim ({self.z_c_dim}) + config.z_s_dim ({self.z_s_dim}) "
          f"must equal config.z_dim ({self.z_dim})"
      )

    self.softmax_layer = nn.Softmax(dim=-1)  # unused within this class

    self.environment_num = config.environment_num
    self.environment_dim = config.environment_dim

    # Encoder / decoder between Phi's features and the latent code z.
    self.encoder = nn.Sequential(
        nn.Linear(self.phi_odim, self.hide_dim),
        nn.BatchNorm1d(self.hide_dim),
        nn.ReLU(),
        nn.Dropout())
    self.decoder = nn.Sequential(
        nn.Linear(self.z_dim, self.hide_dim),
        nn.BatchNorm1d(self.hide_dim),
        nn.ReLU(),
        nn.Linear(self.hide_dim, self.phi_odim))
    self.fc_mu = nn.Sequential(nn.Linear(self.hide_dim, self.z_dim))
    self.fc_logvar = nn.Sequential(nn.Linear(self.hide_dim, self.z_dim))

    # Environment-assignment head on z_s.  The attribute name (sic) is left
    # unchanged: renaming it would change the module's state_dict keys.
    self.environment_idnex_predictor = nn.Sequential(
        nn.Linear(self.z_s_dim, self.environment_num),
        nn.Dropout())
    self.environment_softmax = nn.Softmax(dim=-1)
    self.environment_embedding = nn.ParameterList(
        [nn.Parameter(torch.zeros(self.environment_dim)) for _ in range(self.environment_num)])

    # Label heads: one on [z_s, environment embedding], one on z_c.
    self.label_predict_with_z_s = nn.Sequential(
        nn.Linear(self.z_s_dim + self.environment_dim, out_dim),
        nn.Dropout())
    self.label_predict_with_z_c = nn.Sequential(
        nn.Linear(self.z_c_dim, out_dim),
        nn.Dropout())

    # Learnable label prior p(Y), stored as logits and read through a softmax in
    # forward().  Every entry starts at logit(1 / out_dim), so the initial
    # softmax is uniform.  `eps` clamps the argument away from 1 so that
    # out_dim == 1 gives a finite value instead of +inf (which made the softmax
    # NaN); for out_dim >= 2 the value is unchanged.
    init_logit = torch.logit(torch.tensor(1 / self.out_dim), eps=1e-6).item()
    self.y_distribution = nn.Parameter(torch.full((self.out_dim,), init_logit))
    pretty.pprint(f"out dim is {self.out_dim}")

  def reparameterize(self, mu, log_var):
    """Return ``mu + eps * exp(0.5 * log_var)`` with ``eps ~ N(0, I)``.

    The sample stays differentiable with respect to ``mu`` and ``log_var``.
    """
    std = torch.exp(0.5 * log_var)
    eps = torch.randn_like(std)
    return eps.mul(std).add_(mu)

  def get_parameters(self, base_lr=1.0):
    """Return optimizer parameter groups.

    :param base_lr: base learning rate.
    :return: two groups: ``Phi``'s parameters, and every other parameter of
        the model.  Both currently use ``base_lr``; they are kept separate so
        that Phi's rate can be scaled independently.
    """
    phi_params = list(self.Phi.parameters())

    # Explicit order is kept stable: optimizer state is saved per parameter index.
    other_params = itertools.chain(
        self.encoder.parameters(),
        self.environment_idnex_predictor.parameters(),
        self.environment_embedding.parameters(),
        self.label_predict_with_z_s.parameters(),
        self.label_predict_with_z_c.parameters(),
        self.decoder.parameters(),
        self.fc_mu.parameters(),
        self.fc_logvar.parameters(),
        [self.y_distribution],
    )

    return [
        {"params": phi_params, "lr": 1 * base_lr},
        {"params": other_params, "lr": 1.0 * base_lr},
    ]

  def forward(self, input_data, env_index, rep_learning=False, fast_eta=None, confusion_matrix=None, test_tta=False):
    """Run the full model.

    :param input_data: batch passed to ``Phi``.
    :param env_index: unused; kept for interface compatibility.
    :param rep_learning: unused; kept for interface compatibility.
    :param fast_eta: unused; kept for interface compatibility.
    :param confusion_matrix: optional square matrix (array-like or tensor).
        When given, the z_s-head probabilities ``P`` are replaced by the
        row-normalized least-squares solution ``X`` of
        ``confusion_matrix @ X.T = P.T``.
    :param test_tta: when True, use the latent mean even in training mode.
    :return: tuple ``(prediction_c, prediction_s, y_distribution,
        p_Y_given_zc_zs_E, representation, representation_hat, mu, logvar,
        latent_z)``.  The first two are the logits of the z_c and z_s heads;
        ``y_distribution`` is the raw prior-logit parameter.
    """
    representation = self.Phi(input_data)
    h = self.encoder(representation)

    mu, logvar = self.fc_mu(h), self.fc_logvar(h)
    if self.training and not test_tta:
      latent_z = self.reparameterize(mu, logvar)
    else:
      latent_z = mu  # deterministic: use the mean directly

    representation_hat = self.decoder(latent_z)

    z_c = latent_z[..., :self.z_c_dim]
    z_s = latent_z[..., self.z_c_dim:]

    prediction_c = self.label_predict_with_z_c(z_c)

    # Soft assignment of each sample to the learned environments.
    env_selection_probs = self.environment_softmax(self.environment_idnex_predictor(z_s))

    # Label logits for every (sample, environment) pair, then their mixture
    # under the soft assignment.
    batch_size = z_s.size(0)
    env_based_predictions = []
    for env_embedding in self.environment_embedding:
      env_feat = env_embedding.unsqueeze(0).expand(batch_size, -1)
      z_s_with_env = torch.cat([z_s, env_feat], dim=1)
      env_based_predictions.append(self.label_predict_with_z_s(z_s_with_env))
    env_based_pred_stack = torch.stack(env_based_predictions, dim=1)  # (batch, n_env, out_dim)
    prediction_s = (env_based_pred_stack * env_selection_probs.unsqueeze(-1)).sum(dim=1)

    prediction_c_prob = torch.nn.functional.softmax(prediction_c, dim=-1)  # p(Y = y | z_c)
    prediction_s_prob = torch.nn.functional.softmax(prediction_s, dim=-1)  # p(Y = y | z_s, E)
    prediction_y_prob = torch.nn.functional.softmax(self.y_distribution, dim=-1)  # p(Y = y)

    if confusion_matrix is not None:
      # as_tensor avoids a copy-construct warning for tensor inputs; matching the
      # prediction dtype/device keeps lstsq's operands compatible.
      confusion_matrix = torch.as_tensor(
          confusion_matrix, dtype=prediction_s_prob.dtype, device=prediction_s_prob.device)
      corrected = torch.linalg.lstsq(confusion_matrix, prediction_s_prob.T).solution.T
      prediction_s_prob = corrected / corrected.sum(dim=-1, keepdim=True)

    # NOTE(review): Q is proportional to p(Y|z_c) * p(Y|z_s, E) / p(Y), but it
    # is passed through a softmax as if it were a vector of logits rather than
    # being divided by its sum.  Preserved as-is; confirm this is intended.
    Q = prediction_s_prob * prediction_c_prob / prediction_y_prob
    log_p_Y = torch.nn.functional.log_softmax(Q, dim=-1)
    p_Y_given_zc_zs_E = torch.exp(log_p_Y)

    return prediction_c, prediction_s, self.y_distribution, p_Y_given_zc_zs_E, representation, representation_hat, mu, logvar, latent_z

  def sample_base_classifer(self, x):
    """Legacy helper returning ``Phi(x) @ self.beta``.

    NOTE(review): this class never defines ``beta``, so this raises
    AttributeError unless ``beta`` has been attached by other code.
    """
    x_tensor = torch.Tensor(x)
    return self.Phi(x_tensor) @ self.beta

  # ---- Helpers to freeze / unfreeze parameters and inspect requires_grad ----

  def freeze_all_but_etas(self):
    """Freeze every parameter, then re-enable ``etas`` if present."""
    for para in self.parameters():
      para.requires_grad = False
    for eta in getattr(self, "etas", ()):
      eta.requires_grad = True

  def set_etas_to_zeros(self):
    """Zero every legacy ``eta`` parameter in place; no-op without ``etas``."""
    # no_grad: in-place ops on leaf tensors that require grad are otherwise an error.
    with torch.no_grad():
      for eta in getattr(self, "etas", ()):
        eta.zero_()

  def freeze_all_but_phi(self):
    """Make every parameter trainable except legacy ``etas`` / ``beta``, if present.

    NOTE(review): despite the name, this does not freeze the non-Phi modules
    of this class (encoder, heads, ...); confirm the intended semantics.
    """
    for para in self.parameters():
      para.requires_grad = True
    for eta in getattr(self, "etas", ()):
      eta.requires_grad = False
    beta = getattr(self, "beta", None)
    if beta is not None:
      beta.requires_grad = False

  def freeze_all_but_beta(self):
    """Make every parameter trainable except legacy ``beta``, if present.

    NOTE(review): this looks inverted relative to the name (it freezes
    ``beta`` rather than everything else); behaviour kept, confirm intent.
    """
    for para in self.parameters():
      para.requires_grad = True
    beta = getattr(self, "beta", None)
    if beta is not None:
      beta.requires_grad = False

  def freeze_all(self):
    """Set ``requires_grad = False`` on every parameter."""
    for para in self.parameters():
      para.requires_grad = False

  def free_all(self):
    """Set ``requires_grad = True`` on every parameter."""
    for para in self.parameters():
      para.requires_grad = True

  def check_var_with_required_grad(self):
    """Print the ``requires_grad`` flag of every named parameter."""
    for name, param in self.named_parameters():
      print(f"{name} has requires_grad = {param.requires_grad}")

  def train_only_z_s_classfier(self):
    """Freeze everything except the z_s label head, environment predictor and embeddings."""
    self.freeze_all()
    for module in (self.label_predict_with_z_s,
                   self.environment_idnex_predictor,
                   self.environment_embedding):
      for para in module.parameters():
        para.requires_grad = True

  def get_optimizer_for_specific_parameters(self, lr):
    """Return an Adam optimizer over the z_s label head, environment predictor and embeddings."""
    trainable_parameters = (
        list(self.label_predict_with_z_s.parameters()) +
        list(self.environment_idnex_predictor.parameters()) +
        list(self.environment_embedding.parameters())
    )
    return torch.optim.Adam(trainable_parameters, lr=lr)