import os
import pickle
import argparse
from typing import List, Sequence, Union

import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split
from ptflops import get_model_complexity_info


def prepare_features(text_data, img_data, text_agg: str = "mean"):
    """Build row-aligned text and image feature matrices.

    Entries of ``text_data`` and ``img_data`` are matched on the six-character
    key ``filename[-10:-4]`` (the characters just before a four-character
    suffix such as ``.jpg``). Only keys present in both mappings are kept.

    Rows are emitted in sorted key order. This makes the output deterministic,
    so that two calls sharing the same ``img_data`` and the same set of keys
    (e.g. one for positive and one for negative captions) produce rows that
    line up with each other.

    Args:
        text_data: Mapping from file name to a mapping that holds a
            ``"last_hidden_state"`` tensor, expected as ``[1, seq_len, dim]``.
        img_data: Mapping from file name to an array-like; after ``squeeze()``
            its first row is used as the CLS representation.
        text_agg: How to pool the sequence dimension of the text tensor:
            ``"mean"`` or ``"max"``.

    Returns:
        Tuple ``(txt_feats, img_feats)`` of stacked arrays, one row per
        shared key.

    Raises:
        ValueError: If ``text_agg`` is not recognised, or if the two inputs
            share no keys.
    """
    if text_agg not in ("mean", "max"):
        raise ValueError(f"Unknown text_agg: {text_agg}")

    norm_txt = {fn[-10:-4]: fn for fn in text_data}
    norm_img = {fn[-10:-4]: fn for fn in img_data}

    # Set iteration order is not stable across calls or processes; sorting
    # pins the row order so separate calls stay aligned with each other.
    common_keys = sorted(norm_txt.keys() & norm_img.keys())
    if not common_keys:
        raise ValueError("No overlapping sample keys between text_data and img_data")

    txt_feats, img_feats = [], []
    for key in common_keys:
        hs = text_data[norm_txt[key]]["last_hidden_state"]  # [1, seq_len, dim]
        if text_agg == "mean":
            vec = hs.mean(dim=1)  # [1, dim]
        else:
            vec, _ = hs.max(dim=1)
        # detach() so tensors that still track gradients can be converted.
        txt_feats.append(vec.squeeze(0).detach().cpu().numpy())

        # First row after squeezing is the CLS representation.
        img_feats.append(img_data[norm_img[key]].squeeze()[0])

    return np.stack(txt_feats), np.stack(img_feats)


class EmbeddingDataset(Dataset):
    """Dataset of (image, positive text, negative text) embedding triplets.

    The three NumPy inputs are converted to float32 tensors once, at
    construction time. Item ``i`` is the tuple of the ``i``-th row from each
    of them. The dataset length is the number of image rows.
    """

    @staticmethod
    def _as_float_tensor(array):
        # Wrap the NumPy buffer as a tensor, then cast to float32.
        return torch.from_numpy(array).float()

    def __init__(self, img_feats, pos_txt_feats, neg_txt_feats):
        convert = self._as_float_tensor
        self.img = convert(img_feats)
        self.pos_t = convert(pos_txt_feats)
        self.neg_t = convert(neg_txt_feats)

    def __len__(self):
        n_rows = len(self.img)
        return n_rows

    def __getitem__(self, idx):
        return tuple(store[idx] for store in (self.img, self.pos_t, self.neg_t))


def load_incremental_pickle(pkl_path: str) -> dict:
    """Merge every pickled record stored back-to-back in one file.

    The file is read as a sequence of consecutive pickle records. Each record
    is folded into a single dict with ``dict.update``, so a key appearing in a
    later record overrides the same key from an earlier one. Reading stops at
    the first ``EOFError``, i.e. when the stream is exhausted.

    Args:
        pkl_path: Path to the file holding the concatenated pickles.

    Returns:
        The merged dictionary (empty when the file holds no records).
    """
    merged = {}
    with open(pkl_path, "rb") as stream:
        try:
            # Keep consuming records until the unpickler hits end of file.
            while True:
                merged.update(pickle.load(stream))
        except EOFError:
            pass
    return merged


def load_data(
    image_path: str,
    positive_text_path: str,
    negative_text_path: str,
    batch_size: int = 32,
    test_ratio: float = 0.1,
    val_ratio: float = 0.2,
    shuffle_seed: Union[int, None] = 55,
    text_agg: str = "mean",
):
    """Load image/text embeddings and return train/val/test data loaders.

    Args:
        image_path: Path to a single pickle holding the image features.
        positive_text_path: Path to an incrementally written pickle of
            positive text features.
        negative_text_path: Path to an incrementally written pickle of
            negative text features.
        batch_size: Batch size used by all three loaders.
        test_ratio: Fraction of samples assigned to the test split.
        val_ratio: Fraction of samples assigned to the validation split.
        shuffle_seed: Seed for the split generator; ``None`` passes no
            generator to ``random_split``.
        text_agg: Text pooling mode forwarded to ``prepare_features``.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``. Only the training
        loader shuffles.

    Raises:
        ValueError: If the ratios are negative or sum to 1 or more, if the
            feature matrices differ in length, or if the positive and negative
            text features do not map onto the same image rows.
    """
    # Validate cheap arguments before any expensive file loading. Explicit
    # raises are used instead of ``assert`` so the checks survive ``python -O``.
    if test_ratio < 0 or val_ratio < 0:
        raise ValueError("test_ratio and val_ratio must be non-negative")
    if not test_ratio + val_ratio < 1.0:
        raise ValueError("Sum of test_ratio and val_ratio must be < 1")

    # NOTE(security): pickle executes arbitrary code while loading; only point
    # these paths at files from a trusted source.
    with open(image_path, "rb") as f:
        img_data = pickle.load(f)

    pos_text = load_incremental_pickle(positive_text_path)
    neg_text = load_incremental_pickle(negative_text_path)

    t_pos, i_feats = prepare_features(pos_text, img_data, text_agg=text_agg)
    t_neg, i_feats_neg = prepare_features(neg_text, img_data, text_agg=text_agg)

    if not len(t_pos) == len(i_feats) == len(t_neg):
        raise ValueError("Mismatch in dataset lengths")
    # Equal lengths alone do not prove alignment: both calls must have
    # selected the same images in the same order, otherwise negative rows
    # would be paired with the wrong image.
    if not np.array_equal(i_feats, i_feats_neg):
        raise ValueError(
            "Positive and negative text features are not aligned to the same images"
        )

    dataset = EmbeddingDataset(i_feats, t_pos, t_neg)
    n = len(dataset)

    test_size = int(n * test_ratio)
    val_size = int(n * val_ratio)
    # The training split absorbs whatever the integer truncation left over.
    train_size = n - val_size - test_size

    generator = torch.Generator().manual_seed(shuffle_seed) if shuffle_seed is not None else None
    train_ds, val_ds, test_ds = random_split(
        dataset, [train_size, val_size, test_size], generator=generator
    )

    return (
        DataLoader(train_ds, batch_size=batch_size, shuffle=True),
        DataLoader(val_ds, batch_size=batch_size, shuffle=False),
        DataLoader(test_ds, batch_size=batch_size, shuffle=False),
    )


def calculate_flops(model: nn.Module, input_shape: Union[tuple, List[int]]):
    """
    Estimates FLOPs and parameters using ptflops for a model.

    The measurement runs on a deep copy of the model (moved to CUDA when
    available, otherwise CPU, and put in eval mode), so the caller's model is
    left untouched. Every ``nn.BatchNorm1d`` in the copy is replaced with
    ``nn.Identity`` before measuring, so those layers do not contribute to
    the reported numbers.

    This is best-effort: any failure during measurement is printed and
    reported as ``("N/A", "N/A")`` instead of being raised.

    Args:
        model: PyTorch model
        input_shape: Input tensor shape (excluding batch size). Tuples and
            lists are both accepted.

    Returns:
        Tuple: (flops, params) as strings
    """
    import copy
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_copy = copy.deepcopy(model).to(device).eval()

    def replace_bn_with_identity(module):
        # Walk the module tree and swap each BatchNorm1d child in place.
        for name, child in module.named_children():
            if isinstance(child, nn.BatchNorm1d):
                setattr(module, name, nn.Identity())
            else:
                replace_bn_with_identity(child)

    replace_bn_with_identity(model_copy)

    try:
        # ptflops requires the input resolution to be a real tuple; a list
        # would otherwise fail its internal check and always yield "N/A".
        input_res = tuple(input_shape)
        with torch.no_grad():
            flops, params = get_model_complexity_info(
                model_copy, input_res, as_strings=True,
                print_per_layer_stat=False, verbose=False
            )
    except Exception as e:
        print(f"[FLOPs Error] {e}")
        flops, params = "N/A", "N/A"

    return flops, params
