import torch
import numpy as np
import random

from dataclasses import dataclass, field
from typing import Dict, Any, List


@dataclass
class DataCollator:
    """Collate token chunks into next-token-prediction inputs and labels.

    The incoming chunks are concatenated along dim 0. Without BOS, inputs
    and labels are the sequence shifted by one token against each other.
    With BOS, ``eos_token`` is prepended to the inputs so that every
    original token appears as a label; a duplicate of that view is also
    returned under the ``"standard"`` key.
    """

    # Token id placed in front of every row when ``append_bos`` is True.
    eos_token: int
    # Whether to prepend ``eos_token`` (True) or simply shift by one (False).
    append_bos: bool

    def __call__(self, batch) -> Dict[str, Any]:
        tokens = torch.cat(batch, dim=0)

        if self.append_bos:
            bos_column = tokens.new_full((tokens.size(0), 1), self.eos_token)
            shifted_inputs = torch.cat((bos_column, tokens[:, :-1]), dim=-1)
            return {
                "input_ids": shifted_inputs,
                "labels": tokens,
                "standard": {
                    "input_ids": shifted_inputs,
                    "labels": tokens,
                },
            }

        return {
            "input_ids": tokens[:, :-1],
            "labels": tokens[:, 1:],
        }

@dataclass
class SparseDataCollator:
    """Datacollator for sparse training.

    Each row gets its own set of sampled positions (via
    ``posteriori_sampling_llama2``). The sparse view gathers the tokens at
    those positions and uses the last ``sample_len // 2`` tokens of the
    row as labels. A dense ``"standard"`` view of the first ``sample_len``
    tokens (shifted by one for labels) is also returned for mixed training.
    """

    # Token id prepended to every row when ``append_bos`` is True.
    eos_token: int
    # Whether to prepend ``eos_token`` before sampling positions.
    append_bos: bool
    # Length passed to the sampler as ``target_len``; also the dense view length.
    sample_len: int

    def __call__(self, batch) -> Dict[str, Any]:
        tokens = torch.cat(batch, dim=0)

        if self.append_bos:
            bos_column = tokens.new_full((tokens.size(0), 1), self.eos_token)
            tokens = torch.cat((bos_column, tokens), dim=-1)

        n_predicted = self.sample_len // 2
        # The final column is only ever a label, never an input position.
        candidate_positions = np.arange(tokens.size(1) - 1)

        sampled_rows = []
        for _ in range(tokens.size(0)):
            row = posteriori_sampling_llama2(candidate_positions, self.sample_len, start_tokens=0)
            sampled_rows.append(row.unsqueeze(0))
        positions = torch.cat(sampled_rows, dim=0)

        sparse_labels = tokens[:, -n_predicted:]
        sparse_inputs = reshape_inputs(tokens, positions)

        # Dense view for mixed training.
        dense_inputs = tokens[:, :self.sample_len]
        dense_labels = tokens[:, 1:self.sample_len + 1]

        return {
            "input_ids": sparse_inputs,
            "labels": sparse_labels,
            "position_ids": positions,
            "standard": {
                "input_ids": dense_inputs,
                "labels": dense_labels,
            },
        }

# For GPT2, we extend length from 1024 to 2048
def posteriori_sampling_gpt2(
    position_ids: np.ndarray,
    target_len: int,
    start_tokens: int = 0,
    unseen_position_ids: int = 1024,
):
    """Sample sorted position ids for sparse training of GPT-2.

    The returned ids are the concatenation of three sorted parts:

    1. ``0 .. start_tokens - 1``, always kept to avoid attention sink;
    2. ``target_len // 2 - start_tokens`` ids drawn without replacement from
       ``position_ids[start_tokens:block_idx]``;
    3. the contiguous predicted block
       ``position_ids[block_idx:block_idx + target_len // 2]``.

    ``block_idx`` is drawn uniformly from ``unseen_position_ids``,
    ``unseen_position_ids + target_len // 2``, ... restricted to starts whose
    whole block fits inside ``position_ids``. Every call therefore returns the
    same number of ids, so rows can be stacked into a batch.

    Args:
        position_ids (np.ndarray): ids array encompassing all position ids.
            Sampled values are reused as slice indices, so this is assumed to
            be ``np.arange(n)``.
        target_len (int): the sequence length to produce, composed of the
            predicted part and the sampled part.
        start_tokens (int): number of leading tokens always kept to avoid
            attention sink.
        unseen_position_ids (int): smallest position at which the predicted
            block may start.

    Returns:
        torch.Tensor: 1-D tensor of sampled position ids.

    Raises:
        ValueError: if no full predicted block fits at or after
            ``unseen_position_ids``, or if there are too few positions before
            the block to sample from.
    """
    end_len = target_len // 2

    # Keep some initial tokens to avoid attention sink.
    start_ids = np.arange(start_tokens)

    # Only block starts whose whole block fits are eligible: a truncated block
    # would make this row shorter than the others. When the span divides
    # evenly this is exactly the original candidate set.
    block_stop = max(len(position_ids) - end_len + 1, 0)
    block_starts = position_ids[unseen_position_ids:block_stop:end_len]
    if len(block_starts) == 0:
        raise ValueError(
            f"no block of {end_len} positions fits at or after position "
            f"{unseen_position_ids} in a sequence of {len(position_ids)} positions"
        )
    block_idx = np.random.choice(block_starts, 1)[0]

    # Keep the block for predictions; this part is similar to standard next token prediction.
    end_ids = position_ids[block_idx:block_idx + end_len]

    # The remaining part is sampled (without replacement) from before the block.
    sampled_ids = np.random.choice(position_ids[start_tokens:block_idx], end_len - len(start_ids), replace=False)
    sampled_ids.sort()

    final_ids = np.concatenate((start_ids, sampled_ids, end_ids))
    return torch.from_numpy(final_ids)


# For LLama 2, we extend length from 4096 to 8192
def posteriori_sampling_llama2(
    position_ids: np.ndarray,
    target_len: int,
    start_tokens: int = 0,
    unseen_position_ids: int = 4096,
    segment_len: int = 8,
):
    """Sample sorted position ids for sparse training of Llama 2.

    The returned ids are the concatenation of three sorted parts:

    1. ``0 .. start_tokens - 1``, always kept to avoid attention sink;
    2. ``(target_len // 2 - start_tokens) // segment_len`` contiguous segments
       of ``segment_len`` ids each. Segment starts are drawn without
       replacement from ``start_tokens, start_tokens + segment_len, ...``,
       restricted so that every segment ends before the predicted part;
    3. the predicted part: the last ``target_len // 2`` ids of ``position_ids``.

    The result is strictly increasing. If ``target_len // 2 - start_tokens``
    is not a multiple of ``segment_len``, part 2 is shorter than that value,
    so the total length is below ``target_len``.

    Args:
        position_ids (np.ndarray): ids array encompassing all position ids.
            Sampled values are reused as ids directly, so this is assumed to
            be ``np.arange(n)``.
        target_len (int): the sequence length to produce, composed of the
            predicted part and the sampled part.
        start_tokens (int): number of leading tokens always kept to avoid
            attention sink.
        unseen_position_ids (int): currently unused (the random block start
            it controlled is disabled); kept for signature compatibility.
        segment_len (int): length of each sampled contiguous segment.

    Returns:
        torch.Tensor: 1-D tensor of sampled position ids.

    Raises:
        ValueError: if ``start_tokens`` exceeds ``target_len // 2``, or if
            there is not enough room before the predicted part for the
            required number of segments.
    """
    end_len = target_len // 2

    # Keep some initial tokens to avoid attention sink.
    start_ids = np.arange(start_tokens)

    # Keep the last tokens for predictions; this part is similar to standard next token prediction.
    end_tokens = len(position_ids) - end_len
    end_ids = position_ids[end_tokens:end_tokens + end_len]

    # The remaining part is sampled as contiguous segments.
    n_samples = end_len - start_tokens
    if n_samples < 0:
        raise ValueError(
            f"start_tokens ({start_tokens}) exceeds target_len // 2 ({end_len})"
        )
    n_segments = n_samples // segment_len

    if n_segments > 0:
        # A start within segment_len of end_tokens would yield ids that
        # duplicate the predicted part, so such starts are excluded. When the
        # span divides evenly this is exactly the original candidate set.
        segment_stop = max(end_tokens - segment_len + 1, 0)
        segment_starts = position_ids[start_tokens:segment_stop:segment_len]
        if len(segment_starts) < n_segments:
            raise ValueError(
                f"cannot sample {n_segments} segments of length {segment_len} "
                f"from {len(segment_starts)} available segment starts"
            )
        sampled_segments = np.random.choice(segment_starts, n_segments, replace=False)
        sampled_segments.sort()
        sampled_ids = np.concatenate([np.arange(i, i + segment_len) for i in sampled_segments])
    else:
        # Nothing to sample; np.concatenate would reject an empty list.
        sampled_ids = np.empty(0, dtype=end_ids.dtype)

    final_ids = np.concatenate((start_ids, sampled_ids, end_ids))
    return torch.from_numpy(final_ids)

def reshape_inputs(inputs: torch.Tensor, positions: torch.Tensor):
    """Select, for every row of ``inputs``, the columns listed in ``positions``.

    ``out[b, k] == inputs[b, positions[b, k]]``. A 1-D or single-row
    ``positions`` is broadcast to every row of ``inputs``.

    Args:
        inputs (torch.Tensor): 2-D tensor of shape ``(batch, seq_len)``.
        positions (torch.Tensor): integer column indices of shape
            ``(batch, k)``, ``(1, k)`` or ``(k,)``. It may be on a different
            device or have a narrower integer dtype than required; it is
            converted to ``long`` on ``inputs.device``.

    Returns:
        torch.Tensor: tensor of shape ``(batch, k)`` with ``inputs``' dtype.

    Raises:
        RuntimeError: if a position is outside ``[0, seq_len)`` or the batch
            sizes cannot be broadcast.
    """
    if positions.dim() == 1:
        positions = positions.unsqueeze(0)
    # expand is a no-op for (batch, k) and broadcasts a (1, k) index.
    positions = positions.expand(inputs.size(0), -1)
    # gather keeps everything on inputs' device and bounds-checks each row,
    # instead of silently wrapping into the next row of a flattened view.
    index = positions.to(device=inputs.device, dtype=torch.long)
    return inputs.gather(1, index)
