import csv
import math
import gymnasium as gym
from gymnasium import error, utils
from gymnasium.utils import seeding
from gymnasium.envs.registration import register
from gymnasium import spaces
import numpy as np
import random
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
from torch.nn import functional as F
import time
import math
import copy
import envs.py222 as py222
import random

class cube(gym.Env):
    """2x2x2 cube environment with an optional restricted camera view.

    The cube state is owned by the project-local ``py222`` module
    (``initState`` / ``doAlgStr``). An episode starts from a scrambled state
    and terminates with reward 1.0 as soon as the state equals
    ``py222.initState()``; every other step yields reward 0.0.

    The ``cube_cam`` constructor argument selects what the agent observes:

    * ``"full"``  -- all 24 state entries; only the 12 face-turn actions.
    * ``"face"``  -- 4 state entries at a time out of six views; four extra
      camera actions (``CU``, ``CR``, ``CD``, ``CL``) switch between views.
    * anything else (default ``"orthographic"``) -- 12 state entries at a
      time out of two views; one extra camera action ``C`` toggles the view.

    Camera actions consume a step but leave the cube state untouched.
    """

    metadata = {'render.modes': ['human']}

    # The first NUM_FACE_MOVES entries of ``actions_list`` are cube moves;
    # any action index at or above this value is a camera action.
    NUM_FACE_MOVES = 12

    def __init__(self,episode_steps=1000,scramble_steps=20,random_length=False,cube_cam="orthographic",seed=0):
        """Configure the environment and seed the global RNGs.

        Args:
            episode_steps: Number of steps after which ``step`` reports
                truncation.
            scramble_steps: Number of random face turns applied by ``reset``
                (upper bound when ``random_length`` is true).
            random_length: If true, ``reset`` draws the scramble length
                uniformly from ``1..scramble_steps``.
            cube_cam: Observation mode, see the class docstring.
            seed: Seed applied to ``torch``, ``numpy.random``, ``random`` and
                the action space. Note that this reseeds process-wide RNGs.
        """
        self.episode_steps = episode_steps
        self.scramble_steps = scramble_steps
        self.random_length = random_length
        self.cube_cam = cube_cam

        self.steps = 0
        self.goal_state = py222.initState()
        # Start from a valid state so that step() before reset() does not
        # fail on a missing attribute.
        self.state = py222.initState()
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

        face_moves = ["U","U'","R","R'","F","F'","D","D'","L","L'","B","B'"]

        self.camera_view = 0
        if cube_cam == "full":
            self.actions_list = face_moves
            self.observation_space = spaces.Box(low=0, high=5, shape=(len(self.goal_state),), dtype=np.uint8)
            # No camera actions exist in this mode; the view never changes.
            self.camera_action_views = lambda action,camera_view: camera_view
            self.camera_views = [[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23]]
        elif cube_cam == "face":
            self.actions_list = face_moves + ["CU","CR","CD","CL"]
            # For each camera action: the view reached from each current view
            # (list position = current view index). Built once rather than on
            # every camera action.
            face_transitions = {
                12: [5,0,0,2,0,0], # CU
                13: [1,2,1,1,2,4], # CR
                14: [2,3,3,0,3,3], # CD
                15: [4,5,4,4,5,1], # CL
            }
            self.camera_action_views = lambda action,camera_view: face_transitions[action][camera_view]
            self.camera_views = [[0,1,2,3],[4,5,6,7],[8,9,10,11],[12,13,14,15],[16,17,18,19],[20,21,22,23]]
            self.observation_space = spaces.Box(low=0, high=5, shape=(4,), dtype=np.uint8)
        else:
            self.actions_list = face_moves + ["C"]
            # A single camera action flips between the two views.
            self.camera_action_views = lambda action,camera_view: (camera_view+1)%2
            self.camera_views = [[0,1,2,3,4,5,6,7,8,9,10,11],[12,13,14,15,16,17,18,19,20,21,22,23]]
            self.observation_space = spaces.Box(low=0, high=5, shape=(4*3,), dtype=np.uint8)
        self.action_space = spaces.Discrete(len(self.actions_list),seed=seed)

    def _observe(self):
        """Return the state entries visible from the current camera view."""
        return self.state[self.camera_views[self.camera_view]]

    def step(self, action_ind):
        """Apply one action and return the Gymnasium 5-tuple.

        Args:
            action_ind: Index into ``actions_list``. Indices below
                ``NUM_FACE_MOVES`` turn a face; higher indices move the camera.

        Returns:
            ``(obs, reward, terminated, truncated, info)`` where ``reward`` is
            1.0 and ``terminated`` is true exactly when the state equals the
            goal state, and ``truncated`` is true once ``episode_steps`` steps
            have been taken since the last reset.
        """
        # Normalise numpy integers up front so both branches see a plain int.
        action_ind = int(action_ind)
        self.steps += 1
        if action_ind >= self.NUM_FACE_MOVES:
            self.camera_view = self.camera_action_views(action_ind, self.camera_view)
        else:
            self.state = py222.doAlgStr(self.state, self.actions_list[action_ind])

        solved = bool(abs(self.goal_state - self.state).sum() == 0)
        reward = 1.0 if solved else 0.0
        truncate = self.steps >= self.episode_steps

        return self._observe(), reward, solved, truncate, {}

    def reset(self,seed=None,options=None,scramble_steps=None):
        """Scramble the cube and return the first observation.

        Args:
            seed: If given, reseeds the environment RNG and the global
                ``random`` module that drives the scramble.
            options: Optional dict; a ``"scramble_steps"`` entry overrides the
                configured scramble length for this reset.
            scramble_steps: Explicit override of the scramble length; takes
                precedence over ``options``.

        Returns:
            ``(obs, info)`` with an empty ``info`` dict.
        """
        super().reset(seed=seed)
        if seed is not None:
            random.seed(seed)

        if scramble_steps is None and options:
            scramble_steps = options.get("scramble_steps")
        if scramble_steps is None:
            scramble_steps = self.scramble_steps
        if self.random_length and scramble_steps >= 1:
            scramble_steps = random.randint(1, scramble_steps)

        self.steps = 0
        state = py222.initState()
        for _ in range(scramble_steps):
            # Scramble with face turns only: camera action names are not cube
            # moves and must never be handed to py222.doAlgStr.
            move = self.actions_list[random.randint(0, self.NUM_FACE_MOVES - 1)]
            state = py222.doAlgStr(state, move)

        self.state = state
        self.camera_view = 0
        return self._observe(), {}

    def _run_evaluation(self, select_action, scramble_steps, initial_memory=None, episodes=100, max_steps=1000):
        """Roll out ``episodes`` episodes and average final reward and length.

        ``select_action(obs, memory)`` must return ``(action, new_memory)``;
        ``memory`` starts at ``initial_memory`` in every episode. An episode
        ends on termination, truncation, or after ``max_steps`` steps.
        """
        reward_list = []
        step_list = []
        for _ in range(episodes):
            obs, _info = self.reset(scramble_steps=scramble_steps)
            memory = initial_memory
            reward = 0.0
            steps = 0
            for _ in range(max_steps):
                action, memory = select_action(obs, memory)
                obs, reward, done, truncated, _info = self.step(action)
                steps += 1
                if done or truncated:
                    break
            step_list.append(steps)
            reward_list.append(reward)

        return np.array(reward_list).mean(), np.array(step_list).mean()

    def evaluate(self,policy,scramble_steps=1000):
        """Evaluate ``policy`` over 100 episodes of at most 1000 steps.

        Args:
            policy: Object exposing ``act(obs) -> action``.
            scramble_steps: Scramble length used for every evaluation reset.

        Returns:
            ``(mean_final_reward, mean_episode_length)``.
        """
        return self._run_evaluation(
            lambda obs, _memory: (policy.act(obs), None),
            scramble_steps,
        )

    def evaluateoc(self,policy,scramble_steps=1000):
        """Evaluate an option-based ``policy`` over 100 episodes.

        Args:
            policy: Object exposing ``act(obs, old_option) -> (action, option)``.
                The first call of each episode receives option ``0``.
            scramble_steps: Scramble length used for every evaluation reset.

        Returns:
            ``(mean_final_reward, mean_episode_length)``.
        """
        def select_action(obs, old_option):
            action, option = policy.act(obs, old_option)
            # Copy so later in-place changes by the policy cannot alter the
            # option fed back on the next step.
            return action, copy.deepcopy(option)

        return self._run_evaluation(select_action, scramble_steps, initial_memory=0)


# Register the environment with Gymnasium under the id 'cube-v0'; the entry
# point string names the `cube` class in the `envs.cube2x2` module.
register(id='cube-v0', entry_point='envs.cube2x2:cube')
