import copy
import sys
import torch
import logging
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
from torch import nn
from models import ModelWithRandPrior
from models_context import MarginalPredictorContext, SequentialPredictorContext

##############################################################################
# Functions to define optimizers 
##############################################################################

def get_optimizer(model, config):
    """Build the AdamW optimizer used for training ``model``.

    Every parameter whose name contains ``'prior_weights'`` is left out of the
    optimizer entirely, so it is never updated. All remaining parameters share
    a single group that uses ``config.learning_rate`` and
    ``config.weight_decay``; AdamW betas are fixed to ``(0.9, 0.95)``.

    Args:
        model: a ``torch.nn.Module``.
        config: object exposing ``learning_rate`` and ``weight_decay``.

    Returns:
        torch.optim.AdamW: the configured optimizer.
    """
    # Randomized-prior weights are excluded outright. An earlier variant kept
    # them in a separate group with lr=0 and weight_decay=0 instead.
    trainable = [
        param
        for name, param in model.named_parameters()
        if 'prior_weights' not in name
    ]
    lr = config.learning_rate
    wd = config.weight_decay
    groups = [{"params": trainable, "weight_decay": wd, "lr": lr}]
    return torch.optim.AdamW(groups, lr=lr, weight_decay=wd, betas=(0.9, 0.95))



##############################################################################
# Functions to load models
##############################################################################

def compareModelWeights(model_a, model_b):
    """Return True if two modules share a submodule layout and equal ``weight`` tensors.

    The direct children of ``model_a`` and ``model_b`` must have the same
    names, in the same registration order. Each child that is an ``nn.Module``
    is compared recursively. This includes ``nn.Sequential``,
    ``nn.ModuleList`` and custom subclasses. Each child's ``weight`` attribute
    (if any) is then compared with ``torch.equal``.

    Only ``weight`` attributes are compared. Other parameters such as ``bias``
    and buffers are not checked.

    Args:
        model_a: first ``torch.nn.Module``.
        model_b: second ``torch.nn.Module``.

    Returns:
        bool: False on any name, structure, weight-presence or weight-value
        mismatch, True otherwise.
    """
    children_a = model_a._modules
    children_b = model_b._modules
    # Same child names in the same order (this also covers differing counts).
    if list(children_a.keys()) != list(children_b.keys()):
        return False

    for name in children_a:
        layer_a = children_a[name]
        layer_b = children_b[name]

        # ``_modules`` may hold None for optional submodules; both sides must agree.
        a_is_module = isinstance(layer_a, nn.Module)
        if a_is_module != isinstance(layer_b, nn.Module):
            return False
        # isinstance (not exact type) so custom subclasses and ModuleList are
        # descended into as well. Leaf layers have no children and recurse trivially.
        if a_is_module and not compareModelWeights(layer_a, layer_b):
            return False

        # ``weight`` can legitimately be None (e.g. LayerNorm without affine params).
        weight_a = getattr(layer_a, 'weight', None)
        weight_b = getattr(layer_b, 'weight', None)
        if (weight_a is None) != (weight_b is None):
            return False
        if weight_a is not None and not torch.equal(weight_a.data, weight_b.data):
            return False
    return True


def get_model_and_optimizer_context(config):
    """Instantiate the marginal or sequential context predictor and its optimizer.

    ``config.marginal_vs_sequential`` selects the model:
      * ``'sequential'`` builds a ``SequentialPredictorContext``.
      * ``'marginal'`` builds a ``ModelWithRandPrior`` wrapping
        ``MarginalPredictorContext`` when ``config.prior_scale != 0``.
        Otherwise it builds a plain ``MarginalPredictorContext``.

    Missing ``prior_scale``, ``X_MLP_width`` and ``X_MLP_layer`` are written
    back onto ``config`` as 0, and a missing ``suffstat_eps`` as 1. This
    backfill happens before the mode is validated. ``config.repeat_suffstat``
    gets no fallback, so it must be present in sequential mode.

    Returns:
        tuple: ``(model, {'all': optimizer})``, with the model moved to
        ``config.device``.

    Raises:
        ValueError: if ``marginal_vs_sequential`` is neither 'marginal' nor
            'sequential'.
    """
    sequential = config.marginal_vs_sequential == 'sequential'
    logging.info(f"IS SEQUENTIAL {sequential}")

    # Backfill optional hyperparameters that a config may not define.
    for attr_name, fallback in (('prior_scale', 0),
                                ('X_MLP_width', 0),
                                ('X_MLP_layer', 0),
                                ('suffstat_eps', 1)):
        if not hasattr(config, attr_name):
            setattr(config, attr_name, fallback)

    if not sequential and config.marginal_vs_sequential != 'marginal':
        raise ValueError('config.marginal_vs_sequential must be either marginal or sequential')

    if sequential:
        model = SequentialPredictorContext(
            Z_dim=config.Z_dim,
            X_dim=config.X_dim,
            MLP_width=config.MLP_width,
            MLP_layer=config.MLP_layer,
            X_MLP_layer=config.X_MLP_layer,
            X_MLP_width=config.X_MLP_width,
            suffstat_eps=config.suffstat_eps,
            repeat_suffstat=config.repeat_suffstat,
        ).to(config.device)
    elif config.prior_scale != 0:  # prior_scale is guaranteed set by the backfill
        # NOTE(review): prior_scale is passed both positionally and as a kwarg;
        # presumably one goes to the wrapper and one to the inner model. Confirm
        # against ModelWithRandPrior's signature.
        model = ModelWithRandPrior(
            MarginalPredictorContext,
            config.prior_scale,
            Z_dim=config.Z_dim,
            X_dim=config.X_dim,
            MLP_width=config.MLP_width,
            MLP_layer=config.MLP_layer,
            prior_scale=config.prior_scale,
        ).to(config.device)
        print('Model with prior, with scaling prior outputs')
    else:
        model = MarginalPredictorContext(
            Z_dim=config.Z_dim,
            X_dim=config.X_dim,
            MLP_width=config.MLP_width,
            MLP_layer=config.MLP_layer,
            prior_scale=config.prior_scale,
        ).to(config.device)

    return model, {'all': get_optimizer(model, config)}
    

##############################################################################
# Functions to compute losses
##############################################################################

def loss_from_loss_matrix(loss_matrix, orig_click_mask, how='sum_avg_per_row', weight_factor=1):
    """Reduce a 2-D (rows x steps) loss matrix to a scalar loss.

    Column ``t`` of ``orig_click_mask`` is scaled by ``weight_factor ** t``
    before use, so ``weight_factor=1`` leaves the mask unchanged. The scaled
    mask multiplies the losses elementwise.

    Args:
        loss_matrix: tensor indexed as ``[row, step]``.
        orig_click_mask: mask broadcastable to ``loss_matrix``.
        how: reduction mode:
            'avg_per_row'     -- weighted mean over steps per row, then mean over rows.
            'sum_avg_per_row' -- weighted mean over steps per row, then sum over rows.
            'avg_per_obs'     -- total weighted loss divided by total mask weight.
            'sum_per_obs'     -- total weighted loss.
        weight_factor: geometric per-step weight base.

    Returns:
        A 0-dim tensor.

    Raises:
        ValueError: if ``how`` is not one of the modes above.
    """
    step_idx = torch.arange(loss_matrix.shape[1]).to(loss_matrix.device)
    weights = orig_click_mask * weight_factor ** step_idx
    weighted = loss_matrix * weights  # the click mask is always 1 in the current setup

    if how in ('avg_per_row', 'sum_avg_per_row'):
        per_row = weighted.sum(1) / weights.sum(1)
        return per_row.mean() if how == 'avg_per_row' else per_row.sum()
    if how == 'avg_per_obs':
        return (weighted.sum() / weights.sum()).mean()
    if how == 'sum_per_obs':
        return weighted.sum().sum()
    raise ValueError('Argument "how" not accepted')



def get_val_loss(model, val_loader, device, loss_agg='sum_avg_per_row', 
                 sequential_one_length=None, weight_factor=1, exact=False,
                 embed_data=False, use_X_model=False, verbose=False):
    """Evaluate ``model`` over ``val_loader`` without gradients and collect outputs.

    Calls ``model.eval()`` and leaves the model in eval mode afterwards. Each
    batch dict has its tensors moved to ``device`` in place. The batch is then
    scored with ``model.eval_seq`` and reduced with ``loss_from_loss_matrix``.
    Per-row results are gathered on the CPU.

    Args:
        model: object providing ``eval()`` and ``eval_seq(...)``. With
            ``return_preds=True``, ``eval_seq`` returns
            ``(loss_matrix, row_theta_hats)``.
        val_loader: iterable of dicts containing 'click_length_mask',
            'click_obs' and 'click_rates'. It must also contain 'Z' when the
            model input is used, and 'X' when ``use_X_model`` is set.
        device: device the batch tensors are moved to.
        loss_agg: reduction mode forwarded to ``loss_from_loss_matrix`` as ``how``.
        sequential_one_length: if not None, only this column of the loss
            matrix and mask is scored. This also restricts 'loss_per_t'.
        weight_factor: forwarded to ``loss_from_loss_matrix``.
        exact: forwarded to ``eval_seq`` on the non-X path only.
        embed_data: pass ``batch['Z']`` even when
            ``model.z_encoder_output_dim`` is None. This only applies if the
            model has that attribute at all.
        use_X_model: call ``eval_seq(Z, X, click_obs, ...)`` instead of
            ``eval_seq(Z, click_obs, N=None, ..., exact=exact)``.
        verbose: print per-batch progress. This requires ``len(val_loader)``.

    Returns:
        dict with the following keys:
          * 'loss': summed per-batch loss divided by the total number of rows.
          * 'loss_per_t': column sums of the (unmasked) loss matrix divided by
            the total number of rows.
          * 'theta_hats', 'click_rates', 'click_obs', 'click_obs_masks':
            per-batch tensors concatenated along dim 0.
          * 'click_obs_means': per-row masked SUM of click_obs, despite the
            name. Divide by 'click_obs_counts' to get a mean.
          * 'click_obs_counts': per-row sum of the click mask.

    Raises:
        ValueError: if ``val_loader`` yields no rows.
    """
    model.eval()
    # Z is fed to the model only when it exposes a z-encoder (or embed_data is
    # requested). This does not vary across batches, so it is evaluated once.
    feed_z = hasattr(model, 'z_encoder_output_dim') and (
        model.z_encoder_output_dim is not None or embed_data)

    total_loss = 0
    total_rows = 0
    total_loss_per_t = None
    theta_hats = []
    click_obs_means = []
    click_rates = []
    click_obs_counts = []
    click_obs = []
    click_obs_masks = []

    with torch.no_grad():
        for batch_idx, batch in enumerate(val_loader, start=1):
            if verbose:
                print(f'{batch_idx} out of {len(val_loader)}')
                sys.stdout.flush()
            for k, v in batch.items():
                batch[k] = v.to(device)

            click_mask = batch['click_length_mask']
            model_input = batch['Z'] if feed_z else None

            if use_X_model:
                loss_matrix, row_theta_hats = model.eval_seq(model_input, batch['X'],
                                                             batch['click_obs'],
                                                             return_preds=True)
            else:
                loss_matrix, row_theta_hats = model.eval_seq(model_input,
                                                             batch['click_obs'],
                                                             N=None, return_preds=True, exact=exact)

            if sequential_one_length is not None:
                # A list index keeps the column dimension. Advanced indexing
                # already returns a copy, so no deepcopy is needed.
                loss_matrix = loss_matrix[:, [sequential_one_length]]
                loss_mask = click_mask[:, [sequential_one_length]]
            else:
                loss_mask = click_mask
            loss = loss_from_loss_matrix(loss_matrix, loss_mask, how=loss_agg,
                                         weight_factor=weight_factor)

            theta_hats.append(row_theta_hats.detach().cpu())
            click_obs_means.append((batch['click_obs'] * click_mask).sum(dim=1).cpu())
            click_obs.append(batch['click_obs'].cpu())
            click_obs_masks.append(click_mask.cpu())
            click_rates.append(batch['click_rates'].cpu())
            click_obs_counts.append(click_mask.sum(dim=1).cpu())

            total_loss += loss.item()
            total_rows += len(batch['click_obs'])

            step_losses = loss_matrix.detach().sum(dim=0).cpu()
            if total_loss_per_t is None:
                total_loss_per_t = step_losses
            else:
                total_loss_per_t += step_losses

    if total_rows == 0:
        raise ValueError('val_loader yielded no rows; cannot compute validation loss')

    return {
        'loss':             total_loss / total_rows,
        'loss_per_t':       total_loss_per_t / total_rows,
        'theta_hats':       torch.concatenate(theta_hats),
        'click_obs_means':  torch.concatenate(click_obs_means),
        'click_rates':      torch.concatenate(click_rates),
        'click_obs_counts': torch.concatenate(click_obs_counts),
        'click_obs':        torch.concatenate(click_obs),
        'click_obs_masks':  torch.concatenate(click_obs_masks),
    }

