import torch
import pandas as pd
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn.functional as F
import mmap
import pickle
from utils.dataprocess import DataPrepocessing


# ========================================== State-Action dataloader ==========================================

class SASADataset(Dataset):
    """Flat dataset of (prev_state, prev_action, state, action) transitions.

    ``data`` is iterated as a collection of trajectories, each of which is
    iterated as a collection of mapping-like steps. Steps that lack any of the
    four required keys are skipped silently.

    States are converted to ``torch.float32``. Actions are converted to
    ``torch.float32`` when ``env`` is ``'badminton'`` and to ``torch.long``
    otherwise. Values that already are tensors are stored unchanged (their
    dtype is not coerced).
    """

    def __init__(self, data, env):
        self.pairs = []
        self.env = env
        self._process_data(data)

    def _process_data(self, data):
        """Flatten every complete step of every trajectory into ``self.pairs``."""
        required = ('prev_state', 'prev_action', 'state', 'action')
        # The action dtype depends only on the environment, so pick it once.
        action_dtype = torch.float32 if self.env == 'badminton' else torch.long

        for trajectory in data:
            for step in trajectory:
                if any(key not in step for key in required):
                    continue
                prev_s = self._to_tensor(step['prev_state'], torch.float32)
                cur_s = self._to_tensor(step['state'], torch.float32)
                prev_a = self._to_tensor(step['prev_action'], action_dtype)
                cur_a = self._to_tensor(step['action'], action_dtype)
                self.pairs.append((prev_s, prev_a, cur_s, cur_a))

    @staticmethod
    def _to_tensor(data, dtype):
        """Return ``data`` as a tensor; existing tensors pass through untouched."""
        if isinstance(data, torch.Tensor):
            return data
        return torch.tensor(data, dtype=dtype)

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return the ``idx``-th transition as a 4-tuple of tensors."""
        transition = self.pairs[idx]
        prev_s, prev_a, cur_s, cur_a = transition
        return prev_s, prev_a, cur_s, cur_a



class SASADataLoader:
    """Build train/validation ``DataLoader`` objects over ``SASADataset``.

    The data file is loaded once at construction time, flattened into
    (prev_state, prev_action, state, action) transitions, and randomly split
    into a training and a validation subset.

    Args:
        env_name: Environment identifier. ``'badminton'`` selects the CSV
            pipeline; any other value makes ``pkl_path`` be read as a pickle.
        pkl_path: Path to the data file (CSV for ``'badminton'``, pickle
            otherwise).
        player_name: Forwarded to ``DataPrepocessing`` in the badminton
            pipeline; not used otherwise.
        batch_size: Batch size used by both loaders.
        split_ratio: Fraction of the samples assigned to the training subset;
            the remainder goes to validation.
    """

    def __init__(self, env_name, pkl_path, player_name=None, batch_size=64, split_ratio=0.9):
        print('Begin processing state-action pairs...')
        self.env = env_name
        self.player_name = player_name
        self.train_loader, self.val_loader = self._create_dataloaders(pkl_path, batch_size, split_ratio)
        print('Finished processing data')

    def collate_fn(self, batch):
        """Stack a list of 4-tuples into four batched tensors.

        Returns:
            ``(prev_state, prev_action, state, action)``, each stacked along a
            new leading batch dimension.
        """
        prev_state, prev_action, state, action = zip(*batch)
        return (
            torch.stack(prev_state),
            torch.stack(prev_action),
            torch.stack(state),
            torch.stack(action),
        )

    def _create_dataloaders(self, pkl_path, batch_size, split_ratio):
        """Load the data, split it randomly and wrap both parts in loaders."""
        data = self._load_large_pickle(pkl_path)
        dataset = SASADataset(data, self.env)

        # The validation subset takes whatever the integer truncation leaves.
        train_size = int(split_ratio * len(dataset))
        val_size = len(dataset) - train_size
        train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                                  num_workers=0, pin_memory=False, collate_fn=self.collate_fn)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                                num_workers=0, pin_memory=False, collate_fn=self.collate_fn)
        return train_loader, val_loader

    def _load_large_pickle(self, file_path):
        """Load raw trajectories from ``file_path``.

        For the ``'badminton'`` environment the file is read as CSV and passed
        through ``DataPrepocessing``. For every other environment the file is
        memory-mapped read-only and unpickled.
        """
        if self.env == 'badminton':
            df = pd.read_csv(file_path)
            return DataPrepocessing(df, player_name=self.player_name)

        # NOTE(security): pickle.load can execute arbitrary code; only point
        # this at files from a trusted source.
        with open(file_path, 'rb') as f:
            # Context manager guarantees the mapping is released even when
            # unpickling fails part-way through.
            with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mapped:
                return pickle.load(mapped)

    def get_dataloader(self):
        """Return the loaders as ``{"train": ..., "val": ...}``."""
        return {
            "train": self.train_loader,
            "val": self.val_loader
        }



class SADataset(Dataset):
    """Flat dataset of (state, action) pairs drawn from trajectories.

    ``data`` is iterated as a collection of trajectories, each of which is
    iterated as a collection of mapping-like steps. Steps without both a
    ``'state'`` and an ``'action'`` key are skipped silently.

    States are converted to ``torch.float32``. Actions are converted to
    ``torch.float32`` when ``env`` is ``'badminton'`` and to ``torch.long``
    otherwise. Values that already are tensors are stored unchanged (their
    dtype is not coerced).
    """

    def __init__(self, data, env):
        self.pairs = []
        self.env = env
        self._process_data(data)

    def _process_data(self, data):
        """Flatten every complete step of every trajectory into ``self.pairs``."""
        # The action dtype depends only on the environment, so pick it once.
        action_dtype = torch.float32 if self.env == 'badminton' else torch.long

        for trajectory in data:
            for step in trajectory:
                if 'state' not in step or 'action' not in step:
                    continue
                obs = self._to_tensor(step['state'], torch.float32)
                act = self._to_tensor(step['action'], action_dtype)
                self.pairs.append((obs, act))

    @staticmethod
    def _to_tensor(data, dtype):
        """Return ``data`` as a tensor; existing tensors pass through untouched."""
        if isinstance(data, torch.Tensor):
            return data
        return torch.tensor(data, dtype=dtype)

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return the ``idx``-th (state, action) pair."""
        pair = self.pairs[idx]
        obs, act = pair
        return obs, act



class SADataLoader:
    """Build train/validation ``DataLoader`` objects over ``SADataset``.

    The data file is loaded once at construction time, flattened into
    (state, action) pairs, and randomly split into a training and a validation
    subset.

    Args:
        env_name: Environment identifier. ``'badminton'`` selects the CSV
            pipeline; any other value makes ``pkl_path`` be read as a pickle.
        pkl_path: Path to the data file (CSV for ``'badminton'``, pickle
            otherwise).
        player_name: Forwarded to ``DataPrepocessing`` in the badminton
            pipeline; not used otherwise.
        batch_size: Batch size used by both loaders.
        split_ratio: Fraction of the samples assigned to the training subset;
            the remainder goes to validation.
    """

    def __init__(self, env_name, pkl_path, player_name=None, batch_size=64, split_ratio=0.9):
        print('Begin processing state-action pairs...')
        self.env = env_name
        self.player_name = player_name
        self.train_loader, self.val_loader = self._create_dataloaders(pkl_path, batch_size, split_ratio)
        print('Finished processing data')

    def collate_fn(self, batch):
        """Stack a list of (state, action) pairs into two batched tensors."""
        state, action = zip(*batch)
        return torch.stack(state), torch.stack(action)

    def _create_dataloaders(self, pkl_path, batch_size, split_ratio):
        """Load the data, split it randomly and wrap both parts in loaders."""
        data = self._load_large_pickle(pkl_path)
        dataset = SADataset(data, self.env)

        # The validation subset takes whatever the integer truncation leaves.
        train_size = int(split_ratio * len(dataset))
        val_size = len(dataset) - train_size
        train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                                  num_workers=0, pin_memory=False, collate_fn=self.collate_fn)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                                num_workers=0, pin_memory=False, collate_fn=self.collate_fn)
        return train_loader, val_loader

    def _load_large_pickle(self, file_path):
        """Load raw trajectories from ``file_path``.

        For the ``'badminton'`` environment the file is read as CSV and passed
        through ``DataPrepocessing``. For every other environment the file is
        memory-mapped read-only and unpickled.
        """
        if self.env == 'badminton':
            df = pd.read_csv(file_path)
            return DataPrepocessing(df, player_name=self.player_name)

        # NOTE(security): pickle.load can execute arbitrary code; only point
        # this at files from a trusted source.
        with open(file_path, 'rb') as f:
            # Context manager guarantees the mapping is released even when
            # unpickling fails part-way through.
            with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mapped:
                return pickle.load(mapped)

    def get_dataloader(self):
        """Return the loaders as ``{"train": ..., "val": ...}``."""
        return {
            "train": self.train_loader,
            "val": self.val_loader
        }
        
# ========================================== Trajs dataloader ==========================================

class TrajsDataset(Dataset):
    """Dataset whose items are whole trajectories.

    Each item is a dict ``{'states': Tensor, 'actions': Tensor}`` where both
    tensors are stacked along a new leading (time-step) dimension. Steps
    without both a ``'state'`` and an ``'action'`` key are skipped, and a
    trajectory that yields no usable step is dropped entirely.

    States are converted to ``torch.float32``. Actions are converted to
    ``torch.float32`` when ``env`` is ``'badminton'`` and to ``torch.long``
    otherwise. Values that already are tensors are used as they are (their
    dtype is not coerced).

    Raises:
        ValueError: If no trajectory contributes any usable step.
    """

    def __init__(self, data, env):
        self.trajs = []
        self.env = env
        self._process_data(data)

    def _process_data(self, data):
        """Convert every trajectory into stacked state/action tensors."""
        # The action dtype depends only on the environment, so pick it once.
        action_dtype = torch.float32 if self.env == 'badminton' else torch.long

        for trajectory in data:
            state_seq = []
            action_seq = []
            for step in trajectory:
                if 'state' not in step or 'action' not in step:
                    continue
                obs = self._to_tensor(step['state'], torch.float32)
                act = self._to_tensor(step['action'], action_dtype)

                # Drop a singleton leading dimension from 4-D states.
                has_singleton_lead = obs.dim() == 4 and obs.shape[0] == 1
                if has_singleton_lead:
                    obs = obs.squeeze(0)

                state_seq.append(obs)
                action_seq.append(act)

            if not state_seq:
                continue
            stacked = {
                'states': torch.stack(state_seq),
                'actions': torch.stack(action_seq),
            }
            self.trajs.append(stacked)

        if not self.trajs:
            raise ValueError("No data!")

    @staticmethod
    def _to_tensor(data, dtype):
        """Return ``data`` as a tensor; existing tensors pass through untouched."""
        if isinstance(data, torch.Tensor):
            return data
        return torch.tensor(data, dtype=dtype)

    def __len__(self):
        return len(self.trajs)

    def __getitem__(self, idx):
        """Return the ``idx``-th trajectory dict."""
        return self.trajs[idx]


class TrajsDataLoader:
    """Build train/validation ``DataLoader`` objects over ``TrajsDataset``.

    Every batch is a dict ``{"states": Tensor, "actions": Tensor}`` in which
    each trajectory has been right-padded with ``-1`` along its first
    dimension so the trajectories can be stacked.

    Args:
        env_name: Environment identifier. ``'badminton'`` selects the CSV
            pipeline; any other value makes ``pkl_path`` be read as a pickle.
        pkl_path: Path to the data file (CSV for ``'badminton'``, pickle
            otherwise).
        player_name: Forwarded to ``DataPrepocessing`` in the badminton
            pipeline; not used otherwise.
        batch_size: Batch size used by both loaders.
        split_ratio: Fraction of the trajectories assigned to the training
            subset; the remainder goes to validation.
        pad_length: Length every trajectory is padded to. ``None`` pads each
            batch to the longest trajectory it contains. Trajectories longer
            than ``pad_length`` are not truncated, so stacking fails if such
            a batch holds trajectories of different lengths.
    """

    def __init__(self, env_name, pkl_path, player_name=None, batch_size=128, split_ratio=0.8, pad_length=None):
        print('Begin processing data for trajs...')
        self.env = env_name
        self.player_name = player_name
        # Assigned before the loaders are built: collate_fn reads it for
        # every batch the loaders produce.
        self.pad_length = pad_length
        self.train_loader, self.val_loader = self._create_dataloaders(pkl_path=pkl_path,
                                                                      batch_size=batch_size,
                                                                      split_ratio=split_ratio)
        print('Finished processing data')
        print('----------------------------------------------')

    def pad_to_length(self, x, k, pad_value=-1):
        """Right-pad ``x`` along dimension 0 up to length ``k``.

        Args:
            x: Tensor to pad.
            k: Target length of the first dimension, or ``None`` for no
                padding.
            pad_value: Fill value for the padded entries.

        Returns:
            ``x`` itself when ``k`` is ``None`` or ``x`` already has at least
            ``k`` entries along dimension 0; otherwise a padded copy.
        """
        if k is None:
            return x
        pad_len = k - x.size(0)
        if pad_len <= 0:
            return x
        # F.pad lists (before, after) pairs starting from the LAST dimension,
        # so the pair for dimension 0 comes at the very end.
        pad_dims = [0] * (2 * (x.dim() - 1)) + [0, pad_len]
        return F.pad(x, pad_dims, value=pad_value)

    def collate_fn(self, batch):
        """Pad and stack a list of trajectory dicts into one batch dict.

        Returns:
            ``{"states": Tensor, "actions": Tensor}``, each stacked along a
            new leading batch dimension.
        """
        k = self.pad_length
        if k is None:
            # No fixed length configured: pad to the longest trajectory here.
            k = max(max(item["states"].size(0), item["actions"].size(0)) for item in batch)

        batch_state = [self.pad_to_length(item["states"], k, pad_value=-1) for item in batch]
        batch_action = [self.pad_to_length(item["actions"], k, pad_value=-1) for item in batch]

        return {
            "states": torch.stack(batch_state),    # [B, k, *state_dims]
            "actions": torch.stack(batch_action),  # [B, k, *action_dims]
        }

    def _create_dataloaders(self, pkl_path, batch_size, split_ratio):
        """Load the data, split it randomly and wrap both parts in loaders."""
        data = self._load_large_pickle(pkl_path)
        dataset = TrajsDataset(data, self.env)

        # The validation subset takes whatever the integer truncation leaves.
        train_size = int(split_ratio * len(dataset))
        val_size = len(dataset) - train_size
        train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                                  num_workers=0, pin_memory=False, collate_fn=self.collate_fn)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                                num_workers=0, pin_memory=False, collate_fn=self.collate_fn)
        return train_loader, val_loader

    def _load_large_pickle(self, file_path):
        """Load raw trajectories from ``file_path``.

        For the ``'badminton'`` environment the file is read as CSV and passed
        through ``DataPrepocessing``. For every other environment the file is
        memory-mapped read-only and unpickled.
        """
        if self.env == 'badminton':
            df = pd.read_csv(file_path)
            return DataPrepocessing(df, player_name=self.player_name)

        # NOTE(security): pickle.load can execute arbitrary code; only point
        # this at files from a trusted source.
        with open(file_path, 'rb') as f:
            # Context manager guarantees the mapping is released even when
            # unpickling fails part-way through.
            with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mapped:
                return pickle.load(mapped)

    def get_dataloader(self):
        """Return the loaders as ``{"train": ..., "val": ...}``."""
        return {
            "train": self.train_loader,
            "val": self.val_loader
        }
