import sys

import wandb
from time import sleep
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

class EmbeddingEnsemble(nn.Module):
    """Fuse the outputs of several models into a single embedding.

    Every wrapped model is called on the same input.  Its outputs are stacked
    along a new leading axis and reduced according to ``fusion``:

    * ``'mean'``          -- element-wise average over models.
    * ``'max'``           -- element-wise maximum over models.
    * ``'weighted_mean'`` -- fixed weighted average.  ``weights`` are
      normalised to sum to 1 and stored as a (non-trainable) buffer.

    Args:
        models: Iterable of ``nn.Module`` objects.  Each must accept
            ``model(vision, output_normalize=...)``, and all must return
            tensors of the same shape (presumably ``(batch, dim)`` -- TODO
            confirm against callers).
        fusion: One of ``'mean'``, ``'max'``, ``'weighted_mean'``.
        weights: One weight per model; required for ``'weighted_mean'``.

    Raises:
        ValueError: ``fusion`` is unsupported, or, for ``'weighted_mean'``,
            ``weights`` is missing, has the wrong length, or sums to zero.
    """

    _FUSIONS = ('mean', 'max', 'weighted_mean')

    def __init__(self, models, fusion='mean', weights=None):
        super().__init__()
        self.models = nn.ModuleList(models)
        self.fusion = fusion
        self.weights = weights

        # Fail fast at construction instead of on the first forward pass.
        if fusion not in self._FUSIONS:
            raise ValueError(f"Unsupported fusion method: {fusion}")

        if fusion == 'weighted_mean':
            # Explicit checks rather than `assert`, which `python -O` strips.
            if weights is None or len(weights) != len(self.models):
                raise ValueError("Need weights for weighted_mean fusion.")
            weights_tensor = torch.as_tensor(weights, dtype=torch.float32)
            total = weights_tensor.sum()
            if total == 0:
                # Normalising would divide by zero and yield inf/nan embeddings.
                raise ValueError("weighted_mean weights must not sum to zero.")
            # A buffer follows the module across .to(device) calls but is not trained.
            self.register_buffer('weights_tensor', weights_tensor / total)

    def forward(self, vision, output_normalize=True):
        """Run every model on ``vision`` and fuse their outputs.

        Args:
            vision: Input passed unchanged to each wrapped model.
            output_normalize: Forwarded to each model as ``output_normalize``.

        Returns:
            A tensor with the shape of a single model's output.

        Raises:
            ValueError: ``self.fusion`` is not a supported method.
        """
        embeddings = [
            model(vision, output_normalize=output_normalize) for model in self.models
        ]
        stacked = torch.stack(embeddings, dim=0)  # (num_models, *single_output_shape)

        if self.fusion == 'mean':
            return stacked.mean(dim=0)
        if self.fusion == 'max':
            return stacked.max(dim=0).values
        if self.fusion == 'weighted_mean':
            # Match the embeddings' dtype so half/double outputs work with the
            # float32 buffer; einsum does not mix dtypes.
            w = self.weights_tensor.to(dtype=stacked.dtype)
            # Contract only the model axis; any trailing dims pass through.
            return torch.einsum('m...,m->...', stacked, w)
        raise ValueError(f"Unsupported fusion method: {self.fusion}")
# Learnable fusion weights are not implemented; 'weighted' is not a supported
# `fusion` value (use 'mean', 'max', or 'weighted_mean').
# # Or use fixed weights (normalized internally to sum to 1):
# ensemble_model = EmbeddingEnsemble(models=models, fusion='weighted_mean', weights=[0.5, 0.3, 0.2])
# # Or a plain mean:
# ensemble_model = EmbeddingEnsemble(models=models, fusion='mean')

def load_vision_components(model, load_path, map_location=None):
    """Load vision-tower weights from a checkpoint into ``model``.

    The checkpoint is read with ``torch.load`` and indexed as a mapping.  The
    ``'vision_model'`` entry is loaded into ``model.vision_model``.  When a
    ``'visual_projection'`` entry is present, it is loaded into
    ``model.visual_projection``.

    Args:
        model: Object exposing a ``vision_model`` module, plus a
            ``visual_projection`` module when the checkpoint contains one.
        load_path: Path or file-like object accepted by ``torch.load``.
        map_location: Forwarded to ``torch.load``.  For example, ``'cpu'``
            lets a checkpoint saved on GPU be loaded on a CPU-only host.
            ``None`` (default) keeps ``torch.load``'s own device handling.

    Raises:
        KeyError: the checkpoint has no ``'vision_model'`` entry.
    """
    # NOTE(review): torch.load may unpickle arbitrary objects; only load trusted checkpoints.
    state_dict = torch.load(load_path, map_location=map_location)
    model.vision_model.load_state_dict(state_dict['vision_model'])
    if 'visual_projection' in state_dict:
        model.visual_projection.load_state_dict(state_dict['visual_projection'])

def init_wandb(project_name, model_name, config, *, max_retries=None, **wandb_kwargs):
    """Start a wandb run, retrying when ``wandb.init`` raises.

    By default the run uses ``save_code=True`` and ``mode="offline"``.  Both
    can be overridden through ``wandb_kwargs``.  Any other ``wandb_kwargs``
    are forwarded to ``wandb.init`` unchanged.

    Args:
        project_name: Passed to ``wandb.init`` as ``project``.
        model_name: Passed to ``wandb.init`` as ``name``.
        config: Passed to ``wandb.init`` as ``config``.
        max_retries: Maximum number of retries after a failed attempt.
            ``None`` (default) retries forever, as before.
        **wandb_kwargs: Extra keyword arguments for ``wandb.init``.

    Returns:
        The object returned by ``wandb.init``.

    Raises:
        Exception: the last error from ``wandb.init`` once ``max_retries``
            retries have been exhausted.
    """
    # Presumably wandb's service start-up wait in seconds -- TODO confirm.
    os.environ['WANDB__SERVICE_WAIT'] = '300'
    # Build one kwargs dict so callers can override the defaults.  Previously,
    # passing e.g. mode=... raised a duplicate-keyword TypeError that the
    # retry loop below swallowed forever.
    init_kwargs = {'save_code': True, 'mode': 'offline', **wandb_kwargs}
    failures = 0
    while True:
        try:
            return wandb.init(
                project=project_name, name=model_name, config=config, **init_kwargs,
            )
        except Exception as e:
            failures += 1
            print('wandb connection error', file=sys.stderr)
            print(f'error: {e}', file=sys.stderr)
            if max_retries is not None and failures > max_retries:
                raise
            sleep(1)
            print('retrying..', file=sys.stderr)

def str2bool(v):
    """Parse a boolean-like value (e.g. a command-line string) into a ``bool``.

    ``bool`` inputs are returned unchanged.  Any other input is converted with
    ``str()``, stripped of surrounding whitespace and lower-cased, then matched
    as follows:

    * ``'yes'``, ``'true'``, ``'t'``, ``'y'``, ``'1'`` -> ``True``
    * ``'no'``, ``'false'``, ``'f'``, ``'n'``, ``'0'`` -> ``False``

    Raises:
        ValueError: the value is not a recognised boolean spelling.  When this
            function is used as an argparse ``type=``, argparse reports the
            error as a usage error.
    """
    if isinstance(v, bool):
        return v
    # str() lets ints such as 0/1 through instead of failing with AttributeError.
    text = str(v).strip().lower()
    if text in ('yes', 'true', 't', 'y', '1'):
        return True
    if text in ('no', 'false', 'f', 'n', '0'):
        return False
    raise ValueError(f'Boolean value expected, got {v!r}')

class AverageMeter:
    """Track the latest value alongside a running, count-weighted mean.

    ``reset`` sets all four statistics to 0:

    * ``val``   -- the value given to the most recent ``update``.
    * ``sum``   -- the accumulated ``val * n`` over all updates.
    * ``count`` -- the accumulated ``n`` over all updates.
    * ``avg``   -- ``sum / count``.

    ``name`` labels the meter.  ``fmt`` is a format spec including its leading
    colon (e.g. ``':f'`` or ``':.4f'``) and is used by ``str()`` for both
    ``val`` and ``avg``.
    """

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Clear every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        spec = self.fmt
        # Produces e.g. '{name} {val:.4f} ({avg:.4f})', filled from the instance attributes.
        template = f'{{name}} {{val{spec}}} ({{avg{spec}}})'
        return template.format(**vars(self))