import numpy as np
import random
import networkx as nx
import torch.nn as nn
from model import ResNet

from typing import List, Tuple

class MisInfoSpreadState:
    def __init__(self, node_states: np.ndarray, adjacency_matrix: np.ndarray, time_step: int):
        self.node_states = node_states
        self.adjacency_matrix = adjacency_matrix
        self.time_step = time_step

    def __hash__(self):
        return hash(str(self.node_states))

    def __eq__(self, other):
        return np.array_equal(self.node_states, other.node_states)


class MisInfoSpread:
    INFECTED_VALUE = -1

    def __init__(self, num_nodes, max_time_steps, trust_on_source=1, 
                 positive_info_threshold=0.95, count_infected_nodes = 1, count_actions = 1):
        self.num_nodes = num_nodes
        self.max_time_steps = max_time_steps

        self.trust_on_source = trust_on_source
        self.positive_info_threshold = positive_info_threshold
        self.count_infected_nodes = count_infected_nodes
        self.count_actions = count_actions

    def find_neighbor(self, state):
        neighbors = set()
        for i, node_value in enumerate(state.node_states):
            if node_value == self.INFECTED_VALUE:
                for j, connection in enumerate(state.adjacency_matrix[i]):
                    if connection != 0 and state.node_states[j] != self.INFECTED_VALUE and state.node_states[j] <= self.positive_info_threshold:
                        neighbors.add(j)
        return list(neighbors)

    def find_neighbor_batch(self, states):
        return [self.find_neighbor(state) for state in states]

    def step(self, state, action_list):
        next_state, _ = self.next_state([state], action_list)
        done = not self.find_neighbor(next_state[0]) or next_state[0].time_step >= self.max_time_steps
        return next_state[0], self.reward([state], [next_state[0]])[0], done

    def step_batch(self, states: List[MisInfoSpreadState], actions) -> Tuple[List[MisInfoSpreadState], List[float], List[bool]]:
        results = [self.step(state, action_list) for state, action_list in zip(states, actions)]
        next_states, rewards, dones = zip(*results)
        return list(next_states), list(rewards), list(dones)

    def find_connections(self, nodeID, state):
        connections = []
        for i, connection in enumerate(state.adjacency_matrix[nodeID]):
            if connection != 0:
                connections.append(i)
        return connections

    def next_state(self, states: List['MisInfoSpreadState'], action_list: List[int]) -> Tuple[List['MisInfoSpreadState'], List[float]]:
        next_states, costs = [], []
        for state in states:

            current_node_states = state.node_states.copy()
            adjacency_matrix = state.adjacency_matrix
            current_time_step = state.time_step

            if len(action_list) > 0:

                for action in action_list:
                    connected_nodes = self.find_connections(action, state)
                    opinion_value = 0
                    for node in connected_nodes:
                        opinion_value += current_node_states[node] * adjacency_matrix[action][node]
                    
                    # considering positive source node as connected node with trust_on_source value
                    current_node_states[action] = round((opinion_value + self.trust_on_source * 1), 2)
                    current_node_states[action] = max(min(current_node_states[action], 1), self.INFECTED_VALUE)     


                state.node_states = current_node_states
                neighbors = self.find_neighbor(state)

                for candidate_node in neighbors:
                    if candidate_node not in action_list:
                        connected_nodes = self.find_connections(candidate_node, state)
                        opinion_value = 0
                        for node in connected_nodes:
                            opinion_value += current_node_states[node] * adjacency_matrix[candidate_node][node]
                        
                        current_node_states[candidate_node] = round((opinion_value), 2)
                        current_node_states[candidate_node] = max(min(current_node_states[candidate_node], 1), self.INFECTED_VALUE)

            new_state = MisInfoSpreadState(current_node_states, adjacency_matrix, current_time_step + 1)
            next_states.append(new_state)
            costs.append(1)  # Add relevant cost calculation here

        return next_states, costs

    def reward(self, states: List['MisInfoSpreadState'], next_states: List['MisInfoSpreadState']) -> List[float]:
        rewards = []
        for state, next_state in zip(states, next_states):
            inf_nodes = -(next_state.node_states.count(self.INFECTED_VALUE) - state.node_states.count(self.INFECTED_VALUE))
            candidate_nodes = -len(self.find_neighbor(next_state))/self.num_nodes
            rewards.append(inf_nodes + candidate_nodes)

        return rewards

    def get_nnet_model(self) -> nn.Module:
        return ResNet(num_nodes=self.num_nodes, num_blocks=[3, 4, 4, 3])

    def generate_states(self, num_states):

        num_nodes = self.num_nodes
        if num_states <= 0 or num_nodes <= 0:
            raise ValueError("Number of states and nodes must be positive integers.")

        states = []
        for _ in range(num_states):
            state = [round(random.uniform(-0.5, 0.6), 2) for _ in range(num_nodes)]

            unique_nodes_to_infect = random.sample(range(num_nodes), self.count_infected_nodes)
            for node_index in unique_nodes_to_infect:
                state[node_index] = self.INFECTED_VALUE

            k = 3
            p = 0.4
            G = nx.watts_strogatz_graph(self.num_nodes, k, p)
            adjacency_matrix = nx.to_numpy_array(G).tolist()

            for i in range(len(adjacency_matrix)):
                for j in range(len(adjacency_matrix[i])):
                    if adjacency_matrix[i][j] == 1:
                        # numpy.random.uniform is used here to ensure 1 can be included
                        adjacency_matrix[i][j] = np.round(np.random.uniform(0, 1.0000001), 4)
                        adjacency_matrix[j][i] = adjacency_matrix[i][j]
                    
                    if adjacency_matrix[i][j] == 0:
                        adjacency_matrix[i][j] = 0
                        adjacency_matrix[j][i] = 0

                    if i == j:
                        adjacency_matrix[i][j] = np.round(np.random.uniform(0, 1.0000001), 4)

                # normalize adjacency_matrix[i] and round to 2 decimal places
                adjacency_matrix[i] = np.round(adjacency_matrix[i] / sum(adjacency_matrix[i]), 4)

            # Convert adjacency matrix back to NumPy matrix
            adjacency_matrix = np.array(adjacency_matrix)

            states.append(MisInfoSpreadState(state, adjacency_matrix, 0))

        return states
