import glob
import os
import math
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import collections
import random
from typing import Dict, Optional, Sequence

def seed_all(seed=1029, others: Optional[list] = None):
    """Seed Python, NumPy and PyTorch RNGs, and optionally extra objects.

    Also exports ``PYTHONHASHSEED`` (this only affects subprocesses started
    afterwards; the running interpreter's hash seed is fixed at startup) and
    puts cuDNN into deterministic, non-benchmark mode.

    Args:
        seed: seed value applied to every RNG.
        others: either a single object with a ``seed(seed)`` method, or an
            iterable of objects; each element exposing ``seed`` is seeded and
            the rest are ignored. A non-iterable object without ``seed`` is
            ignored.

    Returns:
        True when ``others`` itself had a ``seed`` method; None otherwise.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    # NOTE: torch.use_deterministic_algorithms(True) is left disabled, as in
    # the original; enabling it makes nondeterministic ops raise at runtime.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    if others is None:
        return None
    if hasattr(others, "seed"):
        others.seed(seed)
        return True
    try:
        items = iter(others)
    except TypeError:
        # Neither seedable nor iterable: nothing to do.
        return None
    # Errors raised by an item's own seed() are propagated rather than
    # silently swallowed, so a failed seeding cannot go unnoticed.
    for item in items:
        if hasattr(item, "seed"):
            item.seed(seed)
    return None

def test(env, policy, test_episode, max_steps=30):
    """Roll out ``policy`` in ``env`` and return goal / hole frequencies.

    Args:
        env: environment exposing ``reset()`` (returns the start state),
            ``action_size``, and a transition table indexed as
            ``env.step[state][action]`` that yields
            ``(next_state, reward, cost, done, hole)``.
        policy: indexable by state; each row is a vector of non-negative
            action weights supporting ``.sum()`` (e.g. a NumPy array row).
            Rows are normalized before sampling, so unnormalized weights are
            accepted; an all-zero row falls back to a uniform random action.
        test_episode: number of episodes to roll out.
        max_steps: per-episode step cap (defaults to the former hard-coded 30).

    Returns:
        ``(goal_count / test_episode, hole_count / test_episode)``.
    """
    actions = range(env.action_size)
    goal_count = 0
    hole_count = 0
    for _ in range(test_episode):
        observation = env.reset()
        done = False
        time_step = 0
        while not done and time_step < max_steps:
            weights = policy[observation]
            total = weights.sum()
            if total == 0:
                action = np.random.choice(actions, size=1)[0]
            else:
                # Normalize so rows that do not sum exactly to 1 are still valid.
                action = np.random.choice(actions, size=1, p=weights / total)[0]
            observation, _reward, _cost, done, hole = env.step[observation][action]
            time_step += 1
            # NOTE(review): a terminal step counts as a goal even when ``hole``
            # is also set, and ``hole`` is counted on every step it is truthy
            # (not once per episode) -- confirm this matches the env semantics.
            if done:
                goal_count += 1
            if hole:
                hole_count += 1
    return goal_count / test_episode, hole_count / test_episode


# The name starts with "test", so pytest would collect this helper as a test
# case if it is ever star-imported into a test module; opt out explicitly.
test.__test__ = False

def Encode(env, num, type='state'):
    """Return a one-hot list encoding a state or action index.

    Args:
        env: object exposing ``state_size`` and ``action_size``.
        num: index to set to 1.
        type: 'state' (length ``env.state_size``) or 'action'
            (length ``env.action_size``).

    Returns:
        A list of ``0.0`` with ``1`` at position ``num``.

    Raises:
        ValueError: if ``type`` is neither 'state' nor 'action'.
    """
    if type == 'state':
        size = env.state_size
    elif type == 'action':
        size = env.action_size
    else:
        # Previously fell through to an UnboundLocalError on ``one_hot``.
        raise ValueError(f"Encode: unknown type {type!r}; expected 'state' or 'action'")
    one_hot = [0.0] * size
    one_hot[num] = 1
    return one_hot

class ReplayBuffer:
    """Bounded FIFO memory of (state, action, reward, next_state, done) tuples.

    Once ``capacity`` entries are held, appending a new one silently evicts
    the oldest (``collections.deque`` with ``maxlen``).
    """

    def __init__(self, capacity):
        self.buffer = collections.deque(maxlen=capacity)

    def add(self, state, action, reward, next_state, done):
        """Store a single transition."""
        record = (state, action, reward, next_state, done)
        self.buffer.append(record)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct stored transitions uniformly at random.

        Returns:
            ``(states, actions, rewards, next_states, dones)``: states and
            next_states are stacked with ``np.array``; the other three are
            tuples in sampled order.

        Raises:
            ValueError: from ``random.sample`` when ``batch_size`` exceeds the
            number of stored transitions.
        """
        picked = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*picked)
        stacked_states = np.array(states)
        stacked_next = np.array(next_states)
        return stacked_states, actions, rewards, stacked_next, dones

    def size(self):
        """Current number of stored transitions."""
        stored = len(self.buffer)
        return stored

class ValueIteration:
    """Tabular value iteration over a deterministic transition table.

    ``env`` must expose ``state_size``, ``action_size`` and a table
    ``env.step[s][a]`` yielding ``(next_state, reward, cost, done, hole)``;
    only ``next_state``, ``reward`` and ``done`` are used here.
    """

    def __init__(self, env, theta, gamma):
        self.env = env
        self.v = [0] * self.env.state_size  # state values V(s), start at 0
        self.theta = theta  # stop once max_s |V_new(s) - V_old(s)| < theta
        self.gamma = gamma  # discount factor
        self.pi = [None for _ in range(self.env.state_size)]  # filled by get_policy()

    def _q_values(self, s):
        """Return the one-step look-ahead values Q(s, a) for every action under ``self.v``."""
        q_values = []
        for a in range(self.env.action_size):
            next_state, r, _, done, _ = self.env.step[s][a]
            # (1 - done) stops bootstrapping past a terminal transition.
            q_values.append(r + self.gamma * self.v[next_state] * (1 - done))
        return q_values

    def value_iteration(self):
        """Apply Bellman optimality sweeps until convergence, then fill ``self.pi``.

        NOTE(review): there is no iteration cap, so this never returns if the
        values fail to converge (e.g. gamma >= 1 with a non-terminating cycle).
        """
        while True:
            # Synchronous sweep: every state is backed up from the previous values.
            new_v = [max(self._q_values(s)) for s in range(self.env.state_size)]
            max_diff = max((abs(new - old) for new, old in zip(new_v, self.v)), default=0)
            self.v = new_v
            if max_diff < self.theta:
                break
        self.get_policy()

    def get_policy(self):
        """Set ``self.pi[s]`` to the greedy policy, splitting probability evenly over tied best actions."""
        for s in range(self.env.state_size):
            q_values = self._q_values(s)
            max_q = max(q_values)
            n_best = q_values.count(max_q)
            self.pi[s] = [1 / n_best if q == max_q else 0 for q in q_values]

def get_f_div_fn(f_type: str):
    """Return ``(f, f_prime_inv)`` for the requested f-divergence.

    ``f`` is the divergence generator and ``f_prime_inv`` is the inverse of
    its derivative; both act elementwise on torch tensors.

    Supported ``f_type`` values:
        'chi2':    f(x) = 0.5 (x - 1)^2;           (f')^{-1}(y) = y + 1
        'softchi': f(x) = x log x - x + 1 for x < 1, else 0.5 (x - 1)^2;
                   (f')^{-1}(y) = exp(y) for y < 0, else y + 1
        'kl':      f(x) = x log x;                 (f')^{-1}(y) = exp(y - 1)

    Raises:
        NotImplementedError: for any other ``f_type``.
    """
    if f_type == 'chi2':
        f_fn = lambda x: 0.5 * (x - 1)**2
        f_prime_inv_fn = lambda x: x + 1

    elif f_type == 'softchi':
        # torch.where evaluates both branches: the 1e-10 keeps log finite at
        # x == 0, and clamp(max=0) stops exp overflowing on the unused branch.
        f_fn = lambda x: torch.where(x < 1,
                                     x * (torch.log(x + 1e-10) - 1) + 1, 0.5 *
                                     (x - 1)**2)
        f_prime_inv_fn = lambda x: torch.where(x < 0, torch.exp(x.clamp(max=0.0)), x + 1)

    elif f_type == 'kl':
        f_fn = lambda x: x * torch.log(x + 1e-10)
        f_prime_inv_fn = lambda x: torch.exp(x - 1)
    else:
        # A single formatted message; passing two args made str(e) a tuple repr.
        raise NotImplementedError(f'Not implemented f_fn: {f_type}')

    return f_fn, f_prime_inv_fn

def print_grad_param(model):
    """Print the summed gradient of every trainable parameter of ``model``.

    Parameters whose ``.grad`` is still None (e.g. no backward pass has reached
    them yet) are reported as ``None`` instead of raising AttributeError.
    """
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        grad = param.grad
        print(name, None if grad is None else grad.sum())
            
def sum_net_grad(model):
    """Return the sum over trainable parameters of ``|param.grad.sum()|``.

    Each parameter's gradient is summed *before* taking the absolute value,
    so opposite-signed entries within one parameter cancel. Parameters with
    no gradient yet (``grad is None``) are skipped instead of raising.

    Returns:
        ``0`` (int) if nothing contributes, otherwise a scalar tensor.
    """
    grad_sum = 0
    for param in model.parameters():
        if param.requires_grad and param.grad is not None:
            grad_sum += abs(param.grad.sum())
    return grad_sum

def delete_folder(path):
    """Delete the files directly inside directory ``path``; the directory itself is kept.

    Non-recursive: subdirectories are skipped (``os.remove`` cannot delete
    them), while symlinks are removed as links. Dot-files are not matched by
    the ``*`` glob and are left in place.
    """
    # Escape so directory names containing glob metacharacters ([, *, ?) are
    # matched literally rather than as patterns.
    pattern = os.path.join(glob.escape(path), '*')
    for entry in glob.glob(pattern):
        if os.path.isdir(entry) and not os.path.islink(entry):
            continue
        os.remove(entry)
        
def _binary_codes(n):
    """Return an (n, bits) float32 tensor whose row i is i written in big-endian binary."""
    # (n - 1).bit_length() equals ceil(log2(n)) for n >= 2 using exact integer
    # arithmetic; max(1, ...) keeps one column for n == 1, where log2(1) == 0
    # columns made the per-row assignment fail.
    digits = max(1, (n - 1).bit_length())
    codes = torch.zeros((n, digits))
    for i in range(n):
        codes[i] = torch.tensor([int(bit) for bit in format(i, f'0{digits}b')], dtype=torch.float32)
    return codes


def Input_Encode(style, state_size, action_size):
    """Build paired (state, action) feature tensors covering every combination.

    Row k of both outputs corresponds to state ``k // action_size`` and action
    ``k % action_size``.

    Args:
        style: 'one_hot' (identity rows), 'binary' (big-endian bit codes with
            max(1, ceil(log2(n))) columns), or the string 'None' (the raw
            index as a single float column).
        state_size: number of states.
        action_size: number of actions.

    Returns:
        ``(state_encode, action_encode)``, float32 tensors with
        ``state_size * action_size`` rows each.

    Raises:
        ValueError: for an unrecognised ``style`` (previously returned None).
    """
    if style == 'one_hot':
        state_feat = F.one_hot(torch.arange(state_size), num_classes=state_size).to(torch.float32)
        action_feat = F.one_hot(torch.arange(action_size), num_classes=action_size).to(torch.float32)
    elif style == 'binary':
        state_feat = _binary_codes(state_size)
        action_feat = _binary_codes(action_size)
    elif style == 'None':
        state_feat = torch.arange(state_size, dtype=torch.float32).reshape(-1, 1)
        action_feat = torch.arange(action_size, dtype=torch.float32).reshape(-1, 1)
    else:
        raise ValueError(f"Input_Encode: unknown style {style!r}; expected 'one_hot', 'binary' or 'None'")
    # Each state row is repeated once per action; the action table is tiled once per state.
    state_encode = state_feat.repeat_interleave(action_size, dim=0)
    action_encode = action_feat.repeat(state_size, 1)
    return state_encode, action_encode

def Loss_of_w(w, u_D_tensor, K_D_tensor, h_D_tensor, mu_0_tensor, lmbda, tau, gamma):
    """Evaluate the three loss terms for ``w`` and their sum.

    With u and h flattened to row vectors:
        lw1 = -(u @ w)
        lw2 = lmbda * sum(|K_D @ w - (1 - gamma) * mu_0|)   (entrywise L1 norm)
        lw3 = tau * (h @ w)

    Returns:
        ``(lw1, lw2, lw3, lw1 + lw2 + lw3)``; lw1 and lw3 have shape
        ``(1, w.shape[1])`` and lw2 is a scalar tensor.

    NOTE(review): ``torch.mm`` requires ``w`` to be 2-D, presumably a column
    vector of shape (n, 1) -- confirm against callers.
    """
    u_row = u_D_tensor.reshape(1, -1)
    h_row = h_D_tensor.reshape(1, -1)
    residual = K_D_tensor @ w - (1 - gamma) * mu_0_tensor
    term_u = -torch.mm(u_row, w)
    term_residual = lmbda * torch.norm(residual, 1)
    term_h = tau * torch.mm(h_row, w)
    total = term_u + term_residual + term_h
    return term_u, term_residual, term_h, total



def plot_env(env):
    """Print the grid map row by row.

    State 0 is shown as 'S', states listed in ``env.hole_state`` as 'H', and
    every other cell as 'L'; each symbol is followed by a single space.
    """
    for r in range(env.nrow):
        row_symbols = []
        for c in range(env.ncol):
            s = r * env.ncol + c
            if s == 0:
                row_symbols.append('S')
            elif s in env.hole_state:
                row_symbols.append('H')
            else:
                row_symbols.append('L')
        print(''.join(sym + ' ' for sym in row_symbols))

def plot_imitate_policy(env, w, random=True):
    """Print the policy implied by ``w`` as a grid of arrows.

    Row ``w[state, :]`` is normalized into action probabilities; the printed
    arrow is either sampled from it (``random=True``) or its argmax. State 0
    prints 'S', states in ``env.hole_state`` 'H', states in ``env.wall_state``
    'W'. A row with zero total weight is treated as uniform instead of
    producing NaN probabilities (which made sampling raise).

    Note: the ``random`` flag shadows the ``random`` module inside this
    function; the name is kept for backward compatibility with keyword callers.
    """
    action_meaning = ['<', 'v', '>', '^']
    for i in range(env.nrow):
        for j in range(env.ncol):
            state = i * env.ncol + j
            if state == 0:
                symbol = 'S'
            elif state in env.hole_state:
                symbol = 'H'
            elif state in env.wall_state:
                symbol = 'W'
            else:
                # Only arrow cells need an action, so S/H/W rows (often all
                # zeros) are never normalized or sampled.
                row = w[state, :]
                total = row.sum()
                if total == 0:
                    w_pi = np.full(len(row), 1.0 / len(row))
                else:
                    w_pi = row / total
                if random:
                    action = np.random.choice(range(env.action_size), size=1, p=w_pi)[0]
                else:
                    action = np.argmax(w_pi)
                symbol = action_meaning[action]
            print(symbol, end=' ')
        print()
        

            
