import copy
import random
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.amp import autocast
from collections import Counter
from torch.nn.utils import parameters_to_vector, vector_to_parameters

# Resolve the compute device once at import time. When CUDA is available,
# newly created float tensors are placed on the GPU by default as well.
# NOTE(review): torch.set_default_tensor_type is reported as deprecated by
# recent PyTorch releases -- confirm the minimum supported version before
# migrating to torch.set_default_device / torch.set_default_dtype.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
if device.type == "cuda":
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

# Keep cuDNN switched on and turn on its benchmark mode.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

class Actor(nn.Module):
    """Deterministic policy: a three-layer MLP whose output is squashed by tanh.

    The output layer is initialised with small uniform weights and a constant
    bias, so a freshly built actor emits values close to
    ``max_action * tanh(init_action_bias)``.

    Args:
        state_dim: Size of the state vector fed to the first linear layer.
        action_dim: Size of the action vector produced by the last layer.
        max_action: Scale applied to the tanh output.
        init_action_bias: Constant written into the output-layer bias.
    """

    def __init__(self, state_dim, action_dim, max_action, init_action_bias=0.0):
        super().__init__()
        hidden_units = 256
        # Layers are created in input-to-output order and wrapped in a
        # Sequential named ``net`` so that state-dict keys stay ``net.<i>.*``.
        stack = [
            nn.Linear(state_dim, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, action_dim),
        ]
        self.net = nn.Sequential(*stack)

        # Re-initialise only the output layer: tiny weights, fixed bias.
        output_layer = stack[-1]
        nn.init.uniform_(output_layer.weight, -3e-3, 3e-3)
        output_layer.bias.data.fill_(init_action_bias)

        self.max_action = max_action

    def forward(self, state):
        """Return ``max_action * tanh(net(state))``."""
        squashed = torch.tanh(self.net(state))
        return self.max_action * squashed

class Critic(nn.Module):
    """Twin Q-networks that score concatenated (state, action) inputs.

    Two independent three-layer MLPs are held in one module: layers
    ``l1``-``l3`` form the first estimator and ``l4``-``l6`` the second.

    Args:
        state_dim: Size of the state vector.
        action_dim: Size of the action vector.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        input_width = state_dim + action_dim

        # First estimator (Q1).
        self.l1 = nn.Linear(input_width, 256)
        self.l2 = nn.Linear(256, 256)
        self.l3 = nn.Linear(256, 1)

        # Second estimator (Q2).
        self.l4 = nn.Linear(input_width, 256)
        self.l5 = nn.Linear(256, 256)
        self.l6 = nn.Linear(256, 1)

    @staticmethod
    def _run_head(joint_input, first, second, last):
        """Apply linear -> ReLU -> linear -> ReLU -> linear to ``joint_input``."""
        hidden = F.relu(first(joint_input))
        hidden = F.relu(second(hidden))
        return last(hidden)

    def forward(self, state, action):
        """Return the pair ``(q1, q2)`` for the given state/action batch."""
        joint_input = torch.cat([state, action], 1)
        first_q = self._run_head(joint_input, self.l1, self.l2, self.l3)
        second_q = self._run_head(joint_input, self.l4, self.l5, self.l6)
        return first_q, second_q

    def Q1(self, state, action):
        """Return only the first estimator's value (skips the second MLP)."""
        joint_input = torch.cat([state, action], 1)
        return self._run_head(joint_input, self.l1, self.l2, self.l3)

class AEAP(object):
    """Twin-critic, delayed-policy-update agent with a prunable set of actors.

    Three actors are created with output biases spread across the action
    range. All active actors are trained against the same critic. At a
    schedule controlled by ``freeze_steps`` / ``prune_interval`` one actor may
    be deactivated, either because its mean Q-value lags behind the best one
    or because it has become nearly identical to another actor.

    Args:
        state_dim: Size of the state vector.
        action_dim: Size of the action vector.
        max_action: Bound applied to actions (actions are clamped to
            ``[-max_action, max_action]``).
        discount: Discount factor used in the bootstrapped target.
        tau: Interpolation factor of the soft target-critic update.
        policy_noise: Standard deviation of the noise added to target actions.
        noise_clip: Absolute clip applied to that noise.
        policy_freq: Actors and the target critic are updated every
            ``policy_freq`` calls to :meth:`train`.
    """

    def __init__(
        self,
        state_dim,
        action_dim,
        max_action,
        discount=0.99,
        tau=0.005,
        policy_noise=0.2,
        noise_clip=0.5,
        policy_freq=2
    ):
        initial_actor_num = 3
        self.actors = nn.ModuleList()
        for i in range(initial_actor_num):
            # c_i are evenly spaced interior points of (-max_action, max_action).
            # arctanh(c_i / max_action) is the pre-tanh bias that makes the
            # actor's initial output sit near c_i.
            c_i = -max_action + 2 * max_action * (i + 1) / (initial_actor_num + 1)
            init_bias = np.arctanh(c_i / max_action)
            self.actors.append(Actor(state_dim, action_dim, max_action, init_action_bias=init_bias))
        # Place the actors explicitly, like the critic below, instead of
        # relying on the process-wide default tensor type.
        self.actors.to(device)

        # Boolean mask over self.actors; False marks a pruned actor.
        self.active_actors = torch.ones(initial_actor_num, dtype=torch.bool, device=device)

        # A plain float learning rate: PyTorch's Adam documentation recommends
        # a float unless fused/capturable is requested.
        lr = 5e-4
        self.actor_optimizers = [torch.optim.Adam(actor.parameters(), lr=lr) for actor in self.actors]
        self.critic = Critic(state_dim, action_dim).to(device)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=lr)

        self.action_dim = action_dim
        self.max_action = max_action
        # Distance threshold for "two actors behave the same": half the
        # diagonal of the [-max_action, max_action]^action_dim box.
        self.criteria = self.max_action * math.sqrt(self.action_dim)
        print('self.criteria: ', self.criteria)
        self.discount = discount
        self.tau = tau
        self.policy_noise = policy_noise
        self.noise_clip = noise_clip
        self.policy_freq = policy_freq
        self.total_it = 0
        # No pruning before this many training iterations.
        self.freeze_steps = 20000
        # Pruning is considered when total_it is a multiple of this value;
        # it is multiplied by backoff_factor after every successful prune.
        self.prune_interval = 10000
        self.backoff_factor = 2
        self.chosen_idx_counter = Counter()

    def convert_flat_idx_to_pair(self, flat_idx, n_actors):
        """Map an index into the upper-triangular pair list to ``(i, j)``.

        Pairs are enumerated row by row with ``i < j``, i.e.
        ``(0, 1), (0, 2), ..., (1, 2), ...``, which is the order in which
        ``calculate_actor_distance`` reads them from ``torch.triu_indices``.

        Args:
            flat_idx: Position in that enumeration (anything usable as a
                list index, including a 0-dim integer tensor).
            n_actors: Number of actors being paired.

        Returns:
            Tuple ``(i, j)`` of positions among the ``n_actors`` actors.

        Raises:
            IndexError: If ``flat_idx`` is outside the enumeration.
        """
        pairs = [(i, j) for i in range(n_actors) for j in range(i + 1, n_actors)]
        return pairs[flat_idx]

    @torch.no_grad()
    def _prune_one_actor(self, next_state):
        """Deactivate at most one active actor based on Q-values and similarity.

        Rule 1: if ``abs(min mean Q) < 0.8 * abs(max mean Q)``, the actor with
        the lowest mean Q is deactivated. Rule 2 (only if rule 1 did not
        fire): if any pair of actors has a maximum per-sample action distance
        below ``self.criteria``, the closest pair is examined and its member
        with the lower mean Q is deactivated. Every deactivation multiplies
        ``self.prune_interval`` by ``self.backoff_factor``.
        """
        # Positions in self.actors of the actors that are still active.
        active_to_original = torch.where(self.active_actors)[0].tolist()
        if len(active_to_original) < 2:
            # A single actor can neither lag behind nor duplicate another one,
            # and the last remaining actor must never be removed.
            return

        # (batch, n_active, action_dim)
        next_actions = torch.stack(
            [self.actors[i](next_state) for i in active_to_original]
        ).permute(1, 0, 2)
        # Clipped double-Q estimate for every active actor's proposal.
        q_values = torch.stack([
            torch.min(*self.critic_target(next_state, next_actions[:, i, :]))
            for i in range(next_actions.shape[1])], dim=1)

        # Per-sample Euclidean distance between every pair of actors' actions;
        # keep the strict upper triangle (each unordered pair once).
        distances = torch.cdist(next_actions, next_actions, p=2)
        n_actors = distances.shape[1]
        idx = torch.triu_indices(n_actors, n_actors, offset=1, device=distances.device)
        pairwise_distances = distances[:, idx[0], idx[1]]
        # Worst-case (over the batch) distance for each pair.
        max_distances = pairwise_distances.max(dim=0)[0]

        mean_q_values = q_values.mean(dim=0)
        # NOTE(review): abs() makes this test depend on the sign of the
        # Q-values (for all-negative values the worst actor has the largest
        # magnitude and rule 1 never fires) -- confirm this is intended.
        max_q = abs(mean_q_values.max())
        min_q = abs(mean_q_values.min())

        if min_q < 0.8 * max_q:
            actor_to_elimi = torch.argmin(mean_q_values)
            print('actor_to_elimi: ', actor_to_elimi)
            self.active_actors[active_to_original[actor_to_elimi]] = False
            self.prune_interval = self.prune_interval * self.backoff_factor
            print(f'self.prune_interval: {self.prune_interval}')
        elif torch.any(max_distances < self.criteria):
            min_dist_idx = torch.argmin(max_distances)
            actor1, actor2 = self.convert_flat_idx_to_pair(min_dist_idx, len(active_to_original))
            print('max_distances: ', max_distances)
            print(f'actor1: {actor1}, actor2: {actor2}')
            q1 = q_values[:, actor1].mean()
            q2 = q_values[:, actor2].mean()
            # Of the two near-duplicates, drop the one the critic rates lower.
            original_idx = active_to_original[actor1 if q1 < q2 else actor2]
            self.active_actors[original_idx] = False
            self.prune_interval = self.prune_interval * self.backoff_factor
            print(f'self.prune_interval: {self.prune_interval}')

    @torch.no_grad()
    def calculate_actor_distance(self, next_state):
        """Optionally prune an actor, then return one active actor's action.

        Pruning is attempted only when ``total_it`` has reached
        ``freeze_steps`` and is a multiple of ``prune_interval``.

        Args:
            next_state: Batch of states passed to the actors (and, when
                pruning, to the target critic).

        Returns:
            The output of a uniformly chosen active actor for ``next_state``,
            computed without gradient tracking.
        """
        if self.total_it >= self.freeze_steps and self.total_it % self.prune_interval == 0:
            self._prune_one_actor(next_state)
        # One .tolist() instead of indexing the mask tensor once per actor.
        active_mask = self.active_actors.tolist()
        active_actors = [actor for actor, active in zip(self.actors, active_mask) if active]
        chosen_actor = random.choice(active_actors)
        return chosen_actor(next_state)

    def select_action(self, state, is_eval=False):
        """Return an action for a single state as a flat NumPy array.

        Args:
            state: Single observation; it is reshaped to ``(1, -1)`` and
                converted with ``torch.FloatTensor`` (assumes a NumPy array --
                TODO confirm against callers).
            is_eval: When true, every active actor proposes an action and the
                one with the highest ``min(Q1, Q2)`` under the target critic
                is returned. Otherwise an active actor is picked uniformly at
                random.

        Returns:
            1-D NumPy array holding the chosen actor's action.
        """
        state = torch.FloatTensor(state.reshape(1, -1)).to(device)
        # Plain ints: random.choice on a tensor raises on some Python 3.11
        # patch releases, and ints index ModuleList / lists directly.
        active_indices = torch.where(self.active_actors)[0].tolist()
        # Nothing here is trained, so no autograd graph is needed on either path.
        with torch.no_grad():
            if is_eval:
                active_actions = [self.actors[idx](state) for idx in active_indices]
                q_values = torch.tensor([
                    torch.min(*self.critic_target(state, proposal)).item()
                    for proposal in active_actions
                ])
                action = active_actions[torch.argmax(q_values).item()]
            else:
                chosen_idx = random.choice(active_indices)
                action = self.actors[chosen_idx](state)
        return action.cpu().numpy().flatten()

    def train(self, replay_buffer, batch_size=256):
        """Run one training iteration.

        The critic is updated on every call. Every ``policy_freq`` calls each
        active actor is updated and the target critic is moved towards the
        critic by a soft update with factor ``tau``.

        Args:
            replay_buffer: Object whose ``sample(batch_size)`` returns
                ``(state, action, next_state, reward, not_done)`` tensors.
            batch_size: Number of transitions requested from the buffer.
        """
        self.total_it += 1
        state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)

        with torch.no_grad():
            # Target action: a randomly chosen active actor plus clipped noise.
            noise = (torch.randn_like(action) * self.policy_noise).clamp(-self.noise_clip, self.noise_clip)
            next_action = (self.calculate_actor_distance(next_state) + noise).clamp(-self.max_action, self.max_action)
            target_Q1, target_Q2 = self.critic_target(next_state, next_action)
            target_Q = torch.min(target_Q1, target_Q2)
            target_Q = reward + not_done * self.discount * target_Q

        current_Q1, current_Q2 = self.critic(state, action)
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        if self.total_it % self.policy_freq != 0:
            return

        for idx in torch.where(self.active_actors)[0].tolist():
            actor = self.actors[idx]
            optimizer = self.actor_optimizers[idx]
            # NOTE(review): the policy gradient is taken through critic_target,
            # not self.critic -- kept as in the original; confirm intended.
            actor_loss = -self.critic_target.Q1(state, actor(state)).mean()
            optimizer.zero_grad()
            # Restrict accumulation to this actor's parameters. A plain
            # backward() would also add to critic_target's .grad buffers,
            # which no optimizer ever zeroes or consumes.
            actor_loss.backward(inputs=list(actor.parameters()))
            optimizer.step()

        # Soft update of the target critic.
        with torch.no_grad():
            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

    def save(self, filename):
        """Write critic, actor and optimizer state dicts next to ``filename``.

        Files written: ``<filename>_critic``, ``<filename>_critic_optimizer``
        and, per actor ``i``, ``<filename>_actor_<i>`` and
        ``<filename>_actor_optimizer_<i>``.

        NOTE(review): ``active_actors``, ``prune_interval`` and ``total_it``
        are not saved, so pruning progress is lost across save/load.
        """
        torch.save(self.critic.state_dict(), filename + "_critic")
        torch.save(self.critic_optimizer.state_dict(), filename + "_critic_optimizer")

        for i, actor in enumerate(self.actors):
            torch.save(actor.state_dict(), filename + f"_actor_{i}")
            torch.save(self.actor_optimizers[i].state_dict(), filename + f"_actor_optimizer_{i}")

    def load(self, filename):
        """Restore the state written by :meth:`save`.

        Tensors are mapped onto the module-level ``device`` so that a
        checkpoint written on a GPU machine can be loaded on a CPU-only one.
        The target critic is rebuilt as a copy of the loaded critic.
        """
        self.critic.load_state_dict(torch.load(filename + "_critic", map_location=device))
        self.critic_optimizer.load_state_dict(
            torch.load(filename + "_critic_optimizer", map_location=device))
        self.critic_target = copy.deepcopy(self.critic)

        for i, actor in enumerate(self.actors):
            actor.load_state_dict(torch.load(filename + f"_actor_{i}", map_location=device))
            self.actor_optimizers[i].load_state_dict(
                torch.load(filename + f"_actor_optimizer_{i}", map_location=device))
