import copy
import math
import os
import random
import h5py
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import Dataset
from collections import defaultdict

from tqdm import tqdm


def get_data_DTDE(data_file):
    """Load every top-level dataset of one HDF5 file into memory.

    Args:
        data_file: path to the HDF5 file.

    Returns:
        dict mapping each top-level key of the file to the fully-read
        dataset contents (``dataset[:]``).
    """
    with h5py.File(data_file, 'r') as handle:
        # Materialise each dataset before the file is closed on exit.
        return {name: handle[name][:] for name in handle.keys()}

def get_data_CTCE(agent_data_file):
    """Merge per-agent HDF5 files into one centralized (joint-agent) dataset.

    Every dataset in every file must be one of the seven fields listed below;
    any other key raises KeyError.

    Args:
        agent_data_file: sequence of HDF5 paths, one per agent.

    Returns:
        dict with:
          - 'actions', 'observations', 'next_observations': per-agent arrays
            concatenated along the last axis.
          - 'rewards', 'costs': element-wise mean across agents.
          - 'terminals': bitwise AND across agents (set only where every
            agent's flag is set).
          - 'timeouts': taken from the first agent's file.

    Raises:
        ValueError: from ``np.concatenate`` when ``agent_data_file`` is empty.
    """
    fields = ('actions', 'rewards', 'costs', 'terminals', 'timeouts',
              'observations', 'next_observations')
    merged = {name: [] for name in fields}

    for path in agent_data_file:
        with h5py.File(path, 'r') as handle:
            for name in handle.keys():
                merged[name].append(handle[name][:])

    # Joint action/observation vectors: stack agents side by side.
    for name in ('actions', 'observations', 'next_observations'):
        merged[name] = np.concatenate(merged[name], axis=-1)
    # A single team reward/cost signal: average over agents.
    for name in ('rewards', 'costs'):
        merged[name] = np.mean(merged[name], axis=0)
    merged['terminals'] = np.bitwise_and.reduce(merged['terminals'], axis=0)
    merged['timeouts'] = merged['timeouts'][0]

    return merged
    
def get_data_CTDE(agent_data_file, agent_sg_file=None, same_r=False, same_c=False):
    """Load per-agent HDF5 files into a CTDE dataset keeping per-agent rewards/costs.

    Every dataset in every ``agent_data_file`` must be one of 'actions',
    'rewards', 'costs', 'terminals', 'timeouts', 'observations',
    'next_observations'; any other key raises KeyError.

    Args:
        agent_data_file: sequence of HDF5 paths, one per agent.
        agent_sg_file: optional sequence of HDF5 paths. From each file only the
            'sg_observations' dataset is read; the results are concatenated
            along the last axis into ``data['sg_observations']``.
        same_r: if True, every agent gets the mean reward across agents.
        same_c: if True, every agent gets the mean cost across agents.

    Returns:
        (data, state_dim, action_dim):
          data: dict with
            - 'actions', 'observations', 'next_observations': per-agent arrays
              concatenated along the last axis;
            - 'rewards', 'costs': per-agent signals stacked on a new last axis
              (or the cross-agent mean repeated once per agent when
              ``same_r`` / ``same_c`` is set);
            - 'terminals': bitwise AND across agents;
            - 'timeouts': taken from the first agent's file;
            - 'sg_observations': only when ``agent_sg_file`` is given.
          state_dim: last-axis size of each file's 'observations' dataset.
          action_dim: last-axis size of each file's 'actions' dataset.

    Raises:
        ValueError: from ``np.concatenate`` when ``agent_data_file`` is empty.
        KeyError: if a file in ``agent_sg_file`` has no 'sg_observations' dataset.
    """
    data = {
        'actions': [], 'rewards': [], 'costs': [], 'terminals': [],
        'timeouts': [], 'observations': [], 'next_observations': [],
    }
    state_dim, action_dim = [], []
    for file in agent_data_file:
        with h5py.File(file, 'r') as f:
            for key in f.keys():
                data[key].append(f[key][:])
                if key == 'observations':
                    state_dim.append(f[key].shape[-1])
                elif key == 'actions':
                    action_dim.append(f[key].shape[-1])

    for key in ('actions', 'observations', 'next_observations'):
        data[key] = np.concatenate(data[key], axis=-1)

    n_agents = len(agent_data_file)
    for key, shared in (('rewards', same_r), ('costs', same_c)):
        if shared:
            # Team-average signal repeated for each agent. Assumes each
            # per-agent dataset is 1-D (one value per step) — TODO confirm.
            mean = np.mean(data[key], axis=0)
            data[key] = np.repeat(mean[:, np.newaxis], n_agents, axis=1)
        else:
            data[key] = np.stack(data[key], axis=-1)

    data['terminals'] = np.bitwise_and.reduce(data['terminals'], axis=0)
    data['timeouts'] = data['timeouts'][0]

    if agent_sg_file is not None:
        # Read only 'sg_observations': appending any other key into `data`
        # would fail, since those entries are already ndarrays at this point.
        sg_parts = []
        for file in agent_sg_file:
            with h5py.File(file, 'r') as f:
                sg_parts.append(f['sg_observations'][:])
        data['sg_observations'] = np.concatenate(sg_parts, axis=-1)

    return data, state_dim, action_dim

def analyze_data(data, cost_limit):
    """Summarise per-episode returns of a single-signal offline dataset.

    The flat transition arrays are split into episodes at every step where
    ``terminals | timeouts`` is set. Rewards and costs are summed per episode.
    Transitions after the last done flag are ignored.

    Args:
        data: mapping with 'rewards', 'costs', 'terminals', 'timeouts'.
            Presumably 1-D arrays of equal length — TODO confirm.
            'terminals' and 'timeouts' must support ``np.bitwise_or``
            (bool or integer dtype).
        cost_limit: maximum episode cost for an episode to count as safe.

    Returns:
        (ret_span, target_returns, bin_size):
          ret_span: max minus min episode return.
          target_returns: ``[[[best], [cost_limit]], [[lower], [cost_limit]]]``.
            ``best`` is the highest return among safe episodes. ``lower`` is a
            less demanding target: ``best * 0.5`` if ``best >= 0``, else
            ``best * 2``.
          bin_size: ``max(round(sqrt(num_episodes / 10)), 1)``.

    Raises:
        ValueError: if no episode is completed, or no episode has
            cost <= cost_limit.
    """
    rewards = data['rewards']
    costs = data['costs']
    dones = np.bitwise_or(data['terminals'], data['timeouts'])

    ret, cost_ret = 0, 0
    ret_traj, cost_traj = [], []
    for i, done in enumerate(dones):
        ret += rewards[i]
        cost_ret += costs[i]
        if done.item():
            ret_traj.append(ret)
            cost_traj.append(cost_ret)
            ret, cost_ret = 0, 0

    if not ret_traj:
        raise ValueError("analyze_data: no complete episode found "
                         "(no terminal or timeout flag is set)")

    ret_traj = np.array(ret_traj, dtype=np.float32)
    cost_traj = np.array(cost_traj, dtype=np.float32)
    ret_span = ret_traj.max() - ret_traj.min()

    safe = cost_traj <= cost_limit
    if not safe.any():
        raise ValueError(
            f"analyze_data: no episode has cost <= cost_limit ({cost_limit})")
    best_safe_ret = ret_traj[safe].max()

    if best_safe_ret >= 0:
        lower_ret = best_safe_ret * 0.5
    else:
        lower_ret = best_safe_ret * 2
    target_returns = [[[best_safe_ret], [cost_limit]],
                      [[lower_ret], [cost_limit]]]

    traj_num = ret_traj.shape[0]
    bin_size = max(round(math.sqrt(traj_num / 10)), 1)

    return ret_span, target_returns, bin_size

def analyze_data_CTDE(data, agent_num, cost_limit):
    """Summarise per-agent episode returns of a CTDE offline dataset.

    The transition arrays are split into episodes at every step where
    ``terminals | timeouts`` is set. Rewards and costs are summed per episode
    and per agent. Transitions after the last done flag are ignored.

    Args:
        data: mapping with 'rewards' and 'costs' indexed as ``[step, agent]``,
            plus 'terminals' and 'timeouts' (one flag per step, combined with
            ``np.bitwise_or``).
        agent_num: number of agents (columns of 'rewards' / 'costs').
        cost_limit: maximum episode cost, averaged over agents, for an
            episode to count as safe.

    Returns:
        (ret_span, target_returns, bin_size, ret_mean, ret_std):
          ret_span: max minus min of the agent-averaged episode return.
          target_returns: ``[[best, limits], [lower, limits]]``.
            ``best`` holds the per-agent returns of the safe episode with the
            highest agent-averaged return. ``lower`` scales each entry by 0.5
            (if >= 0) or 2 (if < 0). ``limits`` is ``[cost_limit] * agent_num``.
          bin_size: ``max(round(sqrt(num_episodes / 10)), 1)``.
          ret_mean / ret_std: per-agent mean and std of episode returns.

    Raises:
        ValueError: if no episode is completed, or no episode has an
            agent-averaged cost <= cost_limit.
    """
    print("Analyzing data...")

    rewards = data['rewards']
    costs = data['costs']
    dones = np.bitwise_or(data['terminals'], data['timeouts'])

    ret = [0] * agent_num
    cost_ret = [0] * agent_num
    ret_traj = [[] for _ in range(agent_num)]
    cost_traj = [[] for _ in range(agent_num)]
    for i, done in enumerate(dones):
        for agent in range(agent_num):
            ret[agent] += rewards[i, agent]
            cost_ret[agent] += costs[i, agent]
        if done.item():
            for agent in range(agent_num):
                ret_traj[agent].append(ret[agent])
                cost_traj[agent].append(cost_ret[agent])
            ret = [0] * agent_num
            cost_ret = [0] * agent_num

    # Shape (num_episodes, agent_num) after the transpose.
    ret_traj = np.array(ret_traj, dtype=np.float32).transpose()
    cost_traj = np.array(cost_traj, dtype=np.float32).transpose()
    if ret_traj.shape[0] == 0:
        raise ValueError("analyze_data_CTDE: no complete episode found "
                         "(no terminal or timeout flag is set)")

    ret_mean = [ret_traj[:, agent].mean() for agent in range(agent_num)]
    ret_std = [ret_traj[:, agent].std() for agent in range(agent_num)]

    mean_ret_traj = ret_traj.mean(axis=-1)
    mean_cost_traj = cost_traj.mean(axis=-1)
    ret_span = mean_ret_traj.max() - mean_ret_traj.min()

    safe_idx = np.flatnonzero(mean_cost_traj <= cost_limit)
    if safe_idx.size == 0:
        raise ValueError(
            f"analyze_data_CTDE: no episode has mean cost <= cost_limit ({cost_limit})")
    # Take argmax over safe episodes only. Looking the best value up by
    # equality crashed on ties, and could pick an unsafe episode that has the
    # same mean return.
    best_safe_index = safe_idx[np.argmax(mean_ret_traj[safe_idx])]
    best_safe_ret = [ret_traj[best_safe_index, agent] for agent in range(agent_num)]
    half_best_safe_ret = [r * 0.5 if r >= 0 else r * 2 for r in best_safe_ret]
    # Two separate limit lists, so callers mutating one do not affect the other.
    target_returns = [[best_safe_ret, [cost_limit] * agent_num],
                      [half_best_safe_ret, [cost_limit] * agent_num]]

    traj_num = ret_traj.shape[0]
    bin_size = max(round(math.sqrt(traj_num / 10)), 1)

    return ret_span, target_returns, bin_size, ret_mean, ret_std

