import os
import random
import torch
import pandas as pd
from attrdict import AttrDict
from tqdm import tqdm 


class MIMICSampler:
    """Pre-generate and replay batches of per-subject context/target splits.

    Rows of ``df`` are grouped by ``SUBJECT_ID`` and each group is ordered by
    ``CHARTTIME``. After those two columns are dropped, the first remaining
    column is used as the target ``y`` and columns ``1:n`` are used as inputs,
    prefixed by a rescaled time coordinate. For each batch a context size
    ``num_ctx`` and a target size ``num_tar`` are drawn at random; the first
    ``num_ctx`` observations of a subject form the context and the next
    ``num_tar`` form the targets.

    Batches are generated once and cached with ``torch.save`` under
    ``save_dir``. If the cache file already exists it is loaded instead.

    NOTE(review): the cache file name encodes only ``n``, ``num_batches`` and
    ``batch_size``. Changing ``df``, ``min_num_points`` or ``max_num_points``
    silently reuses a stale cache, so delete the file to regenerate. No
    normalization is performed in this class despite the ``-normalize``
    suffix.
    """

    def __init__(
        self,
        df,
        n: int,
        save_dir: str,
        num_batches: int,
        batch_size: int = 16,
        device: str = "cuda",
        max_num_points: int = 200,
        min_num_points: int = 50,
        loop: bool = False,
    ):
        """Load cached batches from ``save_dir``, or generate and cache them.

        Args:
            df: DataFrame with ``SUBJECT_ID`` and ``CHARTTIME`` columns plus
                numeric value columns (the first one is the target).
            n: Value columns ``1:n`` are used as inputs; column 0 is the target.
            save_dir: Directory holding the cached batch file.
            num_batches: Number of batches to generate.
            batch_size: Maximum number of subjects per batch.
            device: Device batches are moved to when iterated.
            max_num_points: Exclusive upper bound on ``num_ctx + num_tar``.
            min_num_points: Lower bound for both ``num_ctx`` and ``num_tar``;
                generation requires ``max_num_points > 2 * min_num_points``.
            loop: If True, iteration cycles over the batches forever.
        """
        self.n = n
        self.df = df
        self.save_dir = save_dir
        self.num_batches = num_batches
        self.batch_size = batch_size
        self.device = device
        self.max_num_points = max_num_points
        self.min_num_points = min_num_points
        self.loop = loop

        save_name = f"mimic_{self.n}d_nb{num_batches}-bs{batch_size}-normalize.pt"
        self.save_path = os.path.join(save_dir, save_name)

        if not os.path.exists(self.save_path):
            save_parent = os.path.dirname(self.save_path)
            # os.makedirs("") raises, which happens when save_dir is empty
            # (the cache then lives in the current directory).
            if save_parent:
                os.makedirs(save_parent, exist_ok=True)
            self.batches = self.save()
        else:
            self.batches = self._load_batches(self.save_path)

    @staticmethod
    def _load_batches(path):
        """Load the list of cached batches written by ``save``."""
        # The cache holds AttrDict objects, which torch>=2.6's default
        # ``weights_only=True`` refuses to unpickle. SECURITY: with
        # ``weights_only=False`` arbitrary pickle code can run, so only ever
        # point save_dir at files this class wrote itself.
        try:
            return torch.load(path, weights_only=False)
        except TypeError:
            # torch < 1.13 does not accept the ``weights_only`` keyword.
            return torch.load(path)

    def prepare_patient_data(self, group, num_points):
        """Convert one subject's rows into time and value tensors.

        Args:
            group: Rows of a single subject containing ``SUBJECT_ID``,
                ``CHARTTIME`` and numeric value columns. It is not modified.
            num_points: Scale of the time axis. Times are shifted so the
                earliest observation is 0 and scaled so the latest is
                ``num_points``.

        Returns:
            ``(t, x)``: ``t`` has shape ``(rows, 1)`` and ``x`` has shape
            ``(rows, value_columns)``. Both are float32 and ordered by
            ``CHARTTIME``.
        """
        # Parse before sorting so the order is chronological rather than
        # lexical. A stable sort keeps rows with equal timestamps in input
        # order, which makes the result deterministic.
        group = group.assign(
            CHARTTIME=pd.to_datetime(
                group["CHARTTIME"], format="%Y-%m-%d %H:%M:%S", errors="coerce"
            )
        )
        group = group.sort_values("CHARTTIME", kind="stable").reset_index(drop=True)
        x = group.drop(columns=["SUBJECT_ID", "CHARTTIME"]).to_numpy()

        t = (group["CHARTTIME"] - group["CHARTTIME"].min()).dt.total_seconds()
        span = t.max()
        # When every observation shares one timestamp, t is already all zeros;
        # dividing by the zero span would turn it into NaN.
        if span > 0:
            t = t / span * num_points
        # NOTE(review): unparseable timestamps are coerced to NaT and still
        # produce NaN times here; consider dropping such rows upstream.
        return (
            torch.tensor(t.to_numpy(), dtype=torch.float32).unsqueeze(-1),
            torch.tensor(x, dtype=torch.float32),
        )

    def generate_batch(self, groups=None):
        """Sample one batch of context/target splits.

        Args:
            groups: Optional precomputed ``list(self.df.groupby("SUBJECT_ID"))``.
                Passing it avoids regrouping the DataFrame on every call. The
                list is copied before shuffling, so the caller's list is left
                unchanged.

        Returns:
            AttrDict with ``xc``, ``xt``, ``yc`` and ``yt``, plus ``x`` (xc and
            xt concatenated on dim 1) and ``y`` (yc and yt concatenated on
            dim 1). ``xc`` and ``xt`` hold the rescaled time followed by value
            columns ``1:n``; ``yc`` and ``yt`` hold value column 0. A batch has
            at most ``batch_size`` subjects, and fewer when not enough subjects
            have ``num_ctx + num_tar`` observations.

        Raises:
            RuntimeError: If no subject has enough observations.
        """
        if groups is None:
            groups = list(self.df.groupby("SUBJECT_ID"))
        grouped = list(groups)
        random.shuffle(grouped)

        # Both sizes are >= min_num_points and their sum is < max_num_points,
        # because randint's ``high`` bound is exclusive.
        num_ctx = torch.randint(
            low=self.min_num_points,
            high=self.max_num_points - self.min_num_points,
            size=[1],
        ).item()
        num_tar = torch.randint(
            low=self.min_num_points,
            high=self.max_num_points - num_ctx,
            size=[1],
        ).item()
        num_points = num_ctx + num_tar

        patient_samples = []
        for _, group in grouped:
            # Skip short series before paying for the datetime conversion.
            if len(group) < num_points:
                continue
            t, x = self.prepare_patient_data(group, num_points)

            xc = torch.cat([t[:num_ctx], x[:num_ctx, 1:self.n]], dim=-1)
            xt = torch.cat(
                [t[num_ctx:num_points], x[num_ctx:num_points, 1:self.n]], dim=-1
            )
            yc = x[:num_ctx, :1]
            yt = x[num_ctx:num_points, :1]
            patient_samples.append((xc, xt, yc, yt))

            if len(patient_samples) == self.batch_size:
                break

        if not patient_samples:
            raise RuntimeError(
                f"No subject has at least {num_points} observations "
                f"(num_ctx={num_ctx}, num_tar={num_tar}); cannot build a batch."
            )

        xc, xt, yc, yt = (torch.stack(parts) for parts in zip(*patient_samples))
        batch = AttrDict()
        batch.xc = xc
        batch.xt = xt
        batch.yc = yc
        batch.yt = yt
        batch.x = torch.cat([xc, xt], dim=1)
        batch.y = torch.cat([yc, yt], dim=1)
        return batch

    def save(self):
        """Generate ``num_batches`` CPU batches, write them to ``save_path`` and return them."""
        # Group once; the grouping is the same for every batch, so regrouping
        # the whole DataFrame per batch is wasted work.
        groups = list(self.df.groupby("SUBJECT_ID"))
        batches = []
        for _ in tqdm(range(self.num_batches)):
            batch = self.generate_batch(groups)
            for key, value in batch.items():
                batch[key] = value.cpu()
            batches.append(batch)
        torch.save(batches, self.save_path)
        return batches

    def _to_device(self, batch):
        """Move every tensor in ``batch`` to ``self.device`` in place and return it."""
        for key, value in batch.items():
            batch[key] = value.to(self.device)
        return batch

    def __iter__(self):
        """Yield the cached batches on ``self.device``, cycling forever if ``loop``.

        Batches are moved in place, so after the first pass the stored
        ``self.batches`` themselves live on ``self.device``.
        """
        loop = self.loop
        # With no batches, an endless loop would spin without ever yielding.
        if not self.batches:
            return
        while True:
            for batch in self.batches:
                yield self._to_device(batch)
            if not loop:
                return

    def __len__(self):
        """Return the number of cached batches."""
        return len(self.batches)