import os
import pickle
import logging

import numpy as np
import wandb
import torch
import torch.nn.functional as F
from tqdm import tqdm
import json
import os
import argparse
from sentence_transformers import SentenceTransformer

def extract_off_diagonal(matrix):
    """
    Extract all off-diagonal elements from a square matrix tensor.

    Elements are returned in row-major order: row by row, skipping each
    row's diagonal entry.

    Args:
        matrix (torch.Tensor): A 2-D square matrix tensor of shape (n, n).

    Returns:
        torch.Tensor: A 1D tensor of length n * (n - 1) containing all
        off-diagonal elements. For n <= 1 the result is empty.

    Raises:
        ValueError: If ``matrix`` is not 2-D or is not square.
    """
    # Reject non-2-D input explicitly. A 1-D tensor would otherwise fail with
    # an opaque IndexError on shape[1]. An (n, n, k) tensor would pass the
    # square check but yield an (n*(n-1), k) result instead of a 1-D tensor.
    if matrix.dim() != 2 or matrix.shape[0] != matrix.shape[1]:
        raise ValueError("Input must be a square matrix")

    # Mask that is False on the diagonal and True elsewhere, created on the
    # same device as the input so indexing does not trigger a transfer.
    mask = ~torch.eye(matrix.shape[0], dtype=torch.bool, device=matrix.device)

    # Boolean-mask indexing returns the selected elements flattened in
    # row-major order.
    return matrix[mask]

def _load_logs(base_dir, file_names):
    """
    Load every JSON log that can be read, skipping unreadable ones.

    Args:
        base_dir (str): Directory the file names are relative to.
        file_names (list[str]): JSON files to load, in order.

    Returns:
        list: Parsed JSON documents, in the order of ``file_names``, for the
        files that loaded successfully.
    """
    logs = []
    for file_name in file_names:
        try:
            with open(os.path.join(base_dir, file_name), 'r', encoding='utf-8') as handle:
                logs.append(json.load(handle))
        except (OSError, ValueError) as err:
            # Best-effort: a missing or malformed run is reported and skipped.
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            print(f'Error in {file_name}: {err}')
    return logs


def _split_outputs(entries):
    """
    Collect the '- '-separated segments of ``extra_info`` across runs.

    Args:
        entries (list[dict]): The i-th record from each loaded log.

    Returns:
        list[str]: Every segment after the first '- ' in each record's
        ``extra_info``, in log order. Records without ``extra_info`` are
        skipped. The text before the first '- ' is dropped.
    """
    return [
        segment
        for entry in entries
        if 'extra_info' in entry
        for segment in entry['extra_info'].split('- ')[1:]
    ]


def _diversity_score(embedder, texts):
    """
    Return the mean off-diagonal pairwise similarity of ``texts``.

    Similarity is whatever ``embedder.similarity`` computes; presumably this
    is cosine similarity for the model used here, but that is not confirmed.
    With fewer than two texts there are no pairs, so NaN is returned.
    """
    if len(texts) < 2:
        return float('nan')
    embeddings = embedder.encode(texts)
    similarities = embedder.similarity(embeddings, embeddings)
    return extract_off_diagonal(similarities).mean().item()


parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='deberta', help='model name')
parser.add_argument('--corruption_type', type=str, default='gaussian_noise', help='corruption type')
parser.add_argument('--base_dir', type=str, default='', help='directory containing the input JSON logs')
parser.add_argument('--save_dir', type=str, default='', help='directory to write the .npy score files to')
args = parser.parse_args()
model_name = args.model
corruption_type = args.corruption_type
base_dir = args.base_dir
save_dir = args.save_dir

model = SentenceTransformer("all-mpnet-base-v2")

for noise_level in [1, 2, 3, 4, 5]:
    prefix = f'{model_name}_{corruption_type}_{noise_level}_ensemble_3'
    # Rejection-run logs first, then plain-run logs. Seed 99 runs are
    # excluded from both groups.
    file_names = [
        f'{prefix}__rejection.json',
        f'{prefix}_seed_9__rejection.json',
        f'{prefix}_seed_999__rejection.json',
        f'{prefix}_seed_99999__rejection.json',
        f'{prefix}.json',
        f'{prefix}_seed_9_.json',
        f'{prefix}_seed_999_.json',
        f'{prefix}_seed_99999_.json',
    ]
    logs = _load_logs(base_dir, file_names)
    if not logs:
        # Nothing to score; without this guard logs[0] raises IndexError.
        print(f'No logs loaded for noise level {noise_level}; skipping')
        continue

    # The first loaded log defines how many records are scored. This assumes
    # all logs are aligned record-by-record (TODO confirm).
    scores = []
    for i in tqdm(range(len(logs[0]))):
        outputs = _split_outputs([entries[i] for entries in logs])
        scores.append(_diversity_score(model, outputs))

    save_name = f'{model_name}_{corruption_type}_{noise_level}_diversity_score_embd.npy'
    np.save(os.path.join(save_dir, save_name), np.array(scores))