import random
import torch
from torch import nn
from transformers import MT5Tokenizer, MT5EncoderModel, AutoTokenizer, PreTrainedTokenizerFast

import time
import numpy as np
import gc
import torch.nn.functional as F
import regex as re
# Sentence splitter selection: prefer the ICU-based splitter and fall back to
# the BlingFire-based one when the ICU module cannot be imported.
try:
    from utils.icu_sentence_split import split_sentence_icu as split_fn
except Exception:
    # `Exception` (not a bare `except:`) so that KeyboardInterrupt/SystemExit
    # still propagate, while any import-time failure of the ICU path
    # (presumably a missing module or native library) triggers the fallback.
    from utils.blingfire_sentence_split import split_sentence_blingfire as split_fn



# Default torch device string: "cuda" when a CUDA device is available, else "cpu".
# Used as the default `device` argument of MT5Reward.__init__.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Default Hugging Face checkpoint name for MT5Reward / MT5FeatureExtractor.
MODEL_NAME = "google/mt5-base"  # small/large/xl if you like

def matching(C):
    """
    Monotone (non-decreasing) many-to-one matching: every a_i is assigned to some b_j,
    j can repeat, and order is preserved (j_1 <= j_2 <= ... <= j_n).

    Recurrence: dp[i][j] = C[i][j] + min(dp[i-1][k] for k <= j).
    Ties are resolved by preferring k == j ("stay"), then the smallest k < j;
    the terminal column is the first one with minimal accumulated cost.

    The matrix is copied once to host memory as float64 and the DP runs on
    plain Python floats, so the input may live on any device and may track
    gradients (it is detached; no gradient flows through the result).

    Args:
        C: (n, m) cost matrix where C[i,j] is cost of matching a_i to b_j.
           A torch tensor, or anything with `.shape` that `torch.as_tensor` accepts.
    Returns:
        total_cost: float
        pairs:      list of (i, j) in increasing i (0-based indices)
        pair_costs: list of C[i, j] aligned with pairs
    Raises:
        ValueError: if n > 0 but m == 0 (nothing to match against).
    """
    n, m = C.shape
    if n == 0:
        return 0.0, [], []
    if m == 0:
        raise ValueError("Infeasible: |B|=0 but all A must be matched.")

    # One bulk device->host transfer instead of O(n*m) scalar tensor reads.
    costs = torch.as_tensor(C).detach().to(device="cpu", dtype=torch.float64).tolist()

    inf = float("inf")
    # back[i][j]: column used by row i-1 on the best path ending at (i, j);
    # -1 marks the start (row 0 has no predecessor).
    back = [[-1] * m]
    row = list(costs[0])  # dp values of the current row; only one row is kept

    for i in range(1, n):
        cost_row = costs[i]
        nxt = [0.0] * m
        choice = [0] * m
        # Running minimum of row[0..j-1] and its first index (strict '<' below
        # keeps the earliest column on ties).
        best, best_idx = inf, -1
        for j in range(m):
            stay = row[j]
            if stay <= best:
                # Tie-break: prefer keeping the same b_j (fewer jumps).
                nxt[j] = cost_row[j] + stay
                choice[j] = j
            else:
                # Jump from the cheapest earlier column k < j.
                nxt[j] = cost_row[j] + best
                choice[j] = best_idx
            if stay < best:
                best, best_idx = stay, j
        back.append(choice)
        row = nxt

    # Best terminal column (first index on ties), then walk the backpointers.
    j = min(range(m), key=row.__getitem__)
    total_cost = row[j]

    pairs, pair_costs = [], []
    for i in range(n - 1, -1, -1):
        pairs.append((i, j))
        pair_costs.append(costs[i][j])
        j = back[i][j]
    pairs.reverse()
    pair_costs.reverse()
    return total_cost, pairs, pair_costs


class MT5Reward:
    """
    Embedding-similarity reward between generated completions and reference texts.

    Texts are embedded with `MT5FeatureExtractor`; rewards are computed either
    per text pair (`get_reward`, selected by `metric`) or per sentence with a
    monotone alignment (`get_sentence_level_reward`, selected by `sentence_level`).

    Both selector strings are matched by substring:
      * 'cosine' / 'euclidean'        : plain similarity / negative distance
      * 'triplet' (get_reward only)   : `F.triplet_margin_loss` against a shuffled reference
      * 'randomNeg' [+ 'Cos']         : score minus the score of character-shuffled copies
      * 'NoMargin', 'Norm'            : skip the margin/clamp, rescale the result
      * 'mean' / 'both' (sentence)    : how per-sentence costs are aggregated
    """

    def __init__(self, max_len=20000, model_name=MODEL_NAME, proj_dim=None, freeze=True, batch_size=16, metric='randomNegCos', sentence_level='mean+randomNegCos', margin=1, device=DEVICE):
        self.mt5 = MT5FeatureExtractor(model_name, proj_dim, freeze).to(device)
        self.max_len = max_len
        self.batch_size = batch_size
        self.metric = metric
        self.sentence_level = sentence_level
        self.margin = margin
        # Marker after which the reference text starts inside kwargs['text'] entries.
        self.split = "<|im_start|>assistant\n<|im_start|>think"

    # ------------------------------------------------------------------ helpers

    def _resolve_texts(self, prompts, completions, kwargs):
        """Pick the (completions, original) text lists to compare."""
        if 'translated_prompts' in kwargs:
            # NOTE(review): 'translated_prompts' only acts as a switch here. The
            # previous code bound its value to an unused local named `prompt`;
            # confirm whether `prompts` was meant to be replaced as well.
            completions = kwargs['translated_completions']

        if 'text' in kwargs:
            # Keep what follows the last marker; str.split returns the whole
            # string as its only element when the marker is absent.
            original = [text.split(self.split)[-1] for text in kwargs['text']]
        else:
            original = prompts
        return completions, original

    def _encode(self, texts):
        """Embed `texts` with the instance-wide max_len and batch_size."""
        return self.mt5.encode(texts, max_len=self.max_len, batch_size=self.batch_size)

    @staticmethod
    def _permute(text, order):
        """Reorder the characters of `text` by `order`, skipping out-of-range indices."""
        size = len(text)
        return "".join([text[i] for i in order if i < size])

    def _similarity_matrix(self, c_emb, o_emb, rc_emb, ro_emb):
        """Pairwise sentence similarity (higher is better) per `self.sentence_level`."""
        mode = self.sentence_level
        if 'cosine' in mode:
            sim_matrix = torch.nn.functional.cosine_similarity(
                c_emb.unsqueeze(1), o_emb.unsqueeze(0), dim=-1, eps=1e-8)
            if 'Norm' in mode:
                sim_matrix = (sim_matrix + 1) / 2  # cosine [-1, 1] -> [0, 1]
            return sim_matrix
        if 'euclidean' in mode:
            return -torch.cdist(c_emb, o_emb, p=2)
        if 'randomNeg' in mode:
            if 'Cos' in mode:
                pos_sim_matrix = torch.nn.functional.cosine_similarity(
                    c_emb.unsqueeze(1), o_emb.unsqueeze(0), dim=-1, eps=1e-8)
                neg_sim_matrix = torch.nn.functional.cosine_similarity(
                    rc_emb.unsqueeze(1), ro_emb.unsqueeze(0), dim=-1, eps=1e-8)
                sim_matrix = pos_sim_matrix - neg_sim_matrix
            else:
                # Distances: smaller is better, so the negative pair comes first.
                pos_d_matrix = torch.cdist(c_emb, o_emb, p=2)
                neg_d_matrix = torch.cdist(rc_emb, ro_emb, p=2)
                sim_matrix = neg_d_matrix - pos_d_matrix
            if 'NoMargin' in mode:
                if 'Norm' in mode:
                    # Maps [-2, 2] (difference of two cosines) onto [0, 1].
                    sim_matrix = sim_matrix / 4 + 0.5
            else:
                sim_matrix = torch.clamp(sim_matrix + self.margin, min=0)
                if 'Norm' in mode:
                    # Upper bound 3 holds for cosine scores with margin == 1.
                    sim_matrix = sim_matrix / 3
            return sim_matrix
        raise Exception(f"Unknown sentence_level metric: {self.sentence_level}")

    # --------------------------------------------------------------- public API

    def get_sentence_level_reward(self, prompts, completions, **kwargs):
        """
        Sentence-aligned reward between each completion and its reference text.

        Each text is split into sentences with `split_fn`, all sentences are
        embedded, and for every (completion, reference) pair the completion
        sentences are monotonically matched to reference sentences via
        `matching` on the negated similarity matrix.

        Args:
            prompts: reference texts, used when kwargs has no 'text'.
            completions: generated texts (replaced by
                kwargs['translated_completions'] when 'translated_prompts' is given).
            **kwargs: optional 'text', 'translated_prompts', 'translated_completions'.

        Returns:
            If 'mean' is in `sentence_level`: list of floats, one per pair
            (negated total matching cost divided by the sentence count).
            Otherwise: tuple `(rewards, c_sentences)` where `rewards` is a list
            of 1-D tensors (one value per completion sentence) and
            `c_sentences` the split completion sentences.

        Raises:
            Exception: if `sentence_level` names no known metric.
            ValueError: from `matching` when a reference has no sentences but
                its completion has some.
        """
        completions, original = self._resolve_texts(prompts, completions, kwargs)
        mean_mode = 'mean' in self.sentence_level

        c_sentences = [split_fn(text)[0] for text in completions]
        o_sentences = [split_fn(text)[0] for text in original]
        # Sentence counts per text, needed to un-flatten the embeddings.
        L_c = [len(sents) for sents in c_sentences]
        L_o = [len(sents) for sents in o_sentences]
        # Flatten so every sentence is encoded in one batched pass.
        flat_c = [sent for sents in c_sentences for sent in sents]
        flat_o = [sent for sents in o_sentences for sent in sents]
        if not flat_c:
            print("There is no Generated Sentences")
            if mean_mode:
                return [0] * len(completions)
            # Keep the (rewards, sentences) shape of the non-mean return path:
            # a completion without sentences has an empty per-sentence reward.
            return ([torch.Tensor([]) for _ in completions], c_sentences)

        c_splits = torch.split(self._encode(flat_c), L_c)
        o_splits = torch.split(self._encode(flat_o), L_o)

        rc_splits = ro_splits = None
        if 'random' in self.sentence_level:
            # One shared character permutation for all sentences. It must span
            # the longest sentence in characters; sizing it by the sentence
            # count (as before) silently truncated every shuffled sentence.
            max_chars = max(len(sent) for sent in flat_c + flat_o)
            order = np.arange(max_chars)
            np.random.shuffle(order)
            random_flat_c = [self._permute(sent, order) for sent in flat_c]
            random_flat_o = [self._permute(sent, order) for sent in flat_o]
            rc_splits = torch.split(self._encode(random_flat_c), L_c)
            ro_splits = torch.split(self._encode(random_flat_o), L_o)

        rewards = []
        for i, (c_emb, o_emb) in enumerate(zip(c_splits, o_splits)):
            rc_emb = rc_splits[i] if rc_splits is not None else None
            ro_emb = ro_splits[i] if ro_splits is not None else None
            sim_matrix = self._similarity_matrix(c_emb, o_emb, rc_emb, ro_emb)

            # Order-preserving matching; costs are negated similarities.
            total_cost, pairs, pair_costs = matching(-sim_matrix)
            if mean_mode:
                # 1e-4 guards the division when a completion has no sentences.
                reward = -total_cost / (len(c_emb) + 1e-4)
            elif 'both' in self.sentence_level:
                reward = -total_cost / (len(c_emb) + 1e-4) - torch.Tensor(pair_costs)
            else:
                reward = -torch.Tensor(pair_costs)
            rewards.append(reward)

        print("Sentence-level rewards:", rewards)
        if mean_mode:
            rewards = torch.tensor(rewards)
            return rewards.cpu().tolist()
        return (rewards, c_sentences)  # list of tensors

    def get_reward(self, prompts, completions, **kwargs):
        """
        Text-level reward between each completion and its reference text.

        Completions and references are embedded together and scored pairwise
        according to `self.metric`.

        Args:
            prompts: reference texts, used when kwargs has no 'text'.
            completions: generated texts (replaced by
                kwargs['translated_completions'] when 'translated_prompts' is given).
            **kwargs: optional 'text', 'translated_prompts', 'translated_completions'.

        Returns:
            List of floats, one per (completion, reference) pair.

        Raises:
            Exception: if `metric` names no known metric.
        """
        st = time.time()
        completions, original = self._resolve_texts(prompts, completions, kwargs)

        combined = completions + original
        embeddings = self._encode(combined)
        trans_embedding = embeddings[:len(completions)]
        original_embedding = embeddings[len(completions):]

        if self.metric == 'cosine':
            reward = F.cosine_similarity(trans_embedding, original_embedding, dim=-1, eps=1e-8)
        elif self.metric == 'euclidean':
            # Negative L2 distance, so closer pairs score higher.
            reward = -F.pairwise_distance(trans_embedding, original_embedding, p=2, eps=1e-8)
        elif self.metric == 'triplet':
            # Negative example: each reference with its characters shuffled.
            roriginal = []
            for ori in original:
                li = list(ori)
                np.random.shuffle(li)
                roriginal.append("".join(li))
            random_embedding = self._encode(roriginal)
            # NOTE(review): this is a loss value (lower is better) and uses a
            # fixed margin of 1.0 rather than self.margin - confirm intent.
            reward = F.triplet_margin_loss(trans_embedding, original_embedding, random_embedding, margin=1.0, p=2, eps=1e-8, reduction='none')
        elif 'randomNeg' in self.metric:
            rcompletions, roriginal = self.shuffle(completions, original)
            combined = rcompletions + roriginal
            embeddings = self._encode(combined)
            rtrans_embedding = embeddings[:len(rcompletions)]
            roriginal_embedding = embeddings[len(rcompletions):]
            if 'Cos' in self.metric:
                dpos = F.cosine_similarity(trans_embedding, original_embedding, dim=-1, eps=1e-8)
                dneg = F.cosine_similarity(rtrans_embedding, roriginal_embedding, dim=-1, eps=1e-8)
            else:
                # NOTE(review): with distances, `dpos - dneg` grows as the real
                # pair moves apart, whereas get_sentence_level_reward uses
                # `neg - pos` for the distance case - confirm the intended sign.
                dpos = F.pairwise_distance(trans_embedding, original_embedding, p=2, eps=1e-8)
                dneg = F.pairwise_distance(rtrans_embedding, roriginal_embedding, p=2, eps=1e-8)
            if 'NoMargin' in self.metric:
                reward = dpos - dneg
                if 'Norm' in self.metric:
                    # Maps [-2, 2] (difference of two cosines) onto [0, 1].
                    reward = reward / 4 + 0.5
            else:
                reward = torch.clamp(dpos - dneg + self.margin, min=0)
                if 'Norm' in self.metric:
                    # Upper bound 3 holds for cosine scores with margin == 1.
                    reward = reward / 3
            print(dpos, dneg, reward)
        else:
            raise Exception(f"Unknown metric: {self.metric}")
        print(f"MT5 Reward : {reward}, Time used: {time.time() - st:.2f}s")
        return reward.cpu().tolist()

    def shuffle(self, completions, original):
        """
        Character-shuffle each (completion, reference) pair with a shared permutation.

        For every pair a random permutation spanning the longer of the two
        strings is drawn with `np.random.shuffle` and applied to both, each
        string skipping the indices beyond its own length.

        Returns:
            (rcompletions, roriginal): two lists of shuffled strings.
        """
        rcompletions = []
        roriginal = []
        for comp, ori in zip(completions, original):
            order = np.arange(max(len(comp), len(ori)))
            np.random.shuffle(order)
            rcompletions.append(self._permute(comp, order))
            roriginal.append(self._permute(ori, order))
        return rcompletions, roriginal


class MT5FeatureExtractor(nn.Module):
    """
    Text encoder wrapper that yields one pooled embedding per input.

    When `model_name` contains 'gemma', a SentenceTransformer
    ("google/embeddinggemma-300m") is loaded; otherwise an `AutoTokenizer` and
    `MT5EncoderModel` are loaded from `model_name`.

    Args:
        model_name: checkpoint name; also selects the backend (see above).
        proj_dim: accepted but not used here; `self.proj` is always None.
        freeze: when true, gradients are disabled for all encoder parameters.
    """

    def __init__(self, model_name=MODEL_NAME, proj_dim=None, freeze=True):
        super().__init__()
        self.model_name = model_name
        if 'gemma' in model_name:
            # Imported lazily so the dependency is only needed for this backend.
            from sentence_transformers import SentenceTransformer
            self.enc = SentenceTransformer("google/embeddinggemma-300m")
        else:
            self.tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)
            self.enc = MT5EncoderModel.from_pretrained(model_name)

        if freeze:
            for param in self.enc.parameters():
                param.requires_grad = False
        self.proj = None

    @staticmethod
    def _mean_pool(hidden, attention_mask):
        """Average `hidden` over the sequence axis, counting only masked-in tokens."""
        weights = attention_mask.unsqueeze(-1)  # (B, T, 1)
        summed = (hidden * weights).sum(dim=1)
        counts = weights.sum(dim=1).clamp_min(1e-9)
        return summed / counts

    @torch.no_grad()
    def encode(self, texts, max_len=512, batch_size=16):
        """
        Embed `texts` in batches of `batch_size` and return the stacked result.

        Switches the module to eval mode. For the tokenizer/encoder backend the
        inputs are padded, truncated to `max_len` tokens and mean-pooled over
        non-pad tokens; results are moved to CPU. For the 'gemma' backend the
        batch goes through `encode_document` and `max_len` is not used.

        Returns:
            Tensor with one row per input text.
        """
        self.eval()
        use_gemma = 'gemma' in self.model_name
        chunks = []
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            if use_gemma:
                chunks.append(torch.from_numpy(self.enc.encode_document(batch)))
                continue
            inputs = self.tok(
                batch,
                padding=True,
                truncation=True,
                max_length=max_len,
                return_tensors="pt",
            ).to(self.enc.device)
            hidden = self.enc(**inputs).last_hidden_state  # (B, T, H)
            pooled = self._mean_pool(hidden, inputs.attention_mask)
            if self.proj is not None:
                pooled = self.proj(pooled)
            chunks.append(pooled.detach().cpu())
        return torch.cat(chunks, dim=0)

    def forward(self, input_ids, attention_mask):
        """Train-time forward: run the encoder and return mean-pooled embeddings."""
        hidden = self.enc(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        pooled = self._mean_pool(hidden, attention_mask)
        if self.proj is None:
            return pooled
        return self.proj(pooled)

