import os
# set up logging
import logging

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
import time
import csv
import pdb
import pickle
from typing import List, Tuple, Union, Optional, Dict, Any, Literal

from tqdm import tqdm
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import argparse
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader

from bertviz import head_view, model_view
from transformers import AutoTokenizer, LlamaForCausalLM

from src.dataset import ProbeDataLoader, LMDataloader, GPTDataloaderForInference, ObjectLocationProbeDataLoader, \
    BinaryProbeDataLoader
from src.model import T5ForProbing, GPTForProbing, LlamaForProbing
from src.probe_trainer import Trainer, TrainerConfig
from src.probe_model import BatteryProbeClassification, ObjectLocationProbeClassification, \
    BatteryProbeClassificationTwoLayer


# Maximum source length per `--model_type`; main() passes the selected value
# as the max length when constructing the inference datasets.
_MAX_SOURCE_TEXT_LENGTH = {
    "t5": 512,
    "gpt": 512,
    "llama": 2048,
    "Llama-3.1-8B": 4096,
    "Llama-3.1-405B": 2048,
    "Llama-3.2-1B": 2048,
    "CodeLlama-13b-hf": 4096,
}

# Maximum target length handed to the T5 dataset (LMDataloader) in main().
_MAX_TARGET_TEXT_LENGTH = 100

# The keys define the accepted values of `--model_type` in main().
# NOTE(review): the values look like the hidden size of each model
# (presumably the probe input dimension) -- they are not read anywhere in
# this file, so confirm against the probe-training code.
_INPUT_DIMENSIONS = {
    "t5": 768,
    "gpt": 1600,
    "llama": 5120,
    "Llama-3.1-8B": 4096,
    "Llama-3.1-405B": 16384,
    "Llama-3.2-1B": 2048,
    "CodeLlama-13b-hf": 5120,
}

# make deterministic: seed torch's global RNG once at import time
torch.manual_seed(0)


def mean_of_nonzero(tensor, dim=None, keepdim=False):
    """Calculates the mean of non-zero elements in a PyTorch tensor along specified dimensions.

    The result is a plain ``torch.Tensor`` (no masked-tensor wrapper), so it
    can be assigned directly into ordinary tensors.

    Args:
        tensor: The input tensor.
        dim: The dimension(s) to reduce (int, list or tuple of ints). If None,
            the mean of all non-zero elements is computed and ``keepdim`` is
            ignored.
        keepdim: Whether the output tensor has ``dim`` retained or not.

    Returns:
        The mean of the non-zero elements; NaN wherever there are no non-zero
        elements along the reduced dimension(s).
    """
    non_zero_mask = tensor != 0

    if dim is None:
        # Zeros contribute nothing to the sum, so sum / (#non-zero) is the
        # mean over the non-zero entries; 0 / 0 yields NaN as documented.
        return tensor.sum() / non_zero_mask.sum()

    if isinstance(dim, int):
        dim = (dim,)
    else:
        dim = tuple(dim)

    total = tensor.sum(dim=dim, keepdim=keepdim)
    count = non_zero_mask.sum(dim=dim, keepdim=keepdim)
    return total / count



def aggregate_attentions(attn: Tuple[torch.Tensor], spans: List[Tuple[int, int]], aggregation: Literal["mean", "sum", "max"], normalize=True):
    """Aggregate token-level attention weights into span-level attention weights.

    For every layer, the attention from the tokens of span ``i`` (query side)
    to the tokens of span ``j`` (key side) is reduced to a single score per
    batch element and head. Only pairs with ``j <= i`` are filled in; entries
    with ``j > i`` stay zero. Naive looping for now.

    Args:
        attn: One tensor per layer, each of shape [bs, head, pos, pos].
        spans: ``(start, end)`` token positions per span; ``end`` is inclusive
            and is clipped to the last available position.
        aggregation: How to reduce a span-by-span block: "mean" (over non-zero
            weights only, via ``mean_of_nonzero``), "sum" or "max".
        normalize: If True, L1-normalize each row over the last dimension
            (the attended-to span).

    Returns:
        A list with one CPU tensor of shape [bs, head, span, span] per layer.

    Raises:
        NotImplementedError: If ``aggregation`` is not a supported mode.
        IndexError: If a span refers to positions outside the attention matrix.
    """
    if aggregation not in ("mean", "sum", "max"):
        raise NotImplementedError(f"unsupported aggregation: {aggregation!r}")
    if len(attn) == 0:
        return []

    bs, d_head, n_tokens, _ = attn[0].shape

    # Token positions covered by each span. They do not depend on the layer,
    # so compute them once instead of once per layer and span pair.
    span_positions = [
        list(range(int(start), min(int(end), n_tokens - 1) + 1))
        for start, end in spans
    ]

    new_attn = []
    for attn_layer in attn:
        attn_layer = attn_layer.to("cpu")
        new_attn_layer = torch.zeros((bs, d_head, len(spans), len(spans)))
        for i, span_i_indices in enumerate(span_positions):
            # Only j <= i is filled in (a span never attends to a later span).
            for j, span_j_indices in enumerate(span_positions[: i + 1]):
                # A span that lies entirely beyond the sequence has no tokens;
                # leave its score at zero rather than reducing over nothing.
                if not span_i_indices or not span_j_indices:
                    continue

                # Extract attention weights for span i to span j
                try:
                    span_attention = attn_layer[:, :, span_i_indices][:, :, :, span_j_indices]
                except IndexError as err:
                    raise IndexError(
                        f"span indices out of range for attention of shape "
                        f"{tuple(attn_layer.shape)}: {span_i_indices=}, {span_j_indices=}"
                    ) from err

                # Aggregate the attention weights of the block into one score
                if aggregation == "mean":
                    agg_score = mean_of_nonzero(span_attention, dim=[2, 3])
                elif aggregation == "sum":
                    agg_score = torch.sum(span_attention, dim=(2, 3))
                else:  # "max"
                    # torch.max cannot reduce over several dims at once; amax can.
                    agg_score = torch.amax(span_attention, dim=(2, 3))

                new_attn_layer[:, :, i, j] = agg_score

        # optionally normalize across the last dimension (target token)
        if normalize:
            new_attn_layer = F.normalize(new_attn_layer, p=1, dim=-1)

        new_attn.append(new_attn_layer)

    return new_attn

def get_correct_attention_heads(attn: Tuple[torch.Tensor], spans: List[Tuple[int, int]], correct_span_indices: List[int]) -> torch.Tensor:
    """Mark the (layer, head) pairs whose last token attends to a correct span.

    For every layer and head, the attention paid by the last token to each
    span is summed over the span's tokens. A head counts as attending to a
    correct span when that span is among its (up to) two highest-scoring
    spans; top-2 is used because top-1 is usually phrase 0. Each correct span
    gets its own row so the labels stay distinguishable.

    Args:
        attn: One tensor per layer, each of shape [1, head, pos, pos].
        spans: ``(start, end)`` token positions per span; ``end`` is inclusive
            and clipped to the last available position.
        correct_span_indices: Indices into ``spans`` of the correct
            operation phrase(s).

    Returns:
        A float tensor of shape [len(correct_span_indices), n_layer * n_head].
        Entry ``[label, layer * n_head + head]`` is 1 if that head has the
        label's span in its top spans, else 0.

    Raises:
        ValueError: If the batch size of the attention weights is not 1.
    """
    if attn[0].shape[0] != 1:
        raise ValueError("batch size can only be 1 for attention weights")

    n_layer, n_head, n_tokens = len(attn), attn[0].shape[1], attn[0].shape[-1]
    activating_heads = torch.zeros(len(correct_span_indices), n_layer * n_head)

    # Plain ints so membership tests work even if the indices arrive as tensors.
    target_spans = [int(index) for index in correct_span_indices]
    # topk raises when k exceeds the number of spans.
    k = min(2, len(spans))

    for i_layer, attn_layer in enumerate(attn):
        attn_layer = attn_layer.to("cpu")
        span_attn = torch.zeros((n_head, len(spans)))

        # first let's collect last token's attention on all previous tokens
        for i_span, (start, end) in enumerate(spans):
            stop = min(int(end), n_tokens - 1) + 1
            span_attn[:, i_span] = attn_layer[0, :, -1, int(start):stop].sum(-1)

        # topk returns (values, indices); only the span indices matter here.
        top_span_ids = torch.topk(span_attn, k=k, dim=-1).indices.tolist()
        for i_head, head_top_spans in enumerate(top_span_ids):
            # One column per (layer, head) pair.
            column = i_layer * n_head + i_head
            for i_label, correct_span_index in enumerate(target_spans):
                if correct_span_index in head_top_spans:
                    activating_heads[i_label, column] = 1

    return activating_heads



def main():
    """Compute (or load) and optionally cache model representations for probing.

    Steps:
      1. Parse and validate the command-line flags.
      2. Read the train/test ``.jsonl`` files and the object vocabulary.
      3. Either load pickled per-layer representations
         (``--load_model_representation``) and keep the requested layer, or
         run the selected model over both splits with batch size 1 and
         collect the representation of the last position.
      4. With ``--save_model_representation``, pickle the collected
         all-layer representations to ``--model_representation_path``.

    Raises:
        ValueError: On inconsistent or invalid command-line flags.
        FileNotFoundError: If a file required by ``--dataset_subset`` is missing.
        NotImplementedError: When saving representations for the test split
            of a GPT/Llama model (hidden-state extraction is not wired up).
    """
    parser = argparse.ArgumentParser(description='Train classification network')
    parser.add_argument("--model_type",
                        required=True,
                        choices=_INPUT_DIMENSIONS.keys(),
                        help=f"{_INPUT_DIMENSIONS.keys()} supported.")
    parser.add_argument("--dataset_path",
                        required=True,
                        type=str)
    parser.add_argument("--dataset_subset",
                        dest='dataset_subset',
                        action='store_true'
                        )
    parser.add_argument("--model_path",
                        required=False,
                        default=None,
                        type=str)
    parser.add_argument('--checkpoint_root',
                        default="./probe_checkpoints", type=str)
    parser.add_argument(
        "--object_vocabulary_file",
        type=str,
        default="data/objects_with_bnc_frequency.csv",
        help='Path to a .csv file with a string field "object_names".')
    # Required, so the default below is never actually used.
    parser.add_argument('--layer',
                        required=True,
                        default=-1,
                        type=int)
    parser.add_argument('--epo',
                        default=16,
                        type=int)
    parser.add_argument('--condition_on',
                        choices=["box", "period", "the", "number", "contains"],
                        type=str,
                        dest='condition_on',
                        default='number')
    parser.add_argument('--max_train_data',
                        type=int,
                        default=None)
    parser.add_argument('--max_test_data',
                        type=int,
                        default=None)
    parser.add_argument('--num_prior_state',
                        default=-1,
                        type=int)
    parser.add_argument('--exclude_empty',
                        dest='exclude_empty',
                        action='store_true')

    parser.add_argument('--model_representation_path',
                        default=None,
                        type=str)

    parser.add_argument('--save_model_representation',
                        dest="save_model_representation",
                        action="store_true")

    parser.add_argument('--load_model_representation',
                        dest="load_model_representation",
                        action="store_true")

    parser.add_argument('--include_prompt',
                        dest="include_prompt",
                        action="store_true")

    # distributed inference ( for caching embedding)
    parser.add_argument('--distributed',
                        dest="distributed",
                        action="store_true")
    parser.add_argument("--local-rank", "--local_rank", type=int)

    # Unknown flags (e.g. ones injected by a launcher) are ignored.
    args, _ = parser.parse_known_args()

    # ---- validate flag combinations -------------------------------------
    if (args.condition_on not in ["number", "contains"]) and args.model_type == 't5':
        raise ValueError("--condition_on must be set to 'number' or 'contains' when training a probe on T5.")

    if args.exclude_empty and args.condition_on not in ["contains", "the"]:
        raise ValueError("--exclude_empty can only be used with --condition_on 'contains' or 'the'")

    if args.condition_on in ["contains", "the"] and not args.exclude_empty:
        raise ValueError("--condition_on 'contains' or 'the' can only be used with --exclude_empty")

    if args.save_model_representation and args.model_representation_path is None:
        raise ValueError("--save_model_representation requires --model_representation_path to be set")

    if args.load_model_representation and args.model_representation_path is None:
        raise ValueError("--load_model_representation requires --model_representation_path to be set")

    if args.load_model_representation and args.save_model_representation:
        raise ValueError("--load_model_representation and --save_model_representation cannot be used together")

    # Explicit raises instead of asserts: asserts are stripped under `python -O`.
    if args.max_train_data is not None and args.max_train_data % 7 != 0:
        raise ValueError("number of data points must be divisible by 7")

    if args.max_test_data is not None and args.max_test_data % 7 != 0:
        raise ValueError("number of data points must be divisible by 7")

    if args.dataset_subset:
        required_files = [f'test-subsample-states-{"t5" if args.model_type == "t5" else "gpt"}.jsonl']
        if "movecontent" in args.dataset_path.lower() or "move_content" in args.dataset_path.lower():
            required_files.append('train-subsample-states-mask.p')
        for required_file in required_files:
            required_path = os.path.join(args.dataset_path, required_file)
            if not os.path.exists(required_path):
                raise FileNotFoundError(required_path)

    folder_name = "probing/state"

    if args.exclude_empty:
        folder_name = folder_name + "_exclude_empty"

    print(f"Running experiment for {folder_name}")
    print("[Data]: Reading data...\n")

    # NOTE(review): the device is pinned to CPU for debugging; the automatic
    # selection below is disabled. Re-enable it once debugging is done.
    # device = 'cuda' if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else 'cpu'
    device = "cpu"  # debug
    print(f'Training Device: {device}')

    # ---- load data ------------------------------------------------------
    # Every model other than T5 reads the "gpt"-formatted files.
    data_type = "t5" if args.model_type == "t5" else "gpt"
    dataset_path_train = os.path.join(args.dataset_path,
                                      f'train{"-subsample-states" if args.dataset_subset else ""}-{data_type}.jsonl')
    print("Train dataset:", dataset_path_train)
    dataset_path_test = os.path.join(args.dataset_path,
                                     f'test{"-subsample-states" if args.dataset_subset else ""}-{data_type}.jsonl')

    train_df = pd.read_json(dataset_path_train, orient='records', lines=True)
    test_df = pd.read_json(dataset_path_test, orient='records', lines=True)

    if args.model_type == "t5":
        train_df = train_df[["sentence_masked", "masked_content"]]
        test_df = test_df[["sentence_masked", "masked_content"]]

    # Load object names: name -> row index, plus the names in file order.
    object_map = {}
    object_list = []
    with open(args.object_vocabulary_file, encoding="utf-8-sig") as f:
        reader = csv.DictReader(f)
        for i, row in enumerate(reader):
            object_map[row["object_name"]] = i
            object_list.append(row["object_name"])

    # `act_container_*` hold the representation at the probed layer;
    # `act_all_container_*` hold the representations of all layers.
    act_container_train = []
    act_all_container_train = []
    act_container_test = []
    act_all_container_test = []

    if args.load_model_representation:
        # load pre-computed representations
        # NOTE(review): pickle.load executes arbitrary code from the file;
        # only point --model_representation_path at trusted data.
        train_rep_path = os.path.join(args.model_representation_path, "representations_train.p")
        test_rep_path = os.path.join(args.model_representation_path, "representations_test.p")

        with open(train_rep_path, "rb") as rep_f:
            act_all_container_train = pickle.load(rep_f)

        # --layer is 1-indexed, the stored per-example lists are 0-indexed.
        for act in act_all_container_train:
            act_container_train.append(act[args.layer - 1])

        act_all_container_train.clear()

        with open(test_rep_path, "rb") as rep_f:
            act_all_container_test = pickle.load(rep_f)

        for act in act_all_container_test:
            act_container_test.append(act[args.layer - 1])

        act_all_container_test.clear()

    else:
        # set up distributed inference
        if args.distributed:
            rank = int(os.environ["RANK"])
            print(f"{rank=}")
            inference_device = torch.device(f"cuda:{rank}")
            torch.cuda.set_device(inference_device)  # https://github.com/pytorch/pytorch/issues/146767
            torch.distributed.init_process_group("nccl", device_id=inference_device)
            my_rank = torch.distributed.get_rank()
        else:
            inference_device = device

        # Load the model used to compute representations
        if args.model_type == "t5":
            model = T5ForProbing.from_pretrained(args.model_path)
        elif args.model_type == "gpt":
            model = GPTForProbing.from_pretrained(args.model_path)
        elif "llama" in args.model_type.lower():
            if args.distributed:
                model = LlamaForCausalLM.from_pretrained(args.model_path, tp_plan="auto")  # ,device_map="auto",
            else:
                model = LlamaForCausalLM.from_pretrained(args.model_path).to(device)

        # Set probe layer (1-indexed)
        model.probe_layer = args.layer
        tokenizer = AutoTokenizer.from_pretrained(args.model_path)
        # Inputs are moved to wherever the model actually ended up.
        inference_device = model.device
        model.eval()

        # initialize LM datasets
        if args.model_type == "t5":
            # NOTE(review): `--condition_on_obj` is not declared on the parser
            # above, so `args.condition_on_obj` raises AttributeError and the
            # T5 path cannot run as written. The intended flag/default is not
            # visible in this file -- confirm and restore it.
            train_dataset = LMDataloader(train_df, tokenizer, _MAX_SOURCE_TEXT_LENGTH[args.model_type],
                                         _MAX_TARGET_TEXT_LENGTH, "sentence_masked", "masked_content",
                                         include_empty=not args.exclude_empty, min_prev_objects=args.condition_on_obj)
            test_dataset = LMDataloader(test_df, tokenizer, _MAX_SOURCE_TEXT_LENGTH[args.model_type],
                                        _MAX_TARGET_TEXT_LENGTH, "sentence_masked", "masked_content",
                                        include_empty=not args.exclude_empty, min_prev_objects=args.condition_on_obj)
        else:  # if args.model_type in ["gpt", "llama"]:
            train_dataset = GPTDataloaderForInference(train_df, tokenizer,
                                                      max_length=_MAX_SOURCE_TEXT_LENGTH[args.model_type],
                                                      include_empty=not args.exclude_empty,
                                                      condition_on=args.condition_on,
                                                      # min_prev_objects=args.condition_on_obj,
                                                      include_prompt=args.include_prompt,
                                                      return_span=True)
            test_dataset = GPTDataloaderForInference(test_df, tokenizer,
                                                     max_length=_MAX_SOURCE_TEXT_LENGTH[args.model_type],
                                                     include_empty=not args.exclude_empty,
                                                     condition_on=args.condition_on,
                                                     # min_prev_objects=args.condition_on_obj,
                                                     include_prompt=args.include_prompt,
                                                     return_span=True)

        # truncate dataset if needed
        if args.max_train_data is not None:
            train_dataset = torch.utils.data.Subset(train_dataset, range(args.max_train_data))
        if args.max_test_data is not None:
            test_dataset = torch.utils.data.Subset(test_dataset, range(args.max_test_data))

        # batch_size=1 and no shuffling: one representation per example, in order.
        loader_train = DataLoader(train_dataset, shuffle=False, pin_memory=True, batch_size=1, num_workers=1)
        loader_test = DataLoader(test_dataset, shuffle=False, pin_memory=True, batch_size=1, num_workers=1)

        # compute hidden representations
        if args.model_type == "t5":

            the_token = torch.tensor(8, dtype=torch.long)

            for data in tqdm(loader_train, total=len(loader_train)):
                token_idx = 0
                the_pos = 0

                if args.condition_on == "contains":
                    token_idx = 1
                elif args.condition_on == "the":
                    token_idx = 2
                    the_pos = 1

                # Advance token_idx to later occurrences of 'the' in the target.
                # NOTE(review): `args.condition_on_obj` is undefined (see above).
                while args.condition_on_obj >= the_pos:
                    # 'the' has the token index 8 for T5
                    token_idx = list(data['target_ids'][0]).index(the_token, token_idx + 1)
                    the_pos += 1

                # Feed the decoder the target prefix up to and including token_idx.
                decoder_input = data['target_ids'][:, 0:(token_idx + 1)].to(device, dtype=torch.long)
                ids = data['source_ids'].to(device, dtype=torch.long)
                mask = data['source_mask'].to(device, dtype=torch.long)

                # print(tokenizer.convert_ids_to_tokens(decoder_input[0]))

                if args.save_model_representation:
                    act = model(input_ids=ids, attention_mask=mask, decoder_input_ids=decoder_input,
                                return_all_layers=True)  # representation at first (=mask) token
                    act_container_train.append(act[args.layer - 1][0, -1, :].detach().cpu())
                    act_all_container_train.append([a[0, -1, :].detach().cpu() for a in act])
                else:
                    # forward function automatically outputs the representation at `layer`
                    act = model(input_ids=ids, attention_mask=mask, decoder_input_ids=decoder_input)[0, -1,
                          :].detach().cpu()  # representation at first (=mask) token
                    act_container_train.append(act)

                # Activate this for debugging on a handful of examples
                # if len(act_container_train) == 70:
                #     break

            for data in tqdm(loader_test, total=len(loader_test)):

                token_idx = 0
                the_pos = 0

                if args.condition_on == "contains":
                    token_idx = 1
                elif args.condition_on == "the":
                    token_idx = 2
                    the_pos = 1

                # NOTE(review): `args.condition_on_obj` is undefined (see above).
                while args.condition_on_obj >= the_pos:
                    # 'the' has the token index 8 for T5
                    token_idx = list(data['target_ids'][0]).index(the_token, token_idx + 1)
                    the_pos += 1

                decoder_input = data['target_ids'][:, 0:(token_idx + 1)].to(device, dtype=torch.long)
                ids = data['source_ids'].to(device, dtype=torch.long)
                mask = data['source_mask'].to(device, dtype=torch.long)

                # print(tokenizer.convert_ids_to_tokens(labels[0][:6]))

                if args.save_model_representation:
                    act = model(input_ids=ids, attention_mask=mask, decoder_input_ids=decoder_input,
                                return_all_layers=True)  # representation at first (=mask) token
                    act_container_test.append(act[args.layer - 1][0, -1, :].detach().cpu())
                    act_all_container_test.append([a[0, -1, :].detach().cpu() for a in act])
                else:
                    # forward function automatically outputs the representation at `layer`
                    act = model(input_ids=ids, attention_mask=mask, decoder_input_ids=decoder_input)[0, -1,
                          :].detach().cpu()  # representation at first (=mask) token
                    act_container_test.append(act)

                # Activate this for debugging on a handful of examples
                # if len(act_container_test) == 70:
                #     break

        else:  # args.model_type in ["gpt", "llama"]:
            # Optionally cut trailing tokens so the last position is the one
            # selected by --condition_on.
            end_idx = None
            # I've checked that " n" between 0 and 7 are all single tokens.
            # But this is probably a patchy solution if box num >= 8
            if args.condition_on == "box":
                end_idx = -3  # originally -1
            # "Box" is one token, "Box 3" is 3 tokens
            elif args.condition_on == "period":
                end_idx = -4  # originally -2, our data ends with "contains", so one extra word, also 'Box 3' is 3 tokens (space is one)
            saved_cnt = 0
            for i, data in enumerate(tqdm(loader_train, total=len(loader_train))):
                # if data['numops'] < 3:
                #     continue
                # else:
                #     saved_cnt += 1
                # if saved_cnt > 5:
                #     exit()
                ids = data['prefix_ids'].to(inference_device, dtype=torch.long)
                mask = data['prefix_attn_masks'].to(inference_device, dtype=torch.long)
                if end_idx is not None:
                    ids = ids[0][:end_idx].unsqueeze(0)
                    mask = mask[0][:end_idx].unsqueeze(0)
                if args.save_model_representation:

                    # out[0] (logit): (bs, pos, vocab).
                    # out[1] (hidden): layer X tuple(2), each (bs, d_head, pos, d_key_value_head)
                    # out[2] (attn): layer X (bs, d_head, pos, pos)
                    output = model(input_ids=ids, attention_mask=mask, return_all_layers=True, output_attentions=True)
                    attn = output[2]
                    tokens = tokenizer.convert_ids_to_tokens(ids[0])

                    # not really informative because the sentence is too long
                    # html_view = model_view(attn, tokens, html_action='return')
                    # with open(f"plots/model_view_{args.model_type}.html", 'w') as file:
                    #     file.write(html_view.data)

                    # aggregate attention weights across tokens in the same phrase
                    # phrase_tokens = [t[0] for t in data["span_tokens"]]
                    # phrase_attn = aggregate_attentions(attn, data["span"], aggregation="mean")

                    # html_view = head_view(phrase_attn, phrase_tokens, html_action='return')
                    # os.makedirs("plots/attention_views", exist_ok=True)
                    # with open(f"plots/attention_views/agg_head_view_{args.model_type}_example_{i}.html", 'w') as file:
                    #     file.write(html_view.data)
                    # html_view = model_view(phrase_attn, phrase_tokens, html_action='return')
                    # with open(f"plots/attention_views/agg_model_view_{args.model_type}_example_{i}.html", 'w') as file:
                    #     file.write(html_view.data)

                    # get attention heads that pays attention to the right phrase
                    act_heads = get_correct_attention_heads(attn, data["span"], data["local_op_span_indices"])
                    act_heads_all = torch.any(act_heads, dim=0).int()

                    # NOTE(review): nothing is appended to the train containers
                    # in this branch, so the pickle written below is empty for
                    # GPT/Llama models.
                    # act_all_container_train.append([a[0, -1, :].detach().cpu() for a in act])
                    # act_container_train.append(act[args.layer - 1][0, -1, :].detach().cpu())
                else:
                    # last hidden state
                    act = model(input_ids=ids, attention_mask=mask)[0, -1, :].detach().cpu()
                    act_container_train.append(act)

            for data in tqdm(loader_test, total=len(loader_test)):
                if args.save_model_representation:
                    # This branch used a debugger breakpoint and then read a
                    # variable that was never assigned from the model output.
                    # Fail with a clear message instead of a NameError.
                    # NOTE(review): extract the per-layer hidden states from
                    # the model output here once their location is confirmed.
                    raise NotImplementedError(
                        "Saving per-layer representations of the test split is not implemented "
                        "for GPT/Llama models: hidden states are not extracted from the model output."
                    )

                ids = data['prefix_ids'].to(inference_device, dtype=torch.long)
                mask = data['prefix_attn_masks'].to(inference_device, dtype=torch.long)
                if end_idx is not None:
                    ids = ids[0][:end_idx].unsqueeze(0)
                    mask = mask[0][:end_idx].unsqueeze(0)
                # last hidden state
                act = model(input_ids=ids, attention_mask=mask)[0, -1, :].detach().cpu()
                act_container_test.append(act)

        # Tear down the process group only if one was created above; calling
        # this without an initialized group fails on non-distributed runs.
        if args.distributed and torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()

        if args.save_model_representation:
            os.makedirs(args.model_representation_path, exist_ok=True)

            train_rep_path = os.path.join(args.model_representation_path, "representations_train.p")
            test_rep_path = os.path.join(args.model_representation_path, "representations_test.p")
            with open(train_rep_path, "wb") as rep_f:
                pickle.dump(act_all_container_train, rep_f)
                act_all_container_train.clear()
            with open(test_rep_path, "wb") as rep_f:
                pickle.dump(act_all_container_test, rep_f)
                act_all_container_test.clear()


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()


