from copy import deepcopy

import gym
import torch
import numpy as np

from preflow import FlowMatching


class FlowPolicy:
    """Receding-horizon controller: plan with a base policy, refine with a flow.

    On a re-plan step the base ``policy`` is rolled out for ``seg_len`` steps in
    ``env`` (a planning environment that is reset and moved to the caller's
    simulator state via ``set_state``). With ``num_candidates > 1`` several such
    rollouts are generated and one is selected by its replayed return. The chosen
    action sequence is flattened, optionally passed ``flow_iteration`` times
    through ``flow.compute_target`` with the current observation as context, and
    cached. Later calls replay cached actions, clipped to
    ``[-max_action, max_action]``, until a re-plan is triggered.
    """

    def __init__(
        self,
        env: gym.Env,
        flow: FlowMatching,
        policy: torch.nn.Module,
        action_idx: int = 0,
        max_action: float = 1,
        seg_len: int = 30,
        rollout_frequency: int = 1,
        flow_iteration: int = 1,
        num_candidates: int = 1,
        find_best: bool = False,
        no_flow: bool = False,
        device: torch.device = 'cpu'
        ) -> None:
        """Store configuration and prepare the policy for inference.

        Args:
            env: Planning environment. Must provide ``reset()``,
                ``set_state(qpos, qvel)`` and a 4-tuple ``step()``
                (``obs, reward, done, info``). Its state is overwritten on
                every re-plan.
            flow: Object whose ``compute_target(actions, context=..., **kwargs)``
                maps a flattened action sequence to a refined one.
            policy: Base policy called on a float32 state tensor; its output is
                moved to CPU and converted to numpy. Moved to ``device`` and
                switched to eval mode here.
            action_idx: Index into a fresh plan of the first action to execute.
            max_action: Executed actions are clipped to ``[-max_action, max_action]``.
            seg_len: Number of steps in each planned rollout.
            rollout_frequency: Re-plan once the plan index reaches this value.
                NOTE(review): the index restarts at ``action_idx``, so only
                ``rollout_frequency - action_idx`` actions run between re-plans
                when ``action_idx > 0`` — confirm this is intended.
            flow_iteration: Number of successive ``flow.compute_target`` passes.
            num_candidates: Number of policy rollouts generated per re-plan.
            find_best: With several candidates, keep the highest-return one if
                True, otherwise the lowest-return one.
            no_flow: Skip flow refinement and execute the policy rollout as-is.
            device: Torch device for the policy and planning tensors.

        Raises:
            ValueError: If ``num_candidates`` is less than 1.
        """
        if num_candidates < 1:
            raise ValueError(f"num_candidates must be >= 1, got {num_candidates}")

        self.env = env
        self.flow = flow
        self.policy = policy
        self.device = device
        self.seg_len = seg_len
        self.action_idx = action_idx
        self.max_action = max_action

        self.rollout_frequency = rollout_frequency
        # Starting at rollout_frequency forces a re-plan on the very first call.
        self.current_rollout = rollout_frequency
        self.flow_iteration = flow_iteration
        # Cached (steps, action_dim) plan tensor; None until the first re-plan.
        self.stored_actions = None

        self.num_candidates = num_candidates
        self.find_best = find_best
        self.no_flow = no_flow

        self.policy.to(device)
        self.policy.eval()

    def __call__(
        self,
        start_state: np.ndarray,
        qpos: np.ndarray,
        qvel: np.ndarray,
        **kwargs
        ) -> np.ndarray:
        """Return the next action to execute, re-planning when needed.

        A re-plan happens when no plan is cached, when the plan index reaches
        ``rollout_frequency``, or when the cached plan has no actions left.

        Args:
            start_state: Current observation; fed to the policy and used as the
                flow context.
            qpos: Position part of the simulator state passed to ``env.set_state``.
            qvel: Velocity part of the simulator state passed to ``env.set_state``.
            **kwargs: Forwarded to ``flow.compute_target``.

        Returns:
            The next planned action as a numpy array, clipped to
            ``[-max_action, max_action]``.
        """
        plan_exhausted = (
            self.stored_actions is None
            or self.current_rollout >= len(self.stored_actions)
        )
        if plan_exhausted or self.current_rollout >= self.rollout_frequency:
            self.stored_actions = self._plan(start_state, qpos, qvel, **kwargs)
            self.current_rollout = self.action_idx

        action = self.stored_actions[self.current_rollout, :]
        action = action.clip(-self.max_action, self.max_action).cpu().numpy()
        self.current_rollout += 1
        return action

    def _plan(self, start_state, qpos, qvel, **kwargs) -> torch.Tensor:
        """Build a fresh (steps, action_dim) action plan on ``self.device``."""
        # NOTE(review): candidates only differ if the policy or env is
        # stochastic; with a deterministic policy in eval mode and a
        # deterministic env they are identical — confirm.
        candidates = [
            self._rollout_policy(start_state, qpos, qvel)
            for _ in range(self.num_candidates)
        ]
        if self.num_candidates > 1:
            actions = self._select_candidate(candidates, qpos, qvel)
        else:
            actions = candidates[0]

        context = self._as_tensor(start_state)
        flat_actions = self._as_tensor(np.array(actions)).flatten()
        if not self.no_flow:
            for _ in range(self.flow_iteration):
                flat_actions = self.flow.compute_target(
                    flat_actions, context=context, **kwargs
                )
        # reshape (not view): the flow output is not guaranteed contiguous.
        return flat_actions.reshape(-1, self.env.action_space.shape[0])

    def _rollout_policy(self, start_state, qpos, qvel) -> list:
        """Roll the policy out for ``seg_len`` steps; pad with zeros after ``done``."""
        self._reset_to(qpos, qvel)
        actions = []
        state = np.copy(start_state)
        done = False
        for _ in range(self.seg_len):
            if done:
                # Keep a fixed-length plan even if the episode ended early.
                actions.append(np.zeros(self.env.action_space.shape))
                continue
            with torch.no_grad():
                action = self.policy(self._as_tensor(state)).cpu().numpy()
            state, _, done, _ = self.env.step(action)
            actions.append(action)
        return actions

    def _select_candidate(self, candidates, qpos, qvel):
        """Return the candidate with the highest (find_best) or lowest return.

        Candidates whose replay ends before all their actions are executed are
        ignored. If every candidate is ignored, the first one is returned. Ties
        keep the earliest candidate.
        """
        # Flip the sign so "lowest return" becomes a maximisation as well.
        sign = 1.0 if self.find_best else -1.0
        best_score = -np.inf
        best_candidate = candidates[0]
        for candidate in candidates:
            episode_return = self._evaluate_candidate(candidate, qpos, qvel)
            if episode_return is None:
                continue
            score = sign * episode_return
            if score > best_score:
                best_score = score
                best_candidate = candidate
        return best_candidate

    def _evaluate_candidate(self, candidate, qpos, qvel):
        """Replay ``candidate`` from (qpos, qvel) and return its summed reward.

        Returns None when the episode signals ``done`` before every action in
        ``candidate`` has been executed.
        """
        self._reset_to(qpos, qvel)
        episode_return = 0
        steps = 0
        for action in candidate:
            _, reward, done, _ = self.env.step(action)
            episode_return += reward
            steps += 1
            if done:
                break
        if steps < len(candidate):
            return None
        return episode_return

    def _reset_to(self, qpos, qvel) -> None:
        """Reset the planning env and move it to the given simulator state."""
        self.env.reset()
        self.env.set_state(qpos, qvel)

    def _as_tensor(self, x) -> torch.Tensor:
        """Convert ``x`` to a float32 tensor on ``self.device``."""
        return torch.as_tensor(x, dtype=torch.float32, device=self.device)

    def to(self, device):
        """Move the policy to ``device`` and use it for future planning tensors.

        Returns:
            ``self``, to allow chaining.
        """
        self.device = device
        self.policy.to(device)
        return self