from torchvision import datasets
from torch.utils.data import DataLoader
from utils.utils import TensorDataset
from .sampling import *

import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split, Dataset

import numpy as np
from torchvision import datasets


class ImbalancedMNIST(datasets.MNIST):
    """MNIST subsampled per class to produce a class-imbalanced dataset.

    After the regular ``torchvision.datasets.MNIST`` initialisation, each
    digit class is randomly subsampled (using the global NumPy RNG) down to a
    per-class budget computed by :meth:`get_img_num_per_cls`.

    Args:
        root, train, transform, target_transform, download: forwarded to
            ``datasets.MNIST``.
        imb_type: 'lt' (exponential long-tail decay), 'step' (first half of
            the classes full, second half reduced), anything else = balanced.
        imb_factor: ratio between the smallest and the largest class budget.
    """

    def __init__(self, root, train=True, transform=None, target_transform=None,
                 download=False, imb_type='lt', imb_factor=0.1):
        super(ImbalancedMNIST, self).__init__(root, train=train, transform=transform,
                                              target_transform=target_transform, download=download)
        self.cls_num = 10  # MNIST has 10 classes
        img_num_per_cls = self.get_img_num_per_cls(self.cls_num, imb_type, imb_factor)
        self.gen_imbalanced_data(img_num_per_cls)

    def get_img_num_per_cls(self, cls_num, imb_type, imb_factor):
        """Return the requested number of images for each of ``cls_num`` classes.

        The largest budget is ``len(self.data) / cls_num``.
        """
        img_max = len(self.data) / cls_num
        img_num_per_cls = []
        if imb_type == 'lt':
            # Exponential decay from img_max (class 0) to img_max * imb_factor (last class).
            for cls_idx in range(cls_num):
                num = img_max * (imb_factor ** (cls_idx / (cls_num - 1.0)))
                img_num_per_cls.append(int(num))
        elif imb_type == 'step':
            n_major = cls_num // 2
            img_num_per_cls.extend([int(img_max)] * n_major)
            # Use the remainder so an odd cls_num still yields cls_num entries.
            img_num_per_cls.extend([int(img_max * imb_factor)] * (cls_num - n_major))
        else:
            img_num_per_cls.extend([int(img_max)] * cls_num)
        return img_num_per_cls

    def gen_imbalanced_data(self, img_num_per_cls):
        """Subsample ``self.data``/``self.targets`` to the given per-class budgets.

        Classes are matched to budgets in ascending label order. If a class
        has fewer samples than its budget, all of its samples are kept.
        ``num_per_cls_dict`` records the number actually kept, so ``data`` and
        ``targets`` always have the same length.
        """
        new_data = []
        new_targets = []
        targets_np = np.array(self.targets, dtype=np.int64)
        classes = np.unique(targets_np)
        self.num_per_cls_dict = dict()

        for the_class, the_img_num in zip(classes, img_num_per_cls):
            idx = np.where(targets_np == the_class)[0]
            np.random.shuffle(idx)
            selec_idx = idx[:the_img_num]
            # Slicing may return fewer than the_img_num indices (e.g. MNIST
            # digit 0 has < 6000 training images); count what was taken.
            n_taken = len(selec_idx)
            self.num_per_cls_dict[the_class] = n_taken
            new_data.append(self.data[selec_idx])
            new_targets.extend([the_class] * n_taken)

        self.data = torch.tensor(np.vstack(new_data), dtype=torch.uint8)
        self.targets = new_targets
        print(f"Imbalanced MNIST dataset created. Class distribution: {self.num_per_cls_dict}")


class DataLoader():
    """Factory for the train/validation/test loaders used by the experiments.

    NOTE(review): this class shadows ``torch.utils.data.DataLoader`` imported
    at module level, so every method builds loaders through the
    fully-qualified ``torch.utils.data.DataLoader`` (or a local import).

    Every ``load_*`` method returns ``(n_dim, train_loader, val_loader, test_loader)``.
    """

    def __init__(self, simul_type='t1', train_N=100000, val_N=100000, test_N=100000, K=2, seed=1, batch_size=1024, device='cpu'):
        self.device = device
        self.seed = seed
        self.train_N = train_N
        self.val_N = val_N
        self.test_N = test_N
        self.K = K
        self.batch_size = batch_size
        self.n_dim = 1
        self.simul_type = simul_type

    def load_sampler(self, K, ratio_list, sample_mu_list, sample_var_list, sample_nu_list):
        """Generate synthetic train/val data and a separately parameterised test set.

        Train and validation data come from one pooled sample of
        ``train_N + val_N`` points, in which every component uses the first
        entry of each ``sample_*_list``. That pool is split 60/40 with a
        generator seeded by ``self.seed``. The test set (``test_N`` points)
        uses the full per-component ``sample_*_list`` values.
        """
        sampler = select_sampler(self.simul_type, self.device, self.seed)

        # Train/val: all components share the parameters of component 0.
        origin_mu_list = [sample_mu_list[0]] * K
        origin_var_list = [sample_var_list[0]] * K
        origin_nu_list = [sample_nu_list[0]] * K

        # NOTE(review): the sampler gets K=self.K while the lists above are
        # sized by the K argument; presumably callers pass K == self.K — confirm.
        data = sampler.sample_generation(
            K=self.K, N=self.train_N + self.val_N, ratio_list=ratio_list,
            mu_list=origin_mu_list, var_list=origin_var_list, nu_list=origin_nu_list,
        )
        generator = torch.Generator().manual_seed(self.seed)
        # Single seeded split of the pooled sample (60% train / 40% val).
        train_data, validation_data = torch.utils.data.random_split(data, [0.6, 0.4], generator=generator)
        test_data = sampler.sample_generation(
            K=self.K, N=self.test_N, ratio_list=ratio_list,
            mu_list=sample_mu_list, var_list=sample_var_list, nu_list=sample_nu_list,
        )

        print(len(train_data), len(validation_data), len(test_data))
        n_dim = len(train_data[0])
        train_loader = torch.utils.data.DataLoader(train_data, batch_size=self.batch_size)
        val_loader = torch.utils.data.DataLoader(validation_data, batch_size=self.batch_size)
        test_loader = torch.utils.data.DataLoader(test_data, batch_size=self.batch_size)
        return n_dim, train_loader, val_loader, test_loader

    def load_mnist(self):
        """Load MNIST padded to 32x32, with a seeded 80/20 train/val split.

        The official MNIST test split is used as the test set.
        """
        transform = transforms.Compose([
            transforms.Pad(2),
            transforms.ToTensor(),
        ])

        dataset = datasets.MNIST(root='../data', train=True, transform=transform, download=True)

        train_size = int(0.8 * len(dataset))
        val_size = len(dataset) - train_size
        # Seed the split like the other loaders so runs are reproducible.
        generator = torch.Generator().manual_seed(self.seed)
        train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=generator)
        test_dataset = datasets.MNIST(root='../data', train=False, transform=transform, download=True)

        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
        val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False)
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False)

        n_dim = 32 * 32  # img size with padding
        return n_dim, train_loader, val_loader, test_loader

    def get_img_num_per_cls(self, cls_num, imb_type, imb_factor):
        """Return the requested number of samples per class (see ImbalancedMNIST).

        NOTE(review): reads ``self.data``, which this class never sets; the
        caller must assign it first.
        """
        img_max = len(self.data) / cls_num
        img_num_per_cls = []
        if imb_type == 'lt':
            for cls_idx in range(cls_num):
                num = img_max * (imb_factor ** (cls_idx / (cls_num - 1.0)))
                img_num_per_cls.append(int(num))
        elif imb_type == 'step':
            n_major = cls_num // 2
            img_num_per_cls.extend([int(img_max)] * n_major)
            # Remainder keeps the result at cls_num entries for odd cls_num.
            img_num_per_cls.extend([int(img_max * imb_factor)] * (cls_num - n_major))
        else:
            img_num_per_cls.extend([int(img_max)] * cls_num)
        return img_num_per_cls

    def gen_imbalanced_data(self, img_num_per_cls):
        """Subsample ``self.data``/``self.targets`` to per-class budgets.

        Classes are matched to budgets in ascending label order. A class
        smaller than its budget keeps all of its samples. ``self.data``
        becomes a NumPy array and ``self.targets`` a list of the same length.

        NOTE(review): relies on ``self.data``/``self.targets`` being set by
        the caller.
        """
        new_data = []
        new_targets = []
        targets_np = np.array(self.targets, dtype=np.int64)
        classes = np.unique(targets_np)
        self.num_per_cls_dict = dict()
        for the_class, the_img_num in zip(classes, img_num_per_cls):
            idx = np.where(targets_np == the_class)[0]
            np.random.shuffle(idx)
            selec_idx = idx[:the_img_num]
            # Count what was actually taken so data and targets stay aligned.
            n_taken = len(selec_idx)
            self.num_per_cls_dict[the_class] = n_taken
            new_data.append(self.data[selec_idx, ...])
            new_targets.extend([the_class] * n_taken)

        new_data = np.vstack(new_data)
        print("the number of data:", len(new_data))
        self.data = new_data
        self.targets = new_targets

    def load_data(self, normalized=True, path='./data/soc-Ep.txt'):
        """Build per-node (in-degree, out-degree) features from an edge list.

        Each non-blank, non-'#' line of ``path`` must start with
        ``FromNodeId ToNodeId`` separated by whitespace; extra columns are
        ignored. Node ids are assumed to be non-negative integers. Rows are
        split 60/20/20 into train/val/test with a generator seeded by
        ``self.seed``.

        Args:
            normalized: standardise all features with one global mean/std.
            path: edge-list file (defaults to the original hard-coded path).

        Raises:
            ValueError: if the file contains no edges.
        """
        rows = []
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                # Skip blank lines and SNAP-style '#' comment headers.
                if not line or line.startswith('#'):
                    continue
                rows.append([float(tok) for tok in line.split()])
        if not rows:
            raise ValueError(f"no edges found in {path}")

        # float64 keeps integer node ids exact far beyond float32's 2**24.
        edges = torch.tensor(rows, dtype=torch.float64)
        src = edges[:, 0].long()  # FromNodeId
        dst = edges[:, 1].long()  # ToNodeId

        # Total node count inferred from the largest id among the two id columns.
        num_nodes = int(torch.cat([src, dst]).max()) + 1

        out_degree = torch.bincount(src, minlength=num_nodes)  # counted by source node
        in_degree = torch.bincount(dst, minlength=num_nodes)   # counted by destination node
        data = torch.stack([in_degree.float(), out_degree.float()], dim=1)  # (N, 2)

        if normalized:
            # Global (not per-column) standardisation.
            data = (data - torch.mean(data)) / (torch.std(data) + 1e-6)
        generator = torch.Generator().manual_seed(self.seed)
        train_data, validation_data, test_data = torch.utils.data.random_split(data, [0.6, 0.2, 0.2], generator=generator)
        print(len(train_data), len(validation_data), len(test_data))
        n_dim = len(train_data[0])
        train_loader = torch.utils.data.DataLoader(train_data, batch_size=self.batch_size, generator=generator)
        val_loader = torch.utils.data.DataLoader(validation_data, batch_size=self.batch_size, generator=generator)
        test_loader = torch.utils.data.DataLoader(test_data, batch_size=self.batch_size, generator=generator)

        return n_dim, train_loader, val_loader, test_loader

    def load_spike(self,
               N=1_000_000,
               pi=0.8,              # mix prob for +mu component
               spike_prob=0.5,      # point mass at mode
               alpha=1.5,           # tail (shape) > 0
               scale_pos=1.0,       # scale for +mu component
               scale_neg=2.0,       # scale for -mu component
               mu=0.0):
        """1-D two-component mixture with a point-mass spike, split 50/20/30.

        X is built as follows:
          * The location is +mu with probability ``pi`` and -mu otherwise.
          * With probability ``spike_prob``, X equals that location exactly.
          * Otherwise X = loc + R, where R ~ Lomax(sigma, alpha) with CDF
            1 - (1 + r/sigma)^(-alpha) for r >= 0, and sigma is ``scale_pos``
            or ``scale_neg`` depending on the component.

        The sign of R is fixed to +1, so the continuous part is a one-sided
        (not symmetric) Pareto tail to the right of each location.

        Samples are drawn from the global torch RNG. Only the split is
        seeded, by ``self.seed``.
        """
        # Local import: the enclosing class shadows the torch DataLoader name.
        from torch.utils.data import DataLoader

        assert alpha > 0, "alpha (shape) must be > 0"
        device = self.device

        # 1) Choose component: +mu (Z=1) or -mu (Z=0)
        Z = (torch.rand(N, 1, device=device) < pi).float()
        loc = (2 * Z - 1) * mu  # Z=1 -> +mu, Z=0 -> -mu

        # 2) Decide spike vs continuous within chosen component
        S_spike = (torch.rand(N, 1, device=device) < spike_prob).float()

        # 3) Component-specific scales
        scale_pos_t = torch.full((N, 1), float(scale_pos), device=device)
        scale_neg_t = torch.full((N, 1), float(scale_neg), device=device)
        scales = torch.where(Z.bool(), scale_pos_t, scale_neg_t)  # shape [N,1]

        #    Lomax radial part via inverse CDF: R = sigma * ((1-U)^(-1/alpha) - 1)
        U = torch.rand(N, 1, device=device).clamp_(1e-7, 1 - 1e-7)
        R = scales * ((1 - U).pow(-1.0 / float(alpha)) - 1.0)

        #    Sign fixed to +1 -> positive (one-sided) Pareto tail
        sign = 1

        X_cont = loc + sign * R

        # 4) Mix: spike at the location OR continuous tail from the same location
        data = torch.where(S_spike.bool(), loc, X_cont)  # [N,1]

        # 5) Split & loaders
        generator = torch.Generator(device='cpu').manual_seed(self.seed)

        N = data.shape[0]
        n_train = int(N * 0.5)
        n_val = int(N * 0.2)
        n_test = N - n_train - n_val

        train_data, validation_data, test_data = torch.utils.data.random_split(
            data, [n_train, n_val, n_test], generator=generator
        )

        n_dim = len(train_data[0])  # = 1
        train_loader = DataLoader(train_data, batch_size=self.batch_size, shuffle=True)
        val_loader = DataLoader(validation_data, batch_size=self.batch_size, shuffle=False)
        test_loader = DataLoader(test_data, batch_size=self.batch_size, shuffle=False)
        return n_dim, train_loader, val_loader, test_loader

    def load_manifolds(self, N=60000):
        """Load ``ParetoManifold20D(N=N)`` and split it 50/5/5 (out of 60).

        For the default ``N=60000`` this gives exactly 50000/5000/5000. The
        same proportions are applied for other ``N``, so the split lengths
        always sum to ``len(data)``. ``n_dim`` is the width of ``data.X``.
        """
        data = ParetoManifold20D(N=N)
        classes, counts = torch.unique(data.labels, return_counts=True)
        print('classes:', classes.tolist(), 'counts:', counts.tolist())
        generator = torch.Generator(device='cpu').manual_seed(self.seed)

        n_val = N // 12
        n_test = N // 12
        n_train = N - n_val - n_test

        train_data, validation_data, test_data = torch.utils.data.random_split(
            data, [n_train, n_val, n_test], generator=generator
        )

        n_dim = data.X.shape[1]  # embedding width (10 with the default embedding)
        train_loader = torch.utils.data.DataLoader(train_data, batch_size=self.batch_size, shuffle=True)
        val_loader = torch.utils.data.DataLoader(validation_data, batch_size=self.batch_size, shuffle=False)
        test_loader = torch.utils.data.DataLoader(test_data, batch_size=self.batch_size, shuffle=False)
        return n_dim, train_loader, val_loader, test_loader


def compute_radius(Z: torch.Tensor, norm='l2'):
    """Return the per-row vector norm of ``Z`` taken along dim 1.

    Args:
        Z: tensor whose rows are reduced (norm is computed over dim 1).
        norm: 'l2', 'l1', or one of 'linf' / 'l∞' / 'inf' for the max-norm.

    Raises:
        ValueError: if ``norm`` is not one of the names above.
    """
    infinity = float('inf')
    order_for = {
        'l2': 2,
        'l1': 1,
        'linf': infinity,
        'l∞': infinity,
        'inf': infinity,
    }
    try:
        order = order_for[norm]
    except (KeyError, TypeError):
        raise ValueError(f"unknown norm: {norm}") from None
    return torch.linalg.norm(Z, ord=order, dim=1)

def radial_labels(
    Z: torch.Tensor,
    mode='quantile',            # 'quantile' | 'fixed' | 'geom'
    norm='l2',
    percentiles=(0.5, 0.9, 0.99),
    edges=None,                 # radius boundaries when mode='fixed'
    geom_start=0.0,             # starting radius r0 when mode='geom' (must be > 0)
    geom_ratio=1.8,             # geometric spacing ratio (> 1)
    K=None,                     # number of classes to create (= boundaries + 1)
):
    """Label each row of ``Z`` by bucketing its radius ``compute_radius(Z, norm)``.

    Class ``i`` holds radii in ``(edges[i-1], edges[i]]``, so labels range
    from 0 to ``len(edges)``.

    Modes:
      * 'quantile': edges are the given quantiles of the observed radii,
        e.g. (0.5, 0.9, 0.99) -> 4 classes: core / shoulder / tail / extreme.
        The edges are printed.
      * 'fixed': edges are supplied explicitly.
      * 'geom': ``K - 1`` edges ``geom_start * geom_ratio**i`` for i = 1..K-1.

    Edges are sorted before bucketing because ``torch.bucketize`` requires
    monotonically increasing boundaries.

    Returns:
      y: [N] long tensor (0..len(edges))
      info: dict {'edges': np.ndarray, 'norm': norm, 'mode': mode}

    Raises:
      AssertionError: ``edges`` missing/empty for 'fixed', or ``K`` < 2 for 'geom'.
      ValueError: unknown ``mode`` or ``norm``. Also raised for 'geom' when
        ``geom_start <= 0`` or ``geom_ratio <= 1``, which would give
        degenerate (all-zero) or non-increasing edges.
    """
    r = compute_radius(Z, norm=norm)  # [N]
    r_cpu = r.detach().cpu().numpy()

    if mode == 'quantile':
        q = np.clip(np.asarray(percentiles, dtype=float), 0.0, 1.0)
        edges = np.sort(np.quantile(r_cpu, q))
        print(edges)
    elif mode == 'fixed':
        assert edges is not None and len(edges) >= 1
        edges = np.sort(np.asarray(edges, dtype=float))
    elif mode == 'geom':
        assert K is not None and K >= 2
        if geom_start <= 0:
            raise ValueError(f"geom_start must be > 0 for mode='geom', got {geom_start}")
        if geom_ratio <= 1:
            raise ValueError(f"geom_ratio must be > 1 for mode='geom', got {geom_ratio}")
        edges = np.asarray([geom_start * (geom_ratio ** i) for i in range(1, K)], dtype=float)
    else:
        raise ValueError(f"unknown mode: {mode}")

    # torch.bucketize maps each radius to a class index in 0..len(edges).
    edges_t = torch.as_tensor(edges, dtype=r.dtype, device=r.device)
    y = torch.bucketize(r, edges_t, right=False)  # [N]
    return y.to(torch.long), {'edges': edges, 'norm': norm, 'mode': mode}

# --------- 1) 대칭 Pareto(=Lomax) 1D 샘플러 ----------
def sample_symmetric_lomax(n, alpha=1.5, scale=1.0, device="cpu"):
    """Draw ``n`` samples of a symmetric Pareto (two-sided Lomax) variable.

    The magnitude R is Lomax with P(R > r) = (1 + r/scale)^(-alpha) for
    r >= 0, drawn by inverse-CDF sampling. The sign is +1 or -1 with equal
    probability. Both draws use the global torch RNG, magnitude first.

    Returns:
        1-D tensor of length ``n`` on ``device``.
    """
    uniform = torch.rand(n, device=device).clamp_(1e-12, 1 - 1e-12)
    magnitude = scale * (uniform.pow(-1.0 / alpha) - 1.0)
    coin = torch.rand(n, device=device)
    signs = torch.where(coin >= 0.5, 1.0, -1.0)
    return signs * magnitude

def sample_pareto2d(N, alpha=(6.0, 6.0), scale=(1.0, 1.0), device="cpu", seed=0):
    """Draw ``N`` 2-D points with independent symmetric-Lomax coordinates.

    Args:
        N: number of points.
        alpha: (alpha_1, alpha_2) tail shapes for the two coordinates.
        scale: (scale_1, scale_2) Lomax scales for the two coordinates.
        device: torch device for the samples.
        seed: passed to ``torch.manual_seed``.

    Returns:
        Tensor of shape [N, 2].

    NOTE: seeding uses the *global* torch RNG, so calling this resets the
    global random state as a side effect.
    """
    torch.manual_seed(seed)
    z1 = sample_symmetric_lomax(N, alpha=alpha[0], scale=scale[0], device=device)
    z2 = sample_symmetric_lomax(N, alpha=alpha[1], scale=scale[1], device=device)
    return torch.stack([z1, z2], dim=1)  # [N, 2]

# --------- 2) 고정 비선형 임베딩 F: R^2 -> R^20 ----------
class FixedMLPEmbedding(torch.nn.Module):
    """Frozen random two-layer tanh MLP mapping R^2 -> R^out_dim.

    forward(z) = tanh(W2 @ tanh(W1 @ (scale_in * z) + b1) + b2)

    Weight matrices are re-initialised orthogonally and biases uniformly in
    [-0.1, 0.1]. All parameters are then frozen (``requires_grad=False``).

    Args:
        hidden: hidden-layer width.
        out_dim: output width (10 by default).
        seed: passed to ``torch.manual_seed`` before the layers are built.
        scale_in: factor applied to the input before the first layer.

    NOTE: ``__init__`` reseeds the *global* torch RNG as a side effect.
    """

    def __init__(self, hidden=16, out_dim=10, seed=0, scale_in=0.5):
        super().__init__()
        torch.manual_seed(seed)
        self.scale_in = scale_in
        self.W1 = torch.nn.Linear(2, hidden, bias=True)
        self.W2 = torch.nn.Linear(hidden, out_dim, bias=True)
        # Deterministic fixed parameterisation (order of RNG draws matters for reproducibility).
        for p in self.parameters():
            if p.dim() > 1:
                torch.nn.init.orthogonal_(p)
            else:
                torch.nn.init.uniform_(p, -0.1, 0.1)
            p.requires_grad_(False)

    @torch.no_grad()
    def forward(self, z):
        """Embed ``z`` (last dim 2) into ``out_dim`` features in (-1, 1)."""
        h = torch.tanh(self.W1(self.scale_in * z))
        g = torch.tanh(self.W2(h))
        return g  # [..., out_dim]

# --------- 3) PyTorch Dataset ----------
class ParetoManifold20D(Dataset):
    """Dataset of heavy-tailed 2-D latents pushed through a fixed random MLP.

    * ``Z``: [N, 2] latents from ``sample_pareto2d``, kept for
      diagnostics and visualisation.
    * ``X``: ``FixedMLPEmbedding(seed=seed)(Z)``, plus optional Gaussian
      noise when ``noise_std > 0``. Its width is the embedding ``out_dim``,
      which is 10 by default despite the class name.
    * ``labels``: 4 radial classes from the 50/90/99% quantiles of ||Z||_2.

    ``X``, ``Z`` and ``labels`` are all stored on CPU.

    Items are ``(X[i], Z[i], labels[i])``, or ``(X[i], Z[i])`` if
    ``labels`` is set to None.
    """

    def __init__(self, N=100_000, alpha=(1.8, 1.8), scale=(1.0, 1.0),
                 noise_std=0.0,  # optional small off-manifold noise
                 seed=0, device="cpu"):
        self.device = device
        Z = sample_pareto2d(N, alpha=alpha, scale=scale, device=device, seed=seed)
        # Move the (frozen) embedding to Z's device so non-CPU devices work.
        self.embed = FixedMLPEmbedding(seed=seed).to(device)
        with torch.no_grad():
            X = self.embed(Z)
            if noise_std > 0:
                X = X + noise_std * torch.randn_like(X)
        self.X = X.cpu()
        self.Z = Z.cpu()  # true 2-D latent cause (for diagnostics/visualisation)

        # Quantile-based 4 classes (50/90/99%); stored on CPU alongside X and Z.
        labels, self.label_info = radial_labels(
            Z, mode='quantile', norm='l2', percentiles=(0.5, 0.9, 0.99)
        )
        self.labels = labels.cpu()

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, i):
        if self.labels is None:
            return self.X[i], self.Z[i]
        return self.X[i], self.Z[i], self.labels[i]
