import numpy as np
import argparse
import random

import gym
import d4rl
import hydra

import numpy as np
import torch

from omegaconf import DictConfig, OmegaConf
from buffer import ReplayBuffer
from logger import Logger
from trainer import MFPolicyTrainer
from agent import CQLAgent
import torch.optim as optim

from offline_rl.cql.utils import qlearning_dataset
from offline_rl.cql.save_data_z import load_model
from offline_rl.cql.networks import MLP, ActorProb, Critic
from offline_rl.cql.module import DiagGaussian
from offline_rl.cql.agent import BasePolicy, BCPolicy, bc_goal_policy, IQLPolicy

import os
import h5py
import torch.nn as nn


def calculate_time_dependent_mean(segments, max_length, action_dim):
    """
    Average the vectors observed at each timestep index over all segments.

    A segment only contributes to the timesteps it actually covers, so the
    mean at index ``t`` is taken over the segments whose length exceeds ``t``.
    Timesteps that no segment reaches keep a mean of zero.

    Parameters:
    segments (list of numpy arrays): Trajectory segments of possibly different lengths; each row is one timestep vector.
    max_length (int): Number of timestep rows in the result (no segment may be longer than this).
    action_dim (int): Length of each timestep vector.

    Returns:
    time_means (numpy array): A [max_length, action_dim] array whose row ``t`` is the mean vector at timestep ``t``.
    """
    sums = np.zeros((max_length, action_dim))  # running per-timestep sums, turned into means below
    counts = np.zeros(max_length)  # how many segments contributed to each timestep

    for trajectory in segments:
        for step, vector in enumerate(trajectory):
            sums[step] += vector
            counts[step] += 1

    # Divide only the rows that received data; untouched rows stay at zero.
    observed = counts > 0
    sums[observed] /= counts[observed][:, None]

    return sums

def autocorrelation_lag_1_varying_length(segments, lag=3):
    """
    Calculate a lagged autocorrelation for each segment, where the mean vector at each timestep
    is calculated across all segments with data at that timestep.

    NOTE: despite the function name, the lag was historically hard-coded to 3 timesteps.
    That value is kept as the default so existing callers get the same statistic; pass
    ``lag=1`` for a true lag-1 autocorrelation.

    For a segment ``x`` of length ``T`` and per-timestep means ``m`` the value is

        sum_{t=0}^{T-lag-1} (x_t - m_t) . (x_{t+lag} - m_{t+lag})
        ---------------------------------------------------------
                 sum_{t=0}^{T-lag-1} ||x_t - m_t||^2

    Parameters:
    segments (list of numpy arrays): Each element is a [length, action_dim] array representing
                                     a trajectory segment; lengths may differ, action_dim may not.
    lag (int): Number of timesteps separating the two samples that are correlated (default 3).

    Returns:
    autocorrelations (list of floats): One value per segment, in input order. The value is NaN
                                       when the segment has fewer than ``lag + 1`` timesteps or
                                       when the denominator is zero. An empty input yields ``[]``.

    Raises:
    ValueError: If ``lag`` is negative.
    """
    if lag < 0:
        raise ValueError(f"lag must be non-negative, got {lag}")

    # Nothing to analyse; also avoids max() on an empty sequence and segments[0] below.
    if len(segments) == 0:
        return []

    # Determine the maximum segment length and the action dimension
    max_length = max(len(segment) for segment in segments)
    action_dim = segments[0].shape[1]  # Assuming all segments have the same action dimension

    # Step 1: Calculate the time-dependent mean vector across all segments
    time_means = calculate_time_dependent_mean(segments, max_length, action_dim)

    autocorrelations = []

    # Step 2: Calculate autocorrelation for each segment using the global time-dependent mean
    for segment in segments:
        segment_length = len(segment)

        if segment_length < lag + 1:
            # The segment is too short to contain a single pair of samples `lag` steps apart
            autocorrelations.append(np.nan)
            continue

        # Deviation of every timestep from the cross-segment mean at that same timestep
        deviations = segment - time_means[:segment_length]
        leading = deviations[:segment_length - lag]   # x_t - m_t
        trailing = deviations[lag:]                   # x_{t+lag} - m_{t+lag}

        # Numerator: summed dot products of paired deviations
        numerator = np.sum(leading * trailing)
        # Denominator: summed squared norms of the leading deviations
        denominator = np.sum(leading * leading)

        if denominator == 0:
            autocorrelations.append(np.nan)  # Avoid division by zero
        else:
            autocorrelations.append(numerator / denominator)

    return autocorrelations

"""
suggested hypers
cql-weight=5.0, temperature=1.0 for all D4RL-Gym tasks
"""
def load_z_dataset(dataset, env):
    """
    Convert ``dataset`` into the form produced by ``qlearning_dataset`` for ``env``.

    This is a thin wrapper: both arguments are forwarded by keyword to
    ``qlearning_dataset`` and its result is returned unchanged.

    Parameters:
    dataset: Raw dataset object handed to ``qlearning_dataset``.
    env: Environment identifier handed to ``qlearning_dataset`` (the caller in this file passes the task name).

    Returns:
    Whatever ``qlearning_dataset`` returns (the caller in this file indexes it by string keys).
    """
    # Earlier revisions loaded the data from an ``<env>.h5`` file or filled in
    # zero placeholders here; the conversion is now fully delegated.
    return qlearning_dataset(env=env, dataset=dataset)

def divide_segments_by_latent_action(dataset, segment_length=9):
    """
    Divide dataset["actions"] into consecutive fixed-length segments.

    Despite the function name, the split does NOT look at dataset["latent_action"]
    (the change-point based split was disabled); the actions are simply cut into
    chunks of ``segment_length`` rows.

    The final segment runs to the end of the array, so when the number of rows is
    not a multiple of ``segment_length`` the last segment absorbs the remainder and
    is longer than the others (up to ``2 * segment_length - 1`` rows). When there
    are fewer than ``segment_length`` rows a single segment holding all of them is
    returned.
    NOTE(review): the over-long tail is the historical behaviour and is preserved
    here; confirm whether a separate short tail segment was intended.

    Parameters:
    dataset (dict): A dictionary containing an "actions" array with one row per timestep.
                    A "latent_action" entry may be present but is not used.
    segment_length (int): Number of rows per segment (default 9, the previous hard-coded value).

    Returns:
    segments (list of numpy arrays): Slices of dataset["actions"], in order, that together
                                     cover every row exactly once.

    Raises:
    ValueError: If ``segment_length`` is smaller than 1.
    """
    if segment_length < 1:
        raise ValueError(f"segment_length must be at least 1, got {segment_length}")

    actions = dataset["actions"]

    # Every chunk except the last has exactly `segment_length` rows.
    num_chunks = actions.shape[0] // segment_length
    tail_start = max(num_chunks - 1, 0) * segment_length

    segments = [
        actions[start:start + segment_length]
        for start in range(0, tail_start, segment_length)
    ]
    # The last segment takes everything that is left, including any remainder.
    segments.append(actions[tail_start:])

    return segments


def _str2bool(value):
    """Convert a command-line token such as "true"/"False"/"0" into a real bool.

    ``argparse`` with ``type=bool`` calls ``bool(token)``, which is True for every
    non-empty string (including "False"); this converter parses the text instead.

    Raises:
    argparse.ArgumentTypeError: If the token is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in {"true", "t", "yes", "y", "1", "on"}:
        return True
    if token in {"false", "f", "no", "n", "0", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args():
    """
    Build the command-line parser for the CQL / IQL settings and parse ``sys.argv``.

    Boolean options take an explicit value (e.g. ``--eval_render false``) and are
    parsed with ``_str2bool`` so that "false"/"0"/"no" really yield ``False``.

    Returns:
    argparse.Namespace: The parsed options, with defaults for everything not given.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--algo-name", type=str, default="cql")
    parser.add_argument("--load_high_policy", type=str, default=None)
    parser.add_argument("--eval_render", type=_str2bool, default=False)
    parser.add_argument("--task", type=str, default="kitchen-complete-v0") # also change the env name in cfg
    parser.add_argument("--seed", type=int, default=211)
    parser.add_argument("--hidden-dims", type=int, nargs='*', default=[256, 256, 256])
    parser.add_argument("--actor-lr", type=float, default=1e-4)
    parser.add_argument("--critic-lr", type=float, default=3e-4)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--tau", type=float, default=0.005)
    parser.add_argument("--alpha", type=float, default=0.2)
    parser.add_argument("--target-entropy", type=int, default=None)
    # Previously untyped, so any supplied value (even "False") arrived as a truthy string.
    parser.add_argument("--auto-alpha", type=_str2bool, default=True)
    parser.add_argument("--alpha-lr", type=float, default=1e-4)

    parser.add_argument("--cql-weight", type=float, default=5.0)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--max-q-backup", type=_str2bool, default=False)
    parser.add_argument("--deterministic-backup", type=_str2bool, default=True)
    parser.add_argument("--with-lagrange", type=_str2bool, default=False)
    parser.add_argument("--lagrange-threshold", type=float, default=10.0)
    parser.add_argument("--cql-alpha-lr", type=float, default=3e-4)
    parser.add_argument("--num-repeat-actions", type=int, default=10)

    # iql
    parser.add_argument("--iql-hidden-dims", type=int, nargs='*', default=[256, 256])
    parser.add_argument("--iql-actor-lr", type=float, default=3e-4)
    parser.add_argument("--critic-q-lr", type=float, default=3e-4)
    parser.add_argument("--critic-v-lr", type=float, default=3e-4)
    parser.add_argument("--dropout_rate", type=float, default=0.1)
    parser.add_argument("--lr-decay", type=_str2bool, default=True)
    parser.add_argument("--iql-gamma", type=float, default=0.99)
    parser.add_argument("--iql-tau", type=float, default=0.005)
    parser.add_argument("--expectile", type=float, default=0.7)
    parser.add_argument("--iql-temperature", type=float, default=0.5)

    parser.add_argument("--epoch", type=int, default=1000)
    parser.add_argument("--step-per-epoch", type=int, default=100000)
    parser.add_argument("--eval_episodes", type=int, default=100)
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")

    return parser.parse_args()

@hydra.main(config_path="conf", config_name="save_config")
def train(cfg: DictConfig):
    """
    Load the latent-action dataset and print the per-segment action autocorrelation.

    Despite its name, this entry point does not train anything: it builds the
    environment and dataset, re-indexes the latent actions as discrete ids, cuts
    the raw actions into segments and prints one autocorrelation value per segment.

    Parameters:
    cfg (DictConfig): Hydra configuration, forwarded to ``load_model``.
    """
    # create env and dataset
    args = get_args()
    env = gym.make(args.task)
    # Only the dataset returned by load_model is used below.
    dataset, embed, project_out, bc_policy, load_args = load_model(cfg)
    dataset = load_z_dataset(dataset, args.task)
    # See https://github.com/aviralkumar2907/CQL/blob/master/d4rl/examples/cql_antmaze_new.py#L22
    if 'antmaze' in args.task:
        dataset["rewards"] = (dataset["rewards"] - 0.5) * 4.0
    args.obs_shape = env.observation_space.shape

    ## for discrete action
    # A single np.unique pass yields both the distinct latent actions (whose count
    # is the discrete action dimension) and, for every row, the index of its
    # latent action within that set.
    embed_index_set, latent_index = np.unique(dataset["latent_action"], axis=0, return_inverse=True)
    args.action_dim = embed_index_set.shape[0]
    embed_index_set = embed_index_set.astype("int64")
    dataset["latent_action"] = latent_index.astype("int64")

    # Keep only what the segmentation step reads.
    dataset = {
        "latent_action": dataset["latent_action"],
        "actions": dataset["actions"],
    }
    segments = divide_segments_by_latent_action(dataset)

    autocorrelations = autocorrelation_lag_1_varying_length(segments)

    # Print the autocorrelation for each segment. The statistic is computed with a
    # lag of 3 timesteps (the value used inside autocorrelation_lag_1_varying_length),
    # so the message reports that lag rather than the "lag 1" suggested by the name.
    for idx, autocorr in enumerate(autocorrelations, start=1):
        print(f"Segment {idx} Autocorrelation (lag 3): {autocorr}")

# Script entry point: ``train`` is called without arguments because the
# ``@hydra.main`` decorator on it is what supplies the ``cfg`` parameter.
if __name__ == "__main__":
    train()
