import torch
from torch import nn
from utils.acce_utils import print_0

def weight_init(m):
    """Initialise the parameters of a single module in place.

    Meant to be passed to ``nn.Module.apply``. Handled module types:

    * ``nn.Linear``: Xavier-uniform weight; bias (if any) set to 0.
    * ``nn.Conv1d/2d/3d``: Kaiming-uniform weight (``mode='fan_out'``, ReLU gain);
      the bias is left untouched.
    * ``nn.BatchNorm1d/2d/3d``: weight set to 1 and bias to 0 when the affine
      parameters exist (they are ``None`` for ``affine=False``).

    Any other module type is left unchanged.
    """
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    elif isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        nn.init.kaiming_uniform_(m.weight, mode='fan_out', nonlinearity='relu')

    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        # With affine=False these attributes are None; initialising them would raise.
        if m.weight is not None:
            nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

def freeze(module):
    """Turn off gradient tracking for all parameters of ``module`` (in place)."""
    for p in module.parameters():
        p.requires_grad_(False)

def print_trainable_parameters(model):
    """Print how many of ``model``'s parameters are trainable.

    Output: ``trainable params: T || all params: A || trainable%: P``, where ``T``
    counts elements of parameters with ``requires_grad`` set, ``A`` counts all
    parameter elements, and ``P = 100 * T / A`` (``0.0`` when the model has no
    parameters, instead of raising ZeroDivisionError).
    """
    trainable_parameters = 0
    all_param = 0
    for param in model.parameters():
        numel = param.numel()
        all_param += numel
        if param.requires_grad:
            trainable_parameters += numel
    trainable_pct = 100 * trainable_parameters / all_param if all_param else 0.0
    print(f"trainable params: {trainable_parameters} || all params: {all_param} || trainable%: {trainable_pct}")


def unfreeze(module):
    """Turn gradient tracking back on for all parameters of ``module`` (in place)."""
    for p in module.parameters():
        p.requires_grad_(True)


def freeze_llm_layers(model_base, model, num_freeze_layers):
    """Freeze the first ``num_freeze_layers`` transformer layers of an LLM backbone.

    :param model_base: backbone family name. Any name containing ``'llava'`` uses
        ``model.language_model.model.layers``; ``'gemma'``, ``'llama'``,
        ``'llama_bi'`` and ``'phi'`` use ``model.model.layers``. Other names are a no-op.
    :param model: the model whose layers are frozen in place.
    :param num_freeze_layers: how many leading layers to freeze; values <= 0 are a
        no-op, and values larger than the layer count freeze every layer.
    """
    if num_freeze_layers <= 0:
        return

    if 'llava' in model_base:
        layers = model.language_model.model.layers
    # The original set literal was missing a comma, so 'llama_bi' and 'phi' were
    # concatenated into 'llama_biphi' and never matched.
    elif model_base in {'gemma', 'llama', 'llama_bi', 'phi'}:
        layers = model.model.layers
    else:
        return

    if layers:
        # Slicing clamps to the available layers instead of raising IndexError.
        for layer in layers[:num_freeze_layers]:
            for param in layer.parameters():
                param.requires_grad = False
    print(layers)


def init_score_head(model, do_init, num_labels):
    """Re-initialise the model's score head when requested.

    A base model that has not been fine-tuned needs its score head re-initialised.
    This calls ``model.init_score(num_labels=num_labels)`` only when ``do_init`` is
    truthy and the model defines ``init_score``. Initialisation is best-effort: any
    ``Exception`` it raises is logged with its full traceback via ``print_0`` and
    otherwise swallowed.
    """
    if not (do_init and hasattr(model, 'init_score')):
        return
    try:
        model.init_score(num_labels=num_labels)
    except Exception:
        import traceback
        # format_exc returns the traceback text; print_exc would print to stderr
        # and return None, which is what used to be logged.
        print_0(traceback.format_exc())


def init_class_weight(args, model, label2count, label2i, device, dtype):
    """Compute per-class loss weights and store them on ``model.class_weights``.

    Two sources are supported:

    * Explicit: ``args.cls_config.class_weights`` is a comma-separated string. When
      it has exactly ``max(label2i.values()) + 1`` entries, those floats are used as-is.
    * Derived: otherwise, counts from ``label2count`` are summed per class index
      (labels missing from ``label2i`` are counted under index 1). Each class gets
      ``sqrt(max_count / count)``, and the weights are rescaled to mean 1. They are
      only assigned when the number of weights equals ``len(label2count)``.

    :param args: config object exposing ``args.cls_config.class_weights`` (a str).
    :param model: object that receives the ``class_weights`` tensor attribute.
    :param label2count: mapping label -> example count.
    :param label2i: mapping label -> class index.
    :param device: device for the weight tensor.
    :param dtype: target dtype, or the string ``'auto'`` to keep the default float dtype.
    """
    import collections
    import numpy as np

    explicit = args.cls_config.class_weights.split(',')
    if len(explicit) == max(label2i.values()) + 1:
        model.class_weights = torch.tensor([float(x) for x in explicit]).to(device)
        # Same 'auto' guard as the derived branch: Tensor.to('auto') would raise.
        if dtype != 'auto':
            model.class_weights = model.class_weights.to(dtype)

    else:
        i2count = collections.defaultdict(float)
        for label, count in label2count.items():
            i2count[label2i.get(label, 1)] += count

        max_count = max(label2count.values())
        # Ordered by class index (keys are unique, so tuple sort == key sort).
        class_weights = [max_count / count for _, count in sorted(i2count.items())]
        print_0(class_weights)
        # sqrt dampens the inverse-frequency ratios; dividing by the mean keeps
        # the average weight at 1.
        class_weights = np.sqrt(np.asarray(class_weights))
        class_weights /= np.mean(class_weights)
        print_0(class_weights)

        if len(class_weights) == len(label2count):
            model.class_weights = torch.tensor([float(x) for x in class_weights]).to(device)
            if dtype != 'auto':
                model.class_weights = model.class_weights.to(dtype)
            print_0(model.class_weights, flush=True)
    # The derived branch may skip the assignment; don't crash just to log it.
    print_0(f'model class weights = {getattr(model, "class_weights", None)}', flush=True)


class ClassificationHead(nn.Module):
    """Linear classification head with dropout on its input.

    The incoming hidden states pass through dropout and are then projected to
    ``num_labels`` logits by a single biased ``nn.Linear``.
    """

    def __init__(self, hidden_size: int, num_labels: int, dropout: float = 0.0):
        super().__init__()
        self.dense = nn.Linear(hidden_size, num_labels)
        self.dropout = nn.Dropout(dropout)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        regularised = self.dropout(hidden_states)
        logits = self.dense(regularised)
        return logits


class ClassificationHeadV2(nn.Module):
    """Two-layer MLP classification head.

    Layout of ``self.score``: Dropout -> Linear(hidden, hidden, no bias) -> PReLU
    -> Dropout -> Linear(hidden, num_labels, no bias). Every submodule is passed
    through ``weight_init`` after construction.
    """

    def __init__(self, hidden_size: int, num_labels: int, dropout: float = 0.0):
        super().__init__()
        modules = [
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size, bias=False),
            nn.PReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, num_labels, bias=False),
        ]
        self.score = nn.Sequential(*modules)
        self.score.apply(weight_init)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        logits = self.score(hidden_states)
        return logits


class ClassificationHeadV3(nn.Module):
    """Three-layer MLP classification head with an expanded middle layer.

    Layout of ``self.score``: [Dropout -> Linear(no bias) -> PReLU] twice, going
    ``hidden -> hidden * scale -> hidden``, followed by
    Dropout -> Linear(hidden, num_labels, no bias). Every submodule is passed
    through ``weight_init`` after construction.
    """

    def __init__(self, hidden_size: int, num_labels: int, dropout: float = 0.0, scale=4):
        super().__init__()
        widths = [hidden_size, hidden_size * scale, hidden_size]
        modules = []
        for d_in, d_out in zip(widths[:-1], widths[1:]):
            modules.extend([
                nn.Dropout(dropout),
                nn.Linear(d_in, d_out, bias=False),
                nn.PReLU(),
            ])
        modules.extend([
            nn.Dropout(dropout),
            nn.Linear(widths[-1], num_labels, bias=False),
        ])
        self.score = nn.Sequential(*modules)
        self.score.apply(weight_init)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        logits = self.score(hidden_states)
        return logits


class ClassificationHeadV5(nn.Module):
    """Three-layer MLP classification head with a narrowing Tanh bottleneck.

    Layout of ``self.score``: [Dropout -> Linear(no bias) -> Tanh] twice, going
    ``hidden -> hidden // 2 -> hidden // 4``, followed by
    Dropout -> Linear(hidden // 4, num_labels, no bias). Every submodule is
    passed through ``weight_init`` after construction.
    """

    def __init__(self, hidden_size: int, num_labels: int, dropout: float = 0.0):
        super().__init__()
        widths = [hidden_size, hidden_size // 2, hidden_size // 4]
        modules = []
        for d_in, d_out in zip(widths[:-1], widths[1:]):
            modules.extend([
                nn.Dropout(dropout),
                nn.Linear(d_in, d_out, bias=False),
                nn.Tanh(),
            ])
        modules.extend([
            nn.Dropout(dropout),
            nn.Linear(widths[-1], num_labels, bias=False),
        ])
        self.score = nn.Sequential(*modules)
        self.score.apply(weight_init)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        logits = self.score(hidden_states)
        return logits


def build_mask(max_seq_len, seq_lens):
    """Build a boolean padding mask from per-row lengths.

    :param max_seq_len: padded sequence length (number of mask columns).
    :param seq_lens: 1-D tensor, or sequence of ints, giving the number of valid
        positions in each row.
    :return: bool tensor of shape ``(len(seq_lens), max_seq_len)`` with entry
        ``[b, t]`` True iff ``t < seq_lens[b]``, on the same device as ``seq_lens``
        (CPU for Python sequences).
    """
    # as_tensor accepts lists and avoids the copy warning torch.tensor(tensor) emits.
    seq_lens = torch.as_tensor(seq_lens)
    positions = torch.arange(max_seq_len, device=seq_lens.device)
    # Broadcast (1, max_seq_len) against (batch, 1) -> (batch, max_seq_len).
    return positions.unsqueeze(0) < seq_lens.unsqueeze(1)


class AttentionPooling(nn.Module):
    """Attention pooling over the token axis.

    A small MLP (Linear -> Tanh -> Linear to 1 score) scores every token. The scores
    are softmax-normalised over the valid positions of each row, and the output is
    the attention-weighted sum of the token embeddings.

    :param hidden_size: embedding dimension of the incoming tokens.
    :param drop_rate: dropout probability for the pooled vector.
    :param drop_out: if True, apply dropout to the pooled vector.
    """

    def __init__(self, hidden_size, drop_rate, drop_out=False):
        super(AttentionPooling, self).__init__()
        self.atn_layer = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, 1, bias=False)
        )
        self.drop_layer = nn.Dropout(p=drop_rate)
        self.drop_out = drop_out
        self.atn_layer.apply(weight_init)

    def forward(self, token_embeddings, lengths):
        """
        :param token_embeddings: tensor of shape (batch, seq_len, embed_dim).
        :param lengths: 1-D tensor with each row's number of valid tokens; positions
            ``t >= lengths[b]`` get zero attention (mask built by ``build_mask``).
        :return: pooled tensor of shape (batch, embed_dim). Rows with no valid token
            pool to zeros.
        """
        # nn.Linear acts on the last dim, so no 2-D view is needed. This also works
        # for non-contiguous inputs, where .view() would raise.
        scores = self.atn_layer(token_embeddings).squeeze(-1)  # (batch, seq_len)

        mask = build_mask(token_embeddings.size(1), lengths).to(scores.device)

        # softmax subtracts the row max internally, so large scores cannot overflow
        # the way an explicit exp() did (notably in fp16/bf16).
        scores = scores.masked_fill(~mask, float('-inf'))
        alpha = torch.softmax(scores, dim=1)
        # A row whose positions are all masked gives NaN from softmax; zero it out.
        alpha = alpha.masked_fill(~mask, 0.0)

        # (batch, 1, seq_len) @ (batch, seq_len, embed_dim) -> (batch, embed_dim)
        pooled = torch.bmm(alpha.unsqueeze(1), token_embeddings).squeeze(1)

        if self.drop_out:
            pooled = self.drop_layer(pooled)

        return pooled


class MeanPooling(nn.Module):
    """Average token embeddings over the valid (unpadded) positions of each row."""

    def __init__(self):
        super(MeanPooling, self).__init__()

    def forward(self, token_embeddings, lengths):
        """
        :param token_embeddings: tensor of shape (batch, max_seq_len, embed_dim).
        :param lengths: per-row number of valid tokens (1-D tensor or list of ints).
        :return: tensor of shape (batch, embed_dim). Rows with length 0 give zeros
            instead of NaN.
        """
        # Normalise to a tensor on the embeddings' device. This also accepts plain
        # lists and avoids the copy warning from torch.tensor(tensor).
        lengths = torch.as_tensor(lengths, device=token_embeddings.device)

        mask = build_mask(token_embeddings.size(1), lengths)
        mask = mask.to(token_embeddings.dtype).unsqueeze(2)  # (batch, max_seq_len, 1)

        summed = torch.sum(token_embeddings * mask, dim=1)  # (batch, embed_dim)
        # Clamp so an empty row divides by 1 (yielding zeros) rather than by 0.
        counts = lengths.clamp(min=1).unsqueeze(1).to(token_embeddings.dtype)
        return summed / counts


class LastTokenPooling(nn.Module):
    """Pick one token embedding per row, at position ``lengths[b] - self.delta``.

    ``self.delta`` defaults to 0, so the selected position is ``lengths[b]`` itself.
    NOTE(review): if callers pass the *count* of valid tokens (as MeanPooling
    assumes), position ``lengths[b]`` is the first padding slot, not the last real
    token. Confirm whether callers pass last-token indices here.
    """

    def __init__(self):
        super(LastTokenPooling, self).__init__()
        self.delta = 0

    def forward(self, token_embeddings, lengths):
        """
        :param token_embeddings: tensor indexed as (batch, position, ...).
        :param lengths: integer tensor of per-row positions (must support
            ``- self.delta``).
        :return: the selected embedding for each row, stacked along dim 0.
        """
        row_ids = torch.arange(token_embeddings.shape[0], device=token_embeddings.device)
        positions = lengths - self.delta
        return token_embeddings[row_ids, positions]


class FirstTokenPooling(nn.Module):
    """Pool each row to the embedding at position 0 (e.g. a leading CLS-style token)."""

    def __init__(self):
        super(FirstTokenPooling, self).__init__()

    def forward(self, token_embeddings, lengths):
        """
        :param token_embeddings: tensor indexed as (batch, position, ...).
        :param lengths: unused; kept for interface parity with the other poolers.
        :return: ``token_embeddings[b, 0]`` for every row ``b``, stacked along dim 0.
        """
        first_position = 0
        row_ids = torch.arange(token_embeddings.size(0), device=token_embeddings.device)
        return token_embeddings[row_ids, first_position]


class FirstAndLastTokenPooling(nn.Module):
    """Average the embedding at position 0 with the one at position ``lengths[b]``.

    NOTE(review): as in LastTokenPooling, ``lengths[b]`` is used directly as an
    index. Confirm that callers pass last-token indices rather than token counts.
    """

    def __init__(self):
        super(FirstAndLastTokenPooling, self).__init__()

    def forward(self, token_embeddings, lengths):
        """
        :param token_embeddings: tensor indexed as (batch, position, ...).
        :param lengths: integer tensor (or index sequence) of per-row positions.
        :return: ``(token_embeddings[b, 0] + token_embeddings[b, lengths[b]]) / 2``
            for every row ``b``.
        """
        row_ids = torch.arange(token_embeddings.size(0), device=token_embeddings.device)
        head = token_embeddings[row_ids, 0]
        tail = token_embeddings[row_ids, lengths]
        return (head + tail) / 2


def print_parameters(name, params, params_mask=None, max_out=False):
    """Print a 2-D tensor row by row, two decimals per value, via ``print_0``.

    :param name: label printed before the JSON dump (not printed when ``max_out``).
    :param params: tensor to print; anything that is not 2-D is silently ignored.
    :param params_mask: optional per-row mask. Row ``i`` is truncated to its first
        ``int(params_mask[i].sum())`` values.
    :param max_out: if True, print only the longest formatted row.
    """
    import json

    # Check rank before reading shape[0], so 0-D tensors are ignored instead of raising.
    if len(params.shape) != 2:
        return

    output = []
    for i, row in enumerate(params):
        if params_mask is not None:
            # Truncate before formatting so masked-out values are never formatted.
            row = row[0: int(torch.sum(params_mask[i]))]
        output.append(', '.join('{:.2f}'.format(w) for w in row))

    if max_out:
        # default='' keeps an empty batch from raising ValueError in max().
        print_0(max(output, key=len, default=''))
    else:
        print_0(name, json.dumps(output, indent=2))
