import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.optim.optimizer import Optimizer
import torch.distributed as dist
import time
import random

import sys
sys.path.append("../")
from ring import *
from exponential_graph import *
from one_peer_exponential_graph import *
from adic_graph import *
from one_peer_undirected_equidyn import *
from one_peer_directed_equidyn import *
from directed_equitopo import *
from undirected_equitopo import *
from torus import *
from adic_graph_wo_permutation import *
from one_peer_undirected_equidyn import *


class ParallelActivationScheme():
    """Randomly place the tokens of a family of communication graphs on nodes each round.

    With ``K = floor(log2(n_nodes + 1))`` graphs, graph ``i`` has ``2**i`` tokens, so
    ``1 + 2 + ... + 2**(K-1) = 2**K - 1 <= n_nodes`` tokens exist in total. Every call
    to ``update`` shuffles the nodes and puts one token on each of ``2**K - 1`` of
    them; a node holding a token is "active". Both the current and the next placement
    are kept so that a token can be followed from its current node to its next node.

    Args:
        n_nodes: total number of nodes.
        graph: topology used for every graph: ``"ring"``, ``"exp"``,
            ``"one_peer_adic_wo_permutation"`` or ``"one_peer_undirected_equidyn"``.
        seed: seed of the ``random.Random`` instance used for shuffling.

    Raises:
        Exception: if ``graph`` is not one of the names above (and ``n_nodes >= 1``).
    """

    def __init__(self, n_nodes, graph, seed=0):
        self.n_nodes = n_nodes
        self.n_active_nodes = []  # number of tokens of each graph: 1, 2, 4, ...
        self.graph_list = []
        print(self.n_nodes)
        for i in range(self._num_graphs(self.n_nodes)):
            size = 2**i
            self.n_active_nodes.append(size)
            self.graph_list.append(self._make_graph(graph, size))
        self.seed = seed
        self.random_inst = random.Random(self.seed)

        # Per-graph token placement (lists of node ids indexed by token id) and the
        # flattened set of nodes holding any token, for this round and the next.
        self.current_active_nodes = []
        self.next_active_nodes = []

        self.current_all_active_nodes = []
        self.next_all_active_nodes = []

        # Two updates are needed to fill both the "current" and the "next" placement.
        self.update()
        self.update()

    @staticmethod
    def _num_graphs(n_nodes):
        """Return floor(log2(n_nodes + 1)): the largest K with 2**K - 1 <= n_nodes."""
        # Integer arithmetic: exact for any size and needs no ``math`` import.
        return (n_nodes + 1).bit_length() - 1

    @staticmethod
    def _make_graph(graph, n):
        """Instantiate the topology named ``graph`` on ``n`` tokens."""
        if graph == "ring":
            return Ring(n)
        if graph == "exp":
            return ExponentialGraph(n)
        if graph == "one_peer_adic_wo_permutation":
            return AdicGraphWoPermutation(n, max_degree=1)
        if graph == "one_peer_undirected_equidyn":
            return OnePeerUndirectedEquiDyn(n)
        raise Exception(f"{graph} is invalid input for ParallelActivation Scheme")

    def update(self):
        """Advance one round: the "next" placement becomes current and a new one is drawn.

        Also calls ``update()`` on every graph.
        """
        self.current_active_nodes = self.next_active_nodes
        self.current_all_active_nodes = self.next_all_active_nodes

        # shuffle
        node_list = self.random_inst.sample(list(range(self.n_nodes)), self.n_nodes)

        n_graphs = self._num_graphs(self.n_nodes)
        n_tokens = 2**n_graphs - 1
        # Position 0 of the shuffle is left idle when a spare node exists (this keeps
        # placements identical for a given seed). When n_nodes == 2**K - 1 every node
        # must carry a token, so start at 0; always starting at 1 left the last graph
        # one node short in that case.
        start = 1 if self.n_nodes > n_tokens else 0

        self.next_active_nodes = []
        self.next_all_active_nodes = []
        for i in range(n_graphs):
            size = 2**i
            tokens = node_list[start:start + size]
            self.next_active_nodes.append(tokens)
            self.next_all_active_nodes += tokens
            start += size

        for graph in self.graph_list:
            graph.update()

    def is_active(self, node_id):
        """Return True if ``node_id`` holds a token in the current round."""
        return node_id in self.current_all_active_nodes

    def is_next_active(self, node_id):
        """Return True if ``node_id`` holds a token in the next round."""
        return node_id in self.next_all_active_nodes

    def get_token_location(self, graph_id, token_id):
        """Return the node holding token ``token_id`` of graph ``graph_id`` this round."""
        return self.current_active_nodes[graph_id][token_id]

    def get_next_token_location(self, graph_id, token_id):
        """Return the node holding token ``token_id`` of graph ``graph_id`` next round."""
        return self.next_active_nodes[graph_id][token_id]

    def get_token_id(self, node_id):
        """Return ``(graph_id, token_id)`` held by ``node_id`` this round, else ``(None, None)``."""
        if self.is_active(node_id):
            for graph_id, nodes in enumerate(self.current_active_nodes):
                if node_id in nodes:
                    return graph_id, nodes.index(node_id)
        return None, None

    def get_next_token_id(self, node_id):
        """Return ``(graph_id, token_id)`` held by ``node_id`` next round, else ``(None, None)``."""
        if self.is_next_active(node_id):
            for graph_id, nodes in enumerate(self.next_active_nodes):
                if node_id in nodes:
                    return graph_id, nodes.index(node_id)
        return None, None

    def get_out_neighbors(self, node_id):
        """Map the next-round locations of this node's out-neighbor tokens to the edge values.

        ``node_id`` must be active this round. The values are whatever the graph's
        ``get_out_neighbors`` returns for each out-neighbor token.
        """
        graph_id, token_id = self.get_token_id(node_id)
        out_token_ids = self.graph_list[graph_id].get_out_neighbors(token_id)
        return {self.get_next_token_location(graph_id, out_token_id): out_token_ids[out_token_id] for out_token_id in out_token_ids}

    def get_in_neighbors(self, node_id):
        """Map the current locations of this node's next-round in-neighbor tokens to the edge values.

        ``node_id`` must be active in the next round. The values are whatever the
        graph's ``get_in_neighbors`` returns for each in-neighbor token.
        """
        graph_id, token_id = self.get_next_token_id(node_id)
        in_token_ids = self.graph_list[graph_id].get_in_neighbors(token_id)
        return {self.get_token_location(graph_id, in_token_id): in_token_ids[in_token_id] for in_token_id in in_token_ids}


class ParallelTeleportationOptimizer(Optimizer):
    """Momentum SGD whose parameters and momentum move between nodes with the tokens.

    Each ``step``:
      1. if this node is active, applies ``m = beta * m + grad; p -= lr * m``;
      2. exchanges parameters with the neighbors given by ``activation_scheme`` and,
         on nodes active in the next round, replaces the parameters by the weighted
         average of the in-neighbors' parameters;
      3. sends the momentum from the current holder of a token to its next holder;
      4. advances ``activation_scheme``.

    Point-to-point traffic uses ``torch.distributed.isend``/``irecv`` on CPU tensors,
    with ``node_id`` used as the distributed rank.

    Args:
        params: parameters or param-group dicts, as for any ``torch.optim.Optimizer``.
        activation_scheme: object exposing the ``ParallelActivationScheme`` interface.
        node_id: id (rank) of this node.
        lr: learning rate.
        beta: momentum coefficient.
        device: device on which momentum buffers live and received tensors are placed.
    """

    def __init__(self, params, activation_scheme, node_id: int, lr=1e-5, beta=0.9, device="cpu"):

        self.node_id = node_id
        self.activation_scheme = activation_scheme
        self.device = device
        self.lr = lr
        # Number of tensor elements sent by this node (parameters + momentum).
        self.n_sent_params = 0

        self.avg_p = 0.0
        # Graph whose token this node holds this round (None when idle).
        self.graph_id = self.activation_scheme.get_token_id(self.node_id)[0]

        defaults = dict(lr=lr, beta=beta)
        super(ParallelTeleportationOptimizer, self).__init__(params, defaults)

        # One zero-initialised momentum buffer per parameter, on ``device``.
        for group in self.param_groups:
            group["momentum"] = [torch.zeros_like(p, device=self.device) for p in group["params"]]

    def _enumerate_flat(self, key):
        """Enumerate ``group[key]`` over all param groups with one running index.

        The running index pairs each sent tensor with its receive buffer: it is used
        both as the message tag and as the index into received lists. It therefore
        must not restart for every param group.
        """
        return enumerate(t for group in self.param_groups for t in group[key])

    def is_active(self):
        """Return True if this node holds a token in the current round."""
        return self.activation_scheme.is_active(self.node_id)

    def get_token_id(self):
        """Return ``(graph_id, token_id)`` held by this node, or ``(None, None)``."""
        return self.activation_scheme.get_token_id(self.node_id)

    @torch.no_grad()
    def step(self, closure=None):
        """Run the local update (if active), exchange parameters and momentum, advance.

        Args:
            closure: optional callable that re-evaluates the model and returns the
                loss. As in ``torch.optim``, it is called once, with autograd enabled,
                before the update. It is only called on nodes active this round.

        Returns:
            The closure's return value, or ``None``.
        """
        loss = None

        if self.activation_scheme.is_active(self.node_id):
            if closure is not None:
                # step() runs under no_grad; the closure needs autograd to backprop.
                with torch.enable_grad():
                    loss = closure()

            for group in self.param_groups:
                lr = group["lr"]
                beta = group["beta"]

                for p, momentum in zip(group["params"], group["momentum"]):
                    if p.grad is None:
                        # Parameter did not take part in the backward pass.
                        continue
                    momentum.mul_(beta).add_(p.grad)
                    p.sub_(lr * momentum)

        self.communicate()
        self.update_momentum()
        self.activation_scheme.update()
        self.graph_id = self.activation_scheme.get_token_id(self.node_id)[0]
        return loss

    @torch.no_grad()
    def communicate(self):
        """Send parameters to out-neighbors, receive from in-neighbors, then average.

        An active node sends its parameters to the next-round locations of its
        out-neighbors. A node active in the next round receives from the current
        locations of its in-neighbors and averages. Self-loops need no communication.
        """
        scheme = self.activation_scheme
        out_neighbors = scheme.get_out_neighbors(self.node_id) if scheme.is_active(self.node_id) else {}
        next_active = scheme.is_next_active(self.node_id)
        in_neighbors = scheme.get_in_neighbors(self.node_id) if next_active else {}

        task_list = []
        # Keyed by sender node id. Open design question from the original author:
        # keying by token id might be more convenient.
        received_params = {}

        for node_id in out_neighbors:
            if node_id != self.node_id:
                task_list += self.send_param(node_id)

        for node_id in in_neighbors:
            if node_id != self.node_id:
                tasks, params = self.recv_param(node_id)
                task_list += tasks
                received_params[node_id] = params

        for task in task_list:
            task.wait()

        if next_active:
            self.average_param(received_params)

    @torch.no_grad()
    def send_param(self, node_id):
        """Post non-blocking sends of every parameter to rank ``node_id``; return the work handles."""
        task_list = []
        for tag, p in self._enumerate_flat("params"):
            task_list.append(dist.isend(tensor=p.to("cpu"), dst=node_id, tag=tag))
            self.n_sent_params += torch.numel(p)
        return task_list

    @torch.no_grad()
    def recv_param(self, node_id):
        """Post non-blocking receives of every parameter from rank ``node_id``.

        Returns:
            ``(work_handles, buffers)``. ``buffers`` is a flat list, across all param
            groups, of CPU tensors. They are valid once every handle has been waited on.
        """
        task_list = []
        received = []
        for tag, p in self._enumerate_flat("params"):
            buf = torch.zeros_like(p, device="cpu")
            task_list.append(dist.irecv(tensor=buf, src=node_id, tag=tag))
            received.append(buf)
        return task_list, received

    @torch.no_grad()
    def average_param(self, recieved_params):
        """Replace each parameter by the in-neighbor-weighted sum of parameters.

        Args:
            recieved_params: mapping ``node_id -> flat list of tensors`` as produced
                by ``recv_param``. This node's own parameters supply the self term when
                it is among its in-neighbors; otherwise that term is zero.
        """
        in_neighbors = self.activation_scheme.get_in_neighbors(self.node_id)
        self_weight = in_neighbors.get(self.node_id)

        for idx, p in self._enumerate_flat("params"):
            if self_weight is None:
                # zero_() instead of "* 0.0" so NaN/inf in stale params cannot leak in.
                p.zero_()
            else:
                p.mul_(self_weight)

            for node_id, weight in in_neighbors.items():
                if node_id != self.node_id:
                    p.add_(weight * recieved_params[node_id][idx].to(self.device))

    @torch.no_grad()
    def update_momentum(self):
        """Move the momentum along with the token.

        The node holding a token this round sends its momentum to the node that will
        hold the same ``(graph_id, token_id)`` next round. That node overwrites its
        momentum with the received buffers. Nothing is sent when both nodes coincide.
        """
        scheme = self.activation_scheme
        task_list = []

        if scheme.is_active(self.node_id):
            out_graph_id, out_token_id = scheme.get_token_id(self.node_id)
            out_node_id = scheme.get_next_token_location(out_graph_id, out_token_id)
            if out_node_id != self.node_id:
                task_list += self.send_momentum(out_node_id)

        received = None
        if scheme.is_next_active(self.node_id):
            in_graph_id, in_token_id = scheme.get_next_token_id(self.node_id)
            in_node_id = scheme.get_token_location(in_graph_id, in_token_id)
            if in_node_id != self.node_id:
                tasks, received = self.recv_momentum(in_node_id)
                task_list += tasks

        for task in task_list:
            task.wait()

        if received is not None:
            self.replace_momentum(received)

    @torch.no_grad()
    def send_momentum(self, node_id):
        """Post non-blocking sends of every momentum buffer to rank ``node_id``; return the handles."""
        task_list = []
        for tag, momentum in self._enumerate_flat("momentum"):
            task_list.append(dist.isend(tensor=momentum.to("cpu"), dst=node_id, tag=tag))
            self.n_sent_params += torch.numel(momentum)
        return task_list

    @torch.no_grad()
    def recv_momentum(self, node_id):
        """Post non-blocking receives of every momentum buffer from rank ``node_id``.

        Returns:
            ``(work_handles, buffers)`` with ``buffers`` a flat list of CPU tensors.
        """
        task_list = []
        received = []
        for tag, momentum in self._enumerate_flat("momentum"):
            buf = torch.zeros_like(momentum, device="cpu")
            task_list.append(dist.irecv(tensor=buf, src=node_id, tag=tag))
            received.append(buf)
        return task_list, received

    @torch.no_grad()
    def replace_momentum(self, recieved_params):
        """Overwrite every momentum buffer with the matching tensor of a flat received list."""
        for idx, momentum in self._enumerate_flat("momentum"):
            # copy_ moves data to momentum's device and, unlike "*0 + x", drops NaN/inf.
            momentum.copy_(recieved_params[idx])

    @torch.no_grad()
    def print_param(self):
        """Print every parameter tensor."""
        for group in self.param_groups:
            for p in group["params"]:
                print(p)
