import argparse
import csv
import os
import time
from datetime import datetime
from typing import Dict, List, Tuple

import torch
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
import os as _os

import pynvml  # type: ignore

_NVML_AVAILABLE = True


# Models
from opentslm.model.llm.OpenTSLMFlamingo import OpenTSLMFlamingo
from opentslm.model.llm.OpenTSLMSP import OpenTSLMSP

# Datasets
from opentslm.time_series_datasets.TSQADataset import TSQADataset
from opentslm.time_series_datasets.har_cot.HARCoTQADataset import HARCoTQADataset
from opentslm.time_series_datasets.sleep.SleepEDFCoTQADataset import SleepEDFCoTQADataset
from opentslm.time_series_datasets.ecg_qa.ECGQACoTQADataset import ECGQACoTQADataset
from opentslm.time_series_datasets.simulation.SimulationQADataset import SimulationQADataset
from opentslm.time_series_datasets.util import (
    extend_time_series_to_match_patch_size_and_aggregate,
)


def get_device(device_arg: str | None) -> str:
    if device_arg:
        return device_arg
    return "cuda" if torch.cuda.is_available() else "cpu"


def measure_peak_cuda_bytes() -> int:
    if not torch.cuda.is_available():
        return -1
    torch.cuda.synchronize()
    return int(torch.cuda.max_memory_allocated())


def measure_peak_cuda_reserved_bytes() -> int:
    if not torch.cuda.is_available():
        return -1
    torch.cuda.synchronize()
    return int(torch.cuda.max_memory_reserved())


def nvml_current_process_bytes() -> int:
    if not _NVML_AVAILABLE or not torch.cuda.is_available():
        return -1
    try:
        pynvml.nvmlInit()
        pid = _os.getpid()
        total_bytes = 0
        found = False
        device_count = pynvml.nvmlDeviceGetCount()
        for device_index in range(device_count):
            handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
            # Try compute procs first
            try:
                procs = pynvml.nvmlDeviceGetComputeRunningProcesses_v3(handle)
            except Exception:
                procs = []
            # Fallback to graphics procs
            try:
                procs_gfx = pynvml.nvmlDeviceGetGraphicsRunningProcesses_v3(handle)
            except Exception:
                procs_gfx = []
            for p in list(procs) + list(procs_gfx):
                if int(p.pid) == pid and p.usedGpuMemory is not None and p.usedGpuMemory >= 0:
                    total_bytes += int(p.usedGpuMemory)
                    found = True
        return total_bytes if found else -1
    except Exception:
        return -1


def get_first_batch(dataset, batch_size: int = 1) -> List[Dict[str, any]]:
    # QADataset returns dict samples compatible with model.compute_loss
    batch: List[Dict[str, any]] = []
    for i in range(min(batch_size, len(dataset))):
        batch.append(dataset[i])
    # Ensure time series tensors are padded and converted
    batch = extend_time_series_to_match_patch_size_and_aggregate(batch)
    return batch


def build_optimizer(model, model_type: str, base_lr: float = 2e-4):
    if model_type == "OpenTSLMSP":
        enc_params = [p for p in getattr(model, "encoder").parameters() if p.requires_grad]
        proj_params = [p for p in getattr(model, "projector").parameters() if p.requires_grad]
        param_groups = []
        if len(enc_params) > 0:
            param_groups.append({"params": enc_params, "weight_decay": 0.1})
        if len(proj_params) > 0:
            param_groups.append({"params": proj_params, "weight_decay": 0.1})
        return torch.optim.AdamW(param_groups, lr=base_lr) if len(param_groups) > 0 else None
    # Flamingo-like
    named_params = list(model.named_parameters())
    trainable = list(
        filter(
            lambda np: np[1].requires_grad and not getattr(np[1], "exclude_from_optimizer", False),
            named_params,
        )
    )
    params_with_wd, params_without_wd = [], []
    for name, p in trainable:
        if "gated_cross_attn" in name:
            params_with_wd.append(p)
        else:
            params_without_wd.append(p)
    if len(params_with_wd) + len(params_without_wd) == 0:
        return None
    return torch.optim.AdamW(
        [
            {"params": params_with_wd, "weight_decay": 0.1},
            {"params": params_without_wd, "weight_decay": 0.0},
        ],
        lr=2e-4,
    )


def train_for_steps(model, model_type: str, dataset, steps: int) -> Tuple[float, int, int, int]:
    model.train()
    optimizer = build_optimizer(model, model_type)
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()

    last_loss = 0.0
    # DataLoader with shuffle and collate that pads series
    loader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
        collate_fn=lambda b: extend_time_series_to_match_patch_size_and_aggregate(b),
        drop_last=False,
    )

    pbar = tqdm(total=steps, desc="Training", leave=False)
    max_peak_bytes = -1
    max_reserved_bytes = -1
    max_nvml_bytes = -1
    step = 0
    # Initialize postfix
    pbar.set_postfix(
        {
            "alloc_gb": 0.0,
            "res_gb": 0.0,
            "nvml_gb": 0.0,
        }
    )
    for batch in loader:
        if optimizer:
            optimizer.zero_grad(set_to_none=True)
        loss = model.compute_loss(batch)
        if optimizer and loss.requires_grad:
            print(f"Backpropagating loss of {loss.item()} for step {step}")
            loss.backward()
            optimizer.step()
        last_loss = float(loss.detach().item())
        # Track peak memory across steps
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            current_peak = int(torch.cuda.max_memory_allocated())
            current_reserved = int(torch.cuda.max_memory_reserved())
            if current_peak > max_peak_bytes:
                max_peak_bytes = current_peak
            if current_reserved > max_reserved_bytes:
                max_reserved_bytes = current_reserved
            nvml_bytes = nvml_current_process_bytes()
            if nvml_bytes > max_nvml_bytes:
                max_nvml_bytes = nvml_bytes

            # Update progress bar postfix in GB
            def _to_gb(val: int) -> float:
                return float(val) / (1024.0**3) if isinstance(val, (int, float)) and val >= 0 else 0.0

            pbar.set_postfix(
                {
                    "alloc_gb": f"{_to_gb(max_peak_bytes):.2f}",
                    "res_gb": f"{_to_gb(max_reserved_bytes):.2f}",
                    "nvml_gb": f"{_to_gb(max_nvml_bytes):.2f}",
                }
            )
        step += 1
        pbar.update(1)
        if step >= steps:
            break
    pbar.close()
    if torch.cuda.is_available():
        peak_bytes = max_peak_bytes
        peak_reserved_bytes = max_reserved_bytes
        nvml_peak_bytes = max_nvml_bytes
    else:
        peak_bytes = -1
        peak_reserved_bytes = -1
        nvml_peak_bytes = -1
    return last_loss, peak_bytes, peak_reserved_bytes, nvml_peak_bytes


def ensure_csv(path: str, header: List[str]):
    exists = os.path.exists(path)
    if not exists:
        with open(path, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(header)


def append_row(path: str, row: List[any]):
    with open(path, "a", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(row)


def run_for_dataset(model_name: str, model, dataset_name: str, dataset_obj) -> Dict[str, any]:
    result: Dict[str, any] = {
        "model": model_name,
        "dataset": dataset_name,
        "loss": None,
        "peak_cuda_bytes": None,
        "status": "ok",
        "error": "",
    }
    try:
        # Train for half an epoch, capped at 10000 steps
        steps = max(1, min(len(dataset_obj), 10000))
        loss, peak, peak_reserved, nvml_peak = train_for_steps(model, model_name, dataset_obj, steps)
        result["loss"] = loss
        result["peak_cuda_bytes"] = peak
        result["peak_cuda_reserved_bytes"] = peak_reserved
        result["nvml_peak_bytes"] = nvml_peak
    except Exception as e:
        result["status"] = "error"
        result["error"] = str(e)
    return result


def main():
    parser = argparse.ArgumentParser(
        description="Measure memory use for a single training iteration for a chosen model and dataset."
    )
    parser.add_argument("-llm_id", required=True, help="HuggingFace model id for the language model")
    parser.add_argument(
        "--model",
        required=True,
        choices=["OpenTSLMFlamingo", "OpenTSLMSP"],
        help="Model to instantiate",
    )
    parser.add_argument(
        "--dataset",
        required=True,
        choices=[
            "TSQADataset",
            "HARCoTQADataset",
            "SleepEDFCoTQADataset",
            "ECGQACoTQADataset",
            "SimulationQADataset",
        ],
        help="Dataset to use",
    )
    parser.add_argument("--device", default="cuda", help="Device to run on (e.g., cuda, cuda:0, cpu)")
    parser.add_argument(
        "--length",
        type=int,
        default=100,
        help="Length of time series for SimulationQADataset (default: 100)",
    )
    parser.add_argument(
        "--num_series",
        type=int,
        default=1,
        help="Number of time series for SimulationQADataset (default: 1)",
    )
    parser.add_argument(
        "--results_csv",
        default=os.path.join(REPO_DIR, "memory_use.csv"),
        help="Path to CSV file to append results",
    )
    args = parser.parse_args()

    device = get_device(args.device)

    # CSV header and file
    header = [
        "timestamp",
        "llm_id",
        "device",
        "model",
        "dataset",
        "loss",
        "peak_cuda_bytes",
        "peak_cuda_gb",
        "peak_cuda_reserved_bytes",
        "peak_cuda_reserved_gb",
        "nvml_peak_bytes",
        "nvml_peak_gb",
        "status",
        "error",
    ]
    ensure_csv(args.results_csv, header)

    # Instantiate selected model
    if args.model == "OpenTSLMFlamingo":
        model = OpenTSLMFlamingo(
            device=device,
            llm_id=args.llm_id,
            cross_attn_every_n_layers=1,
            gradient_checkpointing=True,
        )
        eos = model.get_eos_token()
    elif args.model == "OpenTSLMSP":
        model = OpenTSLMSP(llm_id=args.llm_id, device=device)
        eos = model.get_eos_token()
    else:
        raise ValueError(f"Unknown model: {args.model}")

    # Make absolutely sure parameters are on the requested device
    model.to(device)

    # Instantiate selected dataset
    if args.dataset == "TSQADataset":
        dataset = TSQADataset(split="train", EOS_TOKEN=eos)
        dataset_name = "TSQA"
    elif args.dataset == "HARCoTQADataset":
        dataset = HARCoTQADataset(split="train", EOS_TOKEN=eos)
        dataset_name = "HAR-CoT"
    elif args.dataset == "SleepEDFCoTQADataset":
        dataset = SleepEDFCoTQADataset(split="train", EOS_TOKEN=eos)
        dataset_name = "SleepEDF-CoT"
    elif args.dataset == "ECGQACoTQADataset":
        dataset = ECGQACoTQADataset(split="train", EOS_TOKEN=eos, max_samples=1, preload_processed_data=False)
        dataset_name = "ECG-QA-CoT"
    elif args.dataset == "SimulationQADataset":
        dataset = SimulationQADataset(split="train", EOS_TOKEN=eos, length=args.length, num_series=args.num_series)
        dataset_name = f"Simulation-L{args.length}-N{args.num_series}"
    else:
        raise ValueError(f"Unknown dataset: {args.dataset}")

    # Run one iteration and append results
    res = run_for_dataset(args.model, model, dataset_name, dataset)
    peak_bytes = res["peak_cuda_bytes"]
    peak_gb = (float(peak_bytes) / (1024.0**3)) if isinstance(peak_bytes, (int, float)) and peak_bytes >= 0 else -1
    peak_reserved_bytes = res.get("peak_cuda_reserved_bytes", -1)
    peak_reserved_gb = (
        (float(peak_reserved_bytes) / (1024.0**3))
        if isinstance(peak_reserved_bytes, (int, float)) and peak_reserved_bytes >= 0
        else -1
    )
    nvml_peak_bytes = res.get("nvml_peak_bytes", -1)
    nvml_peak_gb = (
        (float(nvml_peak_bytes) / (1024.0**3))
        if isinstance(nvml_peak_bytes, (int, float)) and nvml_peak_bytes >= 0
        else -1
    )
    append_row(
        args.results_csv,
        [
            datetime.utcnow().isoformat(),
            args.llm_id,
            device,
            res["model"],
            res["dataset"],
            res["loss"],
            res["peak_cuda_bytes"],
            (f"{peak_gb:.4f}" if isinstance(peak_gb, float) and peak_gb >= 0 else peak_gb),
            res.get("peak_cuda_reserved_bytes", -1),
            (
                f"{peak_reserved_gb:.4f}"
                if isinstance(peak_reserved_gb, float) and peak_reserved_gb >= 0
                else peak_reserved_gb
            ),
            res.get("nvml_peak_bytes", -1),
            (f"{nvml_peak_gb:.4f}" if isinstance(nvml_peak_gb, float) and nvml_peak_gb >= 0 else nvml_peak_gb),
            res["status"],
            res["error"],
        ],
    )

    print(f"Done. Results appended to: {args.results_csv}")


if __name__ == "__main__":
    main()
