import torch
import numpy as np
import copy
import torch.nn as nn


def configure_model_bn(model):
    """Set up ``model`` so that only its BatchNorm2d layers are trainable.

    The model is put in train mode. In train mode BatchNorm2d normalizes with
    the statistics of the current batch and, when ``track_running_stats`` is
    enabled, keeps updating its running estimates. Every parameter is frozen,
    then the parameters of each ``nn.BatchNorm2d`` module are unfrozen again.
    Running-statistics buffers are left untouched.

    Args:
        model: the ``nn.Module`` to configure; it is modified in place.

    Returns:
        The same ``model`` object.
    """
    model.train()
    # Freeze everything first; only the BatchNorm2d parameters are re-enabled below.
    model.requires_grad_(False)
    batch_norms = [m for m in model.modules() if isinstance(m, nn.BatchNorm2d)]
    for bn in batch_norms:
        bn.requires_grad_(True)
    return model

def configure_model_noadapt(model):
    """Freeze ``model`` completely for plain evaluation (no adaptation).

    The model is switched to eval mode and gradients are disabled for every
    parameter, so nothing is updated at test time.

    Args:
        model: the ``nn.Module`` to configure; it is modified in place.

    Returns:
        The same ``model`` object.
    """
    model.eval()
    for param in model.parameters():
        param.requires_grad_(False)
    return model

def configure_model_whole(model):
    """Make every parameter of ``model`` trainable and enter train mode.

    Unlike the BatchNorm-only configurations, nothing is frozen here: all
    parameters have gradients enabled.

    Args:
        model: the ``nn.Module`` to configure; it is modified in place.

    Returns:
        The same ``model`` object.
    """
    model.train()
    for param in model.parameters():
        param.requires_grad_(True)
    return model

def configure_model_tent(model):
    """Configure ``model`` for Tent-style adaptation of BatchNorm2d layers.

    The model is put in train mode and all parameters are frozen. For each
    ``nn.BatchNorm2d`` module the parameters are unfrozen again. Its running
    statistics are then discarded (``track_running_stats = False``,
    ``running_mean``/``running_var`` set to ``None``), so that it normalizes
    with the current batch statistics in both train and eval mode.

    Args:
        model: the ``nn.Module`` to configure; it is modified in place.

    Returns:
        The same ``model`` object.
    """
    model.train()
    model.requires_grad_(False)
    for module in model.modules():
        if not isinstance(module, nn.BatchNorm2d):
            continue
        module.requires_grad_(True)
        # Without running estimates, BatchNorm2d falls back to batch statistics
        # regardless of train/eval mode.
        module.track_running_stats = False
        module.running_mean, module.running_var = None, None
    return model



def get_img_num_per_cls(num_examples, cls_num, imb_factor, imb_type='exp'):
    """Return the target number of samples for each class of an imbalanced split.

    The head class gets ``img_max = num_examples / cls_num`` samples. The
    remaining classes are scaled down by ``imb_factor`` according to
    ``imb_type``:

    * ``'exp'``: class ``i`` gets ``img_max * imb_factor ** (i / (cls_num - 1))``,
      so the last class gets ``img_max * imb_factor``.
    * ``'step'``: the first ``cls_num // 2`` classes get ``img_max`` and the
      remaining ``cls_num - cls_num // 2`` classes get ``img_max * imb_factor``.
    * anything else: every class gets ``img_max``.

    NOTE(review): with these formulas ``imb_factor`` behaves as a tail/head
    ratio (e.g. 0.01 gives a 100x imbalance); a value > 1 would make later
    classes *larger* — confirm what callers pass.

    Args:
        num_examples: total number of examples in the (balanced) source data.
        cls_num: number of classes.
        imb_factor: multiplier applied to ``img_max`` for the tail.
        imb_type: ``'exp'``, ``'step'``, or any other value for uniform counts.

    Returns:
        A list of exactly ``cls_num`` ints (each count truncated with ``int``).
    """
    img_max = num_examples / cls_num
    if imb_type == 'exp':
        # A single class has no tail; avoid the 0 / 0 in the exponent.
        if cls_num == 1:
            return [int(img_max)]
        return [
            int(img_max * (imb_factor ** (cls_idx / (cls_num - 1.0))))
            for cls_idx in range(cls_num)
        ]
    if imb_type == 'step':
        num_head = cls_num // 2
        # The tail takes the remainder so an odd cls_num still yields cls_num entries.
        num_tail = cls_num - num_head
        return [int(img_max)] * num_head + [int(img_max * imb_factor)] * num_tail
    return [int(img_max)] * cls_num


def make_LT_datasets(x, y, rho, num_classes):
    """Randomly subsample ``(x, y)`` into a long-tailed dataset.

    Target per-class counts come from
    ``get_img_num_per_cls(x.size(0), num_classes, rho)`` with its default
    ``'exp'`` profile. For each class ``c`` a random subset of at most
    ``num_cls[c]`` of that class's samples is kept.

    Args:
        x: data tensor; samples are indexed along dim 0.
        y: label tensor compared against ``0..num_classes-1``; presumably a
            1-D integer tensor aligned with ``x`` along dim 0 — TODO confirm.
        rho: imbalance factor forwarded as ``imb_factor`` to
            ``get_img_num_per_cls``.
        num_classes: number of classes to draw from.

    Returns:
        ``(x_mod, y_mod, num_cls)``. ``x_mod``/``y_mod`` are concatenated
        class by class in order ``0..num_classes-1``. ``num_cls`` is the list
        of *target* counts; a class with fewer samples than its target
        contributes all of its samples instead.
    """
    num_cls = get_img_num_per_cls(x.size(0), num_classes, rho)

    x_mod, y_mod = [], []
    for c in range(num_classes):
        mask = y == c
        x_c, y_c = x[mask], y[mask]
        # Permute over this class's actual size. This keeps indices in range
        # and lets every sample be drawn even if classes are not perfectly balanced.
        idx_c = torch.randperm(x_c.size(0))[:num_cls[c]]
        x_mod.append(x_c[idx_c])
        y_mod.append(y_c[idx_c])

    return torch.cat(x_mod, 0), torch.cat(y_mod, 0), num_cls