import torch
import torch.nn as nn
from chronos import ChronosPipeline, ChronosModel
from chronos import ChronosConfig, ChronosTokenizer
import pandas as pd
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
import os
import logging
import random
import threading
import queue
import math
import time
import pandas as pd
import pyarrow.parquet as pq
import datetime
# import dgl
import torch
import torch.nn.functional as F
import torch.multiprocessing as mp
from torch.utils.data import Dataset, DataLoader
from torch.multiprocessing import Process
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import SequentialSampler
from torch.optim.lr_scheduler import StepLR
from undecorated import undecorated
from types import MethodType
import numpy as np
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import Dataset, DataLoader, DistributedSampler

from pyspark.sql import SparkSession, functions as F
from pyspark.sql.types import (
    StructType, StructField,
    StringType, ArrayType, FloatType, LongType
)
from cycler import cycler
from pyspark.sql.window import Window
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import argparse
import os

# Path (or hub id) of the base Chronos checkpoint. Used as the default for
# main()'s --model flag and as a load path elsewhere in this file.
# NOTE(review): the value looks like a placeholder — set it to a real path.
CHRONOSMODEL_PATH="ChronosModelPath"


class SeriesDataset(Dataset):
    """Map-style dataset over in-memory series.

    Each item is returned as a freshly copied 1-D ``float32`` tensor.
    """

    def __init__(self, series_list):
        # One numeric sequence per entry (cleaned upstream; e.g. -1 padding removed).
        self.data = series_list

    def __len__(self):
        n_series = len(self.data)
        return n_series

    def __getitem__(self, idx):
        values = self.data[idx]
        # np.array (not asarray) so the returned tensor never aliases the stored data.
        as_float32 = np.array(values, dtype=np.float32)
        return torch.from_numpy(as_float32)

class MyChronosPipeline(ChronosPipeline):
    """ChronosPipeline whose ``model`` is wrapped in :class:`MyChronosModel`."""

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """
        Load via the parent ``from_pretrained``, then rewrap the loaded model.

        The new :class:`MyChronosModel` reuses the loaded ``config`` and the
        underlying pretrained ``model`` object, so no weights are copied.
        """
        loaded = super().from_pretrained(*args, **kwargs)
        base = loaded.model
        loaded.model = MyChronosModel(config=base.config, model=base.model)
        return loaded


class GlobalQuantileBins(ChronosTokenizer):
    """Chronos tokenizer that bins raw (unscaled) values with fixed, data-derived edges.

    A value ``x`` maps to ``torch.bucketize(x, boundaries, right=True) + n_special_tokens``.
    Bucket ``i`` therefore covers ``[boundaries[i-1], boundaries[i])``.
    """

    def __init__(self, boundaries: np.ndarray, config: ChronosConfig):
        self.config = config
        # 1-D ascending edges; main() in this file builds [0, q1, ..., q_k, 1e20].
        # Kept on CPU as float32 and moved to the input's device on each call.
        self.boundaries = torch.tensor(boundaries, dtype=torch.float32)

    def _append_eos(self, token_ids: torch.Tensor, attention_mask: torch.Tensor):
        """Append one EOS column (mask True), allocated on the inputs' own device."""
        n_rows = token_ids.shape[0]
        eos = torch.full(
            (n_rows, 1),
            self.config.eos_token_id,
            dtype=token_ids.dtype,
            device=token_ids.device,
        )
        eos_mask = torch.ones((n_rows, 1), dtype=torch.bool, device=attention_mask.device)
        return (
            torch.cat([token_ids, eos], dim=1),
            torch.cat([attention_mask, eos_mask], dim=1),
        )

    def context_input_transform(self, context: torch.Tensor):
        """Tokenize a [B, T] context without any scaling.

        Positions that are NaN or below 0.1 are treated as missing. They get
        ``pad_token_id`` and a False attention mask.

        Returns:
            (token_ids, attention_mask, scale). ``scale`` is the constant ``1``.
        """
        context = context.to(dtype=torch.float32)
        attention_mask = ~torch.isnan(context) & (context >= 0.1)
        token_ids = torch.bucketize(
            context,
            self.boundaries.to(context.device),
            right=True,
        )
        # shift past the special tokens and keep ids inside the vocabulary
        token_ids = token_ids + self.config.n_special_tokens
        token_ids.clamp_(0, self.config.n_tokens - 1)
        token_ids[~attention_mask] = self.config.pad_token_id

        if self.config.use_eos_token and self.config.model_type == "seq2seq":
            token_ids, attention_mask = self._append_eos(token_ids, attention_mask)
        return token_ids, attention_mask, 1

    def label_input_transform(self, label: torch.Tensor, scale):
        """Tokenize labels with the same bins (not called in the code visible here).

        Args:
            label: [B, H] values.
            scale: a per-row tensor [B] or a plain number. This includes the
                ``1`` returned by :meth:`context_input_transform`. Labels are
                divided by it before binning.

        Returns:
            (token_ids, attention_mask). The mask is all True, plus EOS if enabled.
        """
        label = label.to(dtype=torch.float32)
        # as_tensor accepts both tensors and Python numbers; unsqueeze broadcasts per row.
        scale_t = torch.as_tensor(scale, dtype=torch.float32, device=label.device)
        token_ids = torch.bucketize(
            label / scale_t.unsqueeze(-1),
            self.boundaries.to(label.device),
            right=True,
        ) + self.config.n_special_tokens
        # same vocabulary bound as the context path
        token_ids.clamp_(0, self.config.n_tokens - 1)
        attention_mask = torch.ones_like(token_ids, dtype=torch.bool)
        if self.config.use_eos_token:
            token_ids, attention_mask = self._append_eos(token_ids, attention_mask)
        return token_ids, attention_mask

    def output_transform(self, samples: torch.Tensor, scale: Optional[torch.Tensor]):
        """Decode token ids to the lower edge of their bucket (``scale`` is ignored).

        Token ``n_special_tokens + i`` decodes to ``boundaries[i-1]``. Ids outside
        the value range, such as special tokens, are clamped to the first or last
        lower edge.
        """
        edges = self.boundaries.to(samples.device)
        lower_limits = edges[:-1]  # lower edge of bucket i is boundaries[i-1]
        idx = samples.long() - self.config.n_special_tokens - 1
        idx = idx.clamp(0, lower_limits.numel() - 1)
        return lower_limits[idx]


class MyChronosModel(ChronosModel):
    """ChronosModel with an extra teacher-forced pass exposing per-step logits and states.

    ``self.model`` is the wrapped seq2seq model. ``forward_with_embeddings`` calls
    its ``encoder``, ``decoder`` and ``lm_head`` directly.
    """

    @classmethod
    def from_pretrained(cls, *args, boundaries=None, **kwargs):
        """Load a :class:`MyChronosPipeline` whose model is a MyChronosModel.

        This lives on the model class but returns the *pipeline*, which is what
        the ``.model`` / ``.tokenizer`` assignments below operate on. ChronosModel
        does not provide ``from_pretrained`` itself, so loading goes through
        MyChronosPipeline.

        Args:
            *args, **kwargs: forwarded to ``MyChronosPipeline.from_pretrained``.
            boundaries: optional bin edges. When given, the pipeline's tokenizer
                is replaced with :class:`GlobalQuantileBins`.
        """
        pipe = MyChronosPipeline.from_pretrained(*args, **kwargs)
        if not isinstance(pipe.model, MyChronosModel):
            pipe.model = MyChronosModel(config=pipe.model.config, model=pipe.model.model)
        if boundaries is not None:
            pipe.tokenizer = GlobalQuantileBins(boundaries, pipe.model.config)
        return pipe

    def forward_with_embeddings(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        decoder_input_ids: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Teacher-forced pass using the wrapped model's encoder & decoder modules directly.

        The decoder input is ``input_ids`` shifted right by one step, with the pad
        id as the start token.

        Args:
            input_ids: [B, T] token ids (encoder input and teacher-forcing source).
            attention_mask: [B, T] validity mask for the encoder positions.
            decoder_input_ids: ignored; kept for interface compatibility.

        Returns:
          - logits_per_step: [B, 1, T, V]
          - dec_hidden_per_step: [B, 1, T, D]
        """
        device = self.model.device
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device).long()
        pad_id = self.model.config.pad_token_id
        batch_size = input_ids.size(0)

        # Shift-right decoder inputs: [PAD] + input_ids[:, :-1]
        start = torch.full((batch_size, 1), pad_id, dtype=torch.long, device=device)
        dec_in = torch.cat([start, input_ids[:, :-1]], dim=1)  # [B, T]
        dec_attn = (dec_in != pad_id).long()  # [B, T]
        # The start token equals pad_id, so the line above masked it out. That
        # leaves query 0 with no attendable key under the causal mask. Keep it
        # visible; every query then always has at least one valid key.
        dec_attn[:, 0] = 1

        # ----- Encoder -----
        enc_h = self.model.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        ).last_hidden_state  # [B, T, D]

        # ----- Decoder (causal mask handled internally by T5) -----
        dec_h = self.model.decoder(
            input_ids=dec_in,
            attention_mask=dec_attn,
            encoder_hidden_states=enc_h,
            encoder_attention_mask=attention_mask,
            use_cache=False,
            return_dict=True,
        ).last_hidden_state  # [B, T, D]

        # ----- LM head -----
        # NOTE(review): HF's T5ForConditionalGeneration rescales decoder output by
        # d_model**-0.5 before lm_head when word embeddings are tied; that is not
        # done here — confirm against the checkpoint's tie_word_embeddings.
        logits = self.model.lm_head(dec_h)  # [B, T, V]

        return logits.unsqueeze(1), dec_h.unsqueeze(1)

    
class ChronosBinPredictor(nn.Module):
    """
    Wraps a Chronos pipeline to:
      1) prepare and tokenize inputs with GlobalQuantileBins,
      2) run a teacher-forced pass to get per-step logits over bins,
      3) expose logits, next-token targets and a validity mask,
      4) compute masked cross-entropy loss.
    """

    def __init__(self, boundaries, pretrained_model_or_path: str, device: torch.device):
        super().__init__()
        # load Chronos T5 pipeline (its model is already wrapped in MyChronosModel)
        self.pipeline: MyChronosPipeline = MyChronosPipeline.from_pretrained(
            pretrained_model_or_path,
            device_map={"": device},
            torch_dtype=torch.float32,
        )
        self.pipeline.tokenizer = GlobalQuantileBins(boundaries, self.pipeline.model.config)
        orig_model = self.pipeline.model.to(device)
        self.device = device

        if isinstance(orig_model, MyChronosModel):
            self.model = orig_model
        else:
            # Wrap the same pretrained inner model. MyChronosModel needs both
            # `config` and `model`, and sharing the inner model makes a weight copy unnecessary.
            new_model = MyChronosModel(config=orig_model.config, model=orig_model.model)
            new_model.to(device)
            self.pipeline.model = new_model
            self.model = new_model

        # CE loss for bins
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(self, timeseries):
        """Tokenize ``timeseries`` and run the teacher-forced pass.

        Args:
            timeseries: whatever ``pipeline._prepare_and_validate_context``
                accepts. The training loop in this file passes a list of 1-D tensors.

        Returns:
            logits [B, T-1, V], targets [B, T-1] (long), valid [B, T-1] (bool).
            T is the tokenized length, including EOS if appended.
        """
        context = self.pipeline._prepare_and_validate_context(context=timeseries)
        input_ids, attention_mask, _ = self.pipeline.tokenizer.context_input_transform(context)

        # forward_with_embeddings ignores decoder_input_ids; this only fills the argument.
        dummy = torch.empty((input_ids.size(0), 1), dtype=torch.long, device=self.device)
        logits, _ = self.model.forward_with_embeddings(
            input_ids=input_ids.to(self.device),
            attention_mask=attention_mask.to(self.device),
            decoder_input_ids=dummy,
        )
        logits = logits.squeeze(1)  # [B, T, V]

        # Next-token targets (shifted-left)
        targets = input_ids[:, 1:].to(self.device)  # [B, T-1]
        logits = logits[:, :-1, :]  # [B, T-1, V]

        # Loss positions are those whose *target* token (t+1) is valid, i.e. an
        # observed value >= 0.1 or the EOS token.
        # NOTE(review): position t itself is not required to be valid — confirm intent.
        valid = attention_mask[:, 1:].to(self.device)  # [B, T-1]

        return logits, targets, valid

    def compute_loss(self, logits: torch.Tensor, target_ids: torch.Tensor, valid_mask: torch.Tensor):
        """Mean cross-entropy over the positions where ``valid_mask`` is True.

        Args:
            logits: [B, T, V]
            target_ids: [B, T] long
            valid_mask: [B, T] bool (already aligned with ``logits``, as returned by ``forward``)

        Returns:
            Scalar loss. If no position is valid, returns a zero that stays
            attached to ``logits``' graph instead of the NaN an empty mean gives.
        """
        valid_mask = valid_mask.bool()
        if not bool(valid_mask.any()):
            return logits.sum() * 0.0

        logits_flat = logits[valid_mask]  # [N_valid, V]
        targets_flat = target_ids[valid_mask]  # [N_valid]
        return self.ce_loss(logits_flat, targets_flat)


def load_cleaned_series(parquet_root, min_len=10, max_len = None, limit=None, outbound_only = False):
    """Load one series per Parquet row via Spark, keeping only values > 0.

    The inbound column is the shortest-named column containing "inbound". The
    outbound column is chosen the same way. Both must exist.

    Args:
        parquet_root: root path, searched recursively for Parquet files.
        min_len: minimum length (after removing values <= 0) of each returned series.
        max_len: if given, keep only the first ``max_len`` values of each series.
        limit: if given, return at most this many series.
        outbound_only: return the outbound series instead of the inbound ones.

    Returns:
        list of Python lists of values. The list is empty if nothing passes the filters.
    """
    spark = SparkSession.builder \
        .appName("cleaned loader") \
        .config("spark.driver.memory","100g") \
        .config("spark.driver.maxResultSize","50g") \
        .getOrCreate()

    df = spark.read.option("recursiveFileLookup", "true").parquet(parquet_root)
    # Pick one inbound-like and one outbound-like column
    in_cols  = [c for c in df.columns if "inbound"  in c.lower()]
    out_cols = [c for c in df.columns if "outbound" in c.lower()]
    assert in_cols and out_cols, f"Need at least one inbound and one outbound column, got: {df.columns}"

    in_col = min(in_cols, key=len)
    out_col = min(out_cols, key=len)
    src_col = out_col if outbound_only else in_col

    print(f"count of df before cleaning: {df.count()}")
    # Remove values <= 0.0 (this also drops -1 paddings); null arrays become empty.
    df_clean = df.select(
        F.when(F.col(src_col).isNotNull(), F.expr(f"filter({src_col}, x -> x > 0.0)"))
        .otherwise(F.array().cast("array<double>")).alias("series")
    )

    # min_len applies to the series that is actually returned
    df_clean = df_clean.filter(F.size("series") >= min_len)
    if max_len is not None:
        df_clean = df_clean.select(F.expr(f"slice(series, 1, {max_len})").alias("series"))
    if limit is not None:
        df_clean = df_clean.limit(limit)

    # Only the selected column is collected to the driver.
    return [row["series"] for row in df_clean.collect()]

def collate_as_list(batch):
    """
    DataLoader ``collate_fn`` that keeps the batch as a plain list of samples.

    The series have different lengths, so they are not stacked into one tensor.
    The list built by the DataLoader is returned unchanged (same object), and
    each element is still the tensor produced by ``SeriesDataset``.
    """
    return batch


import numpy as np
from collections import Counter

def iterative_quantile_bins(list_of_lists, n_bins):
    """
    Build ``n_bins + 1`` ascending bin edges from quantiles of all values.

    1. Flatten and sort every value in ``list_of_lists``.
    2. Compute the ``remaining_bins - 1`` interior quantile cut points (method="lower").
    3. If a cut value repeats (a plateau of identical values), keep the distinct
       cuts below it plus the plateau value. Then keep only data strictly greater
       than the plateau and repeat.
    4. If no cut repeats, keep all remaining cuts and stop.
    5. Append the global maximum.
    6. Truncate to ``n_bins + 1`` edges, or pad up to that count with midpoints of
       adjacent edges. Gaps between consecutive integers are skipped at first.
       They are split, or the last edge repeated, only when nothing else is left.

    Args:
        list_of_lists: sequence of 1-D numeric sequences.
        n_bins: number of bins; the result has ``n_bins + 1`` edges.

    Returns:
        1-D ``np.ndarray`` of edges, or an empty array when there are no values.
    """
    series = list(list_of_lists)
    if not series:  # np.concatenate raises on an empty sequence
        return np.array([])
    flat = np.sort(np.concatenate(series))
    if flat.size == 0:
        return np.array([])

    edges = []
    data = flat.copy()
    remaining_bins = n_bins

    while remaining_bins > 1 and data.size > 0:
        # interior quantile positions (0% and 100% excluded)
        percents = np.arange(1, remaining_bins) * (100.0 / remaining_bins)
        cuts = np.percentile(data, percents, method="lower")

        # cuts are non-decreasing, so the first repeat marks the lowest plateau
        seen = set()
        first_rep = None
        for v in cuts:
            if v in seen:
                first_rep = v
                break
            seen.add(v)

        if first_rep is None:
            edges.extend(cuts.tolist())
            break

        smaller_cuts = [v for v in cuts if v < first_rep]
        edges.extend(smaller_cuts)
        edges.append(first_rep)
        remaining_bins -= len(smaller_cuts) + 1
        # continue on values strictly above the plateau
        data = data[data > first_rep]

    edges.append(flat[-1])  # global max

    target = n_bins + 1
    if len(edges) > target:
        edges = edges[:target]
    while len(edges) < target:
        needed = target - len(edges)
        new_edges = []
        for lo, hi in zip(edges, edges[1:]):
            if hi == lo + 1:
                # skip gaps between consecutive integers
                # (presumably the data are integer-valued — confirm)
                continue
            new_edges.append((lo + hi) / 2)
            if len(new_edges) >= needed:
                break
        if not new_edges:
            # Nothing was splittable under the rule above (a single edge, or only
            # consecutive-integer gaps). Split any gap, or repeat the last edge,
            # so the loop always makes progress.
            new_edges = [(lo + hi) / 2 for lo, hi in zip(edges, edges[1:])][:needed]
            new_edges = new_edges or [edges[-1]] * needed
        edges = sorted(edges + new_edges)
    return np.array(edges)

def _compute_boundaries(all_series, model_path, save_dir, n_bins=4096):
    """Compute GlobalQuantileBins edges from ``all_series`` and pickle them to ``save_dir``.

    The result is ``[0, cuts[:B-2]..., 1e20]``, where ``cuts`` comes from
    :func:`iterative_quantile_bins` and ``B = n_tokens - n_special_tokens`` is
    read from the config of the Chronos model at ``model_path``.
    """
    import pickle

    cuts = iterative_quantile_bins(all_series, n_bins)
    # Only the config is needed to size the vocabulary; load on CPU and drop it.
    pipeline_tmp = MyChronosPipeline.from_pretrained(
        model_path,
        device_map={"": "cpu"},
        torch_dtype=torch.float32,
    )
    config = pipeline_tmp.model.config
    n_value_tokens = config.n_tokens - config.n_special_tokens
    del pipeline_tmp, config

    boundaries = np.concatenate(([0], cuts[: n_value_tokens - 2], [1e20]))
    os.makedirs(save_dir, exist_ok=True)
    with open(os.path.join(save_dir, "boundaries.pkl"), "wb") as f:
        pickle.dump(boundaries, f)
    return boundaries


def _load_and_split(args):
    """Rank-0 only: load the series and compute boundaries.

    Returns:
        (train_list, val_list, boundaries).
    """
    if args.isCSV:
        df = pd.read_csv(args.parquet_root)
        all_series = df.values.tolist()
        # Binning sees only finite values > 0, mirroring the parquet cleaning
        # rule. The rows themselves are passed on unchanged, in file order.
        positive = [[v for v in row if np.isfinite(v) and v > 0] for row in all_series]
        boundaries = _compute_boundaries(positive, args.model, args.save_dir)
    else:
        all_series = load_cleaned_series(args.parquet_root,
                                         min_len=10,
                                         max_len=args.max_len,
                                         limit=args.limit)
        boundaries = _compute_boundaries(all_series, args.model, args.save_dir)
        random.seed(42)
        random.shuffle(all_series)
    split = int(len(all_series) * args.train_frac)
    return all_series[:split], all_series[split:], boundaries


def _average_gradients(module, world_size):
    """Sum every populated ``.grad`` of ``module`` across ranks, then divide by ``world_size``."""
    for param in module.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad.div_(world_size)


def main():
    """Fine-tune Chronos on bin-tokenized series, one process per GPU.

    Expects ``LOCAL_RANK`` in the environment (e.g. as set by torchrun). Rank 0
    loads the data, computes and pickles the bin boundaries, and broadcasts
    (train, val, boundaries) to all ranks. Each epoch:
    - trains with gradients averaged across ranks;
    - measures token accuracy and CE loss on the validation shards;
    - rank 0 checkpoints on val-loss improvement.
    The early-stop decision is broadcast so every rank stops together.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("parquet_root")
    parser.add_argument("--epochs",     type=int, default=5)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--limit",      type=int, default=None)
    parser.add_argument("--train_frac", type=float, default=0.8,
                        help="Fraction of data to use for training")
    parser.add_argument("--max_len",   type=int, default=None,help="Maximum length of series to consider")
    parser.add_argument("--save_dir",   default="checkpoints",
                        help="Where to save best model")
    # optional path to a checkpoint to resume from
    parser.add_argument("--retrain", type=str, default=None)
    parser.add_argument("--isCSV", type=int, default=None)
    parser.add_argument("--model", type=str, default=CHRONOSMODEL_PATH,
                        help="Base Chronos model path to fine-tune")
    # Early stopping arguments
    parser.add_argument("--early_stop_patience", type=int, default=10,
                        help="Number of consecutive epochs without val loss improvement to trigger early stop")
    parser.add_argument("--early_stop_min_delta", type=float, default=1e-4,
                        help="Minimum relative improvement (absolute decrease) in val loss to reset patience")
    args = parser.parse_args()

    dist.init_process_group(backend="nccl", init_method="env://", timeout=datetime.timedelta(minutes=45))
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = dist.get_world_size()
    rank       = dist.get_rank()
    torch.cuda.set_device(local_rank)
    device = torch.device(f"cuda:{local_rank}")

    # 1) load on rank 0 & broadcast series + boundaries (defined on both CSV and parquet paths)
    if rank == 0:
        train_list, val_list, boundaries = _load_and_split(args)
    else:
        train_list = val_list = boundaries = None

    to_bcast = [train_list, val_list, boundaries]
    dist.broadcast_object_list(to_bcast, src=0)
    train_list, val_list, boundaries = to_bcast
    print(f"Rank {rank}: {len(train_list)} train series, {len(val_list)} val series")

    # 2) build Datasets & DistributedSamplers
    train_ds = SeriesDataset(train_list)
    val_ds   = SeriesDataset(val_list)

    train_sampler = DistributedSampler(train_ds, num_replicas=world_size,
                                       rank=rank, shuffle=True)
    val_sampler   = DistributedSampler(val_ds,   num_replicas=world_size,
                                       rank=rank, shuffle=False)

    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        sampler=train_sampler,
        collate_fn=collate_as_list,
        num_workers=2,
        pin_memory=True,
        drop_last=True
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=args.batch_size,
        sampler=val_sampler,
        collate_fn=collate_as_list,
        num_workers=2,
        pin_memory=True,
        drop_last=False
    )

    # 3) model + DDP
    model = ChronosBinPredictor(boundaries, args.model, device).to(device)
    if args.retrain:
        state = torch.load(args.retrain, map_location=device)
        state = strip_prefix_if_present(state, ["model.model.", "model."])
        model.model.load_state_dict(state, strict=False)
        print(f"Loaded model from {args.retrain}")

    # DDP's constructor broadcasts rank 0's parameters so all replicas start identical.
    # The forward pass below goes through model.module, not the DDP wrapper, so
    # DDP's own gradient hooks stay inactive. Gradients are averaged explicitly in
    # _average_gradients, which lets all ranks agree to skip a failed batch
    # without leaving a collective half-entered.
    model = DDP(model, device_ids=[local_rank], output_device=local_rank)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    scheduler = StepLR(optimizer, step_size=100, gamma=0.99)

    # 4) training + validation loop
    best_acc = 0.0  # still track accuracy for logging
    best_val_loss = float('inf')
    epochs_no_improve = 0
    os.makedirs(args.save_dir, exist_ok=True)

    for epoch in range(args.epochs):
        # — train —
        model.train()
        count = 0
        train_sampler.set_epoch(epoch)
        total_loss = 0.0
        for batch in train_loader:
            # batch is a list of tensors
            optimizer.zero_grad()
            try:
                logits, targets, valid_mask = model.module(batch)
                loss = model.module.compute_loss(logits, targets, valid_mask)
                loss.backward()
            except Exception as err:  # best-effort skip, agreed across ranks below
                print(f"Rank {rank}: error on batch {count}: {err!r}")
                loss = None

            # Every rank must take the same branch, or the gradient all-reduce would hang.
            batch_ok = torch.tensor(0 if loss is None else 1, dtype=torch.long, device=device)
            dist.all_reduce(batch_ok, op=dist.ReduceOp.MIN)
            if batch_ok.item() == 0:
                optimizer.zero_grad()
                print(f"Skipping batch {count} due to error")
                print(f"Batch shape: {[t.shape for t in batch]}")
                continue

            _average_gradients(model.module, world_size)
            if rank == 0 and count % 100 == 0:
                print(f"Epoch {epoch+1}/{args.epochs} step {count} — CE loss: {loss.item():.4f}")
            optimizer.step()
            scheduler.step()
            total_loss += loss.item()
            del loss
            count += 1

        if rank == 0:
            avg_train = total_loss / count if count else float("nan")
            print(f"Epoch {epoch+1}/{args.epochs} — train CE loss (rank 0 mean): {avg_train:.4f}")

        # — validate —
        model.eval()
        val_sampler.set_epoch(epoch)
        correct = torch.tensor(0, dtype=torch.long, device=device)
        total   = torch.tensor(0, dtype=torch.long, device=device)
        # For validation loss (weighted by number of valid tokens)
        val_loss_sum = torch.tensor(0.0, dtype=torch.float32, device=device)
        val_tokens   = torch.tensor(0.0, dtype=torch.float32, device=device)

        with torch.no_grad():
            for batch in val_loader:
                logits, targets, valid_mask = model.module(batch)
                mask = valid_mask.to(device)
                preds = logits.argmax(dim=-1).to(device)
                correct += ((preds == targets.to(device)) & mask).sum()
                total   += mask.sum()
                n_valid = mask.sum().to(torch.float32)
                # batches without valid tokens would contribute NaN * 0 = NaN
                if n_valid.item() > 0:
                    batch_loss = model.module.compute_loss(logits, targets, valid_mask)
                    val_loss_sum += batch_loss.detach() * n_valid
                    val_tokens   += n_valid

        # aggregate across all ranks
        dist.all_reduce(correct, op=dist.ReduceOp.SUM)
        dist.all_reduce(total,   op=dist.ReduceOp.SUM)
        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_tokens,   op=dist.ReduceOp.SUM)

        stop = torch.tensor(0, dtype=torch.long, device=device)
        if rank == 0:
            acc = (correct.float() / total.float()).item()
            val_loss = (val_loss_sum / val_tokens).item() if val_tokens.item() > 0 else float('inf')
            print(f"Epoch {epoch+1}/{args.epochs} — val accuracy: {acc:.4f} | val loss: {val_loss:.6f}")

            if acc > best_acc:
                best_acc = acc

            # Early stopping logic based on validation loss
            improved = (best_val_loss - val_loss) > args.early_stop_min_delta
            if improved:
                best_val_loss = val_loss
                epochs_no_improve = 0
                path = os.path.join(args.save_dir, "chronos_best.pt")
                torch.save(model.module.state_dict(), path)
                print(f"→ New best model (val_loss={val_loss:.6f}) saved to {path}")
            else:
                epochs_no_improve += 1
                print(f"No improvement in val loss (best {best_val_loss:.6f}). Patience {epochs_no_improve}/{args.early_stop_patience}.")
                if epochs_no_improve >= args.early_stop_patience:
                    print("Early stopping triggered.")
                    stop.fill_(1)

        # Only rank 0 tracks val loss; broadcast its decision so every rank leaves together.
        dist.broadcast(stop, src=0)
        if stop.item():
            break

    dist.destroy_process_group()


def strip_prefix_if_present(state_dict, prefixes):
    """Return a new dict whose keys lose the first matching prefix from ``prefixes``.

    Prefixes are tried in order and at most one is removed per key. Keys with
    no matching prefix are kept as-is; values are carried over untouched.
    """
    def _strip(key):
        for prefix in prefixes:
            if key.startswith(prefix):
                return key[len(prefix):]
        return key

    return {_strip(key): value for key, value in state_dict.items()}


import argparse
import pickle
import torch
import pandas as pd
from chronos import ChronosPipeline
from pyspark.sql import SparkSession, functions as F
from pyspark.sql.types import StructType, StructField, StringType, ArrayType, FloatType, LongType

# -------- utility: DDP init / cleanup ----------
def ddp_init():
    """Initialise the NCCL process group (if needed) and bind this process to a local GPU.

    The CUDA device index must be node-local. It is read from ``LOCAL_RANK``
    (the variable ``main()`` also relies on). If that is unset, it falls back to
    the global rank modulo the number of visible GPUs. Using the global rank
    directly, as before, breaks on multi-node jobs.

    Returns:
        (rank, world_size, device)
    """
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    n_gpus = max(torch.cuda.device_count(), 1)
    local_rank = int(os.environ.get("LOCAL_RANK", rank % n_gpus))
    torch.cuda.set_device(local_rank)  # 1 GPU per process
    device = torch.device(f"cuda:{local_rank}")
    return rank, world_size, device

def ddp_cleanup():
    """Synchronise all ranks, then destroy the default process group.

    Does nothing when no process group has been initialised.
    """
    if not dist.is_initialized():
        return
    dist.barrier()
    dist.destroy_process_group()

# --------- simple helpers ----------
def shard_list_by_rank(x, rank, world_size):
    """Round-robin shard of ``x``: keep items at positions ``i`` with ``i % world_size == rank``.

    Deterministic and needs no communication; each rank calls it with its own rank.
    """
    shard = []
    for position, item in enumerate(x):
        if position % world_size != rank:
            continue
        shard.append(item)
    return shard

def save_on_rank0(obj, path, rank):
    """Pickle ``obj`` to ``path``, but only on rank 0; other ranks return immediately."""
    if rank != 0:
        return
    with open(path, "wb") as handle:
        pickle.dump(obj, handle)

# -----------------------------------
# Series-column selection helper
# -----------------------------------

def pick_series_col_by_substring(df, override=None):
    """Choose the series column of ``df`` (anything exposing ``.columns``).

    Resolution order:
    1. ``override`` if given; it must be an existing column.
    2. A column named exactly "inbound" (case-insensitive).
    3. A column named exactly "iei_inbound" (case-insensitive).
    4. The shortest column name containing "inbound", excluding names ending
       in "_filtered".

    Raises:
        RuntimeError: if ``override`` is missing, or no candidate column exists.
    """
    if override is not None:
        if override in df.columns:
            return override
        raise RuntimeError(f"--series_col={override} not in parquet columns {sorted(df.columns)}")

    # lower-cased name -> original name (a later case-variant overwrites an earlier one)
    by_lower = {}
    for name in df.columns:
        by_lower[name.lower()] = name

    # An exact name is itself an unfiltered "inbound" candidate, so checking these
    # before the emptiness test below does not change the outcome.
    for preferred in ("inbound", "iei_inbound"):
        if preferred in by_lower:
            return by_lower[preferred]

    candidates = []
    for low, orig in by_lower.items():
        if "inbound" in low and not low.endswith("_filtered"):
            candidates.append(orig)
    if not candidates:
        raise RuntimeError(f"No column containing 'inbound' found in {sorted(df.columns)}")
    return min(candidates, key=len)

def inference_auto_regression():
    """Run distributed autoregressive (AR) forecasting over Parquet series.

    Meant to be launched with one process per GPU (e.g. via ``torchrun``).
    Rank 0 loads every series with Spark, splits the list round-robin into
    ``world_size`` shards and scatters them. Each rank then rolls the fine-tuned
    Chronos bin predictor forward one step at a time over its shard. It pickles
    ``{key: (forecasts, truths)}`` to ``<save_pkl without extension>/rankNN.pkl``.

    Command-line arguments (parsed from ``sys.argv``):
        parquet_root: Root directory of the Parquet series data.
        --ips_csv: Optional CSV restricting which keys are evaluated. Its header
            must contain ``ip``, ``ip,service_port`` or ``subnet``.
        --save_pkl: Output name. Its final extension is stripped to form the
            per-rank output directory.
        --model: Directory holding ``boundaries.pkl`` and ``chronos_best.pt``.
        --nonhierarchical: Accepted for CLI compatibility; not used here.
        --series_col: Explicit series column. When omitted, it is auto-detected
            from the columns containing "inbound".
        --ctx_frac / --min_ctx / --min_h / --max_h: Context/horizon split controls.
        --batch_size: Number of series per progress-report chunk.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("parquet_root", help="Root path of your Parquet series data")
    parser.add_argument("--ips_csv", default=None)
    parser.add_argument("--save_pkl", default="ar_inference_compressed2.pkl")
    parser.add_argument("--model", default="pretrained model path")
    parser.add_argument("--nonhierarchical", action="store_true")
    parser.add_argument("--series_col", default=None,
                        help="Series column to forecast; auto-detected from 'inbound' columns if omitted")
    parser.add_argument("--ctx_frac", type=float, default=0.70,
                        help="Fraction of each series used as context; remainder is AR horizon")
    parser.add_argument("--min_ctx", type=int, default=10,
                        help="Minimum context steps per series after filtering")
    parser.add_argument("--min_h", type=int, default=1,
                        help="Minimum AR steps per series")
    parser.add_argument("--max_h", type=int, default=20,
                        help="Maximum AR steps per series")
    parser.add_argument("--batch_size", type=int, default=32)
    args = parser.parse_args()

    # -------- DDP setup --------
    rank, world_size, device = ddp_init()
    # splitext only strips the final extension. Splitting on "." instead would
    # give "" for "./out.pkl" and truncate paths such as "/data/v1.2/out.pkl".
    out_dir = os.path.splitext(args.save_pkl)[0]
    os.makedirs(out_dir, exist_ok=True)

    # -------- Load the fine-tuned predictor on this rank's GPU --------
    # NOTE(review): pickle/torch.load execute arbitrary code; only load trusted checkpoints.
    with open(os.path.join(args.model, "boundaries.pkl"), "rb") as fh:
        boundaries = pickle.load(fh)

    predictor = ChronosBinPredictor(
        boundaries,
        CHRONOSMODEL_PATH,
        device
    ).to(device)

    # Load the fine-tuned T5 weights (inner model only) exactly once.
    state = torch.load(os.path.join(args.model, "chronos_best.pt"), map_location=device)
    state = strip_prefix_if_present(state, ["model.model.", "model."])
    predictor.model.model.load_state_dict(state, strict=True)
    del state  # drop the checkpoint copy before inference
    predictor.model.eval()
    print(type(predictor.pipeline.tokenizer))

    # --------- Load series (Spark, rank 0 only) ---------
    def load_min20(parquet_root, allowed_ips, nonHierarchical=False):
        """Load, filter and de-duplicate series on the Spark driver.

        ``allowed_ips`` and ``nonHierarchical`` are kept for signature
        compatibility only. Key filtering is driven by ``args.ips_csv``.

        Returns a key-sorted list of ``(key, series)`` pairs. ``key`` is a scalar
        when there is one key column and a tuple otherwise. ``series`` is the
        value list with ``-1.0`` sentinels removed; for IEI columns, values
        ``<= 0`` are removed as well. Only the longest series per key is kept.
        """
        spark = (
            SparkSession.builder
            .appName(f"Inference Data loader{rank}")
            .config("spark.driver.memory", "200g")
            .config("spark.driver.maxResultSize", "200g")
            .getOrCreate()
        )
        try:
            df = spark.read.option("recursiveFileLookup", "true").parquet(parquet_root)
            if args.ips_csv is not None:
                # Read and normalise the CSV *before* inspecting its columns.
                allowed_pd = pd.read_csv(args.ips_csv)
                allowed_pd.columns = [c.strip().lower() for c in allowed_pd.columns]
                if {"ip", "service_port"}.issubset(allowed_pd.columns):
                    key_cols = ["ip", "service_port"]
                    allowed_pd["service_port"] = allowed_pd["service_port"].astype("int64")
                elif "ip" in allowed_pd.columns:
                    key_cols = ["ip"]
                elif "subnet" in allowed_pd.columns:
                    key_cols = ["subnet"]
                else:
                    raise ValueError("CSV must have header 'ip', or 'ip,service_port', or 'subnet'.")

                allowed_sdf = spark.createDataFrame(allowed_pd[key_cols]).dropDuplicates()
                if "service_port" in df.columns:
                    df = df.withColumn("service_port", F.col("service_port").cast("int"))

                # Derive a /24 "subnet" key from dotted-quad IPs when only "ip" exists.
                if key_cols == ["subnet"] and "subnet" not in df.columns and "ip" in df.columns:
                    df = df.withColumn("subnet", F.regexp_extract("ip", r"^(\d+\.\d+\.\d+)\.\d+$", 1))

                missing = [c for c in key_cols if c not in df.columns]
                if missing:
                    raise RuntimeError(f"Parquet is missing key columns: {missing}")

                df = df.join(allowed_sdf, on=key_cols, how="inner")
            else:
                if "service_port" in df.columns:
                    key_cols = ["ip", "service_port"]
                elif "ip" in df.columns:
                    key_cols = ["ip"]
                elif "subnet" in df.columns:
                    key_cols = ["subnet"]
                else:
                    raise ValueError("Parquet must have column 'ip', or 'ip' and 'service_port', or 'subnet'.")

            series_col = pick_series_col_by_substring(df, override=args.series_col)
            is_iei = "iei" in series_col.lower()
            # Drop -1.0 sentinels; IEI series additionally drop non-positive values.
            # Backticks keep column names with special characters valid in the SQL lambda.
            flt = f"filter(`{series_col}`, x -> x != -1.0{' AND x > 0.0' if is_iei else ''})"

            df = (
                df.withColumn(
                    "series_filtered",
                    F.when(F.col(series_col).isNotNull(), F.expr(flt))
                    .otherwise(F.array().cast("array<double>"))
                )
                .filter(F.size("series_filtered") >= 1)
                .select(*(key_cols + ["series_filtered"]))
                .withColumn("series_len", F.size("series_filtered"))
            )

            # Keep one series per key: the longest. NOTE(review): among series of
            # equal length, which row is kept is not guaranteed to be stable.
            w = Window.partitionBy(*key_cols).orderBy(F.col("series_len").desc())
            rows = (
                df.withColumn("rn", F.row_number().over(w))
                .filter(F.col("rn") == 1)
                .drop("rn", "series_len")
                .orderBy(*key_cols)
                .collect()
            )
        finally:
            # Release the (200g) driver once the rows are materialised on this side.
            spark.stop()

        def pack_key(r):
            return r[key_cols[0]] if len(key_cols) == 1 else tuple(r[c] for c in key_cols)

        return [(pack_key(r), r["series_filtered"]) for r in rows]

    # Only rank 0 touches Spark. It round-robins the series into world_size shards
    # and scatters one shard per rank; non-source ranks must pass an empty list.
    if rank == 0:
        all_series = load_min20(args.parquet_root, args.ips_csv, nonHierarchical=args.nonhierarchical)
        input_list = [all_series[i::world_size] for i in range(world_size)]
        del all_series
    else:
        input_list = []

    out = [None]  # every rank receives exactly one shard
    dist.scatter_object_list(out, input_list, src=0)
    my_series = out[0]
    dist.barrier()

    # --------- Autoregression on this rank ---------
    def split_70_30(length, ctx_frac=0.70, min_ctx=10, min_h=1):
        """Return (L, H) context/horizon lengths for a series of ``length`` values."""
        L = max(min_ctx, int(math.ceil(ctx_frac * length)))
        H = length - L
        if H < min_h:
            take_back = min(min_ctx, (min_h - H))
            L = max(min_ctx, L - take_back)
            H = length - L
        if H <= 0 and length > min_ctx:
            L = length - 1
            H = 1
        return L, max(0, H)

    def ar_one_series(series_vals):
        """Roll the predictor forward over one series, feeding back its own outputs.

        Returns ``(forecasts, truths, correct_bins, total_bins)``. All are empty/zero
        when the series is too short to split into context + horizon.
        """
        T = len(series_vals)
        if T < (args.min_ctx + args.min_h):
            return [], [], 0, 0

        L, H = split_70_30(T, ctx_frac=args.ctx_frac, min_ctx=args.min_ctx, min_h=args.min_h)
        H = min(H, args.max_h)
        if H <= 0:
            return [], [], 0, 0

        # Initial context tensor, shape (1, L), kept on CPU as in the original flow.
        ctx_t = torch.from_numpy(np.array(series_vals[:L], dtype=np.float32)[None, :])

        forecasts = []
        correct_bins = 0
        total_bins = 0
        with torch.inference_mode():
            for _ in range(H):
                logits, targets, _ = predictor.forward(ctx_t)
                argmax_ids = logits.argmax(dim=-1)
                last_id = argmax_ids[0, -1] if argmax_ids.dim() == 2 else argmax_ids[0]
                target_id = targets[0, -1] if targets.dim() == 2 else targets[0]
                if last_id == target_id:
                    correct_bins += 1
                total_bins += 1
                next_val = predictor.pipeline.tokenizer.output_transform(last_id, None)
                next_val = float(next_val.detach().to("cpu"))
                forecasts.append(next_val)
                # Append the prediction as the next context step.
                step = torch.tensor([[next_val]], dtype=torch.float32)  # [1, 1]
                ctx_t = torch.cat([ctx_t, step], dim=1)

        truths = series_vals[L:L + H]
        return forecasts, truths, correct_bins, total_bins

    bs = args.batch_size
    my_results = {}  # key -> (forecasts, truths)
    for i in range(0, len(my_series), bs):
        chunk = my_series[i:i + bs]
        print("inference batch:", i, "-", i + len(chunk) - 1, f"({len(chunk)} series)")
        for ip, series_vals in chunk:
            f, t, c, tot = ar_one_series(series_vals)
            print(f"processed so far (rank {rank}): {len(my_results)+1}/{len(my_series)}  ip={ip}  forecast_len={len(f)} truth_len={len(t)} correct_bin percent={100*c/tot if tot>0 else 0:.1f}%")
            my_results[ip] = (f, t)

    # --------- Save per-rank shard ---------
    rank_out = os.path.join(out_dir, f"rank{rank:02d}.pkl")
    with open(rank_out, "wb") as fh:
        pickle.dump(my_results, fh)
    print(f"[rank {rank}] wrote {len(my_results)} series → {rank_out}")

    ddp_cleanup()


import argparse
import pickle
import pandas as pd
from statsmodels.tsa.arima.model import ARIMA
from pyspark.sql import SparkSession, functions as F
from pyspark.sql.types import (
    StructType, StructField, StringType, ArrayType, FloatType
)



def get_representations_with_ip():
    """Extract per-series Chronos embeddings keyed by (ip, source_file) and save them as CSV.

    Each series is embedded with the fine-tuned predictor's pipeline. The
    embedding at the last token position becomes one CSV row with
    ``repr_dim_<k>`` columns, alongside ``ip``, ``source_file``, ``series_id``,
    ``series_length`` and ``n_tokens``.

    Command-line arguments (parsed from ``sys.argv``):
        parquet_root: Root directory of the Parquet data.
        --model: Directory containing ``boundaries.pkl`` and ``chronos_best.pt``.
        --batch_size: Series per embedding call.
        --ips_csv: Optional CSV with ``ip`` and ``source_file`` columns. When the
            file exists, only those (ip, source_file) pairs are processed.
        --output_csv: Destination CSV path.
        --nonhierarchical: Accepted for CLI compatibility; not used here.

    A batch that fails to embed is reported and skipped (best effort). The
    number of skipped batches is printed at the end.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("parquet_root", help="Root path to parquet data")
    parser.add_argument("--model", default="amazon/chronos-t5-small",
                        help="Path to trained model directory")
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--ips_csv", default="IpsToConsider.csv",
                        help="CSV file with IPs to process")
    parser.add_argument("--output_csv", default="ip_representations.csv",
                        help="Output CSV file path")
    parser.add_argument("--nonhierarchical", action="store_true",
                        help="Use non-hierarchical data format")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Optional allow-list of (ip, source_file) pairs.
    if os.path.exists(args.ips_csv):
        tmp_df = pd.read_csv(args.ips_csv)
        allowed_ips = set(zip(tmp_df["ip"], tmp_df["source_file"]))
        print(f"Loaded {len(allowed_ips)} allowed IPs from {args.ips_csv}")
    else:
        allowed_ips = None
        print("No IP filter file found, processing all IPs")

    def load_with_ip_info(parquet_root, allowed_ips):
        """Return ``[(ip, source_file, series)]``, keeping only positive values in each series.

        The series column is the shortest name containing "inbound", or else the
        shortest containing "outbound". ``allowed_ips`` (a set of (ip, source_file)
        pairs, or None) restricts the rows via an inner join.
        """
        if allowed_ips is not None and not allowed_ips:
            # An explicit but empty allow-list matches nothing. Previously it was
            # treated as "no filter" and every IP was processed.
            print("IP filter file is empty; nothing to process")
            return []

        spark = (
            SparkSession.builder
            .appName("IP-Representations-Load")
            .config("spark.driver.memory", "200g")
            .config("spark.driver.maxResultSize", "200g")
            .getOrCreate()
        )
        try:
            df = spark.read.option("recursiveFileLookup", "true").parquet(parquet_root)
            print(df.columns)
            # Both ip and source_file must match an allowed pair.
            if allowed_ips:
                allowed_df = spark.createDataFrame(list(allowed_ips), ["ip", "source_file"])
                df = df.join(allowed_df, on=["ip", "source_file"], how="inner")

            in_cols = [c for c in df.columns if "inbound" in c.lower()]
            out_cols = [c for c in df.columns if "outbound" in c.lower()]
            if in_cols:
                data_col = min(in_cols, key=len)
            elif out_cols:
                data_col = min(out_cols, key=len)
            else:
                raise ValueError("No inbound or outbound columns found")

            # Keep strictly positive values; rows left empty are dropped.
            df_clean = df.select(
                "ip", "source_file",
                F.when(F.col(data_col).isNotNull(),
                       F.expr(f"filter(`{data_col}`, x -> x > 0.0)"))
                .otherwise(F.array().cast("array<double>")).alias("clean_series")
            ).filter(F.size("clean_series") >= 1)

            results = df_clean.select("ip", "source_file", "clean_series").collect()
        finally:
            spark.stop()  # release the driver even if loading fails

        return [(row["ip"], row["source_file"], row["clean_series"]) for row in results]

    ip_series_data = load_with_ip_info(args.parquet_root, allowed_ips)
    print(f"Loaded {len(ip_series_data)} IP-series pairs")

    print("Loading trained model...")
    # NOTE(review): pickle/torch.load execute arbitrary code; only load trusted checkpoints.
    with open(os.path.join(args.model, "boundaries.pkl"), "rb") as fh:
        boundaries = pickle.load(fh)

    predictor = ChronosBinPredictor(
        boundaries,
        CHRONOSMODEL_PATH,
        device
    ).to(device)
    print(type(predictor.pipeline.tokenizer))

    # Load the fine-tuned T5 weights (inner model only) exactly once.
    state = torch.load(os.path.join(args.model, "chronos_best.pt"), map_location=device)
    state = strip_prefix_if_present(state, ["model.model.", "model."])
    predictor.model.model.load_state_dict(state, strict=True)
    del state
    predictor.model.eval()
    print("Model loaded successfully")

    results_data = []
    failed_batches = 0
    print("Extracting IP-based representations...")

    with torch.no_grad():
        for i in range(0, len(ip_series_data), args.batch_size):
            batch_data = ip_series_data[i:i + args.batch_size]
            try:
                batch_tensors = [torch.tensor(series, dtype=torch.float32)
                                 for _, _, series in batch_data]
                batch_emb, _ = predictor.pipeline.embed(batch_tensors)
                # Last token position per series, presumably the EOS token (confirm
                # against the pipeline's embed). Cast to float32 first because
                # .numpy() rejects bfloat16 tensors.
                embeddings = batch_emb[:, -1, :].detach().float().cpu().numpy()  # [B, D]
            except Exception as e:  # best effort: report and skip the batch
                failed_batches += 1
                print(f"Error processing batch {i//args.batch_size}: {e}")
                continue

            n_tokens = batch_emb.shape[1]
            for j, (ip, source_file, series) in enumerate(batch_data):
                result = {
                    'ip': ip,
                    'source_file': source_file,
                    'series_id': i + j,
                    'series_length': len(series),
                    'n_tokens': n_tokens,
                }
                result.update({f'repr_dim_{dim_idx}': val
                               for dim_idx, val in enumerate(embeddings[j])})
                results_data.append(result)

            if (i // args.batch_size + 1) % 10 == 0:
                print(f"Processed {i + len(batch_data)}/{len(ip_series_data)} IP-series pairs")

    print("Saving IP representations to CSV...")
    df_results = pd.DataFrame(results_data)
    df_results.to_csv(args.output_csv, index=False)

    print(f"Saved {len(results_data)} IP representations to {args.output_csv}")
    print(f"DataFrame shape: {df_results.shape}")
    # An empty result frame has no 'ip' column; avoid a KeyError in the summary.
    n_unique = df_results['ip'].nunique() if 'ip' in df_results.columns else 0
    print(f"Unique IPs processed: {n_unique}")
    if failed_batches:
        print(f"WARNING: {failed_batches} batch(es) failed and were skipped")

    repr_cols = [col for col in df_results.columns if col.startswith('repr_dim_')]
    if repr_cols:
        print(f"Representation dimensions: {len(repr_cols)}")




if __name__=="__main__":
    import sys
    
    # Check command line arguments to determine which function to run
    if len(sys.argv) > 1 and sys.argv[1] == "--help":
        print("Available functions:")
        print("  python pretrain_chronos.py [function_name] [args...]")
        print("  Functions:")
        print("    main                     - Train the model")
        print("    inference_auto_regression - Auto-regressive inference")
        print("    get_representations_with_ip - Extract representations preserving IP information")
        sys.exit(0)
    
    # Default behavior - you can change this to your preferred default function
    if len(sys.argv) > 1:
        function_name = sys.argv[1]
        # Remove the function name from sys.argv so the argparse in each function works correctly
        sys.argv = [sys.argv[0]] + sys.argv[2:]
        
        if function_name == "main":
            main()
        elif function_name == "inference_auto_regression":
            inference_auto_regression()

        elif function_name == "get_representations_with_ip":
            get_representations_with_ip()
        else:
            print(f"Unknown function: {function_name}")
            print("Use --help to see available functions")
            sys.exit(1)
    else:
        # Default function when no arguments provided
        main()



