import torch
from torch.utils.data import DataLoader, WeightedRandomSampler, Dataset, Subset
from torchvision import transforms

from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight

import os, sys
import numpy as np
import random

class Load_Dataset(Dataset):
    """Map-style dataset over a ``{"samples": ..., "labels": ...}`` mapping.

    Samples are stored as a float tensor laid out as (N, C, L): N samples,
    C channels, L time steps. Labels are optional; when present they are
    stored as a long tensor.
    """

    def __init__(self, dataset, dataset_configs):
        """Read samples/labels from ``dataset`` and set up normalization.

        ``dataset_configs`` must expose ``input_channels`` and ``normalize``.
        """
        self.num_channels = dataset_configs.input_channels

        samples = dataset["samples"]
        labels = dataset.get("labels")

        # NumPy inputs are wrapped as tensors; anything else is used as given.
        if isinstance(labels, np.ndarray):
            labels = torch.from_numpy(labels)
        if isinstance(samples, np.ndarray):
            samples = torch.from_numpy(samples)

        # Bring the samples to the (N, C, L) layout.
        rank = len(samples.shape)
        if rank == 2:
            # (N, L) -> (N, 1, L)
            samples = samples.unsqueeze(1)
        elif rank == 3 and samples.shape[1] != self.num_channels:
            # Axis 1 is not the channel axis, so swap the last two axes.
            samples = samples.transpose(1, 2)

        # Optional per-channel normalization, statistics taken over the
        # sample and time axes.
        self.transform = None
        if dataset_configs.normalize:
            channel_mean = torch.mean(samples, dim=(0, 2))
            channel_std = torch.std(samples, dim=(0, 2))
            self.transform = transforms.Normalize(mean=channel_mean, std=channel_std)

        self.x_data = samples.float()
        self.y_data = None if labels is None else labels.long()
        self.len = samples.shape[0]

    def __getitem__(self, index):
        """Return ``(x, y)`` for ``index``; ``y`` is None without labels."""
        sample = self.x_data[index]
        if self.transform:
            # Normalize is applied on a (C, L, 1) view, then the original
            # shape is restored.
            as_image = sample.reshape(self.num_channels, -1, 1)
            sample = self.transform(as_image).reshape(sample.shape)
        label = None if self.y_data is None else self.y_data[index]
        return sample, label

    def __len__(self):
        """Number of samples."""
        return self.len



class CustomDataset(Load_Dataset):
    """``Load_Dataset`` with a random labeled / unlabeled split.

    Dataset positions ``0 .. len(lbl_idx) - 1`` address the labeled samples
    and the remaining positions address the unlabeled ones. Every item is
    returned as ``(x, y, is_labeled)`` with ``is_labeled`` equal to 1 or 0.
    """

    def __init__(self, dataset, dataset_configs, unlabeled_ratio = 0.9):
        """Split the samples; ``int(N * unlabeled_ratio)`` become unlabeled.

        The unlabeled indices are drawn with ``random.sample``, so the split
        follows the state of Python's ``random`` module.
        """
        super().__init__(dataset, dataset_configs)
        N = self.len
        n_unlbl = int(N * unlabeled_ratio)
        self.unlbl_idx = random.sample(range(N), k=n_unlbl)
        # Keep the labeled indices in ascending order instead of relying on
        # set iteration order, so the layout is deterministic for a given
        # split (and is the identity mapping when nothing is unlabeled).
        unlabeled = set(self.unlbl_idx)
        self.lbl_idx = [i for i in range(N) if i not in unlabeled]

    def __getitem__(self, idx):
        """Return ``(x, y, is_labeled)`` for dataset position ``idx``."""
        n_labeled = len(self.lbl_idx)
        if idx < n_labeled:
            data_idx, is_labeled = self.lbl_idx[idx], 1
        else:
            data_idx, is_labeled = self.unlbl_idx[idx - n_labeled], 0

        x = self.x_data[data_idx]
        if self.transform:
            x = self.transform(x.reshape(self.num_channels, -1, 1)).reshape(x.shape)
        # The label is read at the same storage index as the sample. The
        # previous code used the offset into ``unlbl_idx`` for unlabeled
        # items, which paired a sample with another sample's label.
        y = self.y_data[data_idx] if self.y_data is not None else None
        return x, y, is_labeled

    def get_weights(self):
        """Return inverse-class-frequency weights, one per dataset position.

        Entry ``i`` is the weight of the sample returned by ``self[i]``, so
        the result can be handed to ``WeightedRandomSampler``. When no sample
        is unlabeled this equals the plain storage-order weights.
        """
        labels = np.asarray(self.y_data)
        class_sample_count = np.bincount(labels)
        weight = 1./class_sample_count
        # Order labels by dataset position (labeled block, then unlabeled).
        position_to_data = self.lbl_idx + self.unlbl_idx
        samples_weight = torch.as_tensor(weight[labels[position_to_data]], dtype=torch.double)
        return samples_weight

class Aug_Dataset(Load_Dataset):
    """Labeled / unlabeled dataset that also returns augmented views."""

    def __init__(self, dataset, dataset_configs, unlabeled_ratio=0.9, transform=None, strong_transform=None, rot=False, test=None):
        """
        dataset : samples, labels
        dataset_configs : must expose ``input_channels`` and ``normalize``
        unlabeled_ratio : fraction of samples placed in the unlabeled split
        transform : standard transform
        strong_transform : strong_augmentation
        rot : In PAC, originally rotation augmentation are applied for pretext task with image data.
             However, rotation augmentation is not suitable for Time series. So, we apply time-series augmentations.
        test : stored on the instance; not read anywhere in this class
        """
        super().__init__(dataset, dataset_configs)
        N = self.len
        n_unlbl = int(N*unlabeled_ratio)
        self.unlbl_idx = random.sample(range(N), k=n_unlbl)
        # Ascending labeled indices: deterministic, and the identity mapping
        # when nothing is unlabeled.
        unlabeled = set(self.unlbl_idx)
        self.lbl_idx = [i for i in range(N) if i not in unlabeled]

        # The parent constructor already parsed the samples and stored the
        # per-channel Normalize (or None) in ``self.transform``. Keep it as
        # ``self.normalize`` instead of re-parsing the dataset and recomputing
        # the same statistics, then install the augmentation transform.
        self.normalize = self.transform
        self.transform = transform
        self.strong_transform = strong_transform
        self.rot = rot
        self.test = test

    def __getitem__(self, idx):
        """Return the sample at dataset position ``idx`` with its views.

        Without ``rot``: ``(x, x_bar, x_bar2, y, is_labeled)``.
        With ``rot``: ``(x, x_bar, y, is_labeled, all_pretext_x, rot_target)``.
        ``x_bar`` / ``x_bar2`` are None when no strong transform is set and
        ``y`` is None when the dataset has no labels.
        """
        n_labeled = len(self.lbl_idx)
        if idx < n_labeled:
            data_idx = self.lbl_idx[idx]
            is_labeled = 1
        else:
            data_idx = self.unlbl_idx[idx - n_labeled]
            is_labeled = 0

        x = self.x_data[data_idx] # (C, T)
        if self.normalize:
            x = self.normalize(x.reshape(self.num_channels, -1, 1)).reshape(x.shape)

        y = self.y_data[data_idx] if self.y_data is not None else None

        # standard transform (skipped in pretext mode)
        if self.transform is not None and not self.rot:
            x = self._apply_transform(x)

        # strong augmentation: two independent draws of the same transform
        x_bar = x_bar2 = None
        if self.strong_transform is not None:
            x_bar = self.strong_transform(x)
            x_bar2 = self.strong_transform(x)

        if self.rot:
            # Four pretext variants: identity, time flip, jitter, flip + jitter.
            flipped = x.flip(-1)
            all_pretext_x = torch.stack(
                [x, flipped, self._add_jitter(x, scale=0.03), self._add_jitter(flipped, scale=0.03)],
                dim=0,
            )
            rot_target = torch.LongTensor([0,1,2,3])

            return x, x_bar, y, is_labeled, all_pretext_x, rot_target

        return x, x_bar, x_bar2, y, is_labeled

    def get_weights(self):
        """Return inverse-class-frequency weights, one per dataset position.

        Entry ``i`` is the weight of the sample returned by ``self[i]``. When
        no sample is unlabeled this equals the plain storage-order weights.
        """
        labels = np.asarray(self.y_data)
        class_sample_count = np.bincount(labels)
        weight = 1./class_sample_count
        position_to_data = self.lbl_idx + self.unlbl_idx
        samples_weight = torch.as_tensor(weight[labels[position_to_data]], dtype=torch.double)
        return samples_weight

    @staticmethod
    def _add_jitter(signal, scale=0.03):
        """Add zero-mean Gaussian noise with standard deviation ``scale``."""
        return signal + torch.randn_like(signal) * scale

    def _apply_transform(self, x):
        """ Applying a transform while preserving the time series format """
        return self.transform(x)



def data_generator(data_path, domain_id, dataset_configs, hparams, dtype, domain='source', unlabeled_ratio=0.9, sampler = True):
    """Build a ``DataLoader`` over ``<data_path>/<dtype>_<domain_id>.pt``.

    For test data or the source domain, shuffling and ``drop_last`` are off
    and every sample is treated as labeled; otherwise the values come from
    ``dataset_configs`` and ``unlabeled_ratio``. A class-balancing
    ``WeightedRandomSampler`` is used for non-test source data.

    NOTE(review): the ``sampler`` argument is overwritten below, so the value
    passed by the caller has no effect.
    """
    dataset_file = torch.load(os.path.join(data_path, f"{dtype}_{domain_id}.pt"))

    is_test = dtype == "test"
    if is_test or domain == 'source':
        shuffle, drop_last, unlabeled_ratio = False, False, 0
    else:
        shuffle = dataset_configs.shuffle
        drop_last = dataset_configs.drop_last
    sampler = not (domain == "target" or is_test)

    # Wrap the raw file in the labeled / unlabeled dataset.
    dataset = CustomDataset(dataset_file, dataset_configs, unlabeled_ratio)

    loader_kwargs = {
        "dataset": dataset,
        "batch_size": hparams["batch_size"],
        "drop_last": drop_last,
        "num_workers": 8,
    }
    if sampler:
        # A sampler and ``shuffle`` are alternatives, so only one is passed.
        samples_weight = dataset.get_weights()
        loader_kwargs["sampler"] = WeightedRandomSampler(samples_weight, len(samples_weight))
    else:
        loader_kwargs["shuffle"] = shuffle

    return torch.utils.data.DataLoader(**loader_kwargs)

def data_generator_baseline(data_path, domain_id, dataset_configs, hparams, dtype, domain='source', transform=None, strong_transform=None, rot=False, test=None, unlabeled_ratio=0.9, sampler=False):
    """Build a ``DataLoader`` of ``Aug_Dataset`` items for one domain file.

    Loads ``<data_path>/<dtype>_<domain_id>.pt``. For test data or the source
    domain every sample is treated as labeled and shuffling is off; incomplete
    batches are dropped for everything except test data. Non-test source data
    is drawn through a class-balancing ``WeightedRandomSampler``.

    NOTE(review): the ``sampler`` argument is overwritten below, so the value
    passed by the caller has no effect.
    """
    # Loading dataset file from path
    dataset_file = torch.load(os.path.join(data_path, f"{dtype}_{domain_id}.pt"))

    is_test = dtype == 'test'
    shuffle = False if is_test or domain == 'source' else dataset_configs.shuffle
    drop_last = not is_test
    unlabeled_ratio = 0 if is_test or domain == 'source' else unlabeled_ratio
    sampler = not (is_test or domain == 'target')

    dataset = Aug_Dataset(dataset_file, dataset_configs, unlabeled_ratio, transform, strong_transform, rot, test)

    if sampler:
        samples_weight = dataset.get_weights()
        data_sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
        data_loader = torch.utils.data.DataLoader(dataset = dataset,
                                                 batch_size = hparams['batch_size'],
                                                 drop_last = drop_last,
                                                 num_workers= 8,
                                                 sampler = data_sampler)
    else:
        # ``shuffle`` was computed above but never handed to the loader, so
        # the configured shuffling was silently ignored; pass it through, as
        # the other loader factories in this file do.
        data_loader = torch.utils.data.DataLoader(dataset = dataset,
                                                 batch_size = hparams['batch_size'],
                                                 shuffle = shuffle,
                                                 drop_last = drop_last,
                                                 num_workers = 8)
    return data_loader
    
def few_shot_data_generator(data_loader, dataset_configs, num_samples=5):
    """Build a single-batch loader holding up to ``num_samples`` per class.

    The samples are drawn at random (``torch.randperm``) from the dataset
    behind ``data_loader``; ``Subset`` wrappers are unwrapped first, so the
    draw is made from the full base dataset. The base dataset must expose
    ``x_data`` and ``y_data`` tensors.
    """
    # Get the original dataset from Subset if necessary
    dataset = data_loader.dataset
    while isinstance(dataset, torch.utils.data.Subset):  # Ensure we get the base dataset
        dataset = dataset.dataset

    # Now access x_data and y_data
    x_data = dataset.x_data
    y_data = dataset.y_data

    # Iterate over the label values that actually occur (sorted by
    # ``torch.unique``) rather than assuming they are exactly 0..K-1; with
    # contiguous labels the selection order is unchanged.
    selected_ids = []
    for class_label in torch.unique(y_data):
        class_ids = torch.where(y_data == class_label)[0]
        # Slicing the permutation caps the draw at the class size.
        keep = torch.randperm(class_ids.size(0))[:num_samples]
        selected_ids.append(class_ids[keep])
    selected_ids = torch.cat(selected_ids, dim=0)

    few_shot_dataset = {"samples": x_data[selected_ids], "labels": y_data[selected_ids]}
    # Few-shot samples are all labeled: pass unlabeled_ratio=0 explicitly so
    # the default ratio does not flag most of them as unlabeled.
    few_shot_dataset = CustomDataset(few_shot_dataset, dataset_configs, unlabeled_ratio=0)

    few_shot_loader = torch.utils.data.DataLoader(dataset=few_shot_dataset, batch_size=len(few_shot_dataset),
                                                  shuffle=False, drop_last=False, num_workers=8)

    return few_shot_loader

def data_generator_bound(data_path, domain_id, dataset_configs, hparams, dtype, domain='source', unlabeled_ratio=0.9, sampler = True):
    """Build a ``DataLoader`` restricted to the labeled part of a domain file.

    Loads ``<data_path>/<dtype>_<domain_id>.pt``, wraps it in
    ``CustomDataset`` and exposes only the labeled positions through a
    ``Subset``. Flags follow the same rules as ``data_generator``.

    NOTE(review): the ``sampler`` argument is overwritten below, so the value
    passed by the caller has no effect.
    """
    # loading dataset file from path
    dataset_file = torch.load(os.path.join(data_path, f"{dtype}_{domain_id}.pt"))

    shuffle = False if dtype == "test" or domain == 'source' else dataset_configs.shuffle
    drop_last = False if dtype == "test" or domain == 'source' else dataset_configs.drop_last
    unlabeled_ratio = 0 if dtype == "test" or domain == 'source' else unlabeled_ratio
    sampler = False if domain == "target" or dtype == "test" else True

    # Loading datasets using CustomDataset
    dataset = CustomDataset(dataset_file, dataset_configs, unlabeled_ratio)
    # CustomDataset serves its labeled samples at positions
    # 0 .. len(lbl_idx) - 1, so those positions can be listed directly
    # instead of fetching (and normalizing) every item just to read its flag.
    labeled_indices = list(range(len(dataset.lbl_idx)))

    # Subset dataset
    subset_dataset = Subset(dataset, labeled_indices)

    if sampler:
        # The sampler is only enabled when unlabeled_ratio was forced to 0
        # above, so the subset covers the whole dataset and the weight
        # vector has one entry per subset item.
        samples_weight = dataset.get_weights()
        data_sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
        data_loader = torch.utils.data.DataLoader(dataset= subset_dataset,
                                                  batch_size = hparams["batch_size"],
                                                  drop_last = drop_last,
                                                  num_workers = 8,
                                                  sampler = data_sampler)
    else:
        # Dataloaders
        data_loader = torch.utils.data.DataLoader(dataset=subset_dataset,
                                                  batch_size=hparams["batch_size"],
                                                  shuffle=shuffle,
                                                  drop_last=drop_last,
                                                  num_workers=8)

    return data_loader


####################################################################################

class TSRandomHorizontalFlip:
    """Reverse a time series along its last axis with probability 0.5."""

    def __call__(self, ts):
        # One draw from Python's ``random`` decides whether to flip.
        flip = random.random() < 0.5
        return ts.flip(-1) if flip else ts

class RandomErasingTS:
    """Random time-series masking.

    With probability ``p`` a contiguous run of time steps is set to zero in
    every channel. The run length is a fraction of the series length drawn
    uniformly from ``scale``; the start position is drawn per sample.
    Accepts a batch (B, C, T) or a single series (C, T) and returns a tensor
    of the same shape. The input is never modified in place.
    """
    def __init__(self, p=0.5, scale=(0.02,0.33)):
        self.p = p
        self.scale = scale
    def __call__(self, ts):
        if random.random() > self.p:
            return ts
        # Accept a single (C, T) series as well as a batch, like the other
        # transforms used in the per-sample pipeline.
        is_single = ts.dim() == 2
        batch = ts.unsqueeze(0) if is_single else ts
        B, C, T = batch.shape
        # Clamp so a scale above 1 cannot ask for more steps than exist.
        erase_len = min(int(random.uniform(*self.scale)*T), T)
        ts_erased = batch.clone()
        for b in range(B):
            # random.randint is inclusive on both ends.
            start = random.randint(0, T-erase_len)
            ts_erased[b,:,start:start+erase_len] = 0
        return ts_erased.squeeze(0) if is_single else ts_erased
        
class RandAugmentTS:
    """RandAugment-style policy for time series.

    Each call draws ``n`` operations (with replacement) from the pool and
    applies each of them with probability 0.5 at magnitude ``m``.
    """

    def __init__(self, n=2, m=9):
        self.n = n
        self.m = m
        # Entries are (operation, max_v, bias), forwarded to the operation.
        self.augment_pool = [
            (AddNoise, 0.3, 0),
            (Scale, 1.8, 0.1),
            (TimeWarp, 0.3, 0),
            (Permute, 5, 0),
            (Cutout1D, 0.2, 0)
        ]

    def __call__(self, ts):
        """Augment ``ts`` of shape (B, C, T) or (C, T); the rank is kept."""
        had_no_batch_axis = len(ts.shape) == 2
        # The operations work on batches, so add a leading axis if needed.
        batch = ts.unsqueeze(0) if had_no_batch_axis else ts

        for operation, max_v, bias in random.choices(self.augment_pool, k=self.n):
            if random.random() >= 0.5:
                continue
            batch = operation(batch, v=self.m, max_v=max_v, bias=bias)

        return batch.squeeze(0) if had_no_batch_axis else batch

def AddNoise(ts, v=9, max_v=0.3, bias=0):
    """Add Gaussian noise with standard deviation ``v * max_v / 10 + bias``."""
    sigma = float(v) * max_v / 10 + bias
    return ts + torch.randn_like(ts) * sigma

def Scale(ts, v=9, max_v = 1.8, bias=0.1):
    """Multiply the series by the factor ``v * max_v / 10 + bias``."""
    factor = float(v) * max_v / 10 + bias
    return ts * factor

def TimeWarp(ts, v=9, max_v=0.3, bias=0):
    """Resample each series of a (B, C, T) batch along a random time warp.

    The warp strength is ``v * max_v / 10 + bias`` and is used as both shape
    parameters of a Beta distribution. For every sample, T sorted Beta draws
    are stretched to [0, T-1] and used to build an index into the time axis;
    all channels of a sample share the same index.
    A non-positive strength returns the input unchanged.
    """
    v = float(v)*max_v/10+bias
    if v <= 0:
        # Beta needs positive parameters; treat zero magnitude as a no-op
        # instead of letting np.random.beta raise.
        return ts
    B, C, T = ts.shape
    warped = torch.zeros_like(ts)
    orig_steps = np.arange(T)  # same for every sample
    for b in range(B):
        warp_steps = np.sort(np.random.beta(v, v, size=T)) * (T-1)
        warped_indices = np.interp(orig_steps, warp_steps, orig_steps).astype(int)
        # Index with a tensor on the input's device rather than a raw array.
        index = torch.as_tensor(warped_indices, dtype=torch.long, device=ts.device)
        warped[b] = ts[b, :, index]
    return warped

def Cutout1D(ts, v=9, max_v = 0.2, bias=0):
    """Zero one contiguous window of time steps per sample.

    ``ts`` has shape (B, C, T). The window covers the fraction
    ``v * max_v / 10 + bias`` of the T time steps (capped at T) and is zeroed
    in every channel; its start is drawn independently for each sample.
    Returns a new tensor, or the input itself when the window is empty.
    """
    v = float(v)*max_v/10+bias
    B, C, T = ts.shape
    mask_len = min(int(v * T), T)
    if mask_len <= 0:
        return ts
    ts = ts.clone()
    for b in range(B):
        # np.random.randint excludes ``high``: the +1 makes the last valid
        # start (T - mask_len) reachable and keeps mask_len == T from raising.
        start = np.random.randint(0, T - mask_len + 1)
        ts[b, :, start:start+mask_len] = 0
    return ts

def Permute(ts, v = 5, max_v=5, bias = 0):
    """Shuffle time segments of every series in a (B, C, T) batch.

    The time axis is cut into ``int(v * max_v / 10 + bias)`` segments of
    equal length, the last one also taking the remainder. The segments are
    reordered with ``random.shuffle``, independently for each sample. With
    fewer than two segments the input is returned as is.
    """
    n_segments = int(v * max_v / 10 + bias)
    if n_segments <= 1:
        return ts

    B, C, T = ts.shape
    seg_size = T // n_segments
    # Segment k spans cuts[k]:cuts[k + 1]; the final cut is T so the last
    # segment absorbs the leftover time steps.
    cuts = [k * seg_size for k in range(n_segments)] + [T]

    shuffled = torch.zeros_like(ts)
    for b in range(B):
        pieces = [ts[b, :, lo:hi] for lo, hi in zip(cuts[:-1], cuts[1:])]
        random.shuffle(pieces)
        shuffled[b] = torch.cat(pieces, dim=1)
    return shuffled


class Compose:
    """Chain callables: each one receives the previous one's output."""

    def __init__(self, transforms):
        # Callables, applied in the order given.
        self.transforms = transforms

    def __call__(self, x):
        result = x
        for step in self.transforms:
            result = step(result)
        return result

def get_data_transforms(dataset_configs):
    """Return the augmentation pipelines keyed by usage.

    Keys: ``'train'`` (random flip + one RandAugment op at magnitude 9),
    ``'strong'`` (random flip + two RandAugment ops at magnitude 10), and
    ``'val'`` / ``'test'`` (empty pipelines that return their input).

    ``dataset_configs`` is accepted but not read by this function.
    """
    weak_pipeline = Compose([
        TSRandomHorizontalFlip(),
        RandAugmentTS(n=1, m=9),
    ])
    strong_pipeline = Compose([
        TSRandomHorizontalFlip(),
        RandAugmentTS(n=2, m=10)
    ])
    return {
        'train': weak_pipeline,
        'val': Compose([]),
        'strong': strong_pipeline,
        'test': Compose([]),
    }
    