import copy
import open_clip
import pandas as pd
import json
import torch
import argparse
import os
import numpy as np

def class_names(path, dataset):
    """Load CLIP text prompts and the matching label ids for a dataset.

    Args:
        path: Path to the JSON file describing the classes.
        dataset: Either 'Replica' or 'ScanNet'; selects how the JSON is read.
            - 'Replica': the file holds a 'classes' list whose entries have
              'name' and 'id' fields; the returned index is ``int(id) - 1``.
            - 'ScanNet': the file is a flat mapping of key -> class name; the
              returned index is the key exactly as stored in the JSON
              (a string), so callers convert it themselves if needed.

    Returns:
        A ``(prompts, indices)`` tuple of two lists of equal length, where
        ``prompts[k]`` is the text prompt for the class labelled ``indices[k]``.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names. Raising
            here avoids handing back empty lists that only fail much later.
    """
    prompts = []
    indices = []
    if dataset == 'Replica':
        with open(path, 'r') as file:
            data = json.load(file)
        for entry in data['classes']:
            prompts.append('a photo of ' + entry['name'])
            # Stored ids are shifted down by one to obtain the label index.
            indices.append(int(entry['id']) - 1)
    elif dataset == 'ScanNet':
        with open(path, 'r') as file:
            data = json.load(file)
        for key, name in data.items():
            indices.append(key)
            if name != 'picture':
                prompts.append('a photo of ' + name)
            else:
                # 'picture' gets a more specific hand-written prompt.
                prompts.append('a photo of a picture on wall')
    else:
        raise ValueError(f"Unknown dataset: {dataset}")
    return prompts, indices

def clip_text(text, model):
    """Encode text prompts with ``model`` and return the resulting embeddings.

    The model is switched to eval mode and moved to CUDA when available
    (CPU otherwise) as a side effect. Prompts are tokenized with the open_clip
    tokenizer registered under the name 'ViT-H-14'.

    Args:
        text: Prompt(s) handed straight to the open_clip tokenizer.
        model: Object exposing ``eval()``, ``to()`` and ``encode_text()``.

    Returns:
        The output of ``model.encode_text`` computed without gradient
        tracking (presumably shaped [num_prompts, embed_dim] -- confirm
        against the model in use).
    """
    target = "cuda" if torch.cuda.is_available() else "cpu"

    model.eval()
    model.to(target)

    tokenize = open_clip.get_tokenizer('ViT-H-14')
    tokens = tokenize(text).to(target)

    # Inference only: no autograd graph is needed for the text features.
    with torch.no_grad():
        return model.encode_text(tokens)

def combine_embedding(components, mode, alpha_h, alpha_l, alpha_o, alpha_m):
    """Combine embedding components into a single vector for an ablation mode.

    Args:
        components: Mapping with keys among 'E_s', 'E_l', 'E_h', 'E_hide',
            'E_mask'. Only the keys a given mode actually uses are looked up,
            so reduced ablations can pass a partial mapping.
        mode: Ablation type:
          - 'default'    : alpha_h*E_h + alpha_l*E_l + E_s - alpha_o*E_hide + alpha_m*E_mask
          - 'manual'     : same formula as 'default'
          - 's_only'     : E_s
          - 'mask_only'  : E_mask
          - 's_l'        : E_s + alpha_l*E_l
          - 's_l_h'      : E_s + alpha_l*E_l + alpha_h*E_h
          - 'no_hide'    : alpha_h*E_h + alpha_l*E_l + E_s + alpha_m*E_mask
          - 'no_mask'    : alpha_h*E_h + alpha_l*E_l + E_s - alpha_o*E_hide
        alpha_h, alpha_l, alpha_o, alpha_m: Scalar weights for E_h, E_l,
            E_hide and E_mask respectively.

    Returns:
        The combined embedding, of whatever type the component arithmetic
        yields. 's_only' and 'mask_only' return the stored component object
        itself, not a copy.

    Raises:
        ValueError: If ``mode`` is not one of the names listed above.
        KeyError: If a component required by ``mode`` is missing.
    """
    if mode == 's_only':
        return components['E_s']
    if mode == 'mask_only':
        return components['E_mask']
    if mode == 's_l':
        return components['E_s'] + alpha_l * components['E_l']
    if mode == 's_l_h':
        return components['E_s'] + alpha_l * components['E_l'] + alpha_h * components['E_h']
    if mode not in ('no_hide', 'no_mask', 'default', 'manual'):
        raise ValueError(f"Unknown mode: {mode}")

    # The remaining modes share the same leading terms. Terms are accumulated
    # left to right in the order of the documented formula so floating-point
    # results are identical to evaluating that formula as one expression.
    combined = alpha_h * components['E_h'] + alpha_l * components['E_l'] + components['E_s']
    if mode != 'no_hide':
        combined = combined - alpha_o * components['E_hide']
    if mode != 'no_mask':
        combined = combined + alpha_m * components['E_mask']
    return combined

def _resolve_alphas(args, path, dataset):
    """Return (alpha_h, alpha_l, alpha_o, alpha_m) as floats, taking any weight not given on the CLI from the dataset's YAML config."""
    names = ('alpha_h', 'alpha_l', 'alpha_o', 'alpha_m')
    values = {name: getattr(args, name) for name in names}
    if any(value is None for value in values.values()):
        # Imported lazily: the config file (and PyYAML) is only needed when
        # at least one weight was left unspecified on the command line.
        import yaml
        config_name = 'config_Replica.yaml' if dataset == 'Replica' else 'config_ScanNet.yaml'
        with open(path + '/core_configs/' + config_name, 'r') as f:
            config = yaml.safe_load(f)
        values = {
            name: config[name] if value is None else value
            for name, value in values.items()
        }
    return tuple(float(values[name]) for name in names)

def _predict_label(emb_entry, mode, alphas, class_embeddings, class_indices, device):
    """Return the integer class label whose text embedding has the largest dot product with the combined object embedding."""
    comp = {
        key: torch.tensor(emb_entry[key], dtype=torch.float32, device=device)
        for key in ('E_s', 'E_l', 'E_h', 'E_hide', 'E_mask')
    }
    emb_vec = combine_embedding(comp, mode, *alphas)
    scores = emb_vec @ class_embeddings.T
    best = torch.argmax(scores).item()
    return int(class_indices[best])

def main():
    """Predict a class label for every point of a scene and save them to CSV.

    Reads, relative to ``--path``:
      - ``/models/open_clip_pytorch_model.bin``: weights for ``--clip_model``
      - ``/embeddings/<scene>_points_to_ids_ctx.csv``: one row per point, with
        an 'Object id' column
      - ``/embeddings/<scene>_ids_to_embeddings_ctx.json``: per-object
        embedding components keyed by the stringified object id
      - the dataset's class list and, when any ``--alpha_*`` is omitted, its
        YAML config under ``/core_configs/``

    Writes ``/predicted_labels1/<scene>_predicted_labels_<mode>.csv``, which
    is the points table with an added 'labels' column.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--clip_model", type=str, default='EVA02-L-14-336')
    parser.add_argument("--path", type=str, default='')
    # Both are needed to build every input path; requiring them up front
    # gives a clear usage error instead of a late TypeError or empty class list.
    parser.add_argument("--dataset", type=str, required=True, choices=['Replica', 'ScanNet'])
    parser.add_argument("--scene", type=str, required=True)
    parser.add_argument("--mode", type=str, default='default',
                        help="embedding ablation mode: default | s_only | mask_only | s_l | s_l_h | no_hide | no_mask | manual")
    # For sensitivity / manual weights
    parser.add_argument("--alpha_h", type=float, default=None)
    parser.add_argument("--alpha_l", type=float, default=None)
    parser.add_argument("--alpha_o", type=float, default=None)
    parser.add_argument("--alpha_m", type=float, default=None)
    args = parser.parse_args()

    path    = args.path
    dataset = args.dataset
    scene   = args.scene
    mode    = args.mode

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # ----- Load CLIP model -----
    model, _, _ = open_clip.create_model_and_transforms(args.clip_model, pretrained=None)
    # map_location='cpu' lets a checkpoint saved on GPU load on CPU-only
    # hosts; clip_text moves the model to the active device afterwards.
    state_dict = torch.load(path + '/models/open_clip_pytorch_model.bin', map_location='cpu')
    # strict=False: missing/unexpected checkpoint keys are tolerated.
    model.load_state_dict(state_dict, strict=False)

    # ----- Load points->ids -----
    df_points_to_id = pd.read_csv(path + '/embeddings/' + scene + '_points_to_ids_ctx.csv')

    # ----- Load embeddings with components -----
    with open(path + '/embeddings/' + scene + '_ids_to_embeddings_ctx.json', 'r') as f:
        df_ids_to_embeddings = json.load(f)

    # ----- Load class names & indices -----
    if dataset == 'Replica':
        classes, class_indices = class_names(path + '/dataset/Replica-data/info_semantic.json', dataset)
    else:
        classes, class_indices = class_names(path + '/dataset/ScanNet/classes.json', dataset)

    # ----- Get text embeddings -----
    class_embeddings = clip_text(classes, model).to(device)

    # ----- Resolve combination weights (CLI overrides config) -----
    alphas = _resolve_alphas(args, path, dataset)

    # ----- Label every point -----
    # The label depends only on the object id, so it is computed once per
    # distinct id and reused for all points belonging to that object.
    label_by_id = {}
    labels = []
    for object_id in df_points_to_id['Object id']:
        key = str(object_id)
        if key not in label_by_id:
            label_by_id[key] = _predict_label(
                df_ids_to_embeddings[key], mode, alphas,
                class_embeddings, class_indices, device,
            )
        labels.append(label_by_id[key])

    df_points_to_id['labels'] = labels

    os.makedirs(path + '/predicted_labels1/', exist_ok=True)
    out_csv = path + '/predicted_labels1/' + scene + f'_predicted_labels_{mode}.csv'
    df_points_to_id.to_csv(out_csv, index=False)
    print(f"Saved ablation predictions to: {out_csv}")

# Run the labelling pipeline only when executed as a script, not on import.
if __name__ == '__main__':
    main()
