import torch
from torch.optim import lr_scheduler
from torch.utils import data
from torch.utils.data import IterableDataset
from datasets import AbstractDataset
from utils import combine_logs
from torch.utils.data import DataLoader
import torch.nn as nn
from tqdm.auto import tqdm
import wandb
import hydra
from omegaconf import DictConfig, OmegaConf
from load_objs import load_item
import os, random
import numpy as np
import torch.nn.functional as F




def seed_everything(seed: int):
    """Seed every random number generator this script relies on.

    Sets ``PYTHONHASHSEED``, seeds Python's ``random``, NumPy and PyTorch
    (CPU plus all CUDA devices), then disables cuDNN autotuning and requests
    deterministic cuDNN kernels.

    Args:
        seed: Value applied to every generator.
    """
    # NOTE: the interpreter reads PYTHONHASHSEED at start-up, so this
    # assignment is only picked up by processes launched afterwards.
    os.environ["PYTHONHASHSEED"] = str(seed)

    # Apply the same seed to each library-level generator in turn.
    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True






@torch.no_grad()
def cheb_nodes(resolution, device):
    """Return ``resolution`` Chebyshev nodes in ascending order.

    The nodes are ``cos((2k - 1) * pi / (2 * resolution))`` for
    ``k = 1..resolution``.

    Args:
        resolution: Number of nodes R.
        device: Device on which the tensor is created.

    Returns:
        float32 tensor of shape [R], increasing from near -1 to near 1.
    """
    indices = torch.arange(1, resolution + 1, device=device, dtype=torch.float32)
    angles = (2 * indices - 1) * np.pi / (2 * resolution)
    # cos(angles) decreases with the index; reverse it to get ascending nodes.
    return torch.cos(angles).flip(0)

@torch.no_grad()
def cheb_vander(x, max_degree):
    """Build the Chebyshev Vandermonde matrix for the points ``x``.

    Column ``k`` holds T_k(x), filled with the three-term recurrence
    ``T_k = 2 * x * T_{k-1} - T_{k-2}``.

    Args:
        x: 1-D tensor of R sample points (expected to lie in [-1, 1]).
        max_degree: Highest polynomial degree; the matrix has
            ``max_degree + 1`` columns.

    Returns:
        float32 tensor of shape [R, max_degree + 1] on ``x``'s device.
    """
    n_points = x.numel()
    n_terms = max_degree + 1
    basis = torch.empty((n_points, n_terms), device=x.device, dtype=torch.float32)

    # Degree 0 and (if requested) degree 1 seed the recurrence.
    basis[:, 0] = 1.0
    if n_terms > 1:
        basis[:, 1] = x

    for degree in range(2, n_terms):
        previous = basis[:, degree - 1]
        before_previous = basis[:, degree - 2]
        basis[:, degree] = 2 * x * previous - before_previous

    return basis  # [R, K]

@torch.no_grad()
def weighted_degree_from_coeffs(coeffs, norm=True):
    """Compute the degree-weighted sum of absolute coefficients.

    Args:
        coeffs: Tensor of shape [..., K]; index ``k`` on the last axis is the
            coefficient of degree ``k`` (the constant term c0 is included).
        norm: When True, divide by the sum of absolute coefficients (clamped
            to at least 1e-12), giving a weighted mean degree. When False,
            return the raw weighted sum.

    Returns:
        Tensor of shape [...] (the last axis is reduced).
    """
    magnitudes = coeffs.abs()
    n_terms = coeffs.shape[-1]
    degree_index = torch.arange(n_terms, device=coeffs.device, dtype=coeffs.dtype)
    weighted_sum = (magnitudes * degree_index).sum(dim=-1)

    if not norm:
        return weighted_sum

    # Clamp the denominator so all-zero coefficient rows yield 0 instead of NaN.
    total_magnitude = magnitudes.sum(dim=-1).clamp_min(1e-12)
    return weighted_sum / total_magnitude

@torch.no_grad()
def pca_scores(samples_outputs_reshaped: torch.Tensor, k: int = 1):
    """
    Perform PCA individually for each pair and return the top-k principal component scores.

    Each [R, C] slice is centered over R, decomposed with an SVD, and
    projected onto its leading ``k`` right singular vectors. Because singular
    vectors are only defined up to sign, every component is oriented so that
    its largest-magnitude entry is positive, which keeps the scores
    comparable between calls.

    Args:
        samples_outputs_reshaped: [P, R, C]
        k: target reduced dimension K. Values outside ``[1, min(R, C)]`` are
            clamped into that range.
    Returns:
        scores: [P, R, K] with K being the clamped ``k``.
    """
    P, R, C = samples_outputs_reshaped.shape
    # At most min(R, C) components exist; asking for more used to raise an
    # IndexError in the sign-normalisation step.
    k = min(max(1, int(k)), R, C)

    if P == 0:
        # torch.stack cannot handle an empty list; return an empty batch.
        return samples_outputs_reshaped.new_zeros((0, R, k))

    out = []
    for X in samples_outputs_reshaped:                       # X: [R, C]
        Xc = X - X.mean(dim=0, keepdim=True)                 # center over R

        try:
            # Vh: [min(R,C), C]
            _, _, Vh = torch.linalg.svd(Xc, full_matrices=False)
            V = Vh[:k].T                                     # [C, K]
        except RuntimeError:
            # SVD failed: fall back to the first K coordinate axes.
            V = torch.eye(C, device=X.device, dtype=X.dtype)[:, :k]

        # Sign flip: force the peak (max absolute value) of each principal
        # component to be positive. The sign must be read from the raw peak
        # value; clamping it to a positive floor first would always give +1.
        if V.numel() > 0:
            peak_rows = V.abs().argmax(dim=0, keepdim=True)  # [1, K]
            peak_vals = V.gather(0, peak_rows)               # [1, K]
            signs = 1.0 - 2.0 * (peak_vals < 0).to(V.dtype)  # -1 where negative, else +1
            V = V * signs

        out.append(Xc @ V)                                   # [R, K]

    return torch.stack(out, dim=0)                           # [P, R, K]


@torch.no_grad()
def compute_wd_metrics(
    model,
    x_ids,               # [B,L]
    device,
    num_pairs=2,
    resolution=64,
    max_degree=40,
    use_pca=False,
    pca_k=1
):
    """
    Measure how "high-degree" the model output is along embedding interpolations.

    Random pairs of batch elements are chosen, their embeddings are linearly
    interpolated at Chebyshev nodes, the model is run on every interpolated
    point, and the last-position outputs (logits and softmax probabilities)
    are fitted with Chebyshev polynomials up to ``max_degree``.

    The model is switched to eval mode for the measurement and its previous
    train/eval mode is restored on every exit path.

    Args:
        model: must provide ``embed_tokens(ids)`` and
            ``forward_from_embeds(embeds)`` returning ``(preds, _)``.
        x_ids: token ids, [B, L].
        device: device used for the computation.
        num_pairs: number of pairs P to sample (with replacement; the two
            members of a pair are always different batch elements).
        resolution: number of interpolation points R per pair.
        max_degree: highest Chebyshev degree used in the fit.
        use_pca: if True, reduce outputs per pair with ``pca_scores`` first.
        pca_k: number of principal components kept when ``use_pca`` is True.

    Returns:
      wd_logits_norm   (for logits)
      wd_probs_nonorm  (for probs)
      mse_logits
      mse_probs
      All four are zero tensors when the batch has fewer than two elements.
    """
    was_training = model.training
    model.eval()
    try:
        x_ids = x_ids.to(device)
        B, L = x_ids.shape
        if B < 2:
            # No pair can be formed.
            z = torch.tensor(0.0, device=device)
            return z, z, z, z

        # Sample pairs; a non-zero offset modulo B guarantees i2 != i1.
        i1 = torch.randint(0, B, (num_pairs,), device=device)
        off = torch.randint(1, B, (num_pairs,), device=device)
        i2 = (i1 + off) % B

        # Chebyshev nodes: fit on [-1,1], interpolate using alpha01 in [0,1]
        x_cheb = cheb_nodes(resolution, device)          # [R] in [-1,1]
        alpha01 = (x_cheb + 1.0) * 0.5                   # [R] in [0,1]
        a = alpha01.view(1, resolution, 1, 1)            # [1,R,1,1]

        # Continuous input representation produced by the model's embedding
        # step (presumably token + positional embedding), then interpolate.
        h0 = model.embed_tokens(x_ids)                   # [B,L,D]
        h1 = h0[i1]                                      # [P,L,D]
        h2 = h0[i2]                                      # [P,L,D]
        h_interp = h1.unsqueeze(1) + a * (h2.unsqueeze(1) - h1.unsqueeze(1))   # [P,R,L,D]
        h_flat = h_interp.reshape(num_pairs * resolution, L, h0.size(-1))      # [P*R,L,D]

        # Forward pass to get predictions: [P*R,L,C]; keep the last position.
        preds, _ = model.forward_from_embeds(h_flat)
        logits = preds[:, -1, :]                         # [P*R,C]
        C = logits.size(-1)
        logits = logits.view(num_pairs, resolution, C)   # [P,R,C]
        probs = F.softmax(logits, dim=-1)                # [P,R,C]

        if use_pca:
            # After PCA the last axis holds principal-component scores.
            logits = pca_scores(logits, k=pca_k)   # [P,R,K]
            probs = pca_scores(probs, k=pca_k)     # [P,R,K]

        # Chebyshev fitting (least squares through the pseudo-inverse)
        T = cheb_vander(x_cheb, max_degree)              # [R,K]
        T_pinv = torch.linalg.pinv(T)                    # [K,R]
        coeffs_logits = torch.einsum("kr,prc->pkc", T_pinv, logits)  # [P,K,C]
        coeffs_probs = torch.einsum("kr,prc->pkc", T_pinv, probs)    # [P,K,C]

        # wd: compute per class, then average (average over P and C dimensions)
        wd_logits_norm = weighted_degree_from_coeffs(coeffs_logits.permute(0, 2, 1), norm=True).mean()
        wd_probs_nonorm = weighted_degree_from_coeffs(coeffs_probs.permute(0, 2, 1), norm=False).mean()

        # mse: fitting error (helps judge if degree/resolution is sufficient)
        yhat_logits = torch.einsum("rk,pkc->prc", T, coeffs_logits)
        yhat_probs = torch.einsum("rk,pkc->prc", T, coeffs_probs)
        mse_logits = ((yhat_logits - logits) ** 2).mean()
        mse_probs = ((yhat_probs - probs) ** 2).mean()

        return wd_logits_norm, wd_probs_nonorm, mse_logits, mse_probs
    finally:
        # Put the model back into whatever mode the caller had it in,
        # including on the early return and on errors.
        model.train(was_training)




def compute_ce_loss(model, x, y, device):
    """Return the cross-entropy of the model's last-position prediction.

    Only the cross-entropy is computed (no extra logging work).

    Args:
        model: Callable returning ``(preds, _)``; ``preds`` is indexed as
            ``preds[:, -1, :]``.
        x: Model input, moved to ``device`` before the call.
        y: Targets for ``F.cross_entropy``, moved to ``device``.
        device: Device used for the computation.

    Returns:
        Scalar loss tensor.
    """
    preds, _aux = model(x.to(device))
    last_position_logits = preds[:, -1, :]
    targets = y.to(device)
    return F.cross_entropy(last_position_logits, targets)

@torch.no_grad()
def _save_params(model):
    return [p.detach().clone() for p in model.parameters() if p.requires_grad]

@torch.no_grad()
def _restore_params(model, saved):
    idx = 0
    for p in model.parameters():
        if p.requires_grad:
            p.copy_(saved[idx])
            idx += 1

def _grad_norm(grads, eps=1e-12):
    # grads: list of tensors
    return torch.sqrt(sum((g.norm() ** 2) for g in grads)).clamp_min(eps)

def compute_sharpness_metrics(
    model,
    x, y,
    device,
    rho=0.05,
    adaptive=False,
    eps=1e-12,
):
    """
    Estimate loss sharpness with a single gradient-direction weight perturbation.

    The trainable weights are moved by a step of L2 length ``rho`` along the
    normalised loss gradient, the loss is re-evaluated on the same batch, and
    the weights are put back. With ``adaptive=True`` each gradient entry is
    first scaled by ``|w| + 1`` before normalisation.

    The weights, cleared gradients and the original train/eval mode are
    restored even if a forward or backward pass raises.

    Args:
        model: model accepted by ``compute_ce_loss``.
        x, y: batch inputs and targets.
        device: device used for the computation.
        rho: length of the perturbation step.
        adaptive: use the ``|w| + 1`` scaled direction instead of the raw gradient.
        eps: lower bound for the norm used in the normalisation.

    Returns:
      base_loss
      perturbed_loss
      sharpness = perturbed_loss - base_loss
    """
    # Record original mode to avoid affecting external code
    was_training = model.training
    model.train()  # It is recommended to use train mode to measure sharpness (dropout is active in training mode)

    # Save parameters
    saved = _save_params(model)

    try:
        # 1) base loss + grad
        model.zero_grad(set_to_none=True)
        base_loss = compute_ce_loss(model, x, y, device)
        base_loss.backward()

        # Trainable parameters that did not take part in the loss have no
        # gradient; they are simply left unperturbed.
        params = [
            p for p in model.parameters()
            if p.requires_grad and p.grad is not None
        ]
        grads = [p.grad.detach() for p in params]

        # 2) Construct perturbation epsilon and apply it in place
        with torch.no_grad():
            if adaptive:
                directions = [(p.detach().abs() + 1.0) * g for p, g in zip(params, grads)]
            else:
                directions = grads
            n = _grad_norm(directions, eps=eps)
            for p, d in zip(params, directions):
                p.add_(rho * d / n)

        # 3) perturbed loss (no gradient needed)
        model.zero_grad(set_to_none=True)
        with torch.no_grad():
            pert_loss = compute_ce_loss(model, x, y, device)
    finally:
        # 4) restore weights, gradients and mode no matter what happened above
        _restore_params(model, saved)
        model.zero_grad(set_to_none=True)
        if not was_training:
            model.eval()

    base = base_loss.detach()
    sharp = (pert_loss - base).detach()
    return base, pert_loss.detach(), sharp



class GroupDataset(IterableDataset):
    """Endless stream of ``(x, y)`` tensors drawn from an ``AbstractDataset``.

    The split selects which fetch method of the wrapped dataset is called for
    every example: ``fetch_train_example`` for ``'train'`` and
    ``fetch_val_example`` for ``'val'``.
    """

    def __init__(self, dataset: AbstractDataset, split: str):
        """Bind the fetch function that matches ``split``.

        Args:
            dataset: Source of examples.
            split: Either ``'train'`` or ``'val'``.
        """
        super().__init__()
        assert split in {'train', 'val'}
        self.dataset = dataset
        self.split = split
        self.fetch_f = None

        if split == 'train':
            fetcher = dataset.fetch_train_example
        elif split == 'val':
            fetcher = dataset.fetch_val_example
        else:
            raise NotImplementedError
        self.fetch_f = fetcher

    def __iter__(self):
        # The dataset is its own iterator.
        return self

    def __next__(self):
        """Fetch one example and return its first two fields as tensors.

        The third field returned by the fetch function is discarded. This
        method never raises ``StopIteration`` itself.
        """
        inputs, target, _ = self.fetch_f()
        return torch.tensor(inputs), torch.tensor(target)

def train(config):
    """Run the training loop described by ``config``.

    Builds the dataset, model, data loaders, AdamW optimizer and a linear
    warm-up schedule, then trains on the endless training stream. Every
    ``train['eval_every']`` steps it evaluates on ``train['eval_batches']``
    validation batches and optionally logs the weighted-degree and sharpness
    metrics computed on the most recent training batch.

    Args:
        config: Plain dict with the keys read below:
            ``seed`` (optional, default 0), ``dataset``, ``model``,
            ``train`` (``num_workers``, ``bsize``, ``lr``, ``weight_decay``,
            ``betas``, ``warmup_steps``, ``eval_every``, ``eval_batches``,
            ``max_steps``), ``wandb`` (``use_wandb``, ``wandb_project``),
            and the optional sections ``wd`` and ``sharpness`` (both enabled
            when absent).
    """
    seed = int(config.get("seed", 0))
    seed_everything(seed)
    print('using config:', config)
    train_cfg = config['train']
    wandb_cfg = config['wandb']
    if wandb_cfg['use_wandb']:
        wandb.init(project=wandb_cfg['wandb_project'], config=config)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataset = load_item(config['dataset'])
    train_data = GroupDataset(dataset, 'train')
    val_data = GroupDataset(dataset, 'val')
    model = load_item(config['model'], dataset.n_vocab, dataset.n_out, device)
    model.train()
    train_dataloader = DataLoader(train_data, num_workers=train_cfg['num_workers'], batch_size=train_cfg['bsize'])
    val_dataloader = DataLoader(val_data, num_workers=train_cfg['num_workers'], batch_size=train_cfg['bsize'])
    optim = torch.optim.AdamW(model.parameters(), lr=train_cfg['lr'],
                              weight_decay=train_cfg['weight_decay'],
                              betas=train_cfg['betas'])

    warmup_steps = train_cfg['warmup_steps']

    def warmup_factor(s):
        # Linear warm-up from 0 to 1. A missing or non-positive warmup_steps
        # means "no warm-up"; dividing by it would raise ZeroDivisionError
        # as soon as the scheduler is constructed.
        if not warmup_steps or warmup_steps <= 0:
            return 1.0
        return min(s / warmup_steps, 1)

    lr_schedule = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=warmup_factor)
    step = 0
    for x, y in tqdm(train_dataloader):
        loss, logs = model.get_loss(x.to(device), y.to(device))
        optim.zero_grad()
        loss.backward()
        optim.step()
        lr_schedule.step()
        if (step+1) % train_cfg['eval_every'] == 0:
            model.eval()
            with torch.no_grad():
                all_val_logs = []
                for i, (val_x, val_y) in (enumerate(val_dataloader)):
                    if i >= train_cfg['eval_batches']:
                        break
                    _, val_logs = model.get_loss(val_x.to(device), val_y.to(device))
                    all_val_logs.append(val_logs)
            out_log = {'val': combine_logs(all_val_logs), 'train': combine_logs([logs]), 'step': (step+1),
                       'lr': float(lr_schedule.get_last_lr()[0])}

            # Count evaluation times (starting from 1); shared by both
            # optional metric groups below.
            eval_idx = (step + 1) // train_cfg["eval_every"]

            # Compute only during evaluation to avoid slowing down each step
            wd_cfg = config.get("wd", {})
            if wd_cfg.get("enabled", True):
                log_every = int(wd_cfg.get("log_every_eval", 1))
                if log_every <= 1 or (eval_idx % log_every == 0):
                    wd_logits, wd_probs, mse_logits, mse_probs = compute_wd_metrics(
                        model, x, device,
                        num_pairs=int(wd_cfg.get("num_pairs", 2)),
                        resolution=int(wd_cfg.get("resolution", 64)),
                        max_degree=int(wd_cfg.get("max_degree", 40)),
                        use_pca=bool(wd_cfg.get("use_pca", False)),
                        pca_k=int(wd_cfg.get("pca_k", 1)),
                    )
                    out_log["wd_logits_norm"] = float(wd_logits.item())
                    out_log["wd_probs_nonorm"] = float(wd_probs.item())
                    out_log["mse_logits_norm"] = float(mse_logits.item())
                    out_log["mse_probs_nonorm"] = float(mse_probs.item())

            # Sharpness needs gradients, so it runs outside the no_grad block.
            sharp_cfg = config.get("sharpness", {})
            if sharp_cfg.get("enabled", True):
                log_every = int(sharp_cfg.get("log_every_eval", 1))
                if log_every <= 1 or (eval_idx % log_every == 0):
                    rho = float(sharp_cfg.get("rho", 0.05))

                    base, pert, sharp = compute_sharpness_metrics(
                        model, x, y, device, rho=rho, adaptive=False
                    )
                    abase, apert, asharp = compute_sharpness_metrics(
                        model, x, y, device, rho=rho, adaptive=True
                    )

                    out_log["sharp_base_loss"] = float(base.item())
                    out_log["sharp_pert_loss"] = float(pert.item())
                    out_log["sharpness"] = float(sharp.item())

                    out_log["asharp_base_loss"] = float(abase.item())
                    out_log["asharp_pert_loss"] = float(apert.item())
                    out_log["adaptive_sharpness"] = float(asharp.item())

            print(out_log)
            if wandb_cfg['use_wandb']:
                wandb.log(out_log)
            model.train()
        step += 1
        # The training stream is endless; max_steps is the only stop condition.
        if train_cfg['max_steps'] is not None and step >= train_cfg['max_steps']:
            break


@hydra.main(config_path="../config", config_name="train_grokk")
def main(cfg : DictConfig):
    """Hydra entry point: convert the config to plain containers and train.

    NOTE(review): ``OmegaConf.to_container`` is called with its defaults, so
    interpolations are presumably left unresolved — confirm this is intended.
    """
    plain_cfg = OmegaConf.to_container(cfg)
    train(plain_cfg)

# Run training only when executed as a script; main() takes no arguments
# here because the @hydra.main decorator supplies the config.
if __name__ == "__main__":
    main()

