from torch.utils.data import Dataset
from collections import deque, namedtuple
import torch as torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import d4rl

class OfflineDataset():
    """Tensor views of the offline dataset returned by ``env.get_dataset()``.

    The dataset dict must provide the keys ``'observations'``, ``'actions'``,
    ``'rewards'``, ``'terminals'`` and ``'timeouts'``, all indexed by step.
    If the last step is not terminal it is treated as a timeout, so the
    final episode is always closed.

    Three groups of attributes are built (the suffix meanings are not
    defined in this file; the descriptions below state what each holds):

    * ``data_s1_f_b``, ``data_a1_f_b``, ``n_data_f_b``: every observation
      and action in the dataset, unfiltered.
    * ``data_s1_q``, ``data_a1_q``, ``data_r_q``, ``data_s2_q``,
      ``data_a2_q``, ``data_d_q``, ``data_start_q``, ``n_data_q``: one row
      per step that is not a timeout step.  ``s2``/``a2`` are taken from the
      following row of the dataset (wrapping to row 0 for the very last
      row), so for a terminal step they belong to a different episode;
      ``data_d_q`` holds the terminal flag for each row.
    * ``data_s1_f_m``, ``data_a1_f_m``, ``data_r_f_m``, ``data_s2_f_m``,
      ``data_start_f_m``, ``n_data_f_m``: one row per step that is neither
      terminal nor a timeout step.

    ``data_start_*`` is 1 for the first kept step of an episode, else 0.
    Rewards, terminal flags and start flags have shape ``(N, 1)``.

    ``state_mean``/``state_std``, ``action_mean``/``action_std`` and
    ``reward_mean``/``reward_std`` are computed from the ``*_f_m`` group
    before normalization.  Standard deviations that are zero or NaN are
    stored as 1 so that normalization never divides by zero.

    Args:
        device: torch device (or device string) the tensors are moved to.
        env: object exposing ``get_dataset()`` as described above.
        normalize: when true, observations, actions and rewards in all
            groups are standardized in place with the statistics above.
    """

    def __init__(self, device, env, normalize=True):
        self.env = env
        self.device = device
        data = env.get_dataset()
        observations = np.asarray(data['observations'])
        actions = np.asarray(data['actions'])
        rewards = np.asarray(data['rewards'])
        terminals = np.asarray(data['terminals'])
        # Work on a copy: closing the last episode below must not write into
        # the array owned by the environment's dataset.
        timeouts = np.array(data['timeouts'])
        if not terminals[-1]:
            timeouts[-1] = True

        n_steps = len(observations)
        is_timeout = timeouts.astype(bool)
        is_boundary = terminals.astype(bool) | is_timeout
        # A step starts an episode when it is the first step of the dataset
        # or the step right before it ended an episode.
        starts = np.concatenate(([True], is_boundary[:-1]))

        # --- unfiltered observations / actions -------------------------------
        self.data_s1_f_b = self._float_tensor(observations)
        self.data_a1_f_b = self._float_tensor(actions)
        self.n_data_f_b = len(self.data_s1_f_b)

        # --- every step except timeout steps ---------------------------------
        q_idx = np.flatnonzero(~is_timeout)
        # The successor of the very last row wraps around to row 0; that row
        # can only be kept when it is terminal.
        q_next = (q_idx + 1) % n_steps
        self.data_s1_q = self._float_tensor(observations[q_idx])
        self.data_s2_q = self._float_tensor(observations[q_next])
        self.data_a1_q = self._float_tensor(actions[q_idx])
        self.data_a2_q = self._float_tensor(actions[q_next])
        self.data_r_q = self._float_tensor(rewards[q_idx]).unsqueeze(1)
        self.data_d_q = self._int_column(terminals[q_idx])
        self.data_start_q = self._int_column(starts[q_idx])
        self.n_data_q = len(self.data_s1_q)

        # --- steps that are neither terminal nor timeout ---------------------
        # The last row is always a boundary, so ``m_idx + 1`` stays in range.
        m_idx = np.flatnonzero(~is_boundary)
        self.data_s1_f_m = self._float_tensor(observations[m_idx])
        self.data_s2_f_m = self._float_tensor(observations[m_idx + 1])
        self.data_a1_f_m = self._float_tensor(actions[m_idx])
        self.data_r_f_m = self._float_tensor(rewards[m_idx]).unsqueeze(1)
        self.data_start_f_m = self._int_column(starts[m_idx])
        self.n_data_f_m = len(self.data_s1_f_m)

        # --- normalization statistics (from the *_f_m group) -----------------
        self.state_mean = torch.mean(self.data_s1_f_m, dim=0, keepdim=True).to(self.device)
        self.state_std = self._nonzero_std(self.data_s1_f_m).to(self.device)
        self.action_mean = torch.mean(self.data_a1_f_m, dim=0, keepdim=True).to(self.device)
        self.action_std = self._nonzero_std(self.data_a1_f_m).to(self.device)
        self.reward_mean = torch.mean(self.data_r_f_m, dim=0, keepdim=True).to(self.device)
        self.reward_std = self._nonzero_std(self.data_r_f_m).to(self.device)

        if normalize:
            state_tensors = (
                self.data_s1_f_b,
                self.data_s1_f_m,
                self.data_s2_f_m,
                self.data_s1_q,
                self.data_s2_q,
            )
            action_tensors = (
                self.data_a1_f_b,
                self.data_a1_f_m,
                self.data_a1_q,
                self.data_a2_q,
            )
            reward_tensors = (self.data_r_f_m, self.data_r_q)
            # In-place ops keep the attribute bindings created above.
            for tensor in state_tensors:
                tensor.sub_(self.state_mean).div_(self.state_std)
            for tensor in action_tensors:
                tensor.sub_(self.action_mean).div_(self.action_std)
            for tensor in reward_tensors:
                tensor.sub_(self.reward_mean).div_(self.reward_std)

    def _float_tensor(self, array):
        """Copy ``array`` into a float32 tensor on ``self.device``."""
        # torch.tensor always copies, so the in-place normalization can never
        # reach back into the source numpy arrays.
        return torch.tensor(array, dtype=torch.float32).to(self.device)

    def _int_column(self, flags):
        """Copy a 1-D flag array into an ``(N, 1)`` int tensor on ``self.device``."""
        return torch.tensor(flags).to(self.device).to(torch.int).unsqueeze(1)

    @staticmethod
    def _nonzero_std(values):
        """Per-column std of ``values`` with zero/NaN entries replaced by 1."""
        std = torch.std(values, dim=0, keepdim=True)
        # ``std > 0`` is False for both 0 and NaN (NaN arises when there is a
        # single row), so both degenerate cases fall back to a divisor of 1.
        return torch.where(std > 0, std, torch.ones_like(std))
        
        

