import information_geometry as ig

import os
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import random
from collections import defaultdict
from tqdm import trange
import argparse
from tqdm import tqdm

# Run on the GPU when torch can see one; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {DEVICE}")


from transformers import AutoProcessor, MetaClip2Model

# Checkpoint identifier handed to both the processor and the model loader.
MODEL_NAME = "facebook/metaclip-2-worldwide-huge-quickgelu"
clip_processor = AutoProcessor.from_pretrained(MODEL_NAME)
# Load the weights, then move them to DEVICE and switch to eval mode.
clip_model = MetaClip2Model.from_pretrained(MODEL_NAME)
clip_model = clip_model.to(DEVICE).eval()



##### Setup datasets and vocabularies #####
data_path = "IMAGE_DATA_PATH" # Replace with the actual path where data is stored
concept_dict = {
    "shape": ["circles", "squares", "triangles"],
    "color": ["red", "green", "blue", "yellow"],
    "number": [4],
}

G, image_dataset, combo_dataset = ig.load_color_shape_image(
    data_path,
    concept_dict,
    clip_model,
    clip_processor,
    image_num_images=2,
    combo_num_images=2,
)

# One label per sample: "<color>_<shape>" for single-object samples, then
# "<color1>_<shape1>_<color2>_<shape2>" for combo samples. Duplicates are kept
# here on purpose; aggregate_probs later folds them into the unique vocabulary.
original_vocab_list = []
for sample in image_dataset.samples:
    original_vocab_list.append(f"{sample['color']}_{sample['shape']}")
for sample in combo_dataset.samples:
    original_vocab_list.append(
        f"{sample['color1']}_{sample['shape1']}_{sample['color2']}_{sample['shape2']}"
    )

# Unique labels in sorted order, plus the label -> column index lookup.
vocab_list = sorted(set(original_vocab_list))
vocab_dict = dict(zip(vocab_list, range(len(vocab_list))))

print(f"concept dict: {concept_dict}")
print(f"Number of images in image dataset: {len(image_dataset)}")
print(f"Number of images in combo dataset: {len(combo_dataset)}")
print(f"Original vocab size: {len(original_vocab_list)}")
print(f"Unique vocab size: {len(vocab_list)}")



# Command-line options: which concept is being steered, which concept is
# held as the "other" one, and the source (value0) / destination (value1)
# values of the steered concept.
arg = argparse.ArgumentParser()
for option, fallback in (
    ("--target_concept", "color"),
    ("--other_concept", "shape"),
    ("--value0", "blue"),
    ("--value1", "red"),
):
    arg.add_argument(option, type=str, default=fallback)
args = arg.parse_args()

target_concept, other_concept = args.target_concept, args.other_concept
value0, value1 = args.value0, args.value1


#### counterfactual mapping #####
print(f"Target concept: {target_concept}, Other concept: {other_concept}, Value0: {value0}, Value1: {value1}")

# mapping[label] = the label obtained by swapping value0 for value1.
# Labels that already mention value1 are skipped entirely.
mapping = {}
for vocab in vocab_list:
    if value0 not in vocab or value1 in vocab:
        continue

    target_word = vocab.replace(value0, value1)
    parts = vocab.split('_')

    if len(parts) <= 2:
        # Single-object label: the replaced label is recorded as-is.
        # NOTE(review): unlike the combo case below, the target is not checked
        # against vocab_list here -- confirm every single-object target exists.
        mapping[vocab] = target_word
        continue

    if target_word not in vocab_list:
        # Replacing every occurrence produced an unknown label, so replace
        # only one slot instead: colors sit at positions 0 and 2, the other
        # concept at positions 1 and 3.
        slot = 0 if target_concept == "color" else 1
        if parts[slot] == value0:
            parts[slot] = value1
        elif parts[slot + 2] == value0:
            parts[slot + 2] = value1
        target_word = '_'.join(parts)

    if target_word in vocab_list:
        mapping[vocab] = target_word

print(f"Number of mapping: {len(mapping)}")




#### Prepare train and test sets #####
# Prompt prefixes placed in front of every "<color> <shape>" phrase built in
# get_train_test. The empty string keeps the bare phrase. The list order
# matters: the seeded shuffle in get_train_test splits by position, so
# reordering entries changes which prompts land in train vs. test.
prefix_formats = [
    "",
    "a rendering of ",
    "a depiction of ",
    "an illustration of ",
    "a conceptual illustration of ",
    "A rendering of ",
    "A depiction of ",
    "An illustration of ",
    "A conceptual illustration of ",
    "Rendering of ",
    "Depiction of ",
    "Illustration of ",
    "Conceptual illustration of ",
]

def get_train_test(concept_dict, value0, value1, other_concept, prefix_formats, alpha = 0.7):
    """Build paired value0/value1 prompts, split them, and embed each split.

    For every prefix and every value of ``other_concept`` a pair of prompts is
    created that differs only in ``value0`` vs ``value1``. The pairs are
    shuffled with a fixed seed (100) and the first ``alpha`` fraction becomes
    the training split.

    Args:
        concept_dict: Maps a concept name to its list of values.
        value0: Source value of the steered concept.
        value1: Destination value of the steered concept.
        other_concept: ``"shape"`` or ``"color"``; any other name yields no
            prompts at all.
        prefix_formats: Iterable of string prefixes prepended to each prompt.
        alpha: Fraction of prompt pairs assigned to the training split.

    Returns:
        Tuple ``(test_texts0, train_primals0, train_primals1, test_primals0,
        test_primals1)`` where the ``*_primals*`` entries are whatever
        ``ig.get_clip_text_embeddings`` returns for the matching texts.
    """
    # Each entry is (prompt with value0, prompt with value1).
    paired_texts = []
    for prefix in prefix_formats:
        if other_concept == "shape":
            paired_texts.extend(
                (f"{prefix}{value0} {shape}", f"{prefix}{value1} {shape}")
                for shape in concept_dict[other_concept]
            )
        elif other_concept == "color":
            paired_texts.extend(
                (f"{prefix}{color} {value0}", f"{prefix}{color} {value1}")
                for color in concept_dict[other_concept]
            )

    # Fixed seed so the train/test split is the same on every run.
    order = list(range(len(paired_texts)))
    random.seed(100)
    random.shuffle(order)
    cut = int(alpha * len(order))
    train_order, test_order = order[:cut], order[cut:]

    train_texts0 = [paired_texts[k][0] for k in train_order]
    test_texts0 = [paired_texts[k][0] for k in test_order]
    train_texts1 = [paired_texts[k][1] for k in train_order]
    test_texts1 = [paired_texts[k][1] for k in test_order]

    def embed(texts):
        return ig.get_clip_text_embeddings(texts, clip_model, clip_processor)

    train_primals0 = embed(train_texts0)
    train_primals1 = embed(train_texts1)
    test_primals0 = embed(test_texts0)
    test_primals1 = embed(test_texts1)

    return test_texts0, train_primals0, train_primals1, test_primals0, test_primals1


def aggregate_probs(probs, original_vocab_list, vocab_list):
    """Sum probability columns that share a label into one column per label.

    Args:
        probs: 2-D tensor whose column ``j`` belongs to
            ``original_vocab_list[j]``.
        original_vocab_list: Label of every column of ``probs``; may contain
            repeats.
        vocab_list: Labels of the output columns, in output order.

    Returns:
        Tensor of shape ``(probs.size(0), len(vocab_list))`` on
        ``probs.device``; column ``i`` is the sum of all input columns labelled
        ``vocab_list[i]``. Input columns whose label is absent from
        ``vocab_list`` are ignored.
    """
    # Label -> output column(s). Built once so each input column is routed by
    # a dict lookup instead of a scan over the whole vocabulary.
    columns_for = {}
    for i, vocab in enumerate(vocab_list):
        columns_for.setdefault(vocab, []).append(i)

    new_probs = torch.zeros(probs.size(0), len(vocab_list), device=probs.device)
    # Walking the input columns in ascending order keeps the per-column
    # summation order the same as a nested scan, so results are unchanged.
    for j, original_vocab in enumerate(original_vocab_list):
        for i in columns_for.get(original_vocab, ()):
            new_probs[:, i] += probs[:, j]
    return new_probs


# Build the prompt splits and their embeddings, then convert the training
# embeddings with ig.primals_to_duals using G.
(
    test_texts0,
    train_primals0,
    train_primals1,
    test_primals0,
    test_primals1,
) = get_train_test(concept_dict, value0, value1, other_concept, prefix_formats, alpha=0.7)
train_duals0, train_duals1 = (
    ig.primals_to_duals(primals, G) for primals in (train_primals0, train_primals1)
)





#### Compute directions #####
# One steering direction per representation, each obtained from ig.get_MD on
# the value0 / value1 training sets.
directions = dict(
    primal_md=ig.get_MD(train_primals0, train_primals1),
    dual_md=ig.get_MD(train_duals0, train_duals1),
)
direction_names = list(directions)



#### Steering test primals #####
# For every held-out value0 embedding, record the path produced by each
# steering routine ('e' -> ig.e_steering, 'm' -> ig.m_steering) along each
# direction. Result layout: primals_dict_list[sample][direction_name][method].
primals_dict_list = []
for start_primal in tqdm(test_primals0):
    paths = {}
    for name, direction in directions.items():
        paths[name] = {
            'e': ig.e_steering(
                start_primal, direction, G,
                num_steps=400, step_size=0.1, use_tqdm=False,
            ),
            'm': ig.m_steering(
                start_primal, direction, G,
                alpha=5e-3, num_steps=3000, step_size=0.5, use_tqdm=False,
            ),
        }
    primals_dict_list.append(paths)





#### Compute metrics along the paths #####
# Vocabulary columns of the source labels (mapping keys) and of the target
# labels (mapping values). Several sources may share a target, hence the set.
# The indices are sorted so the column order -- and therefore the floating
# point summation order downstream -- is identical on every run, instead of
# depending on the per-process string hash seed.
indices0 = sorted(vocab_dict[source] for source in mapping)
indices1 = sorted({vocab_dict[target] for target in mapping.values()})

def get_base_target_probs(probs, indices0, indices1):
    """Return the per-row probability mass on two groups of columns.

    Args:
        probs: 2-D tensor of per-row probabilities.
        indices0: Column indices of the first (source) group.
        indices1: Column indices of the second (target) group.

    Returns:
        Tuple ``(mass0, mass1)`` of 1-D tensors, one entry per row of
        ``probs``.
    """
    base_mass = probs[:, indices0].sum(dim=-1)
    target_mass = probs[:, indices1].sum(dim=-1)
    return base_mass, target_mass

def get_off_probs(probs, mapping, vocab_dict):
    """Fold each source column into its target column and drop the source.

    Args:
        probs: 2-D tensor of per-row probabilities; it is not modified.
        mapping: Source label -> target label.
        vocab_dict: Label -> column index in ``probs``.

    Returns:
        A new tensor in which every source column has been added to its
        target column and then removed, so it has one column fewer per
        distinct source label.
    """
    merged = probs.clone()
    dropped = set()
    for source, target in mapping.items():
        dst_col = vocab_dict[target]
        src_col = vocab_dict[source]
        merged[:, dst_col] += merged[:, src_col]
        merged[:, src_col] = 0.0
        dropped.add(src_col)

    kept = [col for col in range(merged.size(1)) if col not in dropped]
    return merged[:, kept]

def get_kls(probs, offset = 5e-3):
    """Smoothed KL divergence from the first row to every row.

    Args:
        probs: 2-D tensor; row 0 is the reference distribution.
        offset: Constant added inside both logarithms to avoid ``log(0)``.

    Returns:
        1-D tensor with ``sum(q * (log(q + offset) - log(p + offset)))`` for
        every row ``p``, where ``q`` is row 0.
    """
    reference = probs[0]
    log_reference = torch.log(reference + offset)
    log_rows = torch.log(probs + offset)
    return (reference * (log_reference - log_rows)).sum(dim=-1)

def get_rank_diff(probs, use_topp = True):
    """Weighted change in reciprocal rank of each column relative to row 0.

    Every row is ranked in descending probability (rank 1 = largest). For each
    row the absolute difference ``|1/rank - 1/rank_in_row_0|`` is computed per
    column and averaged with weights taken from row 0.

    Args:
        probs: 2-D tensor of per-row probabilities; row 0 is the reference.
        use_topp: When true, restrict the comparison to the union of columns
            that hold the top 0.999 of the mass in up to 20 evenly spaced
            rows, and renormalise the row-0 weights over those columns.

    Returns:
        1-D tensor with one value per row of ``probs``.
    """
    # The reference distribution is always the first row. It is bound here,
    # before the sampling loop, so both branches weight by the same row.
    q = probs[0]

    if use_topp:
        # Up to 20 evenly spaced rows; short inputs repeat indices, so dedupe.
        seq = sorted(set(np.linspace(0, len(probs) - 1, 20, dtype=int).tolist()))
        selected = set()
        for i in seq:
            sorted_probs, sorted_indices = torch.sort(probs[i], dim=-1, descending=True)
            # Mass strictly before each entry: an entry is kept while the mass
            # ahead of it is still below the 0.999 threshold.
            mass_before = sorted_probs.cumsum(dim=-1) - sorted_probs
            selected.update(sorted_indices[mass_before < 0.999].tolist())
        topp_indices = sorted(selected)

        probs = probs[:, topp_indices]
        min_mass = probs.sum(dim=-1).min()
        if min_mass < 0.99:
            print(min_mass)
            print("Sum of selected probabilities is less than 0.99")

        weight = q[topp_indices]
        weight = weight / weight.sum()
    else:
        weight = q

    _, sorted_indices = torch.sort(probs, dim=-1, descending=True)

    # Invert the sort permutation: ranks[r, c] = position of column c in row r.
    positions = torch.arange(
        probs.size(-1), dtype=torch.float, device=probs.device
    ).expand_as(sorted_indices)
    ranks = torch.zeros_like(sorted_indices, dtype=torch.float)
    ranks.scatter_(-1, sorted_indices, positions)
    ranks += 1

    rank_diff = (1/ranks - 1/ranks[0]).abs()
    return rank_diff @ weight

def get_cos(probs, G, direction):
    """Cosine between each step of the ``probs @ G`` path and ``direction``.

    Args:
        probs: 2-D tensor with one row per step of a path (at least two rows).
        G: Matrix that ``probs`` is multiplied by.
        direction: 1-D tensor to compare each step against.

    Returns:
        1-D tensor with one cosine per row of ``probs``; the last step is
        repeated so the output length matches the number of rows.
    """
    eps = 1e-16  # keeps the divisions finite for zero-length vectors
    duals = probs @ G

    # Consecutive differences, padded with a copy of the final one.
    steps = torch.diff(duals, dim=0)
    steps = torch.cat([steps, steps[-1].unsqueeze(0)], dim=0)

    unit_steps = steps / (steps.norm(dim=-1, keepdim=True) + eps)
    unit_direction = direction / (direction.norm() + eps)
    return unit_steps @ unit_direction


# all_list[direction_name][method] -> list with one metrics dict per test path.
all_list = {dir_name: defaultdict(list) for dir_name in directions}

for dir_name, direction in directions.items():
    for method in ('e', 'm'):
        for path_set in tqdm(primals_dict_list):
            path = path_set[dir_name][method]

            # Distribution over the original (per-sample) vocabulary at every
            # step of the path.
            probs = F.softmax(path @ G.T, dim=-1)
            cos = get_cos(probs, G, direction)

            ### For CLIP ###
            # Collapse repeated labels into the unique vocabulary.
            probs = aggregate_probs(probs, original_vocab_list, vocab_list)

            probs0, probs1 = get_base_target_probs(probs, indices0, indices1)
            cf_sum = probs0 + probs1
            ratio = probs1 / (probs0 + probs1 + 1e-10)
            # Keep only the steps where the target share is still below 0.9999.
            mask = ratio < 0.9999

            off_prob = get_off_probs(probs, mapping, vocab_dict)
            fkl = get_kls(off_prob, offset=1e-6)
            rank_diff = get_rank_diff(off_prob, use_topp=False)

            metrics = {
                "probs0": probs0,
                "probs1": probs1,
                "sum": cf_sum,
                "ratio": ratio,
                "fkl": fkl,
                "rank_diff": rank_diff,
                "cos": cos,
            }
            all_list[dir_name][method].append(
                {key: values[mask].cpu() for key, values in metrics.items()}
            )


base_path = "BASE_PATH" # Replace with the actual base path where data is stored
# Create the output directory first: without it, the saves below fail only
# after all the steering and metric computation has already been done.
os.makedirs(base_path, exist_ok=True)

# Shared tag so the three output files of one run are easy to match up.
run_tag = f"{target_concept}_{value0}_to_{value1}"
torch.save(all_list, os.path.join(base_path, f"{run_tag}.pt"))
torch.save(primals_dict_list, os.path.join(base_path, f"test_steering_paths_{run_tag}.pt"))

# One held-out value0 prompt per line, in the same order as the saved paths.
with open(os.path.join(base_path, f"test_texts_{run_tag}.txt"), "w") as f:
    f.writelines(text + "\n" for text in test_texts0)