import os
import argparse
import json
import csv
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
import umap
import matplotlib.pyplot as plt
import numpy as np
from sentence_transformers import SentenceTransformer
import random
import wandb
import random, numpy as np
import timm


parser = argparse.ArgumentParser(description="Train with selectable losses")
parser.add_argument("--losses",  type=str, default="ce",
                    help="Comma‑separated list of losses to use (ce,triplet,align,center,var)")
parser.add_argument("--project", type=str, default="ce",
                    help="WandB run name / logical experiment id")
parser.add_argument("--folder",  type=str, default=None,
                    help="Directory where checkpoints and logs are stored. Default: ./<project>")
parser.add_argument("--beta_mode", choices=["fixed", "learnable"], default="learnable",
                    help="fixed: fixed distribution; learnable: learnable soft-max vector ")

args = parser.parse_args()

# Loss names the training loop knows how to compute.
KNOWN_LOSSES = ("ce", "triplet", "align", "center", "var")

SELECTED_LOSSES = [s.strip() for s in args.losses.split(",") if s.strip()]
# Validate with parser.error rather than assert. Asserts are stripped under
# `python -O`. A typo would otherwise only surface as a KeyError on the first
# training batch, after the data and the model have been loaded.
if not SELECTED_LOSSES:
    parser.error("At least one loss must be selected")
_unknown_losses = [s for s in SELECTED_LOSSES if s not in KNOWN_LOSSES]
if _unknown_losses:
    parser.error(f"Unknown loss(es) {_unknown_losses}; choose from {', '.join(KNOWN_LOSSES)}")
# Duplicates would produce repeated CSV columns and colliding wandb keys.
if len(set(SELECTED_LOSSES)) != len(SELECTED_LOSSES):
    parser.error("Each loss may be selected at most once")

project_name   = args.project
# Output directory for checkpoints, metrics.csv and UMAP plots.
project_folder = args.folder if args.folder else f"./{project_name}"
os.makedirs(project_folder, exist_ok=True)


def save_path(name: str):
    """Resolve *name* (a file name) to a path inside ``project_folder``."""
    target = os.path.join(project_folder, name)
    return target

# === Determinism ===
SEED = 42


def seed_everything(seed: int):
    """Seed the Python, NumPy and PyTorch RNGs and pin cuDNN to deterministic kernels.

    Note: ``torch.use_deterministic_algorithms`` is explicitly switched *off*
    here (``warn_only`` has no effect in that case). Ops without a
    deterministic implementation may therefore still differ between runs.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(False, warn_only=True)


seed_everything(SEED)

# Dedicated generator handed to the DataLoaders, independent of the global torch RNG.
loader_gen = torch.Generator().manual_seed(SEED)

def seed_worker(worker_id: int):
    """DataLoader ``worker_init_fn``: seed Python's and NumPy's RNGs for this worker.

    PyTorch already seeds each worker's torch RNG with ``base_seed + worker_id``.
    ``base_seed`` is drawn from the DataLoader's generator every epoch.
    Deriving the NumPy/``random`` seeds from ``torch.initial_seed()`` gives
    every worker, and every epoch, a distinct but reproducible stream.

    A fixed ``SEED + worker_id`` seed (and re-calling ``torch.manual_seed``)
    would instead replay identical augmentation streams each epoch whenever
    ``num_workers > 0`` and workers are re-created per epoch.

    Args:
        worker_id: Index of the worker. Unused, because ``torch.initial_seed()``
            already incorporates it, but required by the ``worker_init_fn``
            signature.
    """
    worker_seed = torch.initial_seed() % 2**32  # NumPy seeds must fit in 32 bits
    np.random.seed(worker_seed)
    random.seed(worker_seed)

# Prefer Apple MPS, then CUDA, then CPU.
DEVICE = torch.device(
    "mps" if torch.backends.mps.is_available()
    else ("cuda" if torch.cuda.is_available() else "cpu")
)
print("Device:", DEVICE)

# === Hyper-parameters ===

BETA_INIT  = 0.5    # weight of the coarse alignment term; only used when --beta_mode=fixed
BATCH_SIZE = 128
NUM_EPOCHS = 10000  # upper bound; training also stops after PATIENCE epochs without val_acc gain
LR         = 3e-4
PATIENCE   = 100

# NOTE(review): the API key is left empty here; prefer supplying it through the
# WANDB_API_KEY environment variable rather than committing it to source.
wandb.login(key="")
# Record every setting that changes the objective or the run length, so runs
# that differ only in e.g. beta_mode can be told apart in wandb.
wandb.init(entity="", project="SL_CIFAR_CONVNEXT", name=project_name,
           config=dict(lr=LR, bs=BATCH_SIZE, selected=SELECTED_LOSSES,
                       beta_mode=args.beta_mode, beta_init=BETA_INIT,
                       epochs=NUM_EPOCHS, patience=PATIENCE, seed=SEED))

# === Dataloader ===
# Training-time augmentation.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor()])
# Evaluation data must not be randomly augmented. Otherwise val/test metrics,
# and early stopping on val_acc, would fluctuate with random flips/crops.
# NOTE(review): the timm pretrained ConvNeXt weights presumably expect
# ImageNet mean/std normalisation; consider adding transforms.Normalize to both.
eval_transform = transforms.Compose([transforms.ToTensor()])

train_set = datasets.CIFAR100("./data", train=True,  download=True, transform=transform)
_test_set  = datasets.CIFAR100("./data", train=False, download=True, transform=eval_transform)

# Reproducible 20% / 80% split of the official test split into validation / held-out test.
split_gen = torch.Generator().manual_seed(SEED)
val_len = int(0.2 * len(_test_set))
val_set, test_set = random_split(_test_set, [val_len, len(_test_set) - val_len], generator=split_gen)

# num_workers=0 loads in the main process, so worker_init_fn is not invoked;
# it is kept so that raising num_workers stays reproducible.
train_loader = DataLoader(train_set, BATCH_SIZE, shuffle=True,  num_workers=0,
                          worker_init_fn=seed_worker, generator=loader_gen)
val_loader   = DataLoader(val_set,   BATCH_SIZE, shuffle=False, num_workers=0,
                          worker_init_fn=seed_worker, generator=loader_gen)
test_loader  = DataLoader(test_set,  BATCH_SIZE, shuffle=False, num_workers=0,
                          worker_init_fn=seed_worker, generator=loader_gen)

# Build the fine -> coarse label hierarchy.
# fine_to_coarse_cifar.json is expected to map each fine class name to its coarse class name.
with open("fine_to_coarse_cifar.json", encoding="utf-8") as f:
    fine_to_coarse = json.load(f)
fine_classes = train_set.classes
# Fail early with a readable message instead of a bare KeyError below.
_missing_fine = [c for c in fine_classes if c not in fine_to_coarse]
if _missing_fine:
    raise ValueError(f"fine_to_coarse_cifar.json has no coarse class for: {_missing_fine}")
coarse_classes = sorted(set(fine_to_coarse.values()))
coarse_to_id = {c: i for i, c in enumerate(coarse_classes)}
# fine class index (as used by the dataset labels) -> coarse class index
fine_idx_to_coarse_idx = {i: coarse_to_id[fine_to_coarse[c]] for i, c in enumerate(fine_classes)}

# === Sentence‑BERT ===
# Sentence-BERT model used to embed the class names once, before training.
sem = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")


def _label_embeddings(names):
    """Encode label names into unit-norm SBERT vectors as a tensor on DEVICE."""
    vectors = sem.encode(names, normalize_embeddings=True)
    return torch.tensor(vectors, device=DEVICE)


fine_emb = _label_embeddings(fine_classes)
coarse_emb = _label_embeddings(coarse_classes)

# ---------------------------------------------------------------------------
#  Model
# ---------------------------------------------------------------------------
class ConvNeXtEncoder(nn.Module):
    """ConvNeXt-Tiny backbone with a projection head and a linear classifier.

    ``forward(x)`` returns ``(logits, z)``, where ``z`` is the un-normalised
    projection of the pooled backbone features. The module also owns two sets
    of mixing weights used by the training loop:

    * ``alpha``: one learnable logit per selected loss. ``combine()`` softmaxes
      them into loss weights.
    * ``beta``: weighting of the fine vs. coarse alignment terms. It is either
      learnable (softmax over two logits) or a fixed, non-trainable buffer.

    Args:
        selected_losses: Ordered loss names. ``combine()`` expects a dict keyed by them.
        feat_dim: Size of the projection ``z``.
        num_classes: Number of classifier outputs.
        beta_mode: ``"learnable"`` or ``"fixed"``.
        beta_init: Value of beta in fixed mode. ``None`` falls back to the
            module-level ``BETA_INIT``.

    Raises:
        ValueError: If ``beta_mode`` is not a supported mode.
    """

    # Supported values for ``beta_mode``.
    _BETA_MODES = ("fixed", "learnable")

    def __init__(self, selected_losses, feat_dim=384, num_classes=200, beta_mode='fixed', beta_init=None):
        # Validate before any expensive work (pretrained weight download).
        # An unknown mode previously fell back silently to fixed mode.
        if beta_mode not in self._BETA_MODES:
            raise ValueError(f"beta_mode must be one of {self._BETA_MODES}, got {beta_mode!r}")
        super().__init__()
        self.selected = selected_losses
        # num_classes=0 drops timm's classifier, so the backbone returns pooled
        # features of size net.num_features.
        # NOTE(review): convnext_tiny downsamples by 32, so 32x32 CIFAR inputs
        # reach the last stage as a 1x1 map; consider resizing inputs.
        net = timm.create_model("convnext_tiny", pretrained=True, num_classes=0)
        self.backbone = net
        self.head = nn.Linear(net.num_features, feat_dim)
        self.fc = nn.Linear(feat_dim, num_classes)
        # One logit per selected loss; combine() turns them into softmax weights.
        self.alpha = nn.Parameter(torch.ones(len(self.selected)))
        self._smx = nn.Softmax(dim=0)
        self.beta_mode = beta_mode

        if beta_mode == "learnable":
            # Two logits -> softmax to [w_fine, w_coarse].
            self.beta_raw = nn.Parameter(torch.zeros(2))
        else:
            # Registered as a buffer: follows .to(device) and is saved in
            # state_dict, but receives no gradient.
            value = BETA_INIT if beta_init is None else beta_init
            self.register_buffer("beta_const", torch.tensor(value))

    def forward(self, x):
        """Return ``(logits, z)`` for the input batch ``x``."""
        z = self.head(self.backbone(x))
        return self.fc(z), z

    def combine(self, loss_dict):
        """Weight the selected losses by ``softmax(alpha)``.

        Args:
            loss_dict: Mapping from every name in ``self.selected`` to a scalar loss tensor.

        Returns:
            ``(total, weights)``. ``total`` is the weighted sum and keeps its
            graph; ``weights`` is the detached softmax vector, kept for logging.
        """
        # NOTE(review): alpha is optimised together with the losses it weights,
        # so gradient descent can shift weight toward whichever loss is smallest.
        w = self._smx(self.alpha)
        total = sum(w[i] * loss_dict[n] for i, n in enumerate(self.selected))
        return total, w.detach()

    def beta(self):
        """Return the alignment-loss weighting.

        In learnable mode this is the 2-vector ``[w_fine, w_coarse]``. In fixed
        mode it is the scalar buffer ``beta_const``.
        """
        if self.beta_mode == "learnable":
            return self._smx(self.beta_raw)
        return self.beta_const
# ---------------------------------------------------------------------------
#  Loss functions
# ---------------------------------------------------------------------------

def cosine_align(z_norm, fine_t, coarse_t, beta):
    """Cosine-distance alignment of embeddings to fine and coarse label targets.

    Each term is ``1 - mean(cos(z_norm, target))`` over the batch (dim 1 is
    the feature axis).

    Args:
        z_norm: Batch of embeddings.
        fine_t: Fine-label target embeddings, row-aligned with ``z_norm``.
        coarse_t: Coarse-label target embeddings, row-aligned with ``z_norm``.
        beta: Either a 2-element tensor ``[w_fine, w_coarse]``, or a scalar
            giving the weight of the coarse term (the fine term gets ``1 - beta``).
    """
    fine_gap = 1 - F.cosine_similarity(z_norm, fine_t, dim=1).mean()
    coarse_gap = 1 - F.cosine_similarity(z_norm, coarse_t, dim=1).mean()

    is_pair = torch.is_tensor(beta) and beta.numel() == 2
    w_fine, w_coarse = (beta[0], beta[1]) if is_pair else (1 - beta, beta)
    return w_fine * fine_gap + w_coarse * coarse_gap

def center_regularizer(z_norm):
    """Penalise the batch centroid's distance from the origin (mean of its squared coordinates)."""
    centroid = z_norm.mean(dim=0)
    return (centroid ** 2).mean()

def variance_regularizer(z_norm):
    """Push each feature's batch standard deviation toward 1.

    Uses ``Tensor.std``'s default (unbiased) estimator over dim 0, then returns
    the mean squared deviation from 1.
    """
    spread = z_norm.std(dim=0)
    return torch.square(spread - 1).mean()

def triplet_semantic_loss(z, labels, fine_embeddings, margin=0.2):
    """Cosine triplet loss pulling each embedding toward its own class embedding.

    The positive is ``fine_embeddings[label]``. The negative is the embedding of
    a different class drawn uniformly at random. The per-sample loss is
    ``relu(cos(z, neg) - cos(z, pos) + margin)``, averaged over the batch.
    Cosine similarity is scale-invariant, so ``z`` need not be normalised.

    Args:
        z: Batch of embeddings, compared row-wise (dim 1) with the class embeddings.
        labels: 1-D integer class indices into ``fine_embeddings``.
        fine_embeddings: One embedding row per class.
        margin: Required gap between positive and negative similarity.

    Raises:
        ValueError: If fewer than two classes exist (no valid negative).
    """
    num_classes = fine_embeddings.size(0)
    if num_classes < 2:
        # The former rejection loop would spin forever in this case.
        raise ValueError("triplet_semantic_loss needs at least two classes to draw negatives")
    pos = fine_embeddings[labels]
    # label + offset (mod C), with offset uniform in [1, C), is uniform over the
    # C-1 other classes in a single draw. This replaces a resampling loop with
    # data-dependent iterations and a host sync on every mask.any().
    offsets = torch.randint(1, num_classes, labels.shape, device=labels.device)
    neg = fine_embeddings[(labels + offsets) % num_classes]
    sim_pos = F.cosine_similarity(z, pos, dim=1)
    sim_neg = F.cosine_similarity(z, neg, dim=1)
    return F.relu(sim_neg - sim_pos + margin).mean()

# ---------------------------------------------------------------------------
#  Training setup
# ---------------------------------------------------------------------------
model = ConvNeXtEncoder(SELECTED_LOSSES, num_classes=len(fine_classes), beta_mode=args.beta_mode).to(DEVICE)
optim = torch.optim.Adam(model.parameters(), lr=LR)
ce_fn = nn.CrossEntropyLoss()

# Fine-label index -> coarse-label index as a device tensor, so each batch's
# coarse targets come from a single gather. A per-element `int(i) for i in y`
# loop forces a device->host sync for every sample.
fine_to_coarse_t = torch.tensor(
    [fine_idx_to_coarse_idx[i] for i in range(len(fine_classes))],
    dtype=torch.long, device=DEVICE,
)


def _component_losses(logits, z, y):
    """Compute the selected loss terms for one batch.

    Args:
        logits: Classifier outputs from ``model(x)``.
        z: Un-normalised embeddings from ``model(x)``.
        y: Fine-label indices on DEVICE.

    Returns:
        Dict mapping each name in SELECTED_LOSSES to a scalar loss tensor.
        Unselected losses are not evaluated; they never reached the objective,
        and skipping them avoids wasted compute and triplet RNG draws.
    """
    # `dim` must be passed by keyword: F.normalize's second positional
    # parameter is the norm order p, so F.normalize(z, 1) is an L1
    # normalisation. The center/var regularisers act on z_norm's raw values,
    # which are presumably meant to be unit-L2 rows.
    z_norm = F.normalize(z, dim=1)
    coarse_y = fine_to_coarse_t[y]
    terms = {
        "ce": lambda: ce_fn(logits, y),
        "align": lambda: cosine_align(
            z_norm,
            F.normalize(fine_emb[y], dim=1),
            F.normalize(coarse_emb[coarse_y], dim=1),
            model.beta(),
        ),
        "center": lambda: center_regularizer(z_norm),
        "var": lambda: variance_regularizer(z_norm),
        "triplet": lambda: triplet_semantic_loss(z, y, fine_emb),
    }
    return {name: terms[name]() for name in SELECTED_LOSSES}


def _train_one_epoch(epoch):
    """Run one optimisation pass over train_loader; return (mean combined loss, accuracy)."""
    model.train()
    tot_loss = tot_corr = tot_seen = 0
    for x, y in tqdm(train_loader, desc=f"{epoch+1}/{NUM_EPOCHS} train", leave=False):
        x, y = x.to(DEVICE), y.to(DEVICE)
        optim.zero_grad()
        logits, z = model(x)
        loss, _ = model.combine(_component_losses(logits, z, y))
        loss.backward()
        optim.step()
        n = x.size(0)
        tot_loss += loss.item() * n
        tot_corr += (logits.argmax(1) == y).sum().item()
        tot_seen += n
    return tot_loss / tot_seen, tot_corr / tot_seen


def _evaluate():
    """Evaluate on val_loader.

    Returns:
        ``(mean combined loss, accuracy, per-loss means)``, with the per-loss
        means ordered as SELECTED_LOSSES.
    """
    model.eval()
    v_loss = v_corr = v_seen = 0
    comp_sum = {name: 0.0 for name in SELECTED_LOSSES}
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            logits, z = model(x)
            ld = _component_losses(logits, z, y)
            lv, _ = model.combine(ld)
            n = x.size(0)
            v_loss += lv.item() * n
            v_corr += (logits.argmax(1) == y).sum().item()
            v_seen += n
            for k in SELECTED_LOSSES:
                comp_sum[k] += ld[k].item() * n
    comp_mean = [comp_sum[k] / v_seen for k in SELECTED_LOSSES]
    return v_loss / v_seen, v_corr / v_seen, comp_mean


def _beta_log():
    """Return (csv cells, wandb dict, console string) describing the current beta."""
    if args.beta_mode == "learnable":
        b_fine, b_coarse = model.beta().detach().cpu().tolist()
        return ([b_fine, b_coarse],
                {"b_fine": b_fine, "b_coarse": b_coarse},
                f"b_fine={b_fine:.3f} b_coarse={b_coarse:.3f}")
    return [BETA_INIT], {"b_const": BETA_INIT}, f"b_const={BETA_INIT:.3f}"


def _save_umap(epoch):
    """Project test-set embeddings to 2-D with UMAP and save a scatter plot."""
    model.eval()
    latents_list, labels_list = [], []
    with torch.no_grad():
        for images, labels in test_loader:
            _, embeddings = model(images.to(DEVICE))
            latents_list.append(embeddings.cpu())
            labels_list.append(labels)

    latents = torch.cat(latents_list)
    labels = torch.cat(labels_list)
    reducer = umap.UMAP(n_components=2, random_state=SEED)
    embeddings_2d = reducer.fit_transform(latents.numpy())

    # NOTE(review): tab20 has 20 colours for the 100 fine labels, so several
    # labels share a colour; colouring by coarse label may be more readable.
    plt.figure(figsize=(10, 8))
    plt.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1], c=labels.numpy(), cmap='tab20', s=10)
    plt.colorbar()
    plt.savefig(save_path(f"umap_{epoch}.png"))
    plt.close()


# logging: CSV columns mirror the per-epoch row written below.
header = ["epoch"] + [f"a_{n}" for n in SELECTED_LOSSES]
header += ["b_fine", "b_coarse"] if args.beta_mode == "learnable" else ["b_const"]
header += ["train_loss", "train_acc", "val_loss", "val_acc"] + SELECTED_LOSSES

best_val = patience = 0

# `with` guarantees metrics.csv is flushed and closed even if training raises
# or is interrupted.
with open(save_path("metrics.csv"), "w", newline="") as metrics_f:
    csvw = csv.writer(metrics_f)
    csvw.writerow(header)
    metrics_f.flush()

    for epoch in range(NUM_EPOCHS):
        train_loss, train_acc = _train_one_epoch(epoch)
        val_loss, val_acc, val_ld_mean = _evaluate()

        alpha_epoch = torch.softmax(model.alpha, dim=0).detach().cpu().tolist()
        beta_row, beta_dict, beta_str = _beta_log()

        csvw.writerow([epoch + 1, *alpha_epoch, *beta_row,
                       train_loss, train_acc, val_loss, val_acc, *val_ld_mean])
        metrics_f.flush()
        wandb.log({
            "epoch":      epoch + 1,
            "train_loss": train_loss, "val_loss": val_loss,
            "train_acc":  train_acc,  "val_acc":  val_acc,
            **{f"a_{n}": a for n, a in zip(SELECTED_LOSSES, alpha_epoch)},
            **beta_dict,
            **{f"{n}_val": m for n, m in zip(SELECTED_LOSSES, val_ld_mean)}
        })

        alpha_str = " ".join([f"a_{n}={a:.3f}" for n, a in zip(SELECTED_LOSSES, alpha_epoch)])
        comp_str = " ".join([f"{n}_val={m:.4f}" for n, m in zip(SELECTED_LOSSES, val_ld_mean)])
        print(
            f"Ep{epoch+1:04d} | "
            f"train_loss={train_loss:.4f} val_loss={val_loss:.4f} | "
            f"train_acc={train_acc*100:.2f}% val_acc={val_acc*100:.2f}% | "
            f"{alpha_str} | {beta_str} | {comp_str}"
        )

        # NOTE(review): a full checkpoint (model + Adam state) is written every
        # epoch; over many epochs this can consume a lot of disk.
        torch.save({
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optim.state_dict(),
        }, save_path(f"model_{epoch:04d}.pth"))

        # Early stopping on validation accuracy.
        if val_acc > best_val:
            best_val, patience = val_acc, 0
        else:
            patience += 1
        if patience >= PATIENCE:
            break

        _save_umap(epoch)

wandb.finish()
