import torch
import random
import numpy as np
from utils import set_seed
from typing import Dict, Optional, Sequence
from dataclasses import dataclass
import math

# Label value marking sequence positions that carry no prediction target;
# DataCollatorForFlipflop writes it into every non-target label slot.
IGNORE_INDEX = -100

def generate_batch_randperm(n_batch, n_pos):
    """Return ``n_batch`` independent random permutations of ``range(n_pos)``.

    Each row is the argsort of i.i.d. uniform noise, so the result is an
    integer tensor of shape ``(n_batch, n_pos)`` whose rows are permutations.
    """
    return torch.rand(n_batch, n_pos).argsort(dim=1)


def prepare_data(setting, n_train, n_test, n_context, n_dim, seed_data, *args, label_flip_ratio=0.0, **kwargs):
    """Call ``set_seed(seed_data)`` and build the dataset for the requested ``setting``.

    Args:
        setting: one of ``"cnn"``, ``"shallow"``, ``"spatial"``, ``"one_direction"``,
            ``"one_direction_f"``, ``"one_direction_sparse"``, ``"multiclass_type1"``,
            ``"multiclass_noise"`` or ``"flipflop"``.
        n_train, n_test: number of training / test samples.
        n_context: tokens per sequence.
        n_dim: token dimension.
        seed_data: value passed to ``set_seed`` before any data is generated.
        *args, **kwargs: forwarded to the setting-specific generator
            (``"cnn"`` forwards nothing).
        label_flip_ratio: fraction of training labels to negate (label noise).
            Negation only makes sense for +1/-1 labels. Defaults to 0.0 (no flips).

    Returns:
        For ``"flipflop"``: the value returned by ``generate_data_flipflop``.
        Otherwise ``(train_x, train_y, test_x, test_y, flip_indices_train)``, where
        ``flip_indices_train`` holds the indices of the negated training labels.

    Raises:
        ValueError: if ``setting`` is unknown or ``label_flip_ratio`` is not in [0, 1].
    """
    known_settings = (
        "cnn", "shallow", "spatial", "one_direction", "one_direction_f",
        "one_direction_sparse", "multiclass_type1", "multiclass_noise", "flipflop",
    )
    # Validate before seeding. Otherwise an unknown setting would fall through
    # every branch and fail later with an UnboundLocalError on train_y.
    if setting not in known_settings:
        raise ValueError(f"unknown data setting {setting!r}; expected one of {known_settings}")
    if not 0.0 <= label_flip_ratio <= 1.0:
        raise ValueError(f"label_flip_ratio must lie in [0, 1], got {label_flip_ratio}")

    set_seed(seed_data)
    if setting == "cnn":
        # one data point: one-hot feature tokens followed by Gaussian noise tokens
        train_x, train_y, test_x, test_y = generate_data_cnn(n_train, n_test, n_context, n_dim)
    elif setting == "shallow":
        # one data point: alpha_r % feature from codebook, alpha_c % confused feature, the rest other features.
        # all tokens have additive gaussian noise
        train_x, train_y, test_x, test_y = generate_data_shallow(n_train, n_test, n_context, n_dim, *args, **kwargs)
    elif setting == "spatial":
        # one data point: alpha_r % one hot feature, others have no features
        # all tokens have additive orthogonal gaussian noise
        train_x, train_y, test_x, test_y = generate_data_spatial(n_train, n_test, n_context, n_dim, *args, **kwargs)
    elif setting == "one_direction":
        # one data point: feature_ratio % one hot feature patch, others are noise patches
        train_x, train_y, test_x, test_y = generate_data_one_direction(n_train, n_test, n_context, n_dim, *args, **kwargs)
    elif setting == "one_direction_f":
        # one data point: feature_ratio % one hot feature patch (scaled by snr), others are noise patches
        train_x, train_y, test_x, test_y = generate_data_one_direction_scale_feature(n_train, n_test, n_context, n_dim, *args, **kwargs)
    elif setting == "one_direction_sparse":
        # one data point: feature_ratio % one hot feature patch, others are noise patches
        # in each noise patch/token, only sparsity_level coordinates are nonzero
        train_x, train_y, test_x, test_y = generate_data_one_direction_sparse(n_train, n_test, n_context, n_dim, *args, **kwargs)
    elif setting == "multiclass_type1":
        # one data point: one hot feature & noise (patched=True: separate patches; patched=False: added)
        # the one hot feature is determined by the label, and labels are drawn according to the weights,
        # so this represents class imbalance caused by a p(x) shift
        train_x, train_y, test_x, test_y = generate_data_multiclass_type1(n_train, n_test, n_context, n_dim, *args, **kwargs)
    elif setting == "multiclass_noise":
        # one data point: just gaussian noise; the label partition is given by the weights
        train_x, train_y, test_x, test_y = generate_data_multiclass_noise(n_train, n_test, n_context, n_dim, *args, **kwargs)
    elif setting == "flipflop":
        # Exposing Attention Glitches with Flip-Flop Language Modeling
        # https://arxiv.org/abs/2306.00946
        # TODO: focus on causal language modeling
        # Returned directly: no label flipping is applied to this setting.
        return generate_data_flipflop(n_train, n_test, n_context, n_dim, *args, **kwargs)

    # Label noise on the training split. randperm is drawn even when nothing is
    # flipped, which keeps RNG consumption independent of label_flip_ratio.
    n_flip = int(n_train * label_flip_ratio)
    flip_indices_train = torch.randperm(n_train)[:n_flip]
    train_y[flip_indices_train] = -train_y[flip_indices_train]

    return train_x, train_y, test_x, test_y, flip_indices_train


def generate_data_cnn(n_train, n_test, n_context, n_dim, *args, **kwargs):
    """Generate binary data made of signed one-hot feature tokens followed by noise tokens.

    Every sequence has ``n_context`` tokens of dimension ``n_dim``. The first
    ``n_context // 2`` tokens equal ``+e_0`` for positive samples and ``-e_0``
    for negative samples. The remaining tokens are standard Gaussian noise.
    In each split the first ``B // 2`` samples are positive and the rest are
    negative. Extra positional/keyword arguments are accepted and ignored.

    Returns:
        ``(train_x, train_y, test_x, test_y)`` where ``*_x`` has shape
        ``(B, n_context, n_dim)`` and ``*_y`` holds ``B`` float labels in {+1, -1}.
    """
    print("generate cnn data")

    def gen(B, T, D):
        num_positive_samples = B // 2
        # Give the odd sample (if any) to the negative class so len(labels) == B.
        num_negative_samples = B - num_positive_samples
        labels = torch.cat((torch.ones(num_positive_samples), -torch.ones(num_negative_samples)))

        # The feature must be a D-vector so it broadcasts over (samples, tokens, D).
        # A (D, 1) column would broadcast along the token axis instead.
        feature = torch.zeros(D)
        feature[0] = 1.

        num_feature_tokens = T // 2
        data = torch.zeros(B, T, D)
        data[:num_positive_samples, :num_feature_tokens] = feature
        data[num_positive_samples:, :num_feature_tokens] = -feature
        data[:, num_feature_tokens:] = torch.randn(B, T - num_feature_tokens, D)
        return data, labels

    train_x, train_y = gen(n_train, n_context, n_dim)
    test_x, test_y = gen(n_test, n_context, n_dim)
    return train_x, train_y, test_x, test_y


def generate_data_shallow(n_train, n_test, n_context, n_dim, n_vocab, alpha_r, alpha_c, noise_level, *args, **kwargs):
    """Generate binary data from a one-hot codebook with relevant, confusing and filler tokens.

    The codebook ``v`` is the first ``n_vocab`` rows of the ``n_dim`` identity.
    Each sequence of ``n_context`` tokens is laid out as follows:

    * the first ``int(n_context * alpha_r)`` tokens are ``v[0]`` for positive
      samples and ``v[1]`` for negative samples;
    * the next ``int(n_context * alpha_c)`` tokens hold the opposite codeword;
    * the remaining tokens are codewords drawn uniformly from ``v[2:n_vocab]``.

    Gaussian noise with std ``noise_level`` is then added to every coordinate.
    In each split the first ``B // 2`` samples are positive and the rest negative.
    Drawing filler tokens requires ``2 < n_vocab <= n_dim``.

    Returns:
        ``(train_x, train_y, test_x, test_y)``. ``*_x`` is ``(B, n_context, n_dim)``
        and ``*_y`` holds float labels in {+1, -1}.

    Raises:
        ValueError: if ``alpha_r + alpha_c`` leaves a negative number of filler tokens.
    """
    print(f"generate shallow data")
    vectors = torch.eye(n_dim)[:n_vocab]  # (M, D) one-hot codebook

    def gen(B, T):
        num_positive_samples = B // 2
        # Give the odd sample (if any) to the negative class so len(labels) == B.
        num_negative_samples = B - num_positive_samples
        labels = torch.cat((torch.ones(num_positive_samples), -torch.ones(num_negative_samples)))

        data = torch.zeros((B, T, n_dim))
        num_relevant_tokens = int(T * alpha_r)
        num_confused_tokens = int(T * alpha_c)
        num_discriminative_tokens = num_relevant_tokens + num_confused_tokens
        num_nd_tokens = T - num_discriminative_tokens
        if num_nd_tokens < 0:
            raise ValueError(
                f"alpha_r + alpha_c select {num_discriminative_tokens} tokens, more than n_context={T}"
            )
        relevant = slice(0, num_relevant_tokens)
        confused = slice(num_relevant_tokens, num_discriminative_tokens)

        data[:num_positive_samples, relevant] = vectors[0]
        data[:num_positive_samples, confused] = vectors[1]
        data[num_positive_samples:, relevant] = vectors[1]
        data[num_positive_samples:, confused] = vectors[0]

        # Non-discriminative filler. Index from the end of the discriminative part:
        # ``data[:, -0:]`` would select every token when there is no filler.
        if num_nd_tokens > 0:
            uniform_indices = torch.randint(2, n_vocab, (B, num_nd_tokens))
            data[:, num_discriminative_tokens:] = vectors[uniform_indices]

        data = data + torch.randn_like(data) * noise_level
        return data, labels

    train_x, train_y = gen(n_train, n_context)
    test_x, test_y = gen(n_test, n_context)
    return train_x, train_y, test_x, test_y


def generate_data_spatial(n_train, n_test, n_context, n_dim, alpha_r, proba_nonzero, noise_level, *args, **kwargs):
    """Generate binary data with ``±e_0`` feature tokens and sparse random-sign filler tokens.

    In each sequence of ``n_context`` tokens:

    * the first ``int(n_context * alpha_r)`` tokens are ``+e_0`` (positive) or
      ``-e_0`` (negative);
    * the remaining tokens have random ±1 coordinates, each kept independently
      with probability ``proba_nonzero`` (zero otherwise).

    Gaussian noise with std ``noise_level`` is then added to every token, with
    coordinate 0 zeroed so the noise is orthogonal to the feature direction.
    In each split the first ``B // 2`` samples are positive and the rest negative.

    Returns:
        ``(train_x, train_y, test_x, test_y)``. ``*_x`` is ``(B, n_context, n_dim)``
        and ``*_y`` holds float labels in {+1, -1}.
    """
    print(f"generate spatial data")

    def gen(B, T, D):
        num_positive_samples = B // 2
        # Give the odd sample (if any) to the negative class so len(labels) == B.
        num_negative_samples = B - num_positive_samples
        labels = torch.cat((torch.ones(num_positive_samples), -torch.ones(num_negative_samples)))

        data = torch.zeros((B, T, D))
        num_relevant_tokens = int(T * alpha_r)
        num_nd_tokens = T - num_relevant_tokens

        feature = torch.zeros(D)
        feature[0] = 1.
        data[:num_positive_samples, :num_relevant_tokens] = feature
        data[num_positive_samples:, :num_relevant_tokens] = -feature

        # Non-discriminative tokens: random signs, each coordinate kept with prob proba_nonzero.
        feature_noise = torch.randn(B, num_nd_tokens, D).sign()
        nonzero = torch.bernoulli(torch.ones(B, num_nd_tokens, D) * proba_nonzero)
        data[:, num_relevant_tokens:] = feature_noise * nonzero  # (B, nd, D)

        # Gaussian noise orthogonal to the feature direction, added to every token.
        orth_gaussian_noise = torch.randn_like(data) * noise_level
        orth_gaussian_noise[:, :, 0] = 0.
        data = data + orth_gaussian_noise
        return data, labels

    train_x, train_y = gen(n_train, n_context, n_dim)
    test_x, test_y = gen(n_test, n_context, n_dim)
    return train_x, train_y, test_x, test_y


def generate_data_one_direction(n_train, n_test, n_context, n_dim, positive_ratio, feature_ratio, noise_level, orth_noise, *args, **kwargs):
    """Generate binary data whose only signal is along the direction ``e_0``.

    The first ``int(n_context * feature_ratio)`` tokens of a sequence are
    ``+e_0`` (label +1) or ``-e_0`` (label -1). The remaining tokens are Gaussian
    noise with std ``noise_level``, with coordinate 0 zeroed when ``orth_noise``
    is truthy.

    The training split puts ``int(n_train * positive_ratio)`` positive samples
    first. The test split always uses a 0.5 positive ratio.

    Returns:
        ``(train_x, train_y, test_x, test_y)``. ``*_x`` is ``(B, n_context, n_dim)``
        and ``*_y`` holds float labels in {+1, -1}.
    """
    print(f"generate one direction data")

    def make_split(batch, seq_len, dim, pos_ratio):
        n_pos = int(batch * pos_ratio)
        n_feat = int(seq_len * feature_ratio)

        signal = torch.zeros(dim)
        signal[0] = 1.

        x = torch.zeros((batch, seq_len, dim))
        x[:n_pos, :n_feat] = signal
        x[n_pos:, :n_feat] = -signal

        noise = torch.randn(batch, seq_len - n_feat, dim) * noise_level
        if orth_noise:
            noise[..., 0] = 0.
        x[:, n_feat:] = noise

        y = torch.cat((torch.ones(n_pos), -torch.ones(batch - n_pos)))
        return x, y

    train_x, train_y = make_split(n_train, n_context, n_dim, positive_ratio)
    test_x, test_y = make_split(n_test, n_context, n_dim, 0.5)
    return train_x, train_y, test_x, test_y


def generate_data_one_direction_scale_feature(n_train, n_test, n_context, n_dim, positive_ratio, feature_ratio, snr, noise_level=1.0, *args, **kwargs):
    """Variant of the one-direction data where the feature magnitude is set from an SNR.

    Feature tokens (the first ``int(n_context * feature_ratio)``) are ``±a * e_0``
    with ``a = sqrt(n_dim) * snr * noise_level``. The other tokens are Gaussian
    noise with std ``noise_level``, and their coordinate 0 is always zeroed so the
    noise is orthogonal to the feature.

    The training split puts ``int(n_train * positive_ratio)`` positive samples
    first. The test split always uses a 0.5 positive ratio.

    Returns:
        ``(train_x, train_y, test_x, test_y)``. ``*_x`` is ``(B, n_context, n_dim)``
        and ``*_y`` holds float labels in {+1, -1}.
    """
    print(f"generate one direction data (scale feature)")
    feature_scale = 1. * math.sqrt(n_dim) * snr * noise_level

    def make_split(batch, seq_len, dim, pos_ratio):
        n_pos = int(batch * pos_ratio)
        n_feat = int(seq_len * feature_ratio)

        signal = torch.zeros(dim)
        signal[0] = feature_scale

        x = torch.zeros((batch, seq_len, dim))
        x[:n_pos, :n_feat] = signal
        x[n_pos:, :n_feat] = -signal

        noise = torch.randn(batch, seq_len - n_feat, dim) * noise_level
        noise[..., 0] = 0.
        x[:, n_feat:] = noise

        y = torch.cat((torch.ones(n_pos), -torch.ones(batch - n_pos)))
        return x, y

    train_x, train_y = make_split(n_train, n_context, n_dim, positive_ratio)
    test_x, test_y = make_split(n_test, n_context, n_dim, 0.5)
    return train_x, train_y, test_x, test_y


def generate_data_one_direction_sparse(n_train, n_test, n_context, n_dim, 
                                       positive_ratio, feature_ratio, snr, 
                                       noise_level, sparsity_level, ortho, overlap, *args, **kwargs):
    """Generate one-direction binary data whose noise tokens may be sparse.

    Feature tokens (the first ``int(n_context * feature_ratio)``) are ``±a * e_0``
    with ``a = sqrt(s) * snr * noise_level``. Here ``s`` is ``sparsity_level``,
    or ``n_dim`` when ``sparsity_level`` is None. The remaining tokens carry
    Gaussian noise with std ``noise_level``:

    * dense (``sparsity_level`` is None or equals ``n_dim``): every coordinate is
      noisy, except coordinate 0 when ``ortho`` is truthy;
    * sparse, ``overlap`` falsy: each noise token gets ``s`` consecutive coordinates,
      distinct across all tokens (requires ``s * B * num_noise_tokens < n_dim``).
      Coordinate 0 is skipped when ``ortho``;
    * sparse, ``overlap`` truthy: each noise token gets ``s`` distinct coordinates
      chosen at random, from ``1..n_dim-1`` when ``ortho``.

    The training split uses ``positive_ratio`` and the given ``overlap``. The test
    split uses a 0.5 positive ratio and always ``overlap=True``.

    Returns:
        ``(train_x, train_y, test_x, test_y)``. ``*_x`` is ``(B, n_context, n_dim)``
        and ``*_y`` holds float labels in {+1, -1}.
    """
    print(f"generate one direction sparse data")
    # sparsity_level=None selects dense noise over all n_dim coordinates, so the
    # feature is scaled by sqrt(n_dim). math.sqrt(None) would raise otherwise.
    effective_sparsity = n_dim if sparsity_level is None else sparsity_level

    def gen(B, T, D, positive_ratio, overlap):
        num_positive_samples = int(B * positive_ratio)
        num_negative_samples = B - num_positive_samples
        labels = torch.cat((torch.ones(num_positive_samples), -torch.ones(num_negative_samples)))

        data = torch.zeros((B, T, D))
        num_feature_tokens = int(T * feature_ratio)
        num_gn_tokens = T - num_feature_tokens

        feature = torch.zeros(D)
        feature[0] = 1. * math.sqrt(effective_sparsity) * snr * noise_level
        data[:num_positive_samples, :num_feature_tokens] = feature
        data[num_positive_samples:, :num_feature_tokens] = -feature

        gaussian_noise = torch.randn(B, num_gn_tokens, D) * noise_level
        if sparsity_level is None or sparsity_level == D:
            if ortho:
                gaussian_noise[:, :, 0] = 0.
            data[:, num_feature_tokens:] = gaussian_noise
        elif not overlap:
            print("nonoverlap sparsity")
            assert sparsity_level * B * num_gn_tokens < D, "nonoverlap sparisty requires the inequality s * B * N_p < D holds."
            # Hand out consecutive coordinates so no two noise tokens share a support.
            # With ortho, the +1 offset keeps coordinate 0 (the feature) free.
            if ortho:
                perm = (torch.arange(sparsity_level * B * num_gn_tokens) + 1) % D
                assert (perm != 0).all()
            else:
                perm = (torch.arange(sparsity_level * B * num_gn_tokens)) % D
            perm = perm.view(B, num_gn_tokens, sparsity_level)
            # data[:, k:] is a view, so scatter_ writes into data.
            # Token (i, j) receives gaussian_noise[i, j, :s] at coordinates perm[i, j].
            data[:, num_feature_tokens:].scatter_(dim=2, index=perm, src=gaussian_noise)
        else:
            print("overlap sparsity")
            if ortho:
                # Random supports over coordinates 1..D-1 (shifted by one to skip 0).
                perm = generate_batch_randperm(B * num_gn_tokens, D - 1)[:, :sparsity_level]
                perm = perm + 1
                assert (perm != 0).all()
            else:
                perm = generate_batch_randperm(B * num_gn_tokens, D)[:, :sparsity_level]
            perm = perm.view(B, num_gn_tokens, sparsity_level)
            data[:, num_feature_tokens:].scatter_(dim=2, index=perm, src=gaussian_noise)

        return data, labels

    train_x, train_y = gen(n_train, n_context, n_dim, positive_ratio, overlap=overlap)
    test_x, test_y = gen(n_test, n_context, n_dim, 0.5, overlap=True)
    return train_x, train_y, test_x, test_y


def _exponential_class_labels(num_samples, num_classes):
    """Return sorted int64 labels where class ``k < num_classes - 1`` fills up to ``2**k`` slots.

    Slots are filled in class order, and everything left over goes to the last
    class. The result equals the prefix of ``[0, 1, 1, 2, 2, 2, 2, ...]`` (class k
    repeated ``2**k`` times), padded with ``num_classes - 1`` when ``num_samples``
    exceeds ``2**num_classes - 1``. It is built in O(num_samples) instead of
    materializing all ``2**num_classes - 1`` entries.
    """
    labels = torch.full((num_samples,), num_classes - 1, dtype=torch.int64)
    start = 0
    for k in range(num_classes - 1):
        if start >= num_samples:
            break
        stop = min(start + 2 ** k, num_samples)
        labels[start:stop] = k
        start = stop
    return labels


def sample_classes_from_weights(num_samples, num_classes, weights):
    """Draw ``num_samples`` class labels in ``[0, num_classes)``, sorted ascending.

    Args:
        num_samples: number of labels to return.
        num_classes: number of classes.
        weights: the class distribution:
            * ``None``: uniform;
            * ``"power"``: ``p_k ∝ 1 / (e + k)``;
            * ``"exponential"``: deterministic, class k gets ``2**k`` samples
              (the last class absorbs any remainder);
            * a ``torch.Tensor``, list, tuple or ``np.ndarray`` of non-negative
              per-class weights, normalized to sum to 1.

    Returns:
        An int64 tensor of shape ``(num_samples,)`` sorted in ascending order.

    Raises:
        NotImplementedError: for an unknown scheme name or unsupported ``weights`` type.
    """
    if weights is None:
        probabilities = torch.ones(num_classes) / num_classes
    elif isinstance(weights, str):
        if weights == "exponential":
            return _exponential_class_labels(num_samples, num_classes)
        if weights != "power":
            raise NotImplementedError(f"unsupported weights scheme: {weights!r}")
        probabilities = torch.as_tensor([1 / (np.exp(1) + i) for i in range(num_classes)])
        probabilities /= probabilities.sum()  # normalize to sum to 1
    elif isinstance(weights, torch.Tensor):
        probabilities = weights / weights.sum()
    elif isinstance(weights, (list, tuple, np.ndarray)):
        probabilities = torch.as_tensor(weights, dtype=torch.float)
        probabilities = probabilities / probabilities.sum()
    else:
        raise NotImplementedError(f"unsupported weights type: {type(weights).__name__}")

    labels = torch.multinomial(probabilities, num_samples, replacement=True)
    return labels.sort()[0]
    

def generate_data_multiclass_type1(n_train, n_test, n_context, n_dim, alpha_f, noise_level, orth_noise, patched, class_num, weights=None, *args, **kwargs):
    """Generate multiclass data whose input encodes the label as a one-hot feature.

    Class ``c`` has feature ``e_c`` (the first ``class_num`` rows of the ``n_dim``
    identity). Labels are drawn by ``sample_classes_from_weights``: the training
    split uses ``weights`` and the test split is uniform (``weights=None``). So
    any class imbalance comes from the label-dependent shift of p(x).

    * ``patched`` truthy: samples are ``(n_context, n_dim)`` sequences. The first
      ``int(n_context * alpha_f)`` tokens are the class feature and the rest are
      Gaussian noise tokens with std ``noise_level``.
    * ``patched`` falsy: samples are single ``n_dim`` vectors, equal to the class
      feature plus Gaussian noise with std ``noise_level``. ``n_context`` and
      ``alpha_f`` are not used in this case.

    With ``orth_noise`` truthy, the noise is zeroed on the first ``class_num``
    coordinates, i.e. it is orthogonal to every class feature.

    Returns:
        ``(train_x, train_y, test_x, test_y)`` with integer labels sorted ascending.
    """
    print(f"generate multiclass data type 1")
    assert n_dim >= class_num

    def gen(B, T, D, weights):
        features = torch.eye(D)[:class_num]
        labels = sample_classes_from_weights(B, class_num, weights)

        num_feature_tokens = int(T * alpha_f)
        num_gn_tokens = T - num_feature_tokens

        data = features[labels]  # (B, D); advanced indexing returns a copy
        if patched:
            # Repeat the feature over all T tokens, then overwrite the tail with noise.
            data = data.unsqueeze(1).expand(-1, T, -1).clone()
            gaussian_noise = torch.randn(B, num_gn_tokens, D) * noise_level
            if orth_noise:
                gaussian_noise[:, :, :class_num] = 0.
            data[:, num_feature_tokens:] = gaussian_noise
        else:
            gaussian_noise = torch.randn(B, D) * noise_level
            if orth_noise:
                gaussian_noise[:, :class_num] = 0.
            data += gaussian_noise

        return data, labels

    train_x, train_y = gen(n_train, n_context, n_dim, weights=weights)
    test_x, test_y = gen(n_test, n_context, n_dim, weights=None)
    return train_x, train_y, test_x, test_y


def generate_data_multiclass_noise(n_train, n_test, n_context, n_dim, noise_level, patched, class_num, weights=None, *args, **kwargs):
    """Generate pure-noise inputs paired with multiclass labels.

    The inputs carry no class information. Each sample is an ``n_dim`` standard
    Gaussian vector scaled by ``noise_level``. Training labels are drawn by
    ``sample_classes_from_weights`` with ``weights``, and test labels with
    ``weights=None``. ``n_context`` is not used, and only the unpatched layout is
    supported (``patched`` must be falsy).

    Returns:
        ``(train_x, train_y, test_x, test_y)`` with ``*_x`` of shape ``(B, n_dim)``.
    """
    print(f"generate multiclass noise data")
    assert class_num <= n_dim and n_train <= n_dim
    assert not patched

    def draw(num_samples, class_weights):
        # Labels are drawn before the inputs so the RNG order matches the other generators.
        labels = sample_classes_from_weights(num_samples, class_num, class_weights)
        inputs = torch.randn(num_samples, n_dim) * noise_level
        return inputs, labels

    train_x, train_y = draw(n_train, weights)
    test_x, test_y = draw(n_test, None)
    return train_x, train_y, test_x, test_y


def generate_data_flipflop(
        n_train, n_test, n_context, n_dim, 
        path=None, 
        cache_dir="/root/autodl-tmp/huggingface", 
        *args, 
        tokenizer_path="/root/autodl-tmp/hf/models/gpt2",
        **kwargs,
    ):
    """Load and tokenize the flip-flop language-modeling dataset.

    Reference: "Exposing Attention Glitches with Flip-Flop Language Modeling",
    https://arxiv.org/abs/2306.00946

    Four splits are loaded from the parquet shards under ``{path}/data``:
    "train", "val", "val_dense" and "val_sparse". The first ``n_train`` training
    rows and the first ``n_test`` rows of each validation split are kept. Each
    text is tokenized with the tokenizer at ``tokenizer_path`` and its
    ``input_ids`` are truncated to ``n_context``. The five task symbols are then
    remapped to a compact vocabulary: "0"->0, "1"->1, "i"->2, "w"->3, "r"->4.

    Args:
        n_train: number of training rows kept.
        n_test: number of rows kept from each validation split.
        n_context: maximum number of ``input_ids`` kept per sequence.
        n_dim: unused; accepted so all generators share one signature.
        path: dataset root; defaults to the local synthseq/flipflop snapshot.
        cache_dir: cache directory passed to ``load_dataset``.
        tokenizer_path: directory passed to ``AutoTokenizer.from_pretrained``.

    Returns:
        The result of ``raw_datasets.map(...)``. Each split keeps only the fields
        produced by ``preprocess_function``, with ``input_ids`` remapped.
    """
    from datasets import load_dataset
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

    if path is None:
        path = "/root/autodl-tmp/hf/datasets/synthseq/flipflop"
    data_files = {
        "train": [
            f"{path}/data/train-00000-of-00002-b4ca324082d96883.parquet",
            f"{path}/data/train-00001-of-00002-bf5f777704418c83.parquet",
        ],
        "val": f"{path}/data/val-00000-of-00001-fec0c03d88b56508.parquet",
        "val_dense": f"{path}/data/val_dense-00000-of-00001-cd57636c8e1ff5a5.parquet",
        "val_sparse": f"{path}/data/val_sparse-00000-of-00001-1fcf65938d0f40dd.parquet",
    }
    raw_datasets = load_dataset("parquet", data_files=data_files, cache_dir=cache_dir)
    raw_datasets["train"] = raw_datasets["train"].select(range(n_train))
    for split in ("val", "val_dense", "val_sparse"):
        raw_datasets[split] = raw_datasets[split].select(range(n_test))

    # Tokenizer ids of the flip-flop symbols -> compact ids used downstream.
    # Replacement is sequential, which is safe here: no new id (0-4) equals a
    # later original id (15, 16, 72, 86, 81).
    mapping = {
        15: 0, # 0
        16: 1, # 1
        72: 2, # i
        86: 3, # w
        81: 4, # r
    }

    def preprocess_function(examples):
        result = tokenizer(examples["text"], return_tensors="pt")
        # NOTE(review): only input_ids is truncated to n_context. Any other tokenizer
        # outputs (e.g. an attention mask) keep their full length; confirm that is intended.
        input_ids = result["input_ids"][..., :n_context].clone()
        for original, new in mapping.items():
            input_ids = torch.where(input_ids == original, torch.tensor(new), input_ids)
        result["input_ids"] = input_ids
        return result

    processed_datasets = raw_datasets.map(
        preprocess_function,
        batched=True,
        remove_columns=raw_datasets["train"].column_names,
        desc="Running tokenizer on dataset",
    )
    print(processed_datasets)
    # Labels are built per batch by DataCollatorForFlipflop.
    return processed_datasets


@dataclass
class DataCollatorForFlipflop(object):
    """Batch flip-flop examples and build next-token labels for read positions.

    A position is a prediction target iff the token immediately before it is the
    read instruction "r" (remapped to id 4 by ``generate_data_flipflop``). All
    other label slots are set to ``IGNORE_INDEX``.
    """

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Stack ``instance["input_ids"]`` into a ``(batch, seqlen)`` tensor and derive labels.

        Returns a dict with:
            input_ids: the stacked token ids.
            labels: ``input_ids`` where the previous token is 4, ``IGNORE_INDEX`` elsewhere.
            num_labels: per-sequence count of non-ignored labels, as floats.
        """
        input_ids = torch.as_tensor([instance["input_ids"] for instance in instances])

        # Mark "previous token is a read" explicitly. torch.roll would wrap the last
        # token around onto position 0 and wrongly label it when a sequence ends in "r".
        follows_read = torch.zeros_like(input_ids, dtype=torch.bool)
        follows_read[:, 1:] = input_ids[:, :-1] == 4
        labels = input_ids.masked_fill(~follows_read, IGNORE_INDEX)

        num_labels = (labels != IGNORE_INDEX).float().sum(dim=-1)

        return dict(
            input_ids=input_ids,
            labels=labels,
            num_labels=num_labels,
        )
