import os
import json
import numpy as np
import torch
from tqdm import trange
import information_geometry as ig
import argparse
import torch.nn.functional as F
from collections import defaultdict

MODEL_NAME = "google/gemma-3-4b-pt"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# Parse the command line BEFORE loading the model: the load is the expensive
# step, and `--help` or a mistyped flag should not have to pay for it.
arg = argparse.ArgumentParser()
arg.add_argument("--concept_name", type=str, default="verb_en_fr")
args = arg.parse_args()

concept_name = args.concept_name

# `G` is later used as `path @ G.T` (logits) and `probs @ G` (duals), so it is
# presumably the (vocab, dim) unembedding matrix -- TODO confirm against
# information_geometry.load_model_and_vocab.
model, tokenizer, vocab_dict, vocab_list, G = ig.load_model_and_vocab(MODEL_NAME, device=DEVICE)

base_path = "BASE_PATH"  # Replace with the actual base path where data is stored
concept_path = os.path.join(base_path, concept_name)

# Indexed below as primals_dict_list[i][direction_name][path_type].
primals_dict_list = torch.load(os.path.join(concept_path, "test_steering_paths.pt"))

# Read the JSON explicitly as UTF-8 rather than with the locale default, so
# that non-ASCII tokens decode the same way on every platform.
with open(os.path.join("data", f"mapping_{concept_name}.json"), "r", encoding="utf-8") as f:
    mapping = json.load(f)
# NOTE(review): presumably restricts the mapping to tokens present in
# vocab_dict -- confirm in information_geometry.get_clean_mapping.
mapping = ig.get_clean_mapping(mapping, vocab_dict)

# Keyed by direction name; each value is passed to get_cos as `direction`.
directions = torch.load(os.path.join(concept_path, "directions.pt"))

# Vocabulary ids of the mapping's keys (indices0) and of its distinct values
# (indices1). Dict keys are already unique; dict.fromkeys de-duplicates the
# values while keeping a stable order, unlike iterating a set of strings,
# whose order changes between interpreter runs (hash randomisation).
indices0 = [vocab_dict[i] for i in mapping.keys()]
indices1 = [vocab_dict[i] for i in dict.fromkeys(mapping.values())]

def get_base_target_probs(probs, indices0, indices1):
    """Sum probability mass over two groups of columns.

    Args:
        probs: 2-D tensor; columns are selected with ``probs[:, indices]``.
        indices0: column indices of the first group.
        indices1: column indices of the second group.

    Returns:
        A pair ``(mass0, mass1)`` of 1-D tensors, one entry per row of
        ``probs``, holding the row-wise sum over ``indices0`` and
        ``indices1`` respectively.
    """
    mass0 = probs[:, indices0].sum(dim=-1)
    mass1 = probs[:, indices1].sum(dim=-1)
    return mass0, mass1

def get_off_probs(probs, mapping, vocab_dict):
    """Fold each mapped key column into its value column and drop the key columns.

    For every ``key -> value`` pair in ``mapping`` (processed in iteration
    order), the probability in the key's column is added to the value's
    column and the key's column is zeroed. All key columns are then removed,
    so the result has fewer columns than ``probs``. The input tensor is not
    modified.

    Args:
        probs: 2-D tensor with one column per vocabulary id.
        mapping: dict-like of token -> token.
        vocab_dict: dict of token -> column index in ``probs``.

    Returns:
        A new tensor with the same rows as ``probs`` and the key columns
        removed; the remaining columns keep their original relative order.
    """
    off_probs = probs.clone()
    dropped = []
    for k, v in mapping.items():
        src, dst = vocab_dict[k], vocab_dict[v]
        off_probs[:, dst] += off_probs[:, src]
        off_probs[:, src] = 0.0
        dropped.append(src)

    # A boolean keep-mask selects the surviving columns in one pass. Testing
    # `i not in <list>` for every vocabulary column would cost
    # O(vocab * len(mapping)) on each call.
    keep = torch.ones(off_probs.size(1), dtype=torch.bool, device=off_probs.device)
    keep[torch.tensor(dropped, dtype=torch.long, device=off_probs.device)] = False

    return off_probs[:, keep]

def get_kls(probs, offset = 5e-3):
    """Smoothed divergence of every row of ``probs`` from its first row.

    With ``q = probs[0]``, each output entry is
    ``sum(q * (log(q + offset) - log(p + offset)))`` over the last
    dimension, where ``p`` is the corresponding row of ``probs``. The entry
    for row 0 is therefore zero.

    Args:
        probs: tensor whose first index selects the reference row ``q``.
        offset: constant added inside both logarithms (not to the ``q``
            weighting factor) to keep them finite at zero probability.

    Returns:
        Tensor with the last dimension of ``probs`` summed out.
    """
    reference = probs[0]
    log_reference = torch.log(reference + offset)
    log_rows = torch.log(probs + offset)
    return torch.sum(reference * (log_reference - log_rows), dim=-1)

def get_rank_diff(probs):
    """Weighted change in reciprocal token rank of each row relative to row 0.

    Steps:
      1. Sample 20 integer row indices evenly between the first and last row.
      2. For each sampled row, collect the columns whose preceding
         (descending-sorted) cumulative mass is below 0.999, and take the
         union of those columns over all sampled rows.
      3. Restrict ``probs`` to that union and rank the columns within every
         row (1 = largest probability).
      4. Compute ``|1/rank - 1/rank_in_row_0|`` per column and average it with
         weights given by the last sampled row, renormalised over the union.

    A warning is printed if any restricted row sums to less than 0.99.

    Args:
        probs: 2-D tensor, one distribution per row.

    Returns:
        1-D tensor with one value per row of ``probs``.
    """
    sampled_rows = np.linspace(0, len(probs) - 1, 20, dtype=int).tolist()

    selected = []
    for row_idx in sampled_rows:
        row = probs[row_idx]
        row_sorted, row_order = torch.sort(row, dim=-1, descending=True)
        # Mass strictly before each sorted position; a column is kept while
        # that preceding mass is still under the 0.999 threshold.
        mass_before = row_sorted.cumsum(dim=-1) - row_sorted
        selected.extend(row_order[mass_before < 0.999].tolist())
    selected = list(set(selected))

    # NOTE(review): the weights are taken from the LAST sampled row of the
    # unrestricted input (the value the loop variable holds after the loop).
    # Confirm this is intended rather than, e.g., row 0.
    weight_source = probs[sampled_rows[-1]]

    restricted = probs[:, selected]
    if restricted.sum(dim =-1).min() < 0.99:
        print(restricted.sum(dim =-1).min())
        print("Sum of selected probabilities is less than 0.99")

    _, order = torch.sort(restricted, dim=-1, descending=True)

    # Invert the sort permutation: scatter position j of the sorted order back
    # to the column that occupies it, giving 0-based ranks per column.
    positions = torch.arange(restricted.size(-1),
                             dtype=torch.float).expand_as(order).to(restricted.device)
    ranks = torch.zeros_like(order, dtype=torch.float)
    ranks.scatter_(-1, order, positions)
    ranks += 1

    rank_diff = (1/ranks - 1/ranks[0]).abs()
    weight = weight_source[selected]
    weight = weight / weight.sum()
    return rank_diff @ weight

def get_cos(probs, G, direction):
    """Cosine between each step of the path ``probs @ G`` and ``direction``.

    The rows of ``probs`` are mapped through ``G``; consecutive rows are
    differenced to obtain per-step displacement vectors, and the last
    displacement is repeated once so the output has one entry per input row.
    Each displacement and ``direction`` are scaled to unit length (with a
    1e-16 guard in the denominator) before taking the dot product.

    Args:
        probs: 2-D tensor with at least two rows.
        G: matrix right-multiplied onto ``probs``.
        direction: 1-D tensor in the column space of ``probs @ G``.

    Returns:
        1-D tensor with one cosine value per row of ``probs``.
    """
    mapped = probs @ G
    steps = mapped[1:] - mapped[:-1]
    # Repeat the final step so the result lines up with the rows of `probs`.
    steps = torch.cat([steps, steps[-1].unsqueeze(0)], dim=0)

    step_lengths = steps.norm(dim=-1, keepdim=True)
    unit_steps = steps / (step_lengths + 1e-16)
    unit_direction = direction / (direction.norm() + 1e-16)
    return unit_steps @ unit_direction




# Results are stored as all_list[direction_name][path_type] -> list with one
# metrics dict per entry of primals_dict_list.
all_list = {name: defaultdict(list) for name in directions.keys()}

# Pure evaluation: disable autograd so that, if G or the stored paths require
# grad, the saved metric tensors do not keep a computation graph alive.
with torch.no_grad():
    for dir_name in directions.keys():
        # NOTE(review): 'e' and 'm' are the two path variants stored per
        # direction -- presumably e-/m-geodesics; confirm with the code that
        # writes test_steering_paths.pt.
        for path_type in ['e', 'm']:
            for i in trange(len(primals_dict_list)):
                path = primals_dict_list[i][dir_name][path_type]

                # One softmax distribution over the vocabulary per path step.
                probs = F.softmax(path @ G.T, dim=-1)

                probs0, probs1 = get_base_target_probs(probs, indices0, indices1)
                cf_sum = probs0 + probs1
                # Share of the combined mass that sits on the indices1 group.
                ratio = probs1 / (cf_sum + 1e-10)
                # Keep only the steps whose share is still below 0.9999.
                mask = ratio < 0.9999

                off_prob = get_off_probs(probs, mapping, vocab_dict)
                fkl = get_kls(off_prob, offset = 1e-6)
                rank_diff = get_rank_diff(off_prob)
                cos = get_cos(probs, G, directions[dir_name])

                all_list[dir_name][path_type].append({
                    "probs0": probs0[mask].cpu(),
                    "probs1": probs1[mask].cpu(),
                    "sum": cf_sum[mask].cpu(),
                    "ratio": ratio[mask].cpu(),
                    "fkl": fkl[mask].cpu(),
                    "rank_diff": rank_diff[mask].cpu(),
                    "cos": cos[mask].cpu()
                })

torch.save(all_list, os.path.join(concept_path, "test_steering_metrics.pt"))
