import os
import math
import random
from typing import List, Dict, Tuple, Iterable

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Sampler
from transformers import AutoModel, AutoTokenizer, get_linear_schedule_with_warmup
import json
import tqdm


class MultiDescDataset(Dataset):
    """Map-style dataset yielding an item id together with ``k`` of its descriptions.

    Each access draws a fresh random subset, so repeated epochs see different
    description combinations for the same item.

    Args:
        item2descs: mapping from item id to its list of descriptions.
        item_ids: the item ids exposed by this dataset, in index order.
        k: number of descriptions returned per item.
    """

    def __init__(self, item2descs: Dict[str, List[str]], item_ids: List[str], k=3):
        self.item2descs = item2descs
        self.item_ids = item_ids
        self.k = k

    def __len__(self):
        return len(self.item_ids)

    def __getitem__(self, idx):
        """Return ``(item_id, descriptions)`` with exactly ``self.k`` descriptions."""
        item_id = self.item_ids[idx]
        pool = self.item2descs[item_id]
        if len(pool) < self.k:
            # Too few distinct descriptions: draw with replacement so the
            # group still has exactly k entries (duplicates are possible).
            return item_id, random.choices(pool, k=self.k)
        # Enough descriptions: draw k distinct ones without replacement.
        return item_id, random.sample(pool, self.k)


class Collator:
    """Turn a list of ``(item_id, descriptions)`` samples into model inputs.

    The descriptions of all items in the batch are flattened into one list and
    tokenized together. ``labels[j]`` is the batch-local index of the item that
    the j-th flattened description belongs to. Descriptions of the same item
    therefore share a label, which is what the contrastive loss uses as positives.

    Args:
        tokenizer: a callable tokenizer (Hugging Face style) accepting
            ``padding``, ``truncation``, ``max_length`` and ``return_tensors``.
        max_length: requested truncation length in tokens. Clamped per call to
            the tokenizer's ``model_max_length`` when that is a positive int.
    """

    def __init__(self, tokenizer, max_length=2048):
        self.tokenizer = tokenizer
        self.max_length = max_length

    def _truncation_length(self):
        """Return ``max_length`` clamped to the tokenizer's ``model_max_length``.

        Truncating to more tokens than the model has positions breaks the
        encoder on long inputs. For example, the default 2048 with
        ``bert-base-uncased``, whose tokenizer reports 512. Tokenizers without
        a positive integer ``model_max_length`` are left unclamped.
        """
        limit = getattr(self.tokenizer, "model_max_length", None)
        if self.max_length is None or not (isinstance(limit, int) and limit > 0):
            return self.max_length
        return min(self.max_length, limit)

    def __call__(self, batch):
        """Tokenize a batch and build per-description group labels.

        Returns:
            ``(enc, labels)``: ``enc`` is whatever the tokenizer returns for the
            flattened descriptions; ``labels`` is a 1-D tensor with one
            group id per description.
        """
        _items, desc_groups = zip(*batch)  # batch: [(item, [desc1, ...]), ...]
        texts = [desc for group in desc_groups for desc in group]
        enc = self.tokenizer(
            texts,
            padding=True, truncation=True, max_length=self._truncation_length(),
            return_tensors="pt",
        )
        labels = torch.tensor(
            [group_id for group_id, group in enumerate(desc_groups) for _ in group]
        )
        return enc, labels



class BlockShuffleSampler(Sampler[int]):
    """Yield indices block by block, shuffling block order and within-block order.

    The index list is cut into ``num_blocks`` contiguous blocks of
    ``ceil(n / num_blocks)`` elements (the last may be shorter). Each
    iteration shuffles the order of the blocks, then the order inside each
    block. Indices from one block are therefore always emitted contiguously.
    A new permutation is drawn every time the sampler is iterated.

    Args:
        indices: the indices to emit (copied).
        num_blocks: requested number of blocks; values below 1 are treated as 1.
    """

    def __init__(self, indices: List[int], num_blocks: int):
        self.indices = list(indices)
        self.num_blocks = max(1, num_blocks)

    def __iter__(self) -> Iterable[int]:
        n = len(self.indices)
        if n == 0:
            # Nothing to emit. Without this guard block_size would be 0 and
            # range(0, 0, 0) below raises ValueError.
            return
        block_size = math.ceil(n / self.num_blocks)
        # Slicing copies, so the in-place shuffles below never touch self.indices.
        blocks = [self.indices[i:i + block_size] for i in range(0, n, block_size)]
        random.shuffle(blocks)
        for blk in blocks:
            random.shuffle(blk)
            yield from blk

    def __len__(self):
        return len(self.indices)


class MeanPooler(nn.Module):
    """Average token vectors over the positions where ``attention_mask`` is non-zero."""

    def forward(self, last_hidden_state, attention_mask):
        """Return the masked mean of ``last_hidden_state`` along dimension 1.

        The mask is broadcast over the last (feature) dimension. Rows with an
        all-zero mask divide by 1 instead of 0, so they yield a zero vector
        rather than NaN.
        """
        weights = attention_mask[..., None]
        token_counts = weights.sum(1).clamp(min=1)
        summed = (last_hidden_state * weights).sum(1)
        return summed / token_counts


class BiEncoder(nn.Module):
    """Transformer encoder followed by masked mean pooling and L2 normalization.

    Args:
        model_name: name or local path passed to ``AutoModel.from_pretrained``.
    """

    def __init__(self, model_name="bert-base-uncased"):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        self.pool = MeanPooler()
        # Initialized to log(0.07). NOTE(review): forward() never reads this
        # parameter, so it receives no gradient. Confirm whether it is meant to
        # replace a fixed temperature in the loss.
        self.logit_scale = nn.Parameter(torch.log(torch.tensor(0.07)))

    def forward(self, enc):
        """Encode a tokenized batch into unit-length sentence embeddings.

        Args:
            enc: mapping of encoder keyword inputs; must contain
                ``"attention_mask"``, which is also used for pooling.

        Returns:
            Pooled embeddings normalized along the last dimension.
        """
        hidden = self.encoder(**enc).last_hidden_state
        pooled = self.pool(hidden, enc["attention_mask"])
        return torch.nn.functional.normalize(pooled, dim=-1)


def multi_positive_contrastive_loss(emb, labels, temperature=0.07):
    """Multi-positive (supervised) contrastive loss over a batch of embeddings.

    Every row of ``emb`` acts as an anchor. Its positives are all *other* rows
    with the same label. Its softmax denominator covers all other rows,
    positives and negatives. The anchor itself is excluded from both:
    otherwise its self-similarity (the largest logit for unit vectors)
    dominates the denominator and weakens the gradient.

    Args:
        emb: 2-D tensor ``(N, D)`` of embeddings. ``BiEncoder.forward``
            L2-normalizes them, so ``emb @ emb.T`` is cosine similarity.
        labels: 1-D tensor ``(N,)`` of group ids; equal ids mark positives.
        temperature: softmax temperature (previously hard-coded to 0.07).

    Returns:
        Scalar tensor: the mean over anchors of the negative mean
        log-probability assigned to that anchor's positives. Anchors without
        any positive contribute 0.
    """
    # Compute logits in float32: under fp16 autocast the matmul yields half
    # precision, and the following exp/log reductions are more stable in fp32.
    sim = (emb @ emb.t()).float() / temperature
    n = sim.size(0)

    self_mask = torch.eye(n, dtype=torch.bool, device=sim.device)
    pos_mask = (labels.unsqueeze(0) == labels.unsqueeze(1)) & ~self_mask

    # Drop each anchor's self-similarity from the denominator. A finite
    # sentinel (instead of -inf) keeps the forward and backward pass NaN-free
    # even for a single-row batch.
    denom_logits = sim.masked_fill(self_mask, torch.finfo(sim.dtype).min)
    log_prob = sim - torch.logsumexp(denom_logits, dim=1, keepdim=True)

    pos_log_prob = log_prob.masked_fill(~pos_mask, 0.0)
    loss = -pos_log_prob.sum(1) / pos_mask.sum(1).clamp(min=1)
    return loss.mean()


def _plot_epoch_curve(step_losses, ep, output_dir):
    """Save one epoch's per-step loss curve to ``epoch_{ep}_loss_curve.png``."""
    import matplotlib.pyplot as plt

    plt.figure(figsize=(10, 6))
    plt.plot(range(1, len(step_losses) + 1), step_losses, marker='.', linestyle='-', color='b')
    plt.title(f'Epoch {ep} Training Loss')
    plt.xlabel('Step')
    plt.ylabel('Loss')
    plt.grid(True)
    epoch_loss_plot_path = os.path.join(output_dir, f'epoch_{ep}_loss_curve.png')
    plt.savefig(epoch_loss_plot_path)
    print(f"Epoch {ep} Loss: {epoch_loss_plot_path}")
    plt.close()


def _plot_summary_curves(epoch_losses, all_step_losses, epoch_steps, output_dir):
    """Save the per-epoch average curve and the all-steps curve with epoch markers."""
    import matplotlib.pyplot as plt

    plt.figure(figsize=(10, 6))
    plt.plot(range(1, len(epoch_losses) + 1), epoch_losses, marker='o', linestyle='-', color='b')
    plt.title('Training Loss Curve (Average per Epoch)')
    plt.xlabel('Epoch')
    plt.ylabel('Average Loss')
    plt.grid(True)
    avg_loss_plot_path = os.path.join(output_dir, 'average_loss_curve.png')
    plt.savefig(avg_loss_plot_path)
    print(f"Loss: {avg_loss_plot_path}")
    plt.close()

    plt.figure(figsize=(12, 6))
    plt.plot(range(1, len(all_step_losses) + 1), all_step_losses, marker='.', linestyle='-', color='g', alpha=0.6)
    # epoch_steps[i] is the cumulative step count at the end of epoch i + 1.
    for i, step in enumerate(epoch_steps):
        plt.axvline(x=step, color='r', linestyle='--', alpha=0.3)
        plt.text(step, plt.ylim()[1] * 0.9, f'Epoch {i+1}', rotation=90)
    plt.title('Total Training Loss Curve (All Steps)')
    plt.xlabel('Step')
    plt.ylabel('Loss')
    plt.grid(True)
    total_loss_plot_path = os.path.join(output_dir, 'total_loss_curve.png')
    plt.savefig(total_loss_plot_path)
    print(f"Loss Curve: {total_loss_plot_path}")
    plt.close()


def train(
    item2descs: Dict[str, List[str]],
    model_name="bert-base-uncased",
    output_dir="./ckpt_multidesc",
    num_blocks=5,
    batch_size=16,
    k=3,
    epochs=3,
    lr=3e-5,
    warmup_ratio=0.1,
    max_length=2048,
    fp16=True,
    seed=42,
    plot_loss=True,
    save_every=1,
):
    """Fine-tune a ``BiEncoder`` with ``multi_positive_contrastive_loss``.

    Each training example is one item with ``k`` of its descriptions. The
    descriptions of the same item are positives for each other; those of the
    other items in the batch are negatives.

    Args:
        item2descs: mapping item id -> descriptions; items with no description
            are skipped.
        model_name: name or path of the pretrained encoder and tokenizer.
        output_dir: directory for checkpoints and loss plots. An empty string
            means the current directory.
        num_blocks: number of blocks for ``BlockShuffleSampler``.
        batch_size: items per batch (each contributes ``k`` descriptions).
        k: descriptions sampled per item.
        epochs: number of passes over the items.
        lr: AdamW learning rate.
        warmup_ratio: fraction of total steps used for linear warmup.
        max_length: requested tokenizer truncation length.
        fp16: enable CUDA autocast and gradient scaling.
        seed: seed for ``random`` and ``torch``.
        plot_loss: save per-epoch and summary loss curves as PNG files.
        save_every: save a checkpoint every ``save_every`` epochs; the last
            epoch is always saved. The default 1 saves every epoch.

    Raises:
        ValueError: if no item has at least one description.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # os.makedirs("") raises FileNotFoundError; treat an empty path as the cwd.
    output_dir = output_dir or "."
    save_every = max(1, save_every)

    items = [item for item, descs in item2descs.items() if len(descs) >= 1]
    if not items:
        raise ValueError("item2descs contains no item with at least one description")
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    dataset = MultiDescDataset(item2descs, items, k=k)

    sampler = BlockShuffleSampler(list(range(len(dataset))), num_blocks=num_blocks)
    collator = Collator(tokenizer, max_length=max_length)
    # Dropping the last partial batch is only safe when at least one full batch
    # exists. Otherwise the loader is empty and the epoch average divides by 0.
    loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler,
                        collate_fn=collator, drop_last=len(dataset) >= batch_size)
    steps_per_epoch = len(loader)

    model = BiEncoder(model_name).to(device)
    optim = torch.optim.AdamW(model.parameters(), lr=lr)
    num_training_steps = epochs * steps_per_epoch
    num_warmup = int(warmup_ratio * num_training_steps)
    scheduler = get_linear_schedule_with_warmup(optim, num_warmup, num_training_steps)
    scaler = torch.cuda.amp.GradScaler(enabled=fp16)

    os.makedirs(output_dir, exist_ok=True)
    epoch_losses = []
    all_step_losses = []
    epoch_steps = []  # cumulative step count at the end of each epoch
    ckpt = None  # last saved checkpoint directory; None if nothing was saved

    for ep in tqdm.tqdm(range(1, epochs + 1)):
        model.train()
        running_loss = 0.0

        for step, (enc, labels) in enumerate(loader, 1):
            enc = {name: tensor.to(device) for name, tensor in enc.items()}
            labels = labels.to(device)

            with torch.cuda.amp.autocast(enabled=fp16):
                emb = model(enc)
                loss = multi_positive_contrastive_loss(emb, labels)

            scaler.scale(loss).backward()
            scaler.step(optim)
            scaler.update()
            optim.zero_grad(set_to_none=True)
            scheduler.step()

            loss_value = loss.item()  # one device->host sync per step
            running_loss += loss_value
            all_step_losses.append(loss_value)

            if step % 10 == 0 or step == steps_per_epoch:
                print(f"[Epoch {ep}/{epochs}] Step {step}/{steps_per_epoch} "
                      f"Loss {running_loss/step:.4f}")

        epoch_loss = running_loss / steps_per_epoch
        epoch_losses.append(epoch_loss)
        epoch_steps.append(len(all_step_losses))
        print(f"[Epoch {ep}/{epochs}] Average Loss: {epoch_loss:.4f}")
        if plot_loss:
            # Every epoch appends exactly steps_per_epoch losses.
            _plot_epoch_curve(all_step_losses[-steps_per_epoch:], ep, output_dir)

        if ep % save_every == 0 or ep == epochs:
            ckpt = os.path.join(output_dir, f"epoch{ep}")
            os.makedirs(ckpt, exist_ok=True)
            model.encoder.save_pretrained(ckpt)
            tokenizer.save_pretrained(ckpt)

    if plot_loss and epoch_losses:
        _plot_summary_curves(epoch_losses, all_step_losses, epoch_steps, output_dir)

    print("model path:", ckpt)


def read_data(filename):
    """Load tool descriptions from a JSON Lines file.

    Each non-blank line must be a JSON object whose *first* key is the tool id
    and whose value is a dict of descriptions, e.g.
    ``{"tool_a": {"d1": "...", "d2": "..."}}``. Other keys on the same line
    are ignored. If a tool id repeats, the later line wins. Blank lines and
    empty objects are skipped.

    Args:
        filename: path to the ``.jsonl`` file (read as UTF-8).

    Returns:
        Dict mapping tool id -> list of its description values, in file order.
    """
    data = {}
    with open(filename, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # tolerate blank / trailing empty lines
            record = dict(json.loads(line))
            if not record:
                continue
            tool_id, descs = next(iter(record.items()))
            data[tool_id] = list(descs.values())
    return data

if __name__ == "__main__":
    data = read_data("./WorkflowAgent/embedding_test/tool_desdata/tool_des.jsonl")

    train(
        item2descs=data,
        model_name="./WorkflowAgent/embedding_test/embedding_model/huggingface",
        # Was "", which makes os.makedirs("") raise FileNotFoundError before
        # training starts; "." keeps outputs in the current directory.
        output_dir=".",
        num_blocks=3,
        # NOTE(review): each step encodes batch_size * k = 1572 descriptions of
        # up to max_length tokens; confirm this fits in device memory.
        batch_size=524,
        k=3,
        # NOTE(review): train() writes a full checkpoint per epoch, so 1000
        # epochs produce 1000 checkpoint directories.
        epochs=1000,
        plot_loss=True,
    )
