import os
import sys
# Snapshot the source text of the launched script (sys.argv[0]) into the
# module-level name `code` before anything else runs, so the exact code can be
# logged later. NOTE(review): if this module is imported rather than run
# directly, sys.argv[0] is the launching script, not necessarily this file.
with open(sys.argv[0]) as f:
    code = f.read() # read the code of this file ASAP, for logging
import uuid
import glob
import time
from dataclasses import dataclass
from einops import rearrange
from typing import Optional, Tuple
import math
import numbers
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP

class Rotary(torch.nn.Module):
    """Build (and cache) the cos/sin tables used for rotary position embeddings.

    ``inv_freq`` holds ``dim / 2`` inverse frequencies ``1 / base**(2i / dim)``.
    It is stored as a plain tensor attribute (not a parameter or buffer), so it
    is not part of the state dict and is not moved by ``module.to(...)``.

    Args:
        dim: size of the per-head feature dimension that will be rotated.
        base: base of the geometric frequency progression.
    """

    def __init__(self, dim, base=10000):
        super().__init__()
        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.seq_len_cached = None
        self.device_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x):
        """Return ``(cos, sin)`` tables for the sequence length of ``x``.

        Args:
            x: tensor whose dimension 1 is the sequence length; only its shape
                and device are read.

        Returns:
            Two tensors of shape ``(1, seq_len, 1, dim / 2)`` on ``x.device``.
        """
        seq_len = x.shape[1]
        device = x.device
        # The cache must be keyed on the device as well as the length:
        # otherwise a same-length input on another device would receive
        # tables that live on the wrong device.
        if seq_len != self.seq_len_cached or device != self.device_cached:
            # Compute on inv_freq's own device (as before) and move the result,
            # so the table values do not depend on where x lives.
            t = torch.arange(seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
            freqs = torch.outer(t, self.inv_freq).to(device)
            self.cos_cached = freqs.cos()
            self.sin_cached = freqs.sin()
            self.seq_len_cached = seq_len
            self.device_cached = device
        return self.cos_cached[None, :, None, :], self.sin_cached[None, :, None, :]

def apply_rotary_emb(x, cos, sin):
    """Rotate the two halves of the last dimension of ``x`` by ``cos``/``sin``.

    ``x`` must be 4-dimensional. Its last dimension is split into a first and a
    second half; ``cos`` and ``sin`` must broadcast against each half. The
    result is cast back to the dtype of ``x``.
    """
    assert x.ndim == 4 # multihead attention
    half = x.shape[3] // 2
    first, second = x[..., :half], x[..., half:]
    rotated_first = first * cos + second * sin
    rotated_second = first * (-sin) + second * cos
    rotated = torch.cat((rotated_first, rotated_second), dim=3)
    return rotated.type_as(x)


def create_sinusoidal_embeddings(n_pos, dim, out):
    """Fill ``out`` in place with a fixed sine/cosine position table.

    Column ``j`` of position ``pos`` uses the angle
    ``pos / 10000 ** (2 * (j // 2) / dim)``; even columns receive its sine and
    odd columns its cosine. ``out`` is indexed as ``out[:, :, cols]``, so it
    must be 3-dimensional with trailing shape ``(n_pos, dim)``. Afterwards
    ``out`` is detached and marked as not requiring grad. Returns ``None``.
    """
    # One divisor per column, shared by every position.
    divisors = [np.power(10000, 2 * (col // 2) / dim) for col in range(dim)]
    angles = np.array([[pos / divisor for divisor in divisors] for pos in range(n_pos)])
    sin_part = np.sin(angles[:, 0::2])
    cos_part = np.cos(angles[:, 1::2])
    out[:, :, 0::2] = torch.FloatTensor(sin_part)
    out[:, :, 1::2] = torch.FloatTensor(cos_part)
    out.detach_()
    out.requires_grad = False


def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
        is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
    """Explicit (non-fused) scaled dot-product attention with a prefix-causal mask.

    Args:
        query, key, value: tensors whose last two dimensions are
            ``(sequence, head_dim)``; leading dimensions must broadcast under
            matrix multiplication.
        attn_mask: optional mask broadcastable to the attention weights. A
            boolean mask marks the positions that MAY be attended; any other
            dtype is added to the attention scores. Must be ``None`` when
            ``is_causal`` is set.
        dropout_p: dropout probability applied to the attention weights. Note
            that dropout is always applied in training mode (``train=True``),
            independent of any module's train/eval state.
        is_causal: when true, query ``i`` may attend to keys ``<= i`` AND to
            the first ``L // 2`` key positions (``L`` = query length), i.e. the
            leading half of the sequence is visible to every query.
        scale: score multiplier; defaults to ``1 / sqrt(head_dim)``.
        enable_gqa: repeat key/value heads (dimension ``-3``) so that their
            count matches the number of query heads.

    Returns:
        The attention output ``softmax(Q K^T * scale + bias) @ V``.
    """
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    # Allocate directly on the target device instead of building on the CPU and
    # copying over on every call.
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        visible = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        img_token_num = L // 2
        # The leading half of the keys is visible from every query position.
        visible[:, :img_token_num] = True
        attn_bias.masked_fill_(visible.logical_not(), float("-inf"))

    if attn_mask is not None:
        # Combine out of place so masks with leading batch/head dimensions
        # broadcast against the (L, S) bias instead of failing in-place.
        if attn_mask.dtype == torch.bool:
            mask_bias = torch.zeros_like(attn_mask, dtype=query.dtype)
            mask_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            attn_bias = attn_bias + mask_bias
        else:
            # Cast first so the bias keeps the query dtype, as before.
            attn_bias = attn_bias + attn_mask.to(attn_bias.dtype)

    if enable_gqa:
        key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
        value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)

    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value

def kendall_tau_distance(seq):
    """Score how close ``seq`` is to strictly increasing order, in ``[-1, 1]``.

    Every index pair ``i < j`` with ``seq[i] >= seq[j]`` (ties included) counts
    as discordant. The result is ``1 - 2 * discordant / total_pairs``: ``1.0``
    for a strictly increasing sequence, ``-1.0`` when every pair is discordant.
    Sequences with fewer than two elements score ``1.0``.
    """
    length = len(seq)
    if length < 2:
        return 1.0

    pair_count = float(length * (length - 1) // 2)
    out_of_order = sum(
        1
        for left in range(length)
        for right in range(left + 1, length)
        if seq[left] >= seq[right]
    )
    return 1.0 - 2 * out_of_order / pair_count

def longest_increasing_subsequence(seq):
    """Return the length of the longest strictly increasing subsequence of ``seq``.

    Uses the quadratic dynamic program: the best subsequence ending at each
    position extends the best one ending at any earlier, strictly smaller
    element. An empty (falsy) ``seq`` yields ``0``.
    """
    if not seq:
        return 0

    best_ending_at = []
    for idx, current in enumerate(seq):
        extensions = (
            best_ending_at[prev] + 1
            for prev in range(idx)
            if current > seq[prev]
        )
        # With nothing smaller before it, the element starts a run of length 1.
        best_ending_at.append(max(extensions, default=1))
    return max(best_ending_at)

class CausalSelfAttention(nn.Module):
    """Multi-head self-attention with QK RMS-normalisation and rotary embeddings.

    Reads ``config.n_head`` and ``config.emb_dim``; ``emb_dim`` must be
    divisible by ``n_head``. Attention is computed by this file's
    ``scaled_dot_product_attention`` with ``is_causal=True``.
    """

    def __init__(self, config):
        super().__init__()
        self.n_head = config.n_head
        self.emb_dim = config.emb_dim
        self.head_dim = self.emb_dim // self.n_head
        assert self.emb_dim % self.n_head == 0

        def square_projection():
            # Bias-free emb_dim -> emb_dim linear map.
            return nn.Linear(self.emb_dim, self.emb_dim, bias=False)

        self.c_q = square_projection()
        self.c_k = square_projection()
        self.c_v = square_projection()
        # output projection, zero-initialised (suggested by @Grad62304977)
        self.c_proj = square_projection()
        nn.init.zeros_(self.c_proj.weight)
        self.rotary = Rotary(self.head_dim)
        self.rms_norm = RMSNorm((self.head_dim,))

    def forward(self, x):
        """Attend over ``x`` of shape (batch, sequence, emb_dim); same shape out."""
        batch, seq_len, _ = x.size()
        per_head = (batch, seq_len, self.n_head, self.head_dim)
        q = self.c_q(x).view(per_head)
        k = self.c_k(x).view(per_head)
        v = self.c_v(x).view(per_head)

        cos, sin = self.rotary(q)
        # QK norm (suggested by @Grad62304977) is applied before the rotation.
        q = apply_rotary_emb(self.rms_norm(q), cos, sin)
        k = apply_rotary_emb(self.rms_norm(k), cos, sin)

        # The attention helper expects (batch, head, sequence, head_dim).
        attended = scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
        )
        # re-assemble all head outputs side by side
        merged = attended.transpose(1, 2).contiguous().view_as(x)
        return self.c_proj(merged)

class MLP(nn.Module):
    """Bias-free feed-forward layer: expand 4x, squared ReLU, project back.

    Reads ``config.emb_dim``. The output projection starts at zero, so a
    freshly constructed layer outputs zeros.
    """

    def __init__(self, config):
        super().__init__()
        hidden_dim = 4 * config.emb_dim
        self.c_fc = nn.Linear(config.emb_dim, hidden_dim, bias=False)
        self.c_proj = nn.Linear(hidden_dim, config.emb_dim, bias=False)
        # zero init suggested by @Grad62304977
        nn.init.zeros_(self.c_proj.weight)

    def forward(self, x):
        # Squared ReLU, https://arxiv.org/abs/2109.08668v2; ~1-2% better than
        # GELU; suggested by @SKYLINEZ007 and @Grad62304977
        hidden = F.relu(self.c_fc(x))
        return self.c_proj(hidden.square())

class Block(nn.Module):
    """Pre-norm transformer block: attention then MLP, each with a residual.

    A single ``RMSNorm`` instance (one shared weight) normalises the input of
    both sub-layers.
    """

    def __init__(self, config):
        super().__init__()
        self.attn = CausalSelfAttention(config)
        self.mlp = MLP(config)
        self.rms_norm = RMSNorm((config.emb_dim,))

    def forward(self, x):
        attn_out = self.attn(self.rms_norm(x))
        x = x + attn_out
        mlp_out = self.mlp(self.rms_norm(x))
        return x + mlp_out
    
@dataclass
class OrderConfig:
    """Hyper-parameters read by ``OrderPredictor`` and its sub-modules."""

    # Number of input items per sample; also the number of order classes.
    vocab_size: int = 32
    # Number of stacked transformer blocks.
    n_layer: int = 12
    # Attention heads per block; must divide emb_dim.
    n_head: int = 6
    # Width of the token embeddings.
    emb_dim: int = 768
    # Side length used to size the raw-input projection (3 * image_size ** 2).
    image_size: int = 224
    # Initialise the positional embedding with a sine/cosine table instead of
    # normal(0, 0.02) noise.
    sincos_pos_emb: bool = True
    # Predict a single scalar per item instead of vocab_size class logits.
    regressive: bool = False
    # Inputs are feature vectors of size feature_dim rather than raw images.
    clip_feature_input: bool = False
    # Size of each input feature vector when clip_feature_input is set.
    feature_dim: int = 768
    # Create a LayerNorm over the input features (only with clip_feature_input).
    feature_normalization: bool = True


def rms_norm(x, eps=1e-8):
    """Divide ``x`` by the root-mean-square of its last dimension.

    Args:
        x: floating-point tensor; statistics are taken over ``dim=-1``.
        eps: value added to the mean square before the square root. ``None``
            selects ``torch.finfo(x.dtype).eps``.

    Returns:
        A tensor with the same shape as ``x``.
    """
    if eps is None:
        eps = torch.finfo(x.dtype).eps
    rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps)
    return x / rms

class RMSNorm(nn.Module):
    """Root-mean-square normalisation with an optional learnable scale.

    Normalisation is always computed over the last dimension of the input (see
    ``rms_norm``); ``normalized_shape`` only determines the shape of the
    learnable ``weight``, which must broadcast against the input.

    Args:
        normalized_shape: shape of ``weight``; a bare integer is wrapped into
            a one-element tuple.
        eps: stabiliser added inside the square root; ``None`` falls back to
            the machine epsilon of the input dtype.
        elementwise_affine: when true, multiply the normalised input by a
            learnable ``weight`` initialised to ones.
        device, dtype: forwarded to the ``weight`` allocation.
    """

    __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
    normalized_shape: Tuple[int, ...]
    eps: Optional[float]
    elementwise_affine: bool

    def __init__(
        self,
        normalized_shape: Tuple[int, ...],
        eps: Optional[float] = 1e-8,
        elementwise_affine: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = nn.Parameter(
                torch.empty(self.normalized_shape, **factory_kwargs)
            )
        else:
            self.register_parameter("weight", None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Reset the learnable scale (if any) to ones."""
        if self.elementwise_affine:
            nn.init.ones_(self.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` over its last dimension and apply the scale."""
        # rms_norm resolves eps=None to the dtype's machine epsilon.
        x = rms_norm(x, self.eps)
        if self.elementwise_affine:
            x = x * self.weight
        return x

    def extra_repr(self) -> str:
        return (
            "{normalized_shape}, eps={eps}, "
            "elementwise_affine={elementwise_affine}".format(**self.__dict__)
        )


class OrderPredictor(nn.Module):
    """Transformer that predicts an ordering over ``config.vocab_size`` items.

    The input sequence consists of ``vocab_size`` item embeddings followed by
    up to ``vocab_size`` order-token embeddings (hence positional embeddings
    and the token embedding table of size ``2 * vocab_size``). ``forward``
    decodes the order tokens greedily one at a time; ``forward_train`` feeds
    the ground-truth order tokens (teacher forcing) and returns either scores
    or a loss.

    Args:
        config: an ``OrderConfig``; a default-constructed one is used when
            ``None``.
    """

    def __init__(self, config=None):
        super().__init__()
        if config is None:
            config = OrderConfig()

        self.config = config

        if not config.clip_feature_input:
            # Raw inputs: each item is flattened to 3 * image_size ** 2 values.
            self.feature_normalization = None
            self.image_to_embedding = nn.Linear(3 * config.image_size ** 2, config.emb_dim)
        else:
            if config.feature_normalization:
                self.feature_normalization = nn.LayerNorm(config.feature_dim)
            else:
                self.feature_normalization = None
            if config.feature_dim != config.emb_dim:
                self.image_to_embedding = nn.Linear(config.feature_dim, config.emb_dim)
            else:
                self.image_to_embedding = nn.Identity()

        # positional embedding: one slot per item plus one per order token
        self.pos_embedding = nn.Parameter(torch.zeros(1, config.vocab_size * 2, config.emb_dim))
        if config.sincos_pos_emb:
            # Only the values are initialised here; the helper's
            # requires_grad=False applies to the `.data` alias, not to the
            # Parameter itself.
            create_sinusoidal_embeddings(config.vocab_size * 2, config.emb_dim, self.pos_embedding.data)
        else:
            self.pos_embedding.data.normal_(mean=0.0, std=0.02)

        self.transformer = nn.ModuleDict(dict(
            embed = nn.Embedding(config.vocab_size * 2, config.emb_dim),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
        ))
        if not config.regressive:
            self.lm_head = nn.Linear(config.emb_dim, config.vocab_size, bias=False)
        else:
            self.lm_head = nn.Linear(config.emb_dim, 1, bias=False)
        self.lm_head.weight.data.zero_()
        self.pre_rms_norm = RMSNorm(config.emb_dim)
        self.post_rms_norm = RMSNorm(config.emb_dim)

        self.total_params = sum(p.numel() for p in self.parameters())
        print(f"Total parameters: {self.total_params}")

    @torch.jit.ignore
    def no_weight_decay(self):
        """Names of parameters to exclude from weight decay.

        Presumably consumed by an optimizer factory outside this file.
        """
        return {"pos_embedding"}

    def _run_blocks(self, x):
        """Apply every transformer block followed by the final RMS norm."""
        for block in self.transformer.h:
            x = block(x)
        return self.post_rms_norm(x)

    @staticmethod
    def _stepwise_lis(rows):
        """For each row, return the LIS length of every prefix of that row."""
        return [
            [longest_increasing_subsequence(row[:end]) for end in range(1, len(row) + 1)]
            for row in rows
        ]

    def forward(self, image):
        """Greedily decode an order for every sample.

        Args:
            image: tensor whose first two dimensions are
                ``(batch, vocab_size)``; the remaining dimensions are
                flattened per item.

        Returns:
            A tuple ``(sorted_logits, kendall_tau_distances,
            lis_lengths_stepwise)``: the float logits of each decoding step
            stacked along dimension 1, and, for every prefix of the predicted
            index sequence, its ``kendall_tau_distance`` score (float) and its
            longest-increasing-subsequence length (int64), both of shape
            ``(batch, vocab_size)``.
        """
        b, t = image.size()[:2]
        assert t == self.config.vocab_size, f"image tensor must have {self.config.vocab_size} time steps"
        vocab_size = self.config.vocab_size

        x = image.view(b, t, -1)
        if self.feature_normalization is not None:
            x = self.feature_normalization(x)
        x = self.image_to_embedding(x)
        x = self.pre_rms_norm(x)

        step_logits = []
        step_indices = []
        # Exactly vocab_size decoding steps are needed. The previous extra
        # iteration and the trailing full pass over the sequence produced
        # values that were never read, so they are not computed any more.
        for output_num in range(vocab_size):
            y = self._run_blocks(x + self.pos_embedding[:, :vocab_size + output_num])
            predicted_logits = self.lm_head(y[:, -1])
            predicted_indices = torch.argmax(predicted_logits, dim=-1)
            step_logits.append(predicted_logits)
            step_indices.append(predicted_indices)
            # Feed the prediction back as the next order token.
            x = torch.cat((x, self.transformer.embed(predicted_indices).unsqueeze(1)), dim=1)

        sorted_logits = torch.stack(step_logits, dim=1).float()
        sorted_indices = torch.stack(step_indices, dim=1)

        # Convert once to Python lists instead of slicing the tensor (and
        # synchronising with the device) for every prefix.
        index_rows = sorted_indices.tolist()
        lis_lengths_stepwise = torch.tensor(
            self._stepwise_lis(index_rows),
            dtype=sorted_indices.dtype, device=sorted_indices.device,
        ).reshape(b, vocab_size)
        kendall_tau_distances = torch.tensor(
            [[kendall_tau_distance(row[:end]) for end in range(1, len(row) + 1)] for row in index_rows],
            dtype=torch.float32, device=sorted_indices.device,
        ).reshape(b, vocab_size)

        return sorted_logits, kendall_tau_distances, lis_lengths_stepwise

    def forward_train(self, image, keep_order=True, get_score=True, use_mask=True):
        """Teacher-forced pass returning either scores or a training loss.

        Args:
            image: tensor whose first two dimensions are
                ``(batch, vocab_size)``.
            keep_order: when false, every sample's items are shuffled with an
                independent random permutation and the targets are the
                original positions of the shuffled items.
            get_score: when true, return scores only (no loss).
            use_mask: classification loss only. Mask out, at every step, the
                classes already consumed by earlier target tokens.

        Returns:
            Classification (``config.regressive`` false):
                ``get_score``: ``(sorted_logits, sorted_predicted_indices,
                lis_lengths, lis_lengths_stepwise)``;
                otherwise ``(loss, sorted_predicted_indices, lis_lengths,
                average_lis_length, logits, sorted_logits)``.
            Regression (``config.regressive`` true):
                ``get_score``: ``(sorted_logits, score, lis_lengths,
                lis_lengths_stepwise)``;
                otherwise ``(loss, score, lis_lengths, lis_lengths_stepwise,
                logits, sorted_logits)``.
            ``sorted_*`` values are re-ordered by target index; ``lis_lengths``
            is a Python list with one LIS length per sample.
        """
        b, t = image.size()[:2]
        assert t == self.config.vocab_size, f"image tensor must have {self.config.vocab_size} time steps"
        vocab_size = self.config.vocab_size
        x = image

        if keep_order:
            # repeat() (not expand()) keeps the targets contiguous for .view().
            original_indices = torch.arange(t, device=x.device).repeat(b, 1)
            shuffled_x = x
        else:
            shuffled_x = torch.empty_like(x)
            original_indices = torch.empty(b, t, dtype=torch.long, device=x.device)
            for i in range(b):
                random_idx = torch.randperm(t)
                shuffled_x[i] = x[i, random_idx]
                original_indices[i] = torch.argsort(random_idx)

        x = shuffled_x.view(b, t, -1)
        # NOTE(review): forward() applies self.feature_normalization at this
        # point, this path does not. Left unchanged because altering it would
        # change training numerics -- confirm which behaviour is intended.
        x = self.image_to_embedding(x)
        x = self.pre_rms_norm(x)

        order_tokens = self.transformer.embed(original_indices)
        x = torch.cat((x, order_tokens), dim=1)
        x = self._run_blocks(x + self.pos_embedding)

        targets = original_indices

        # Position vocab_size - 1 + k predicts the k-th order token.
        logits = self.lm_head(x[:, vocab_size - 1:-1]).float()
        b, t, _ = logits.size()

        # Re-order every sample's steps by ascending target index.
        sort_order = targets.argsort(dim=1)
        batch_index = torch.arange(b, device=logits.device).unsqueeze(1)
        sorted_logits = logits[batch_index, sort_order]

        if not self.config.regressive:
            _, predicted_indices = torch.max(logits, dim=-1)
            sorted_predicted_indices_nomask = predicted_indices[batch_index, sort_order]

            index_rows = sorted_predicted_indices_nomask.tolist()
            lis_lengths_nomask = [longest_increasing_subsequence(row) for row in index_rows]
            lis_lengths_stepwise = torch.tensor(
                self._stepwise_lis(index_rows),
                dtype=torch.float32, device=logits.device,
            ).reshape(b, t)
            if get_score:
                return sorted_logits, sorted_predicted_indices_nomask, lis_lengths_nomask, lis_lengths_stepwise

            average_lis_length_nomask = (sum(lis_lengths_nomask) / len(lis_lengths_nomask)) / vocab_size
            if use_mask:
                # A class is "used" at step k if it was the target of an
                # earlier step; such classes are pushed to -1e9.
                targets_onehot = F.one_hot(targets, vocab_size).int()
                used_targets = torch.cumsum(targets_onehot, dim=1) > 0
                used_targets = torch.roll(used_targets, 1, 1)
                used_targets[:, 0] = 0
                masked_logits_all = torch.where(used_targets, logits.new_tensor(-1e9), logits)
                loss = F.cross_entropy(masked_logits_all.view(-1, vocab_size), targets.view(-1), ignore_index=-1)
            else:
                loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))
            return loss, sorted_predicted_indices_nomask, lis_lengths_nomask, average_lis_length_nomask, logits, sorted_logits

        # Regressive head. The score statistics are computed regardless of
        # get_score: the loss path returns them too and previously failed with
        # an UnboundLocalError when get_score was False.
        score = torch.sigmoid(sorted_logits)
        score_rows = score.tolist()
        lis_lengths_nomask = [longest_increasing_subsequence(row) for row in score_rows]
        lis_lengths_stepwise = torch.tensor(
            self._stepwise_lis(score_rows),
            dtype=score.dtype, device=score.device,
        ).reshape(score.shape)
        if get_score:
            return sorted_logits, score, lis_lengths_nomask, lis_lengths_stepwise

        targets = targets.float() / vocab_size
        logits = torch.sigmoid(logits.squeeze(-1))
        loss = F.mse_loss(logits, targets)
        return loss, score, lis_lengths_nomask, lis_lengths_stepwise, logits, sorted_logits