import random
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as opt

class ENV:
    """Episodic environment of horizon ``H`` with a hidden target bit string.

    The internal state is a 3-item list ``[bits, pairs, flag]``:

    * ``bits``  -- ``H-2`` values written by the agent's actions,
    * ``pairs`` -- ``H-2`` 2-tuples, rewritten only on the last step,
    * ``flag``  -- 0/1 value drawn on the first step.

    Episode dynamics (see ``step``):

    * ``t == 0``: the action is ignored and ``flag`` is drawn (forced to 1
      when ``deterministic_right`` is set).
    * ``0 < t < H-1``: if ``flag == 1`` the action is stored in
      ``bits[t-1]``; otherwise nothing changes.
    * ``t == H-1``: the action is ignored. With ``flag == 1`` the reward is 1
      only if every bit equals the hidden string ``b``; with ``flag == 0``
      the reward is always 1.

    The rollout helpers below take the environment as their first argument
    under the name ``env`` (instead of ``self``); they are ordinary instance
    methods.
    """

    def __init__(self, H):
        """Create an environment of horizon ``H`` with a random hidden string."""
        self.H = H
        # Observation width: one-hot time step (H+1) + bits (H-2)
        # + flattened pairs (2*(H-2)) + flag (1).
        self.d_state = self.H+1 + H-2+2*(H-2)+1
        self.n_act = 2

        self.t = 0
        # Populated by reset(); observing or stepping before that fails.
        self.state = None
        # Hidden target the bits must match to earn the reward when flag == 1.
        self.b = [random.randint(0, 1) for _ in range(self.H-2)]

        # When True, the first step always sets flag to 1 instead of sampling.
        self.deterministic_right = False

    def reset(self):
        """Start a new episode: time 0, all bits/pairs/flag zeroed."""
        self.t = 0
        self.state = [[0]*(self.H-2), [(0, 0)]*(self.H-2), 0]

    def optimal_state(self):
        """Return the timestep-free observation with zero bits, all pairs (1, 0), flag 0.

        This is the observation reached after a rewarded final step.
        """
        n_bits = self.H-2
        s = [0]*(n_bits*3+1)
        for i in range(n_bits):
            # First component of the i-th pair in the flattened layout.
            s[n_bits+2*i] = 1

        return s

    def observe_state(self):
        """Return the observation prefixed with a one-hot encoding of ``t``.

        The result is a flat list of length ``H+1`` (time) followed by the
        output of ``observe_state_without_timestep``.
        """
        time_step = [0.]*(self.H+1)
        time_step[self.t] = 1.

        return time_step + self.observe_state_without_timestep()

    def observe_state_without_timestep(self):
        """Return the state flattened as ``bits + pair components + [flag]``."""
        # Deep copy so the returned list never aliases the live state.
        o = copy.deepcopy(self.state[0])
        o.extend(component for pair in self.state[1] for component in pair)
        o.append(self.state[2])

        return o

    def finished(self):
        """Return True once ``H`` steps have been taken since the last reset."""
        return self.t == self.H

    def step(self, a):
        """Advance the episode by one step using action ``a``.

        Returns:
            ``(r, observation)`` where ``r`` is the reward of this step
            (non-zero only on the last step) and ``observation`` is
            ``observe_state()`` after the transition.

        Raises:
            IndexError: if the episode is already finished. The time counter
                is left untouched so ``finished()`` stays True.
        """
        if self.t >= self.H:
            # Previously this advanced t past H and only then failed inside
            # observe_state(), leaving the environment in a state where
            # finished() could never become True again.
            raise IndexError('step() called on a finished episode; call reset() first')

        r = 0.
        n_bits = self.H-2

        if self.t == 0:
            # First step: the action is ignored, only the flag is drawn.
            if self.deterministic_right:
                self.state[2] = 1
            else:
                self.state[2] = int(random.random() <= 1/2)
        elif self.t == self.H-1:
            # Last step: the action is ignored, the reward is decided here.
            if self.state[2] == 1:
                matches = [bool(bit == target) for bit, target in zip(self.state[0], self.b)]
                # Each pair records whether the corresponding bit matched.
                self.state[1] = [(1, 0) if ok else (0, 1) for ok in matches]
                self.state[0] = [0]*n_bits
                self.state[2] = 0
                if all(matches):
                    r = 1.
            else:
                # Keep the state a list (not a tuple) so it has the same
                # mutable layout as after reset().
                self.state = [[0]*n_bits, [(1, 0)]*n_bits, 0]
                r = 1.
        else:
            # Intermediate step: the action writes one bit, but only when flag == 1.
            if self.state[2] == 1:
                self.state[0][self.t-1] = a

        self.t += 1

        return r, self.observe_state()

    def sample_trajectory(env, policy, simulator=False):
        """Roll out one episode and return its per-step data.

        Args:
            policy: if ``simulator`` is False, called with the observation as
                a tensor and must return the action. If ``simulator`` is
                True, called with a deep copy of the environment and must
                return ``(action, policy_data)``.
            simulator: selects the calling convention described above.

        Returns:
            ``(states, actions, rewards, returns, policy_data)``. ``states``
            has ``H+1`` entries (it includes the final observation);
            ``returns`` repeats the episode's total reward ``H`` times;
            ``policy_data`` is empty when ``simulator`` is False.
        """
        env.reset()
        states, actions, rewards = [], [], []
        policy_data = []  # per-step extra output of the policy (simulator mode only)
        while not env.finished():
            state = env.observe_state()
            states.append(state)
            with torch.no_grad():
                if simulator:
                    a, step_policy_data = policy(copy.deepcopy(env))
                    policy_data.append(step_policy_data)
                else:
                    a = policy(torch.tensor(state))
            r, _ = env.step(a)
            actions.append(a)
            rewards.append(r)
        states.append(env.observe_state())
        returns = [sum(rewards)]*env.H

        return states, actions, rewards, returns, policy_data

    def sample_full_trajectory(env, policy, return_reward=False):
        """Roll out one episode and return it as a single flat list.

        ``policy`` is called with ``(t, observation_tensor)`` where the
        observation has no time-step prefix, and must return the action.

        Returns:
            The flat trajectory, or ``(trajectory, last_reward)`` when
            ``return_reward`` is True.
        """
        env.reset()
        trajectory = []
        t = 0
        r = 0.  # defined up front so a zero-length episode does not hit an unbound name
        while not env.finished():
            state = env.observe_state_without_timestep()
            trajectory += state
            if t > 0:
                # Reward of the previous step follows the state it led to.
                trajectory.append(r)
            with torch.no_grad():
                a = policy((t, torch.tensor(state)))
            trajectory.append(a)
            r, _ = env.step(a)
            t += 1

        trajectory += env.observe_state_without_timestep()
        trajectory.append(r)

        if return_reward:
            return trajectory, r
        else:
            return trajectory # state,a,state',r,a,state'',r,...

    def sample_full_trajectories(env, policy, N):
        """Sample ``N`` flat trajectories and stack them into a long tensor."""
        trajectories = [env.sample_full_trajectory(policy) for _ in range(N)]

        return torch.tensor(trajectories, dtype=torch.long)

    def collect_data(env, policy, N, detect_optimal_trajectory=False, simulator=False):
        """Sample ``N`` episodes and flatten them into per-step tensors.

        Args:
            policy, simulator: forwarded to ``sample_trajectory``.
            N: number of episodes; must be at least 1 (stacking an empty
                list of episodes fails).
            detect_optimal_trajectory: also report whether some episode
                earned total reward 1 with the flag equal to 1 just before
                the last step.

        Returns:
            ``(D, R)`` or ``(D, R, optimal_trajectory_discovered)``. ``D``
            maps 'states', 'actions', 'rewards', 'next_states', 'end_states',
            'returns' and 'policy_data' to tensors whose first dimension
            runs over all steps of all episodes; ``R`` is the average total
            reward per episode.
        """
        keys = ('states', 'actions', 'rewards', 'next_states', 'end_states', 'returns', 'policy_data')
        D = {key: [] for key in keys}
        optimal_trajectory_discovered = False
        for _ in range(N):
            states, actions, rewards, returns, policy_data = env.sample_trajectory(policy, simulator)
            # states[-2] is the observation before the last step; its last
            # entry is the flag.
            if detect_optimal_trajectory and sum(rewards) == 1. and states[-2][-1] == 1:
                optimal_trajectory_discovered = True
            D['states'].append(torch.tensor(states[:-1]))
            D['actions'].append(torch.tensor(actions))
            D['rewards'].append(torch.tensor(rewards))
            D['next_states'].append(torch.tensor(states[1:]))
            # Final observation repeated once per step of the episode.
            D['end_states'].append(torch.tensor(states[-1]).repeat((env.H, 1)))
            D['returns'].append(torch.tensor(returns))
            D['policy_data'].append(torch.tensor(policy_data))

        for key in ('states', 'next_states', 'end_states'):
            D[key] = torch.stack(D[key]).view(-1, env.d_state).float()
        D['actions'] = torch.stack(D['actions']).view(-1)
        D['rewards'] = torch.stack(D['rewards']).view(-1).float()
        D['returns'] = torch.stack(D['returns']).view(-1).float()
        D['policy_data'] = torch.stack(D['policy_data']).view(-1, env.n_act).float()
        R = D['rewards'].sum().item()/N

        if detect_optimal_trajectory:
            return D, R, optimal_trajectory_discovered
        else:
            return D, R

    def sample_policy_trajectory_print(env, policy):
        """Roll out one episode, printing every observation and the last reward.

        ``policy`` is called with the observation as a tensor and must return
        the action.

        Returns:
            ``(states, actions)`` with ``H+1`` observations and ``H`` actions.
        """
        env.reset()
        states = []
        actions = []
        r = 0.  # defined up front so a zero-length episode does not hit an unbound name
        while not env.finished():
            state = env.observe_state()
            states.append(state)
            print(state)
            # No gradients are needed for a rollout; matches sample_trajectory.
            with torch.no_grad():
                a = policy(torch.tensor(state))
            r, _ = env.step(a)
            actions.append(a)
        final_state = env.observe_state()
        states.append(final_state)
        print(final_state)
        print(r)

        return states, actions

    def eval(self, agent, N=1000, simulator=False):
        """Run ``agent.policy`` for ``N`` episodes and print summary statistics.

        Returns:
            True if at least one sampled episode satisfied the optimal
            trajectory condition checked by ``collect_data``.
        """
        _, R, optimal_trajectory_discovered = self.collect_data(
            agent.policy, N, detect_optimal_trajectory=True, simulator=simulator)

        print('average returns : '+str(R))
        print('optimal trajectory discovered : '+str(optimal_trajectory_discovered))

        return optimal_trajectory_discovered
