import torch
import argparse
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import get_peft_model, LoraConfig, TaskType, LoraModel
from torch.nn.functional import cosine_similarity, cross_entropy
from torch.utils.data import DataLoader, Subset
from probe_for_lora import AttnEmotionProbe
from dataset_classes import synth_text_dataset, synth_dataset_path
from sklearn.model_selection import train_test_split
from collections import defaultdict
import copy
from tqdm import tqdm
from transformers import get_cosine_schedule_with_warmup
from torch.amp import GradScaler, autocast
from contextlib import nullcontext
from torch.nn.functional import cosine_similarity
import torch.nn.init as init
import math
from train_emotion_classifier import EmotionClassifier
from torch.nn import CrossEntropyLoss
from torch.utils.checkpoint import checkpoint
import gc

def run_epoch_ce(model, dataloader, adapter_name,
                 optimizer=None, scheduler=None, mode="train", epoch=0,
                 α=0.5, β=0.1, γ=0.1, exp=1.25, patience=100):
    """Run one train/eval pass that steers the model toward ``adapter_name``.

    Every step computes a weighted next-token cross-entropy on the full
    prompt+answer sequence plus a two-tier margin loss on the log-probs of the
    canonical emotion token, and adds one of:

    * a semantic-drift penalty ``(1 - cos) + γ * ||Δh||**exp`` between the
      last-layer hidden states with adapters enabled vs. disabled, weighted by
      ``β``; or
    * a latent emotion-classifier cross-entropy on the prompt-only hidden
      states, weighted by ``α``.

    ``use_semantic_loss`` starts as True and is only ever set to True, so with
    the current code the classifier branch is never taken.

    NOTE: this function reads module-level globals populated by the
    ``__main__`` block: ``tokenizer``, ``target_ids``, ``synonym_token_set``,
    ``targetted_logits``, ``label2id``, ``classifier`` and ``adapter_save_path``.

    Args:
        model: causal LM whose target linears were wrapped by
            ``replace_with_lora_wrappers``.
        dataloader: yields dicts with ``input_ids``/``attention_mask`` and
            ``text_wo_answers_id``/``text_wo_answers_attention_mask``.
        adapter_name: emotion label being trained (key into ``label2id``,
            ``targetted_logits`` and ``synonym_token_set``).
        optimizer, scheduler: used (and required) when ``mode == "train"``.
        mode: ``"train"`` runs under fp16 autocast with optimizer steps; any
            other value runs under ``torch.no_grad()``.
        epoch: only used to name the emergency checkpoint written on OOM.
        α, β, γ, exp: loss weights / exponent described above.
        patience: training steps without a >= 1e-2 improvement of the step
            loss before stopping early.

    Returns:
        Dict with the mean ``loss``, ``token_loss``, ``latent_ce_loss`` and
        ``semantic_loss`` over the steps that actually completed (steps skipped
        on OOM and steps after early stopping are not counted).
    """
    model.train() if mode == "train" else model.eval()
    total_loss, token_ce_losses, latent_ce_losses, semantic_losses = 0.0, 0.0, 0.0, 0.0
    best_loss, steps_since_improvement = float("inf"), 0
    completed_steps = 0
    device = model.device
    scaler = GradScaler()
    context = autocast(dtype=torch.float16, device_type="cuda") if mode == "train" else torch.no_grad()

    # Token-CE class weights: 1 for tokenizer entries, 0 for any extra rows of
    # the LM head beyond len(tokenizer), 20 for emotion words and synonyms.
    weights = torch.cat([
        torch.ones(len(tokenizer), device=device),
        torch.zeros(model.lm_head.out_features - len(tokenizer), device=device)
    ])
    weights[target_ids + [j for i in synonym_token_set.values() for j in i]] = 20
    margin1 = 0.5
    margin2 = 10

    token_ce_loss_fn = CrossEntropyLoss(weight=weights, ignore_index=tokenizer.pad_token_id)
    emotion_ce_loss_fn = CrossEntropyLoss()
    use_semantic_loss = True

    # Distractors are the synonym tokens of every *other* emotion. Compare the
    # str keys against the str adapter name (not the per-batch label tensor) so
    # the adapter's own synonyms are excluded. Loop-invariant: build once.
    distractor_tids = list(set().union(*[v for k, v in synonym_token_set.items() if k != adapter_name]))

    tqdm_iter = tqdm(enumerate(dataloader), total=len(dataloader), desc=f"{mode.title()} Epoch")
    for i, batch in tqdm_iter:
        try:
            with context:
                user_input_ids = batch["text_wo_answers_id"].to(device)  # [B, T_user]
                user_mask = batch["text_wo_answers_attention_mask"].to(device)  # [B, T_user]
                full_input_ids = batch["input_ids"].to(device)  # [B, T_full]
                full_mask = batch["attention_mask"].to(device)  # [B, T_full]
                target_emotion = torch.tensor(label2id[adapter_name]).repeat(user_input_ids.shape[0]).to(device)  # [B]

                if use_semantic_loss:
                    # Reference pass with the adapters switched off (no graph).
                    disable_spectral_adapters(model)
                    with torch.no_grad():
                        base_out = model.forward(full_input_ids, attention_mask=full_mask, output_hidden_states=True, use_cache=False)
                        h_base = base_out.hidden_states
                    enable_spectral_adapters(model)

                    full_out = model.forward(full_input_ids, attention_mask=full_mask, labels=full_input_ids, output_hidden_states=True, use_cache=False)
                    h_shifted = full_out.hidden_states
                    cos_sim = cosine_similarity(h_shifted[-1], h_base[-1], dim=-1).mean()
                    # Boolean mask so only real tokens contribute; indexing with
                    # an integer attention mask would gather rows 0/1 instead.
                    delta_norm = (h_shifted[-1] - h_base[-1]).norm(dim=-1)[full_mask.bool()].mean()
                    semantic_loss = (1.0 - cos_sim) + γ * delta_norm ** exp
                else:
                    user_out = model.forward(user_input_ids, attention_mask=user_mask, output_hidden_states=True, use_cache=False)
                    h_shifted_user = user_out.hidden_states[-1]  # [B, T_user, D]

                    # Classifier loss on the adapter-shifted prompt representation.
                    logits_all = classifier(h_shifted_user, user_mask)
                    logits = (logits_all["logits_mean"] + logits_all["logits_first"] + logits_all["logits_last"]) / 3
                    temp = 0.5
                    logits = logits / temp

                    latent_ce_loss = emotion_ce_loss_fn(logits, target_emotion)

                    full_out = model.forward(full_input_ids, attention_mask=full_mask, labels=full_input_ids, output_hidden_states=True, use_cache=False)

                logits = full_out.logits  # [B, T_full, V]
                # Causal LM: logits at position t predict token t+1, so pair
                # them with the labels shifted one step to the left.
                token_ce_loss = token_ce_loss_fn(
                    logits[:, :-1, :].reshape(-1, logits.size(-1)),
                    full_input_ids[:, 1:].reshape(-1),
                )

                log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
                canon_logit = log_probs[..., targetted_logits[adapter_name]]  # [B, T]
                syn_logit = log_probs[..., synonym_token_set[adapter_name]].mean(dim=-1)  # [B, T]
                distractor_logit = log_probs[..., distractor_tids].logsumexp(dim=-1)  # [B, T]

                # Canonical token should beat the mean synonym log-prob ...
                tier1 = torch.nn.functional.relu(margin1 - (canon_logit - syn_logit)).mean()
                # ... and the pooled (logsumexp) distractor mass by a wider margin.
                tier2 = torch.nn.functional.relu(margin2 - (canon_logit - distractor_logit)).mean()

                margin_loss = tier1 + tier2
                token_loss = token_ce_loss + 0.2 * margin_loss

                # === Total loss ===
                if use_semantic_loss:
                    loss = token_loss + β * semantic_loss
                else:
                    loss = token_loss + α * latent_ce_loss

            if mode == "train":
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)
                scheduler.step()

            total_loss += loss.item()
            token_ce_losses += token_ce_loss.item()
            completed_steps += 1
            if use_semantic_loss:
                semantic_losses += semantic_loss.item()
                postfix_dict = {
                    "token_loss": token_ce_loss.item(),
                }
                # Mean logit of each emotion word over batch and positions
                # (progress display only).
                mean_logits = full_out.logits.mean(dim=(0, 1))
                postfix_dict.update({word: mean_logits[tid].item() for word, tid in targetted_logits.items()})
                tqdm_iter.set_postfix(postfix_dict)
            else:
                latent_ce_losses += latent_ce_loss.item()
                tqdm_iter.set_postfix({
                    "token_loss": token_ce_loss.item(),
                    "latent_ce_loss": latent_ce_loss.item(),
                    "logits_min": logits.min().item(),
                    "logits_max": logits.max().item(),
                    "step_loss": loss.item()
                })
                if latent_ce_loss < 1e-1:
                    use_semantic_loss = True

            # Early stopping check
            if mode == "train":
                if loss.item() < best_loss - 1e-2:
                    best_loss = loss.item()
                    steps_since_improvement = 0
                else:
                    steps_since_improvement += 1
                    if steps_since_improvement >= patience:
                        print(f"Early stopping at step {i}, loss plateaued.")
                        break
        except torch.OutOfMemoryError:
            # Drop any gradients a partially-completed backward left behind so
            # they are not applied together with the next step's gradients.
            if optimizer is not None:
                optimizer.zero_grad(set_to_none=True)
            model.save_pretrained(f"{adapter_save_path}_epoch_{epoch + 1}")
            tokenizer.save_pretrained(f"{adapter_save_path}_epoch_{epoch + 1}")
            print(f"OOM at step {i}, batch seq_len: {batch['input_ids'].shape[1]}")
            torch.cuda.empty_cache()
            continue

    # Average over the steps that really ran (early stop / OOM skips excluded);
    # max(..., 1) keeps an empty epoch from dividing by zero.
    n_steps = max(completed_steps, 1)
    return {
        "loss": total_loss / n_steps,
        "token_loss": token_ce_losses / n_steps,
        "latent_ce_loss": latent_ce_losses / n_steps,
        "semantic_loss": semantic_losses / n_steps,
    }


def make_stratified_subset(dataset, label_key="label", n_per_class=667, random_state=None):
    """Return a label-stratified random ``Subset`` of ``dataset``.

    The subset size is ``n_per_class * number_of_distinct_labels``. Sampling is
    stratified, meaning class proportions are preserved. Each class therefore
    gets exactly ``n_per_class`` examples only when the dataset is balanced.

    Args:
        dataset: indexable/iterable of mapping-like examples with ``label_key``;
            tensor labels are converted with ``.item()``.
        label_key: key holding each example's label.
        n_per_class: target examples per class (see note above).
        random_state: forwarded to ``train_test_split`` for reproducibility;
            ``None`` keeps the previous nondeterministic behaviour.

    Returns:
        ``torch.utils.data.Subset`` over the selected indices. If the requested
        size is at least the dataset size, the whole dataset is returned.
    """
    labels = [ex[label_key].item() if isinstance(ex[label_key], torch.Tensor) else ex[label_key] for ex in dataset]
    subset_size = n_per_class * len(set(labels))
    all_indices = list(range(len(dataset)))
    if subset_size >= len(all_indices):
        # train_test_split rejects train_size >= n_samples; the full dataset is
        # already the largest possible stratified subset.
        return Subset(dataset, all_indices)
    # train_test_split returns (train, test): the requested-size part is the
    # *first* element.
    stratified_indices, _ = train_test_split(
        all_indices,
        stratify=labels,
        train_size=subset_size,
        random_state=random_state,
    )
    return Subset(dataset, stratified_indices)


class PostActivationLoRA(torch.nn.Module):
    """Low-rank residual adapter applied to the *output* of a linear layer.

    ``forward`` returns ``y + B(A(y)) * alpha / r`` where ``y = base(x)``.
    ``A`` is initialised to the first ``r`` rows of ``Vh`` and ``B`` to their
    transpose. ``steering_vector`` is accepted but not used by this class.
    """

    def __init__(self, base_layer: torch.nn.Linear, steering_vector, Vh, r: int = 8, alpha: float = 1.0, bias: bool = False):
        super().__init__()
        self.base = base_layer
        target_device = base_layer.weight.device
        width = base_layer.out_features
        # Both projections act on the base layer's output width.
        self.lora_A = torch.nn.Linear(width, r, bias=bias).to(target_device)
        self.lora_B = torch.nn.Linear(r, width, bias=bias).to(target_device)
        self.scaling = alpha / r
        self.r = r
        self.enabled = True
        self.reset_parameters(steering_vector, Vh)

    def reset_parameters(self, steering_vector, Vh):
        """Load ``Vh[:r]`` into ``lora_A`` and its transpose into ``lora_B``."""
        with torch.no_grad():
            basis = Vh[:self.r]
            self.lora_A.weight.copy_(basis)
            self.lora_B.weight.copy_(basis.T)

    def forward(self, x):
        out = self.base(x)
        if self.enabled:
            out = out + self.lora_B(self.lora_A(out)) * self.scaling
        return out

    def enable(self):
        self.enabled = True

    def disable(self):
        self.enabled = False


class SubspaceShiftLoRA(torch.nn.Module):
    """Adapter that adds a learned offset inside a fixed subspace of a layer's output.

    With ``y = base(x)`` and ``V = Vh[:r]`` (stored as a buffer), ``forward``
    returns ``y + ((y @ V.T) + lora_shift) @ V * alpha / r``. The base layer's
    parameters are frozen. ``steering_vector`` is accepted for signature
    parity with the other wrappers but is not used.
    """

    def __init__(self, base_layer: torch.nn.Linear, steering_vector, Vh: torch.Tensor, r: int = 8, alpha: float = 1.0):
        super().__init__()
        self.base = base_layer
        device = base_layer.weight.device
        Vh_slice = Vh[:r].to(device)
        # Clamp the rank to the basis rows actually available (as
        # SubspaceLoRALinear does) so lora_shift matches the projection width.
        r = min(Vh_slice.shape[0], r)
        if r < 1:
            raise ValueError("Vh must provide at least one basis row")
        self.register_buffer("Vh", Vh_slice)
        self.lora_shift = torch.nn.Parameter(torch.randn(r, device=device) * math.sqrt(2 / r))
        self.scaling = alpha / r
        self.enabled = True
        for p in self.base.parameters():
            p.requires_grad = False

    def forward(self, x):
        base_out = self.base(x)
        if not self.enabled:
            return base_out
        proj = torch.matmul(base_out, self.Vh.T)  # [..., r] subspace coordinates
        # Plain broadcasting over the trailing dim works for any input rank.
        shifted = proj + self.lora_shift
        delta = torch.matmul(shifted, self.Vh)  # back to the output width
        return base_out + delta * self.scaling

    def enable(self):
        self.enabled = True

    def disable(self):
        self.enabled = False


class SubspaceLoRALinear(torch.nn.Module):
    """Adapter that applies a small MLP to a layer's output within a fixed subspace.

    With ``y = base(x)`` and ``V = Vh[:r]``, ``forward`` returns
    ``y + GELU(Linear(y @ V.T)) @ V * alpha / r``. ``r`` is clamped to the
    number of rows in ``Vh``. The base layer's parameters are frozen.
    ``steering_vector`` is accepted for signature parity but not used.
    """

    def __init__(self, base_layer: torch.nn.Linear, steering_vector, Vh: torch.Tensor, r: int = 8, alpha: float = 1.0):
        super().__init__()
        self.base = base_layer
        device = base_layer.weight.device
        # Non-persistent buffer: follows .to()/dtype casts together with the
        # module, but stays out of state_dict so saved checkpoints are unchanged.
        self.register_buffer("Vh", Vh[:r].to(device), persistent=False)
        r = min(self.Vh.shape[0], r)
        self.scaling = alpha / r
        self.enabled = True
        self.lora_mlp = torch.nn.Sequential(
            torch.nn.Linear(r, r, bias=True),
            torch.nn.GELU(),
        ).to(device)
        for p in self.base.parameters():
            p.requires_grad = False

    def forward(self, x):
        base_out = self.base(x)
        if not self.enabled:
            return base_out
        proj = torch.matmul(base_out, self.Vh.T)  # subspace coordinates
        transformed = self.lora_mlp(proj)
        delta = torch.matmul(transformed, self.Vh)  # back to the output width
        return base_out + delta * self.scaling

    def enable(self):
        self.enabled = True

    def disable(self):
        self.enabled = False



def enable_spectral_adapters(model: torch.nn.Module):
    """Switch every spectral adapter in ``model`` back on and reset trainability.

    Pass 1 calls ``enable()`` on each adapter wrapper. It then marks all
    parameters of adapters and of ``torch.nn.LayerNorm`` modules as trainable;
    this includes the wrapped base layers. Pass 2 freezes every parameter
    whose name contains neither ``"lora"`` nor (case-insensitively)
    ``"norm"``, which re-freezes those base weights.
    """
    adapter_types = (PostActivationLoRA, SubspaceShiftLoRA, SubspaceLoRALinear)
    for module in model.modules():
        is_adapter = isinstance(module, adapter_types)
        if is_adapter:
            module.enable()
        if is_adapter or isinstance(module, torch.nn.LayerNorm):
            module.requires_grad_(True)

    for param_name, param in model.named_parameters():
        keep_trainable = "lora" in param_name or "norm" in param_name.lower()
        if not keep_trainable:
            param.requires_grad = False


def disable_spectral_adapters(model: torch.nn.Module):
    """Turn off every spectral adapter in ``model`` and freeze all its parameters.

    The parameters of each adapter module are frozen, including those of the
    wrapped base layer.
    """
    adapter_types = (PostActivationLoRA, SubspaceShiftLoRA, SubspaceLoRALinear)
    adapters = [m for m in model.modules() if isinstance(m, adapter_types)]
    for adapter in adapters:
        adapter.disable()
        adapter.requires_grad_(False)


def replace_with_lora_wrappers(model: torch.nn.Module, target_emotion, target_modules: list[str], steering_vectors, Vh, r=8, alpha=1.0):
    """Wrap, in place, every ``torch.nn.Linear`` whose dotted path is in ``target_modules``.

    Each match is replaced by a ``SubspaceLoRALinear``. Its basis is looked up
    under ``"<last path part>_<third path part>"`` (e.g.
    ``"model.layers.3.mlp.down_proj"`` -> ``"down_proj_3"``) in both
    ``steering_vectors`` and ``Vh``. Wrapped modules are not descended into;
    every other child is searched recursively.
    """
    def _visit(parent, parent_path=""):
        for child_name, child in list(parent.named_children()):
            qualified = f"{parent_path}.{child_name}" if parent_path else child_name
            if qualified not in target_modules or not isinstance(child, torch.nn.Linear):
                _visit(child, qualified)
                continue
            parts = qualified.split(".")
            key = parts[-1] + "_" + parts[2]
            adapter = SubspaceLoRALinear(
                child,
                r=r,
                steering_vector=steering_vectors[key][target_emotion],
                Vh=Vh[key]["Vh"],
                alpha=alpha,
            )
            setattr(parent, child_name, adapter)

    _visit(model)


def initialize_peft_lora_with_geometry(peft_model, layer_name, λ_target, V_k, adapter_name=None):
    """Overwrite a PEFT LoRA pair's weights from a steering vector and PCA basis.

    ``lora_A`` (shape ``[r, D]``) becomes ``0.15 * λ_target[:r]`` broadcast
    across each row, plus Kaiming-uniform noise. ``lora_B`` (shape
    ``[D_out, r]``) becomes ``V_k[:r].T``, which assumes ``D == D_out``.

    Args:
        peft_model: a PEFT-wrapped model (e.g. from ``get_peft_model``).
        layer_name: dotted module path, e.g. ``"model.layers.24.mlp.down_proj"``.
        λ_target: ``[k]`` tensor; must have at least ``r`` entries.
        V_k: ``[k, D]`` PCA basis.
        adapter_name: PEFT adapter name. When omitted, the module-level
            ``adapter_name`` global is used, as earlier call sites expect.

    Raises:
        ValueError: if ``λ_target`` has fewer than ``r`` entries.
        NameError: if ``adapter_name`` is omitted and no such global exists.
        KeyError: if the LoRA modules are not found under ``layer_name``.
    """
    if adapter_name is None:
        try:
            adapter_name = globals()["adapter_name"]
        except KeyError:
            raise NameError("adapter_name must be passed or defined at module level") from None

    # Build the name -> module map once and look up both halves of the pair.
    modules = dict(peft_model.named_modules())
    lora_module = modules[f"{layer_name}.lora_A.{adapter_name}"]
    B_module = modules[f"{layer_name}.lora_B.{adapter_name}"]

    A = lora_module.weight  # [r, D]
    B = B_module.weight     # [D_out, r]
    r = A.shape[0]

    if λ_target.shape[0] < r:
        raise ValueError(f"LoRA rank {r} > λ_target dim {λ_target.shape[0]}")

    # Truncate λ_target and V_k to the LoRA rank (no padding is done).
    λ = λ_target[:r].to(A.device) * 0.15
    V = V_k[:r].T.to(B.device)

    with torch.no_grad():
        noise = torch.empty_like(A)
        init.kaiming_uniform_(noise, a=math.sqrt(5))
        A.copy_(λ.unsqueeze(1) + noise)  # [r, 1] + [r, D]
        B.copy_(V)                       # [D_out, r]


if __name__ == "__main__":
    # --- Config ---
    parser = argparse.ArgumentParser()
    parser.add_argument("--adapter_name", type=str, default="happy")
    # anger, disgust, envy, excitement, fear, happy, neutral, sad, surprise
    args = parser.parse_args()

    r = 40
    # r = 20
    # model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
    # model_name = "mistralai/Ministral-8B-Instruct-2410"  # needs to be the same as the one in dataset_classes.py
    model_name = "allenai/OLMo-2-1124-7B-Instruct"  # possible that this won't work because the subspace was built with enable_thinking=True
    adapter_name = args.adapter_name
    adapter_save_path = f"emotion_loras/olmo_base_space_in_instruct_{adapter_name}_latent_space_shift_tokence_{r}"
    target_layers = list(range(1, 32))
    # NOTE(review): `epochs` is unused -- the training loop below runs a single
    # epoch and the scheduler is sized for exactly one pass.
    epochs = 4
    multi_layer = True  # Set False to use only final hidden state
    last_layer = False
    batch_size = 8

    # --- Setup ---
    # weights_only=False unpickles arbitrary objects: only load trusted files.
    # steering_vectors = torch.load("steering_vectors.pt", weights_only=False)
    steering_vectors = torch.load("steering_vectors_olmo_base.pt", weights_only=False)
    # kl_divergence_stats = torch.load("kl_divergence_stats.pt")
    # manifolds_and_centroids = torch.load("hidden_state_emotional_data_no_lora.pt", weights_only=False)
    # manifolds_and_centroids = torch.load("hidden_state_synth_no_lora_qwen_with_exclusions_non_thinking.pt", weights_only=False)
    manifolds_and_centroids = torch.load("hidden_state_synth_no_lora_olmo_base_with_exclusions.pt", weights_only=False)
    # Map each steering-vector key to the dotted module path it targets. Keys are
    # presumably "<proj>_<layer>" (e.g. "down_proj_12"): the first two "_" parts
    # name the projection and the last part is the layer index.
    target_modules = [f"model.layers.{i.split('_')[-1]}.{'mlp' if '_'.join(i.split('_')[0:2]) in ['up_proj', 'down_proj', 'gate_proj'] else 'self_attn'}.{'_'.join(i.split('_')[0:2])}" for i in steering_vectors.keys()]

    # lora_config = LoraConfig(r=r, lora_alpha=4, target_modules=target_modules, bias="none", task_type=TaskType.CAUSAL_LM)

    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token
    replace_with_lora_wrappers(model, adapter_name, target_modules, steering_vectors=steering_vectors, Vh=manifolds_and_centroids["manifolds"], r=r, alpha=16)

    # model = LoraModel(model,lora_config, adapter_name=adapter_name)
    # model.add_adapter(adapter_name=adapter_name, adapter_config=lora_config)
    # model.set_adapter(adapter_name)
    # [initialize_peft_lora_with_geometry(model, "model."+target_module, steering_vectors[target_module.split(".")[-1] + "_" + target_module.split(".")[2]][adapter_name], manifolds_and_centroids["manifolds"][target_module.split(".")[-1] + "_" + target_module.split(".")[2]]["Vh"]) for target_module in tqdm(target_modules)]
    model.train()

    prompt_template = [
        {"role": "user", "content": 'What emotion is expressed in the sentence: "{text}"?\n\nChoices: {emotion_list}\n\nAnswer:'},
        {"role": "assistant", "content": f"{adapter_name}"}
    ]

    synth_dataset = synth_text_dataset(synth_dataset_path, tokenizer, prompt_template=prompt_template, N=5000)
    label2id = {label: i for i, label in enumerate(sorted(set(synth_dataset.dataset["label"])))}
    id2label = {i: label for label, i in label2id.items()}
    synth_dataset.dataset["label"] = [label2id[lbl] for lbl in synth_dataset.dataset["label"]]

    target_words = list(label2id.keys())

    target_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(" ".join(target_words)))
    # The zip below pairs words with token ids positionally; it is only correct
    # if every emotion word is exactly one token.
    if len(target_ids) != len(target_words):
        raise ValueError(
            f"Expected one token per emotion word, got {len(target_ids)} tokens "
            f"for {len(target_words)} words; targetted_logits would be misaligned."
        )
    targetted_logits = {word: tid for word, tid in zip(target_words, target_ids)}
    synonym_token_set = {e: tokenizer.convert_tokens_to_ids(tokenizer.tokenize(" ".join(synth_dataset.emotion_synonyms[e]["words"][1:]))) for e in target_words}

    # Split by original example so all variants of a source sentence land on
    # the same side of the train/val split.
    grouped = defaultdict(list)
    for i, idx in enumerate(synth_dataset.dataset["original_index"]):
        grouped[idx].append(i)
    unique_ids = list(grouped.keys())
    train_ids, val_ids = train_test_split(unique_ids, test_size=0.15, random_state=42)
    train_indices = [i for idx in train_ids for i in grouped[idx]]
    val_indices = [i for idx in val_ids for i in grouped[idx]]
    synth_dataset_copy = copy.deepcopy(synth_dataset)
    train_dataset = Subset(synth_dataset_copy, train_indices)
    val_dataset = Subset(synth_dataset_copy, val_indices)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8, collate_fn=synth_text_dataset.make_collate_fn(tokenizer.pad_token_id), pin_memory=True)
    # Validation does not need shuffling; a fixed order keeps metrics reproducible.
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=8, collate_fn=synth_text_dataset.make_collate_fn(tokenizer.pad_token_id), pin_memory=True)

    classifier = EmotionClassifier(hidden_dim=4096, num_classes=len(label2id)).to(model.device)
    classifier.load_state_dict(torch.load("best_classifier.pt"))

    for name, param in model.named_parameters():
        if "lora_" in name or "shift" in name:
            param.requires_grad = True
        else:
            param.requires_grad = False

    # params = [(n, p) for n, p in model.named_parameters()
    #           if p.requires_grad and ("lora" in n or "norm" in n or "layernorm" in n.lower() or "shift" in n.lower())]
    params = [(n, p) for n, p in model.named_parameters()
              if p.requires_grad and ("lora" in n)]
    optimizer = torch.optim.AdamW([p for _, p in params], lr=1e-3, weight_decay=1e-2)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=50,
        num_training_steps=len(train_dataloader),
    )

    # --- Training loop ---
    for epoch in range(1):
        train_res = run_epoch_ce(model, train_dataloader, adapter_name, optimizer=optimizer, scheduler=scheduler, mode="train")
        model.save_pretrained(f"{adapter_save_path}_epoch_{epoch+1}", safe_serialization=True)
        tokenizer.save_pretrained(f"{adapter_save_path}_epoch_{epoch+1}")
        print(f"Saved adapter to {adapter_save_path}_epoch_{epoch+1}")

