import math
import torch
import numpy as np
import wandb
import random, pdb

def get_sample_duration():
    """Draw a single value from a Zipf distribution with exponent 2.

    Uses NumPy's global random state (``np.random.zipf``), so results are
    reproducible after ``np.random.seed``.

    Returns:
        The first (and only) element of the length-1 sample array.
    """
    samples = np.random.zipf(a=2, size=1)
    return samples[0]




def _ensemble_action(self, preprocessed_obs, Q):
    """Pick an action from the argmax proposals of several Q-networks.

    Helper for ``greedy_action``; only called when ``self.reset_multi`` is 2
    or 4. Each network proposes its own greedy action. What happens next is
    driven by ``self.reset_time``:

    * negative: one network is picked uniformly at random and its proposal
      is returned;
    * ``0 <= reset_time < reset_multi``: network number
      ``(reset_time + 1) % reset_multi`` (0 is ``policy_network``, 1 is
      ``policy_network2``, ...) scores every proposal, and one proposal is
      sampled from a softmax over those scores.

    Args:
        self: Agent exposing ``reset_multi``, ``reset_time``, ``reset_ww``,
            ``policy_network2`` and, when ``reset_multi == 4``,
            ``policy_network3`` / ``policy_network4``.
        preprocessed_obs: Observation forwarded unchanged to every network.
        Q: Output of ``self.policy_network(preprocessed_obs)``, already
            computed by the caller.

    Returns:
        The selected action index as a Python ``int``.

    Raises:
        ValueError: If ``reset_time`` is non-negative but not a valid network
            index (previously this surfaced as ``UnboundLocalError``).
    """
    q_values = [Q, self.policy_network2(preprocessed_obs)]
    if self.reset_multi == 4:
        q_values.append(self.policy_network3(preprocessed_obs))
        q_values.append(self.policy_network4(preprocessed_obs))
    n_networks = len(q_values)

    if self.reset_time < 0:
        # One uniform draw split into equal-width bins, one bin per network
        # (same thresholds as 0.5 for two networks, 0.25/0.5/0.75 for four).
        picked = int(torch.rand(1).item() * n_networks)
        return torch.argmax(q_values[picked]).item()

    if self.reset_time not in range(n_networks):
        raise ValueError(
            "reset_time must be negative or in [0, %d) when reset_multi is %d, got %r"
            % (n_networks, n_networks, self.reset_time)
        )

    candidates = [torch.argmax(q).item() for q in q_values]
    scorer = q_values[(int(self.reset_time) + 1) % n_networks]

    # `.item()` below requires each slice to hold exactly one element, i.e.
    # the networks are evaluated on a single observation per call.
    raw_scores = [scorer[:, a] for a in candidates]
    # Normalise by the largest candidate score (floored at 1e-4 so the
    # division stays finite for non-positive scores), then scale by
    # reset_ww before the softmax.
    scale = max(1e-4, max(score.item() for score in raw_scores))
    probs = torch.nn.functional.softmax(
        torch.tensor(raw_scores) / scale * self.reset_ww, dim=0
    )
    chosen_index = torch.distributions.Categorical(probs).sample()
    return candidates[chosen_index.item()]


def greedy_action(
    self,
    sample,
    eps_threshold,
    preprocessed_obs,
    num_frames,
    hidden_states
):
    """Select an action greedily from the agent's Q-network(s).

    Three modes, all evaluated under ``torch.no_grad()``:

    * Recurrent: when ``hidden_states`` contains a ``"q"`` entry, the policy
      network is called with that hidden state, the argmax action is taken,
      and a new hidden-state dict is returned (``"rnd"`` is carried over
      unchanged).
    * Ensemble: when ``self.reset_multi`` is 2 or 4, the choice is delegated
      to ``_ensemble_action``.
    * Plain: otherwise the argmax of ``self.policy_network(preprocessed_obs)``.

    Args:
        self: Agent object providing ``policy_network`` and ``reset_multi``
            (plus the attributes ``_ensemble_action`` reads).
        sample: Unused; kept so all exploration strategies share a signature.
        eps_threshold: Unused; see ``sample``.
        preprocessed_obs: Observation passed to the network(s).
        num_frames: Unused; see ``sample``.
        hidden_states: ``None``/empty, or a dict with ``"q"`` and ``"rnd"``
            entries for the recurrent mode.

    Returns:
        Tuple ``(action, new_hidden_states)``. ``new_hidden_states`` is the
        input ``hidden_states`` object unless the recurrent mode ran.

    Raises:
        ValueError: Propagated from ``_ensemble_action`` for an out-of-range
            ``reset_time``.
    """
    with torch.no_grad():
        if hidden_states and "q" in hidden_states:
            Q, next_q_hidden = self.policy_network(
                preprocessed_obs, hidden_states["q"]
            )
            new_hidden_states = {"q": next_q_hidden, "rnd": hidden_states["rnd"]}
            # Fix: this branch used to leave `action` unassigned, so the
            # function crashed with UnboundLocalError for recurrent policies.
            return torch.argmax(Q).item(), new_hidden_states

        Q = self.policy_network(preprocessed_obs)
        if self.reset_multi in (2, 4):
            action = _ensemble_action(self, preprocessed_obs, Q)
        else:
            action = torch.argmax(Q).item()

    return action, hidden_states


def epsilon_action(
    self,
    sample,
    eps_threshold,
    preprocessed_obs,
    num_frames,
    hidden_states
):
    """Epsilon-greedy action selection.

    When ``sample > eps_threshold`` the agent exploits: the decision is
    delegated to ``greedy_action`` so both strategies share a single
    implementation of the network-based choice instead of two copies that
    can drift apart. Otherwise the agent explores by drawing an action
    uniformly from ``range(self.n_actions)`` and returns ``hidden_states``
    untouched.

    Args:
        self: Agent object; must provide ``n_actions`` plus whatever
            ``greedy_action`` reads.
        sample: Number compared against ``eps_threshold`` to decide between
            exploiting and exploring.
        eps_threshold: Current exploration threshold.
        preprocessed_obs: Observation forwarded to ``greedy_action``.
        num_frames: Forwarded to ``greedy_action``; not used here.
        hidden_states: Optional recurrent state, forwarded to
            ``greedy_action`` or returned as-is when exploring.

    Returns:
        Tuple ``(action, new_hidden_states)``.
    """
    if sample > eps_threshold:
        return greedy_action(
            self,
            sample,
            eps_threshold,
            preprocessed_obs,
            num_frames,
            hidden_states,
        )

    # Exploration: no network is evaluated, so the hidden state is unchanged.
    return random.randrange(self.n_actions), hidden_states


def select_action(
    self,
    exploration_type,
    preprocessed_obs,
    num_frames,
    hidden_states=None,
):
    """Choose an action using the requested exploration strategy.

    Computes the current epsilon (linear decay from ``self.epsilon`` towards
    ``self.eps_end`` over ``self.eps_decay_time_steps`` frames, never below
    ``self.eps_end``; 0 when ``num_frames`` is ``None``), logs it under the
    ``"epsilon"`` tag via ``self.tb_writer.add_scalar``, increments
    ``self.steps_done`` and dispatches to the matching strategy function.

    Args:
        self: Agent object providing the attributes named above.
        exploration_type: ``"greedy"`` or ``"epsilon"``.
        preprocessed_obs: Observation handed to the strategy.
        num_frames: Frame counter used for the decay and as the logging step.
        hidden_states: Optional recurrent state passed through to the strategy.

    Returns:
        Whatever the strategy returns: ``(action, new_hidden_states)``.

    Raises:
        KeyError: If ``exploration_type`` is not a known strategy name.
    """
    # Drawn up front so the random stream is consumed once per call.
    sample = random.random()

    if num_frames is None:
        eps_threshold = 0
    else:
        decay_per_frame = (self.epsilon - self.eps_end) / self.eps_decay_time_steps
        eps_threshold = max(self.epsilon - decay_per_frame * num_frames, self.eps_end)

    self.tb_writer.add_scalar("epsilon", eps_threshold, num_frames)
    self.steps_done += 1

    strategies = dict(greedy=greedy_action, epsilon=epsilon_action)
    strategy = strategies[exploration_type]
    return strategy(
        self,
        sample,
        eps_threshold,
        preprocessed_obs,
        num_frames,
        hidden_states,
    )
