import copy
import math
from collections import defaultdict
import torch
import torch.multiprocessing as mp
from torch.multiprocessing import Queue

import federated
from .base import *
import utils

# Incomplete legacy implementation for multi-GPU training, kept commented out for reference only.

# def _convert_list_to_dict(list_data):
#     dict_data = {i: x for i, x in enumerate(list_data)}
#     return dict_data

# class MMTController(FederatedController):
#     def __init__(self, *args, **kwargs):
#         super(MMTController, self).__init__(*args, **kwargs)
#         self.parent_tree, self.data_tree = self.construct_tree()
#         self.root_node_idx = max(self.parent_tree.values())
#         self.subfl_models = {id: None for id in self.data_tree.keys()}
#         self.subfl_accu_rounds = {id: 0 for id in self.data_tree.keys()}
#         self.num_gpus = torch.cuda.device_count()
#         self.device_memory = {f'cuda:{i}': torch.cuda.get_device_properties(i).total_memory for i in range(self.num_gpus)}
    
#     def train(self, num_rounds, tokenizer):
#         super().train(num_rounds, tokenizer)   # update server model by training the federation
#         self.subfl_models[self.root_node_idx] = self.server_model    # assign trained server model to the root node
        
#         # accumulate training rounds for lazy training of sub-fl models
#         for id in self.data_tree.keys():
#             if id != self.root_node_idx:
#                 self.subfl_accu_rounds[id] += num_rounds

#         # We don't actually train the sub-FL models here, just accumulate rounds
#         # The actual training will happen when needed (e.g., in the leave method)

#     def train_subfl_model(self, node_id, initial_device):
#         devices = [f'cuda:{i}' for i in range(self.num_gpus)]
#         start_index = devices.index(initial_device)
        
#         for i in range(self.num_gpus):
#             device = devices[(start_index + i) % self.num_gpus]
#             print(f"Attempting to train sub-FL model {node_id} on {device}")
            
#             try:
#                 available_memory = torch.cuda.memory_allocated(device)
#                 if available_memory > self.device_memory[device] * 0.9:
#                     print(f"Device {device} is near capacity, trying next device")
#                     continue

#                 if self.subfl_models[node_id] is None:
#                     self.subfl_models[node_id] = copy.deepcopy(self.init_server_model)
#                 model = self.subfl_models[node_id].to(device)
                
#                 self.config.federated.num_rounds = self.subfl_accu_rounds[node_id]
                
#                 federated.train_fed(
#                     model, 
#                     _convert_list_to_dict(self.data_tree[node_id]), 
#                     eval_loaders=self.eval_loaders,
#                     config=self.config,
#                     do_logging=False,
#                     device=device,
#                     tokenizer=self.tokenizer
#                 )
#                 self.subfl_accu_rounds[node_id] = 0
#                 self.subfl_models[node_id] = model.cpu()
#                 return
#             except RuntimeError as e:
#                 if "out of memory" in str(e):
#                     print(f"Device {device} out of memory, trying next device")
#                 else:
#                     raise
        
#         print(f"Unable to train sub-FL model {node_id} on any available device due to memory constraints")

#     def leave(self, unlearned_client):
#         super().leave(unlearned_client)

#         self.client_loaders.pop(unlearned_client)

#         removed_parents = self.remove_influenced_parents(unlearned_client)
#         print("parent_unlearned_tree:", self.parent_tree)
#         print("removed_parents:", removed_parents)
#         for p in removed_parents:
#             self.data_tree.pop(p)
#             self.subfl_models.pop(p)
        
#         children_tree = defaultdict(list)
#         for c in self.parent_tree:
#             p = self.parent_tree[c]
#             children_tree[p].append(c)
#         print("children_tree:", children_tree)

#         aggr_nodes = []
#         for p in removed_parents: 
#             aggr_nodes.extend(children_tree.get(p, []))
#         print("aggr_nodes:", aggr_nodes)

#         if self.parallel:
#             processes = []
#             for i, node_id in enumerate(aggr_nodes):
#                 gpu_id = i % self.num_gpus
#                 initial_device = f'cuda:{gpu_id}'
#                 p = utils.create_mp_process(
#                     self.train_subfl_model,
#                     node_id,
#                     initial_device,
#                 )
#                 p.start()
#                 processes.append(p)
            
#             for p in processes:
#                 p.join()
#         else:
#             for node_id in aggr_nodes:
#                 self.train_subfl_model(node_id, self.device)

#         aggr_grads = []
#         aggr_weighted = []
#         for node_id in aggr_nodes:
#             grad = federated.compute_grad_update(self.server_model, self.subfl_models[node_id])
#             aggr_grads.append(grad)
#             num_samples = sum(len(x) for x in self.data_tree[node_id])
#             aggr_weighted.append(num_samples)
#         fed_samples = sum(aggr_weighted)
#         aggr_weighted = [x / fed_samples for x in aggr_weighted]

#         self.server_model = federated.aggr_grads(self.server_model, aggr_grads, aggr_weighted)

#         self.root_node_idx = max(self.parent_tree.values())
#         self.subfl_models[self.root_node_idx] = self.server_model 
#         self.data_tree[self.root_node_idx] = []
#         for node_id in aggr_nodes:
#             self.parent_tree[node_id] = self.root_node_idx
#             self.data_tree[self.root_node_idx].extend(self.data_tree[node_id])
        
#         self.evaluate()
    
#     def join(self, client_id, dataloader, attach_to):
#         super().join(client_id, dataloader)

#         self.parent_tree[client_id] = attach_to
#         self.subfl_models[client_id] = None 
#         self.data_tree[client_id] = [self.client_loaders[client_id]]
#         node = client_id
#         while node in self.parent_tree:
#             parent_node = self.parent_tree[node] 
#             self.data_tree[parent_node].extend(self.data_tree[node])
#             node = parent_node
        
#         self.evaluate()
        
#     def construct_tree(self):
#         num_clients = self.config.federated.num_clients
#         tree_height = int(math.ceil(math.log2(num_clients))) + 1
#         parent_tree = {}
#         data_tree = {}
#         for level in range(tree_height):
#             if level == 0:
#                 no_child_nodes = 2 ** (tree_height - 1)
#                 for i in range(no_child_nodes):
#                     if i in self.client_loaders.keys():
#                         data_tree[i] = [self.client_loaders[i]]
#                     else:
#                         data_tree[i] = []
#                 child_offset_idx = no_child_nodes
#                 curr_node_idx = no_child_nodes
#             else:
#                 for i in range(no_child_nodes // 2):
#                     left_child = curr_node_idx - child_offset_idx
#                     right_child = left_child + 1
#                     parent_tree[left_child] = curr_node_idx
#                     parent_tree[right_child] = curr_node_idx
#                     data_tree[curr_node_idx] = []
#                     data_tree[curr_node_idx].extend(data_tree[left_child])
#                     data_tree[curr_node_idx].extend(data_tree[right_child])
#                     curr_node_idx += 1
#                     child_offset_idx -= 1
#                 no_child_nodes //= 2

#         return parent_tree, data_tree

#     def remove_influenced_parents(self, unlearned_node):
#         node = unlearned_node
#         removed_parents = []
#         while node in self.parent_tree:
#             parent_node = self.parent_tree[node]
#             removed_parents.append(parent_node)
#             self.parent_tree.pop(node)
#             node = parent_node
#         return removed_parents


import copy
import math
from collections import defaultdict

import federated
from .base import *


def _convert_list_to_dict(list_data):
    dict_data = {i: x for i, x in enumerate(list_data)}
    return dict_data

class MMTController(FederatedController):
    """Federated controller that supports client unlearning through a k-ary tree of sub-FL models.

    Clients are the leaves of a complete k-ary tree; leaf slots beyond the
    number of clients are dummy leaves with no data. Each node caches a sub-FL
    model for the clients of its subtree. Sub-FL models are trained lazily:
    ``train`` only records how many rounds each node owes, and ``leave`` replays
    those rounds for the nodes it has to re-aggregate.

    Attributes:
        k: branching factor of the tree.
        parent_tree: ``{node_idx: parent_node_idx}`` for every node still linked
            into the tree. Leaves are ``0 .. len(client_mapping) - 1``; internal
            nodes are numbered level by level above them, the root being the
            largest index.
        root_node_idx: index of the root node.
        client_mapping: ``client_mapping[leaf_idx]`` is the id of the client in
            that leaf, or ``None`` for a dummy leaf.
        sub_models: cached sub-FL model per node (``None`` until first trained);
            the root's entry is the server model itself.
        sub_accu_rounds: federated rounds each node has accumulated but has not
            been trained for yet.
    """

    def __init__(self, *args, k: int = 2, **kwargs):
        """Build the tree on top of the base controller.

        Args:
            k: branching factor of the tree; ``-1`` uses one branch per client in
                ``client_loaders`` (at least 2).
            *args, **kwargs: forwarded to ``FederatedController``.

        Raises:
            ValueError: if the resulting branching factor is smaller than 2.
        """
        super().__init__(*args, **kwargs)
        if k == -1:
            # a tree still needs at least 2 leaf slots under its root
            self.k = max(2, len(self.client_loaders))
        else:
            self.k = k
        if self.k < 2:
            raise ValueError(f"Tree branching factor k must be >= 2 (or -1), got {k}")
        self.parent_tree, self.root_node_idx = self.construct_tree()
        self.sub_models = {i: None for i in range(self.root_node_idx + 1)}
        self.sub_accu_rounds = {i: 0 for i in range(self.root_node_idx + 1)}

    def train(self, num_rounds):
        """Train the whole federation and defer the training of sub-FL models.

        The trained server model becomes the root's cached model; every other
        linked node only records ``num_rounds`` more rounds owed, which ``leave``
        replays when that node has to be retrained.
        """
        super().train(num_rounds)   # update server model by training the federation
        self.sub_models[self.root_node_idx] = self.server_model    # assign trained server model to the root node
        # accumulate training rounds for lazy training of sub-fl models
        for node_idx in self.parent_tree:
            self.sub_accu_rounds[node_idx] = self.sub_accu_rounds.get(node_idx, 0) + num_rounds

    def leave(self, unlearned_client):
        """Unlearn ``unlearned_client`` by rebuilding the root from unaffected sub-FL models.

        Every ancestor of the client's leaf is unlinked and its cached model
        dropped. The surviving children of those ancestors are lazily trained
        for the rounds they have accumulated, combined (weighted by their number
        of training samples) into the new server model, and re-attached to the
        root.
        """
        super().leave(unlearned_client)
        self.client_loaders.pop(unlearned_client)

        # remove all ancestors of the unlearned client
        removed_parents = self.remove_influenced_parents(unlearned_client)
        print("parent_unlearned_tree:", self.parent_tree)
        print("removed_parents:", removed_parents)
        for p in removed_parents:
            self.sub_models.pop(p, None)
            self.sub_accu_rounds[p] = 0

        # trace back the remaining children of the removed ancestors
        children_tree = defaultdict(list)
        for child_idx, parent_idx in self.parent_tree.items():
            children_tree[parent_idx].append(child_idx)
        print("children_tree:", children_tree)

        # collect list of remaining children/sub-FL to be retrained
        aggr_nodes = []
        for p in removed_parents:
            aggr_nodes.extend(children_tree.get(p, []))
        print("aggr_nodes:", aggr_nodes)

        trained_nodes, sub_clients_weight = self._train_sub_models(aggr_nodes, children_tree)
        print('sub_accu_rounds:', self.sub_accu_rounds)
        self.server_model = self._aggregate_sub_models(trained_nodes, sub_clients_weight)

        # Re-create the root node. The old root has the largest node index, so this
        # normally selects it again; the default covers a tree with no links left.
        self.root_node_idx = max(self.parent_tree.values(), default=self.root_node_idx)
        self.sub_models[self.root_node_idx] = self.server_model
        for node_idx in aggr_nodes:
            self.parent_tree[node_idx] = self.root_node_idx

    def _sub_model_device(self, slot):
        """Return the device on which the ``slot``-th sub-FL model of this round trains.

        Raises:
            RuntimeError: if multi-GPU training needs more CUDA devices than exist.
        """
        parallel_cfg = self.config.parallel
        if not (parallel_cfg.do_parallel and parallel_cfg.do_multi_gpus):
            return self.device  # doesn't change device
        device_id = int(slot / parallel_cfg.num_clients_per_device)
        if device_id >= torch.cuda.device_count():
            raise RuntimeError("Not enough devices.")
        return torch.device("cuda", device_id)

    def _train_sub_models(self, aggr_nodes, children_tree):
        """Lazily train the cached sub-FL model of every node in ``aggr_nodes``.

        Each node is trained on the clients of its subtree for the rounds it has
        accumulated, and its counter is reset. A node whose subtree holds no
        training samples (only dummy or departed leaves) would get zero weight
        in the aggregation, so it is not trained.

        Returns:
            (trained_nodes, num_train_samples): the nodes that were trained and
            the number of training samples in each one's subtree.
        """
        do_parallel = self.config.parallel.do_parallel
        processes = []
        trained_nodes = []
        sub_clients_weight = []
        # num_rounds is temporarily overwritten per node; capture the configured value
        # once so it is restored correctly, even if training raises
        original_config_round = self.config.federated.num_rounds
        try:
            # do lazy training of sub-fl models (given cached number of rounds)
            for node_idx in aggr_nodes:
                sub_client_loaders = _convert_list_to_dict(self.get_recursive_data(node_idx, children_tree))
                num_train_samples = sum(len(x["train"].dataset) for x in sub_client_loaders.values())
                num_rounds = self.sub_accu_rounds.get(node_idx, 0)
                self.sub_accu_rounds[node_idx] = 0      # reset number of accumulated rounds
                if num_train_samples == 0:
                    continue
                if self.sub_models.get(node_idx) is None:  # the model isn't initialized
                    self.sub_models[node_idx] = copy.deepcopy(self.init_server_model)
                self.config.federated.num_rounds = num_rounds
                device = self._sub_model_device(len(trained_nodes))
                trained_nodes.append(node_idx)
                sub_clients_weight.append(num_train_samples)

                if do_parallel:
                    print(f"Started process for node {node_idx}")
                    # NOTE(review): relies on train_fed updating these shared-memory
                    # parameters in place inside the child process; confirm this
                    # still holds when `device` is a GPU.
                    self.sub_models[node_idx].share_memory()
                    p = utils.create_mp_process(federated.train_fed,
                                                self.sub_models[node_idx],
                                                sub_client_loaders,
                                                self.config,
                                                eval_fn=None,
                                                device=device,
                                                is_llm=self.is_llm,
                                                tokenizer=self.tokenizer,
                                                )
                    p.start()
                    processes.append(p)
                else:
                    print(f"Train sub-FL model {node_idx}")
                    federated.train_fed(self.sub_models[node_idx],
                                        sub_client_loaders,
                                        self.config,
                                        eval_fn=None,
                                        device=device,
                                        is_llm=self.is_llm,
                                        tokenizer=self.tokenizer,
                                        )

            for count, p in enumerate(processes, start=1):
                p.join()
                print(f"Process {count} joined")
        finally:
            self.config.federated.num_rounds = original_config_round
        return trained_nodes, sub_clients_weight

    def _aggregate_sub_models(self, node_ids, weights):
        """Combine the sub-FL models of ``node_ids`` into a new server model via ``federated.aggr_params``.

        ``weights`` are raw sample counts and are normalised to sum to 1. All
        sub-models are first moved to the server model's device. With nothing
        to aggregate no client data remains, so a fresh copy of the initial
        (untrained) model is returned instead.
        """
        server_device = next(self.server_model.parameters()).device
        if not node_ids:
            print("No sub-FL model with training data left; resetting server model to its initialization")
            return copy.deepcopy(self.init_server_model).to(server_device)
        total_samples = sum(weights)    # > 0: only nodes with samples are passed in
        sub_clients_weight = [w / total_samples for w in weights]
        sub_models = []
        for node_idx in node_ids:
            self.sub_models[node_idx].to(server_device)
            sub_models.append(self.sub_models[node_idx])
        print("Sub_clients_weight:", sub_clients_weight)
        return federated.aggr_params(self.server_model, sub_models, sub_clients_weight)

    def join(self, client_id, dataloader, attach_to):
        """Add a client as a new leaf under the existing tree node ``attach_to``.

        The client takes a free leaf slot (see ``_find_free_leaf``). Sub-FL
        models are trained lazily, so the new leaf only needs an empty cache
        entry. ``attach_to`` is expected to be a node index of this tree.

        Raises:
            RuntimeError: if every leaf slot is occupied by an active client.
        """
        super().join(client_id, dataloader)
        leaf_idx = self._find_free_leaf(client_id)
        self.client_mapping[leaf_idx] = client_id
        self.parent_tree[leaf_idx] = attach_to
        self.sub_models[leaf_idx] = None
        self.sub_accu_rounds[leaf_idx] = 0

    def _find_free_leaf(self, client_id):
        """Return the leaf slot for ``client_id``: its previous slot, else a dummy or vacated one."""
        if client_id in self.client_mapping:
            # a returning client reuses its slot so that leaf lookups stay unique
            return self.client_mapping.index(client_id)
        for leaf_idx, mapped_client in enumerate(self.client_mapping):
            if mapped_client is None or mapped_client not in self.client_loaders:
                return leaf_idx
        raise RuntimeError(f"No free leaf slot in the tree for client {client_id}")

    def construct_tree(self):
        """Construct a complete k-ary tree, bottom-up.

        Leaf nodes are all clients in the federated setting plus dummy nodes, for
        a total of ``k ** depth`` leaves with the smallest ``depth >= 1`` that
        fits every client. Also (re)builds ``self.client_mapping``.

        Returns:
            (parent_tree, root_node_idx), where parent_tree is
            ``{node: immediate parent node}``.
        """
        # size the tree with integer arithmetic: math.log(n, k) can overshoot an
        # exact power (e.g. log(125, 5) > 3) and add a whole useless level
        num_clients = max(self.config.federated.num_clients, len(self.client_loaders))
        num_leaves = self.k
        while num_leaves < num_clients:
            num_leaves *= self.k

        self.client_mapping = [None] * num_leaves
        for leaf_idx, client_idx in enumerate(self.client_loaders):
            self.client_mapping[leaf_idx] = client_idx

        parent_tree = {}    # {node: immediate parent node}
        level_nodes = list(range(num_leaves))
        next_node_idx = num_leaves
        # every level has a power-of-k size, so each group is exactly k children
        while len(level_nodes) > 1:
            parents = []
            for start in range(0, len(level_nodes), self.k):
                for child_idx in level_nodes[start:start + self.k]:
                    parent_tree[child_idx] = next_node_idx
                parents.append(next_node_idx)
                next_node_idx += 1
            level_nodes = parents
        root_node_idx = level_nodes[0]
        return parent_tree, root_node_idx

    def remove_influenced_parents(self, unlearned_node):
        """Unlink the leaf of client ``unlearned_node`` and every ancestor above it.

        Walks from the client's leaf to the top of its chain, removing each
        node's link from ``parent_tree``.

        Returns:
            The ancestors visited, ordered bottom-up.

        Raises:
            ValueError: if ``unlearned_node`` has no leaf in ``client_mapping``.
        """
        node_idx = self.client_mapping.index(unlearned_node)
        removed_parents = []
        while node_idx in self.parent_tree:
            parent_node = self.parent_tree.pop(node_idx)
            removed_parents.append(parent_node)
            node_idx = parent_node
        return removed_parents

    def get_recursive_data(self, node_idx, children_tree):
        """Return the loaders of every active client in the subtree rooted at ``node_idx``.

        Dummy leaves and leaves of clients that have left contribute nothing.
        ``children_tree`` maps a node to its children and is not modified.
        """
        if node_idx is None:
            return []
        if 0 <= node_idx < len(self.client_mapping):    # leaf slot
            client_idx = self.client_mapping[node_idx]
            if client_idx is None or client_idx not in self.client_loaders:
                return []
            return [self.client_loaders[client_idx]]
        data = []
        for child_idx in children_tree.get(node_idx, []):
            data.extend(self.get_recursive_data(child_idx, children_tree))
        return data
