import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import glob

from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR, SequentialLR

from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig 
from datasets import load_from_disk
from torch.cuda.amp import autocast, GradScaler
from models import StreamingSafetyHead
import math
from transformers import get_cosine_schedule_with_warmup


from tqdm import tqdm
import random
import numpy as np
import tempfile


def set_seed(seed: int):
    """Seed every random number generator used by this script.

    Seeds Python's ``random``, NumPy and PyTorch (CPU, the current CUDA
    device and all CUDA devices), then switches cuDNN to deterministic mode,
    turns cuDNN benchmarking off and disables TF32 matmuls.

    Args:
        seed: Value handed to every seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    # Determinism-related backend switches.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cuda.matmul.allow_tf32 = False

    print(f"Random seed set globally to {seed}")


def count_parameters(model):
    """Return the number of trainable parameters of ``model`` in millions.

    Only parameters with ``requires_grad`` set are counted.
    """
    n_trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            n_trainable += param.numel()
    return n_trainable / 1_000_000


def remove_none_fields(d):
    """Recursively drop dictionary entries whose value is ``None``.

    Dictionaries are rebuilt without their ``None``-valued keys, lists are
    rebuilt element by element (``None`` list items themselves are kept), and
    any other value is returned unchanged.
    """
    if isinstance(d, list):
        return [remove_none_fields(item) for item in d]
    if not isinstance(d, dict):
        return d

    cleaned = {}
    for key, value in d.items():
        if value is None:
            continue
        cleaned[key] = remove_none_fields(value)
    return cleaned


def find_sequence(lst, seq):
    """Return the index of the first occurrence of ``seq`` inside ``lst``.

    Each window ``lst[i:i + len(seq)]`` is compared to ``seq`` with ``==``;
    ``-1`` is returned when no window matches.
    """
    window = len(seq)
    last_start = len(lst) - window
    start = 0
    while start <= last_start:
        if lst[start:start + window] == seq:
            return start
        start += 1
    return -1


class SafetyDataset(Dataset):
    """
    Per-sample cached dataset:
    - cache_dir: {dataset_dir}/safety_cache/{model_name}/idx{idx_layer}maxlength{max_length}/
    - sample files: sample{orig_idx:08d}.pt, containing:
    {'embeddings': Tensor[seq, hidden], 'assistant_start': int, 'labels': Tensor[T_assistant]}
    - Build only on rank=0, others wait for barrier and just read.
    """
    def __init__(self,
        dataset_dir,
        model_name,
        tokenizer=None,
        base_model=None,
        idx_layer: int = 20,
        max_length: int = 4096,
        device: str = "cpu",
        build_cache_if_missing: bool = False,
        overwrite: bool = False,
        max_build_samples: int | None = None, # None=all
        debug_limit: int | None = None
        ):
        self.dataset_dir = dataset_dir
        self.model_name = model_name
        self.idx_layer = idx_layer
        self.max_length = max_length
        self.device = device

        # initialize the cache_dir and num_supervised_token
        self.user_prompt_marker = [151645, 198, 151644, 77091, 198]
        self.assistant_end = -4
        self.num_supervised_token = 10
        self.cache_dir = os.path.join(
                    dataset_dir,
                    f"safety_cache/{model_name.replace('/', '-')}/idx{idx_layer}_maxlength{max_length}"
                )
        os.makedirs(self.cache_dir, exist_ok=True)
    
        need_build = (len(glob.glob(os.path.join(self.cache_dir, "sample_*.pt"))) == 0)
        if need_build and build_cache_if_missing:
            assert tokenizer is not None and base_model is not None, "Building cache requires tokenizer and base_model."
            self._build_cache_per_sample(
                tokenizer=tokenizer,
                base_model=base_model,
                overwrite=overwrite,
                max_build_samples=max_build_samples
            )
    
        self.files = sorted(glob.glob(os.path.join(self.cache_dir, "sample_*.pt")))
        if debug_limit is not None:
            self.files = self.files[:debug_limit]
    
        if len(self.files) == 0:
            raise FileNotFoundError(f"No cached samples found in {self.cache_dir}. "
                                    f"Set build_cache_if_missing=True on rank=0 to build first.")


    def _build_cache_per_sample(self, tokenizer, base_model, overwrite=False, max_build_samples=None):
        print(f"Building per-sample cache into {self.cache_dir} ...")
        data = load_from_disk(self.dataset_dir)
        total = len(data) if max_build_samples is None else min(len(data), max_build_samples)
     
        base_model.eval()
        with torch.no_grad():
            for i in tqdm(range(total), desc="Build samples"):
                sample_path = os.path.join(self.cache_dir, f"sample_{i:08d}.pt")
                if (not overwrite) and os.path.exists(sample_path):
                    continue
    
                info = data[i]
                if dataset in ["mmsafety", "figstep"]:
                    info['messages'] = remove_none_fields(info['messages'])
                text = tokenizer.apply_chat_template(
                    info['messages'][:2],
                    tokenize=False,
                    add_generation_prompt=True,
                    max_length=self.max_length,
                    truncation=True
                )
                model_inputs = tokenizer([text], return_tensors="pt").to(self.device)
                if dataset in ['mmsafety', 'figstep']:
                    label = int(info['messages'][-1]['content'][0]['text'])
                else:
                    label = int(info['messages'][-1]['content'])
    
                if "Qwen2.5-Omni" in self.model_name:
                    output = base_model.generate(
                        **model_inputs,
                        thinker_max_new_tokens=1,
                        temperature=0,
                        top_p=1.0,
                        top_k=0,
                        do_sample=False,
                        repetition_penalty=1.0,
                        output_hidden_states=True,
                        return_dict_in_generate=True,
                        return_audio=False
                    )

                else:
                    output = base_model.generate(
                        **model_inputs,
                        max_new_tokens=1,
                        temperature=0,
                        top_p=1.0,
                        top_k=0,
                        do_sample=False,
                        repetition_penalty=1.0,
                        output_hidden_states=True,
                        return_dict_in_generate=True,
                    )

                hidden_states = output.hidden_states[0][self.idx_layer]  # (1, seq, hidden)
    
                user_to_assistant_pos = find_sequence(model_inputs.input_ids[0].tolist(), self.user_prompt_marker)
                if user_to_assistant_pos < 0:
                    continue
                assistant_start = user_to_assistant_pos + len(self.user_prompt_marker)
    
                seq_len = model_inputs.input_ids[:, assistant_start:self.assistant_end].shape[-1]
                if seq_len <= 0:
                    continue
    
                labels = torch.full((1, seq_len), -100, dtype=torch.long, device=self.device)
                labels[:, :self.num_supervised_token] = 0
                labels[:, -self.num_supervised_token:] = torch.tensor([label], device=self.device).unsqueeze(1).expand(-1, min(seq_len, self.num_supervised_token))
    
                embedding_cpu = hidden_states[0, :self.assistant_end, :].detach().cpu().contiguous()
                labels_cpu = labels[0].detach().cpu().contiguous()
    
                payload = {
                    "embeddings": embedding_cpu,          # (seq, hidden)
                    "assistant_start": int(assistant_start),
                    "labels": labels_cpu                   # (T_assistant,)
                }
    
                tmp_fd, tmp_path = tempfile.mkstemp(dir=self.cache_dir)
                os.close(tmp_fd)
                torch.save(payload, tmp_path)
                os.replace(tmp_path, sample_path)
    
        print(f"Cache build finished at {self.cache_dir}")
    
    def __len__(self):
        return len(self.files)
    
    def __getitem__(self, idx):
        obj = torch.load(self.files[idx], map_location="cpu")
        embeddings = obj["embeddings"]            # (seq, hidden), cpu tensor
        assistant_start = int(obj["assistant_start"])
        labels = torch.as_tensor(obj["labels"], dtype=torch.long)  # (T_assistant)
        return {
            "embeddings": embeddings,
            "assistant_start": assistant_start,
            "labels": labels
        }

set_seed(42)


# ----------------------------------------------------------------------------
# Run configuration
# ----------------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
bf16 = True


modeling_method = "SLD"   # "SLD" -> StreamingSafetyHead, "MLP" -> plain MLP head
loss_type = "ATC"
model_name = "Qwen/Qwen3-8B"
load_base_model = True
# Single definition of the dataset name; SafetyDataset reads it while building
# the cache, so it has to be set before the datasets are created.
dataset = "seval"
idx_layer = 21
batch_size = 1
gradient_acc_steps = 32
max_length = 4096


# The base model is only needed to build the hidden-state cache.
if load_base_model:
    base_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto"
    )
    base_model.eval()
    for p in base_model.parameters():
        p.requires_grad = False
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

else:
    base_model = None
    tokenizer = None


train_dataset_dir = "data/s_eval/qwen3_8b/trainset/"
test_dataset_dir = "data/s_eval/qwen3_8b/testset/"

train_dataset = SafetyDataset(
        dataset_dir=train_dataset_dir,
        tokenizer=tokenizer,
        base_model=base_model,
        model_name=model_name,
        device=device,
        idx_layer=idx_layer,
        max_length=max_length,
        build_cache_if_missing=True,
        overwrite=False,
        max_build_samples=None
        )
test_dataset = SafetyDataset(
        dataset_dir=test_dataset_dir,
        tokenizer=tokenizer,
        base_model=base_model,
        model_name=model_name,
        device=device,
        idx_layer=idx_layer,
        max_length=max_length,
        build_cache_if_missing=True,
        overwrite=False,
        max_build_samples=None
        )
# Pinned host memory is only useful (and only available without a warning)
# when a CUDA device is present.
pin_memory = torch.cuda.is_available()
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=pin_memory)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=pin_memory)

input_dim = AutoConfig.from_pretrained(model_name).hidden_size

# The cache is on disk now; release the base model before training the head.
del base_model
del tokenizer
if torch.cuda.is_available():
    torch.cuda.empty_cache()

eval_steps = gradient_acc_steps * 100


if modeling_method == "SLD":
    safety_head = StreamingSafetyHead(
            input_dim=input_dim,
            proj_dim=1024,
            mem_dim=1024,
            num_labels=2,
            use_dt=True)

elif modeling_method == "MLP":
    # head of ShieldHead
    safety_head = nn.Sequential(
            nn.Linear(input_dim, 16384, bias=False),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(16384, 1024, bias=False),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(1024, 8192, bias=False),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(8192, 2)
            )
else:
    print('---------Unrecognized modeling method-----------')
    # Exit with a non-zero status so callers can detect the failure.
    sys.exit(1)

safety_head.to(device=device, dtype=torch.bfloat16)
# Assigning `safety_head.requires_grad = True` would merely create a plain
# attribute on the module; requires_grad_() sets the flag on every parameter.
safety_head.requires_grad_(True)

print("Total trainable parameters: ", count_parameters(safety_head), 'M')

optimizer = AdamW(
    safety_head.parameters(),
    lr=5e-5,
    weight_decay=0.1,
    betas=(0.9, 0.95),
    eps=1e-8
)
criterion = torch.nn.CrossEntropyLoss(ignore_index=-100)
max_grad_norm = 1.0


# ----------------------------------------------------------------------------
# Learning-rate schedule
# ----------------------------------------------------------------------------
num_train_epochs = 10 if dataset in ['mmsafety', 'figstep'] else 1
max_steps = -1            # > 0 overrides the epoch-derived number of updates
lr_scheduler_type = "cosine"
warmup_ratio = 0.05
warmup_steps = 0          # > 0 overrides warmup_ratio

# One optimizer update per accumulation window, counting a trailing partial one.
num_update_steps_per_epoch = math.ceil(len(train_loader) / gradient_acc_steps)
if max_steps is None or max_steps < 0:
    total_training_steps = num_train_epochs * num_update_steps_per_epoch
else:
    total_training_steps = max_steps

if warmup_steps and warmup_steps > 0:
    computed_warmup_steps = warmup_steps
else:
    computed_warmup_steps = int(total_training_steps * warmup_ratio)

scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=computed_warmup_steps,
    num_training_steps=total_training_steps
)

save_dir = (
    f"./checkpoints_safety_head/{model_name.replace('/', '-')}/"
    f"{modeling_method}/{loss_type}/{idx_layer}/{dataset}/"
)
os.makedirs(save_dir, exist_ok=True)

def compute_temporal_smoothness_loss(logits, valid_mask=None, lam=0.01):
    """Penalise squared changes of the class probabilities between adjacent steps.

    Args:
        logits: Tensor indexed as (B, T, C); softmax is taken over the last axis.
        valid_mask: Optional boolean (B, T) tensor; a pair of adjacent steps
            contributes only when both of its steps are marked valid.
        lam: Weight applied to the mean squared difference.

    Returns:
        A scalar tensor: ``lam`` times the mean squared probability difference
        over the retained pairs, or an unweighted zero when ``T < 2``.
    """
    def _zero():
        return torch.zeros([], device=logits.device, dtype=logits.dtype)

    if logits.size(1) < 2:
        return _zero()

    class_probs = F.softmax(logits, dim=-1)  # (B, T, C)
    step_delta = class_probs[:, 1:, :] - class_probs[:, :-1, :]  # (B, T-1, C)

    if valid_mask is not None:
        # A transition counts only if both of its endpoints are valid.
        pair_ok = (valid_mask[:, 1:] & valid_mask[:, :-1]).unsqueeze(-1)  # (B, T-1, 1)
        step_delta = step_delta[pair_ok.expand_as(step_delta)]

    if step_delta.numel() > 0:
        penalty = step_delta.pow(2).mean()
    else:
        penalty = _zero()
    return lam * penalty


def compute_temporal_tv_monotone_loss(logits, valid_mask=None, lam_tv=0.01, lam_mono=0.01):
    """Total-variation plus monotonicity penalty on the class-1 probability.

    Args:
        logits: Tensor indexed as (B, T, C); the softmax probability of class
            index 1 is tracked along the T axis.
        valid_mask: Optional boolean (B, T) tensor; a pair of adjacent steps
            contributes only when both of its steps are marked valid.
        lam_tv: Weight of the mean absolute step-to-step change.
        lam_mono: Weight of the mean decrease (only negative changes count).

    Returns:
        A scalar tensor ``lam_tv * tv + lam_mono * mono``; zero when ``T < 2``
        or when no pair of steps is retained.
    """
    zero = torch.zeros([], device=logits.device, dtype=logits.dtype)
    if logits.size(1) < 2:
        return zero

    class1_prob = torch.softmax(logits, dim=-1)[..., 1]   # (B, T)
    step_change = class1_prob[:, 1:] - class1_prob[:, :-1]  # (B, T-1)

    if valid_mask is not None:
        # A transition counts only if both of its endpoints are valid.
        pair_ok = valid_mask[:, 1:] & valid_mask[:, :-1]  # (B, T-1)
        step_change = step_change[pair_ok]

    if step_change.numel() == 0:
        return zero

    total_variation = step_change.abs().mean()
    # Only drops in the class-1 probability are penalised.
    decrease = torch.relu(-step_change).mean()
    return lam_tv * total_variation + lam_mono * decrease



# ----------------------------------------------------------------------------
# Training loop: one cached sample per micro-batch, gradients accumulated over
# `gradient_acc_steps` micro-batches per optimizer update.
# ----------------------------------------------------------------------------
global_step = 0
completed_steps = 0
safety_head.train()

# Autocast has to target the device that is actually used; hard-coding "cuda"
# is wrong when the script falls back to the CPU.
autocast_device_type = device.type
num_micro_batches = len(train_loader)
reached_max_steps = False

for epoch in range(num_train_epochs):
    # Statistics of the current accumulation window (reset after every update).
    total_loss = 0.0            # sum of the un-scaled micro-batch losses
    total_tokens = 0
    total_correct = 0
    window_micro_batches = 0

    optimizer.zero_grad(set_to_none=True)

    for step, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_train_epochs}")):
        assert batch["labels"].size(0) == 1, "Current implementation assumes batch_size_per_device=1 for streaming."
        labels = batch["labels"].to(device)  # (1, T_assistant)
        feat = batch['embeddings'].to(device)  # cached (seq, hidden) tensor, batched by the loader

        # Depending on how the loader collates the integer it may arrive as a
        # list/tuple, a tensor or a plain int.
        assistant_start = batch['assistant_start']
        if isinstance(assistant_start, (list, tuple)):
            assistant_start = assistant_start[0]
        if isinstance(assistant_start, torch.Tensor):
            assistant_start = int(assistant_start.item())
        else:
            assistant_start = int(assistant_start)

        with torch.autocast(device_type=autocast_device_type, dtype=torch.bfloat16, enabled=bf16):

            if modeling_method == 'SLD':
                logits = safety_head(feat, assistant_start)
            else:
                logits = safety_head(feat[:, assistant_start:])

            if loss_type == "ATC":
                # Cross-entropy on the labelled positions (-100 is ignored) ...
                loss_ce = criterion(logits.view(-1, logits.size(-1)), labels.view(-1))

                # ... plus a temporal regulariser over every position.
                reg_mask = torch.ones_like(labels, dtype=torch.bool)
                loss_smooth = compute_temporal_tv_monotone_loss(
                            logits, valid_mask=reg_mask, lam_tv=0.01, lam_mono=0.01
                            )

                loss = loss_ce + loss_smooth

            else:
                # ce loss on the last token only
                loss = criterion(logits[:, -1, :].view(-1, logits.size(-1)), labels[:, -1].view(-1))

        # Scale only the tensor that is back-propagated, so the logged loss
        # below stays un-scaled.
        (loss / gradient_acc_steps).backward()

        with torch.no_grad():
            total_loss += loss.item()
            window_micro_batches += 1
            preds = logits.argmax(dim=-1)  # (1, T_assistant)
            mask = (labels != -100)
            correct = (preds[mask] == labels[mask]).sum().item()
            total_correct += correct
            total_tokens += mask.sum().item()

        # Update at the end of every full window and also after the last
        # micro-batch of the epoch: otherwise the gradients of a trailing
        # partial window would be thrown away by the next zero_grad(), and
        # fewer updates would run than the ceil-based schedule expects.
        is_window_end = (step + 1) % gradient_acc_steps == 0
        is_last_micro_batch = (step + 1) == num_micro_batches
        if is_window_end or is_last_micro_batch:
            torch.nn.utils.clip_grad_norm_(safety_head.parameters(), max_grad_norm)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad(set_to_none=True)

            completed_steps += 1
            global_step += 1

            current_lr = optimizer.param_groups[0]['lr']
            # Mean un-scaled loss over the micro-batches of this window.
            avg_loss = total_loss / window_micro_batches
            avg_acc = (total_correct / total_tokens) if total_tokens > 0 else 0.0
            print(f"Epoch [{epoch+1}/{num_train_epochs}], "
                  f"UpdateStep [{completed_steps}/{total_training_steps}], "
                  f"LR: {current_lr:.2e}, Loss: {avg_loss:.4f}, Acc(token): {avg_acc:.4f}")

            total_loss = 0.0
            total_correct = 0
            total_tokens = 0
            window_micro_batches = 0

            if max_steps is not None and max_steps > 0 and completed_steps >= max_steps:
                reached_max_steps = True
                break

    # Save before a possible early stop, so a run bounded by max_steps does
    # not end without a checkpoint.
    ckpt_path = os.path.join(save_dir, f"model_epoch_{epoch}.pt")
    torch.save(safety_head.state_dict(), ckpt_path)
    print(f"Saved checkpoint: {ckpt_path}")

    if reached_max_steps:
        print("Reached max_steps. Stopping training.")
        break

print("Training complete!")

