import pandas as pd
import torch
from torch import Tensor
from typing import Tuple, List, Callable
import random
import warnings
warnings.filterwarnings("ignore")

import pdb
import time

from quantizer import *
from server import *
from client import *
from dataset_manager import *
from model_manager import *
from utils import *

from threading import Thread

class Trainer:
    
    SLOW_CLIENTS_RATIO = 0.3 #0.9#0.3 ## Proportion of slow clients in the population

    def __init__(self, algorithm, dataset_name, client_count, train_sets_list, test_set,
                 local_step, group_count, quantizer, initial_model, log_period,
                 gpu_ids, server_averaging=True, client_averaging=True):
        """Create the server, the evaluation loader and the client population.

        Args:
            algorithm: Name that ``train`` uses to select a training loop.
            dataset_name: Passed to ``get_batch_size``, ``get_optimizer`` and
                ``get_criterion``; also decides the evaluation batch size.
            client_count: N, the number of clients to create.
            train_sets_list: Training dataset(s), forwarded to ``setup_clients``.
            test_set: Dataset wrapped by ``self.test_loader``.
            local_step: K, stored as ``self.LOCAL_STEP`` and handed to the
                clients as their per-round step budget.
            group_count: S, stored as ``self.GROUP_COUNT``; the number of
                clients the server interacts with in one round.
            quantizer: Forwarded to the ``Server`` and to every ``Client``.
            initial_model: Forwarded to the ``Server`` and to every ``Client``.
            log_period: Threshold on ``server.time - last_tested`` after which
                the training loops call ``self.test``.
            gpu_ids: Device ids; the first one is given to the server.
            server_averaging: Flag read by the training loops.
            client_averaging: Flag read by the training loops.
        """
        # --- plain configuration -------------------------------------------
        self.algorithm = algorithm
        self.dataset_name = dataset_name
        self.gpu_ids = gpu_ids
        self.initial_model = initial_model
        self.log_period = log_period
        self.server_averaging = server_averaging
        self.client_averaging = client_averaging

        self.client_count = client_count  # N: number of agents/clients
        self.LOCAL_STEP = local_step      # K: steps a model takes before its next interaction
        self.GROUP_COUNT = group_count    # S: models the server interacts with per step

        # --- bookkeeping ----------------------------------------------------
        self.clients = []
        self.history = []
        self.last_tested = 0

        # --- dataset-dependent helpers and the server -----------------------
        batch_size = get_batch_size(dataset_name)
        self.optimizer = get_optimizer(dataset_name)
        self.criterion = get_criterion(dataset_name)
        self.server = Server(initial_model, self.criterion, quantizer, gpu_id=gpu_ids[0])

        # --- evaluation loader ----------------------------------------------
        if dataset_name in ("cifar 10", "celeba"):
            test_batch_size = 1000
        else:
            test_batch_size = 2000
        self.test_loader = data.DataLoader(test_set, batch_size=test_batch_size,
                                           shuffle=True, num_workers=6)

        # --- clients (reads self.client_count / gpu_ids / optimizer / criterion)
        self.setup_clients(train_sets_list, batch_size, quantizer, initial_model, gpu_ids)
        
        
    def setup_clients(self, train_sets_list, batch_size, quantizer, initial_model, gpu_ids):
        """Create ``self.client_count`` clients and append them to ``self.clients``.

        When ``train_sets_list`` holds a single dataset, every client reads it
        through its own ``DistributedSampler`` (``client_count`` replicas,
        rank ``i``).  Otherwise client ``i`` gets a shuffled loader over
        ``train_sets_list[i]``.

        The first ``int(SLOW_CLIENTS_RATIO * client_count)`` clients are
        created with ``fast=False`` and all later ones with ``fast=True``.
        (The previous ``i > ...`` comparison marked one client too many as
        slow, e.g. 4 out of 10 for a ratio of 0.3.)

        Args:
            train_sets_list: One shared dataset or one dataset per client.
            batch_size: Batch size of every client loader.
            quantizer: Forwarded to each ``Client``.
            initial_model: Forwarded to each ``Client``.
            gpu_ids: Not read here; devices are taken round-robin from
                ``self.gpu_ids``.
        """
        shared_dataset = len(train_sets_list) == 1
        gpu_count = len(self.gpu_ids)
        slow_client_count = int(Trainer.SLOW_CLIENTS_RATIO * self.client_count)
        for i in range(self.client_count):
            if shared_dataset:
                sampler = torch.utils.data.distributed.DistributedSampler(
                    train_sets_list[0], self.client_count, i, shuffle=True)
                dataloader_i = data.DataLoader(train_sets_list[0], batch_size=batch_size,
                                               num_workers=2, sampler=sampler)
            else:
                dataloader_i = data.DataLoader(train_sets_list[i], batch_size=batch_size,
                                               num_workers=2, shuffle=True)

            # Clients with index below the slow quota are the slow ones.
            fast = i >= slow_client_count
            client_i = Client(index=i, model=initial_model, optimizer=self.optimizer,
                              criterion=self.criterion, dataloader=dataloader_i,
                              quantizer=quantizer, gpu_id=self.gpu_ids[i % gpu_count], fast=fast)
            self.clients.append(client_i)
            print(f"Client {i+1} is added to the population.")
    
    def train_client(self, client, lr, p, server_SD, client_dictionary_mode, server_model_ratio_on_client):
        """Carry out one client's part of a server round.

        The client runs via ``client.run_until`` with a budget of
        ``self.LOCAL_STEP`` steps, uploads its quantized dictionary, is mixed
        with ``server_SD`` and finally has its clock set to the server's.

        Args:
            client: The participating client.
            lr: Base learning rate; multiplied by ``1 / (1 - p)`` when
                ``self.server_averaging`` is set.
            p: Server ratio in averaging.
            server_SD: Server state dict the client is averaged with.
            client_dictionary_mode: ``mode`` passed to
                ``client.get_model_dictionary``.
            server_model_ratio_on_client: Ratio passed to
                ``client.average_with_server_SD``.
        """
        if self.server_averaging:
            lr_scale = 1 / (1 - p)
        else:
            lr_scale = 1
        _, steps_done = client.run_until(lr * lr_scale, self.server.time, self.LOCAL_STEP)

        upload = client.get_model_dictionary(quantized=True, mode=client_dictionary_mode)
        self.send_dict_to_server(upload)
        self.server.seen_local_steps += steps_done

        client.average_with_server_SD(server_SD, server_model_ratio_on_client)
        client.time = self.server.time
        
    def train_quantized_fl(self, lr, time_limit):
        """Run the "quantized_fl" loop until ``self.server.time`` reaches ``time_limit``.

        Every round samples ``GROUP_COUNT`` clients and takes one snapshot of
        the quantized server state.  Each sampled client then runs, uploads
        its dictionary and is averaged with that pre-round snapshot (the work
        done by ``train_client``).  Afterwards the server averages the
        received dictionaries and its clock advances by
        ``Server.server_interaction_time`` plus ``Trainer.server_waiting_time``.

        NOTE(review): ``Trainer.server_waiting_time`` is not assigned in the
        class body visible here -- presumably the launching script sets it;
        confirm.

        Args:
            lr: Base learning rate handed to the clients.
            time_limit: Value of ``self.server.time`` at which the loop stops.

        Returns:
            ``self.history``, the list of records appended by ``self.test``.
        """
        wall_clock_mark = time.time()
        self.diverged = False

        # Server ratio in averaging.
        p = 1 / (self.GROUP_COUNT + 1)
        ratio_on_client = p if self.client_averaging else 1
        ratio_on_server = p if self.server_averaging else 0
        dictionary_mode = "state"

        self.test()
        # NOTE: a time-based LR decay (x0.6 once 60% of time_limit has passed,
        # a further x0.33 after 80%) can be added at the top of this loop.
        while self.server.time < time_limit:
            since_last_test = self.server.time - self.last_tested
            in_warmup_window = since_last_test >= 10 and self.server.time <= 100
            if since_last_test >= self.log_period or in_warmup_window:
                self.test()
                print(f"Real time: {time.time() - wall_clock_mark}")
                wall_clock_mark = time.time()

            round_clients = random.sample(self.clients, self.GROUP_COUNT)
            server_SD = self.server.get_model_SD(quantized=True)
            # Clients are handled one after another; train_client can also be
            # used as a Thread target if a concurrent round is wanted.
            for client in round_clients:
                self.train_client(client, lr, p, server_SD, dictionary_mode, ratio_on_client)

            self.server.average_received_SDs(server_model_ratio=ratio_on_server)
            self.server.interaction_count += 1
            self.server.time += Server.server_interaction_time
            self.server.time += Trainer.server_waiting_time

        self.test()
        return self.history
    
    def train_Fed_Avg(self, lr, time_limit):
        """Run the "Fed_Avg" loop until ``self.server.time`` reaches ``time_limit``.

        Each round snapshots the quantized server state, samples
        ``GROUP_COUNT`` clients, loads the snapshot into every sampled client,
        lets it take ``LOCAL_STEP`` steps and collects its state dict.  The
        server clock then advances by the longest client run time of the round
        plus ``Server.server_interaction_time``, and the server replaces its
        model with the average of the received dicts (server ratio 0).

        Args:
            lr: Learning rate passed to ``client.take_step``.
            time_limit: Value of ``self.server.time`` at which the loop stops.

        Returns:
            ``self.history``, the list of records appended by ``self.test``.
        """
        self.diverged = False
        self.test()
        # NOTE: a time-based LR decay (x0.6 once 60% of time_limit has passed,
        # a further x0.33 after 80%) can be added at the top of this loop.
        while self.server.time < time_limit:
            since_last_test = self.server.time - self.last_tested
            in_warmup_window = since_last_test >= 10 and self.server.time <= 100
            if since_last_test >= self.log_period or in_warmup_window:
                self.test()

            global_SD = self.server.get_model_SD(quantized=True)
            participants = random.sample(self.clients, self.GROUP_COUNT)
            durations = []
            for participant in participants:
                participant.load_SD(global_SD, quantized=True)
                duration, steps_done = participant.take_step(self.LOCAL_STEP, lr)
                durations.append(duration)
                self.send_dict_to_server(participant.get_model_SD(quantized=True))
                self.server.seen_local_steps += steps_done

            # The round lasts as long as its slowest participant.
            self.server.time += max(durations)
            self.server.time += Server.server_interaction_time
            self.server.average_received_SDs(server_model_ratio=0)
            self.server.interaction_count += 1

        self.test()
        return self.history
    
    def train_Fed_Buff(self, lr, time_limit, buffersize=10):
        """Run the "Fed_Buff" loop until ``self.server.time`` reaches ``time_limit``.

        All clients start from the server model in the first round.  In every
        round the clients in ``indices`` load the current server state, take
        ``LOCAL_STEP`` steps and have their accumulated run time increased.
        The ``buffersize`` clients with the smallest accumulated run time fill
        the buffer: their state dicts are sent to the server, their run time
        is moved up to the moment the buffer became full, and they are the
        ones restarted in the next round.  The server clock is set to that
        moment plus ``Server.server_interaction_time``.

        Run times and the server clock are kept as plain Python numbers; the
        values coming out of ``torch.topk`` are converted instead of being
        stored as 0-dim tensors (which previously ended up in ``self.history``).

        Args:
            lr: Learning rate passed to ``client.take_step``.
            time_limit: Value of ``self.server.time`` at which the loop stops.
            buffersize: Number of client updates aggregated per round; clamped
                to ``self.client_count`` because ``torch.topk`` rejects a
                larger ``k``.

        Returns:
            ``self.history``, the list of records appended by ``self.test``.
        """
        self.diverged = False
        self.test()
        indices = list(range(self.client_count))
        run_times = [0] * self.client_count
        buffer_len = min(buffersize, self.client_count)
        while self.server.time < time_limit:
            since_last_test = self.server.time - self.last_tested
            if (since_last_test >= self.log_period or
                    (since_last_test >= 10 and self.server.time <= 100)):
                self.test()

            server_SD = self.server.get_model_SD(quantized=True)
            for index in indices:
                client = self.clients[index]
                client.load_SD(server_SD, quantized=True)
                run_time, taken_steps = client.take_step(self.LOCAL_STEP, lr)
                self.server.seen_local_steps += taken_steps
                run_times[index] += run_time

            # topk on the negated times yields the smallest accumulated times.
            values, fastest = torch.topk(-torch.tensor(run_times), buffer_len)
            buffer_full_time = (-values).max().item()
            indices = fastest.tolist()
            for index in indices:
                client = self.clients[index]
                # Buffered clients wait until the slowest of them has arrived.
                run_times[index] = buffer_full_time
                client_SD = client.get_model_SD(quantized=True)
                self.send_dict_to_server(client_SD)

            self.server.time = buffer_full_time
            self.server.time += Server.server_interaction_time
            self.server.average_received_SDs(server_model_ratio=0)
            self.server.interaction_count += 1
        self.test()

        return self.history
        
        
    def train_our(self, lr, time_limit):
        """Run the "our" loop until ``self.server.time`` reaches ``time_limit``.

        A round has two phases.  First every sampled client runs via
        ``client.run_until`` and uploads its quantized dictionary, after which
        the server averages what it received.  Then a fresh snapshot of the
        updated server state is taken and each sampled client is averaged with
        it -- unlike ``train_quantized_fl``, where clients mix with the
        snapshot taken before the round.  The server clock finally advances by
        ``Server.server_interaction_time`` plus ``Trainer.server_waiting_time``.

        NOTE(review): ``Trainer.server_waiting_time`` is not assigned in the
        class body visible here -- presumably the launching script sets it;
        confirm.

        Args:
            lr: Base learning rate; multiplied by ``1 / (1 - p)`` when
                ``self.server_averaging`` is set.
            time_limit: Value of ``self.server.time`` at which the loop stops.

        Returns:
            ``self.history``, the list of records appended by ``self.test``.
        """
        wall_clock_mark = time.time()
        self.diverged = False

        # Server ratio in averaging.
        p = 1 / (self.GROUP_COUNT + 1)
        ratio_on_client = p if self.client_averaging else 1
        ratio_on_server = p if self.server_averaging else 0
        dictionary_mode = "state"

        self.test()
        # NOTE: a time-based LR decay (x0.6 once 60% of time_limit has passed,
        # a further x0.33 after 80%) can be added at the top of this loop.
        while self.server.time < time_limit:
            since_last_test = self.server.time - self.last_tested
            in_warmup_window = since_last_test >= 10 and self.server.time <= 100
            if since_last_test >= self.log_period or in_warmup_window:
                self.test()
                print(f"Real time: {time.time() - wall_clock_mark}")
                wall_clock_mark = time.time()

            round_clients = random.sample(self.clients, self.GROUP_COUNT)
            # The pre-round snapshot is not used below; the call is kept in
            # case get_model_SD has side effects -- NOTE(review): confirm.
            self.server.get_model_SD(quantized=True)

            # Phase 1: local work and upload.
            for client in round_clients:
                if self.server_averaging:
                    step_lr = lr * (1 / (1 - p))
                else:
                    step_lr = lr * 1
                _, steps_done = client.run_until(step_lr, self.server.time, self.LOCAL_STEP)
                upload = client.get_model_dictionary(quantized=True, mode=dictionary_mode)
                self.send_dict_to_server(upload)
                self.server.seen_local_steps += steps_done
                client.time = self.server.time

            # Phase 2: aggregate, then push the new server model to the group.
            self.server.average_received_SDs(server_model_ratio=ratio_on_server)
            updated_SD = self.server.get_model_SD(quantized=True)
            for client in round_clients:
                client.average_with_server_SD(updated_SD, ratio_on_client)

            self.server.interaction_count += 1
            self.server.time += Server.server_interaction_time
            self.server.time += Trainer.server_waiting_time

        self.test()
        return self.history
    
    
    def train_langevin(self, lr, time_limit):
        """Run the "langevin" loop until ``self.server.time`` reaches ``time_limit``.

        Same round structure as ``train_our`` -- clients run and upload, the
        server averages, then the sampled clients are averaged with the
        updated server state -- except that the local work is done by
        ``client.run_until_langevin``.

        NOTE(review): ``Trainer.server_waiting_time`` is not assigned in the
        class body visible here -- presumably the launching script sets it;
        confirm.

        Args:
            lr: Base learning rate; multiplied by ``1 / (1 - p)`` when
                ``self.server_averaging`` is set.
            time_limit: Value of ``self.server.time`` at which the loop stops.

        Returns:
            ``self.history``, the list of records appended by ``self.test``.
        """
        wall_clock_mark = time.time()
        self.diverged = False

        # Server ratio in averaging.
        p = 1 / (self.GROUP_COUNT + 1)
        if self.client_averaging:
            ratio_on_client = p
        else:
            ratio_on_client = 1
        if self.server_averaging:
            ratio_on_server = p
        else:
            ratio_on_server = 0
        dictionary_mode = "state"

        self.test()
        while self.server.time < time_limit:
            waited = self.server.time - self.last_tested
            if waited >= self.log_period or (waited >= 10 and self.server.time <= 100):
                self.test()
                print(f"Real time: {time.time() - wall_clock_mark}")
                wall_clock_mark = time.time()

            group = random.sample(self.clients, self.GROUP_COUNT)
            # The pre-round snapshot is not used below; the call is kept in
            # case get_model_SD has side effects -- NOTE(review): confirm.
            self.server.get_model_SD(quantized=True)

            for member in group:
                boost = 1 / (1 - p) if self.server_averaging else 1
                _, steps_done = member.run_until_langevin(lr * boost, self.server.time,
                                                          self.LOCAL_STEP)
                self.send_dict_to_server(
                    member.get_model_dictionary(quantized=True, mode=dictionary_mode))
                self.server.seen_local_steps += steps_done
                member.time = self.server.time

            self.server.average_received_SDs(server_model_ratio=ratio_on_server)
            fresh_SD = self.server.get_model_SD(quantized=True)
            for member in group:
                member.average_with_server_SD(fresh_SD, ratio_on_client)

            self.server.interaction_count += 1
            self.server.time += Server.server_interaction_time
            self.server.time += Trainer.server_waiting_time

        self.test()
        return self.history
    
    
    def train_FLASQ_reweight(self, lr, time_limit):
        """Run the "FLASQ_reweight" loop until ``self.server.time`` reaches ``time_limit``.

        Same round structure as ``train_our``, but what a client uploads is
        re-weighted by the number of steps it actually took: with ``n`` steps
        the upload is ``(1/n) * after + (1 - 1/n) * before`` per key, where
        ``before``/``after`` are the client dictionaries read around
        ``client.run_until``.  A client that took zero steps uploads its
        dictionary unchanged.

        NOTE(review): ``Trainer.server_waiting_time`` is not assigned in the
        class body visible here -- presumably the launching script sets it;
        confirm.

        Args:
            lr: Base learning rate; multiplied by ``1 / (1 - p)`` when
                ``self.server_averaging`` is set.
            time_limit: Value of ``self.server.time`` at which the loop stops.

        Returns:
            ``self.history``, the list of records appended by ``self.test``.
        """
        wall_clock_mark = time.time()
        self.diverged = False

        # Server ratio in averaging.
        p = 1 / (self.GROUP_COUNT + 1)
        ratio_on_client = p if self.client_averaging else 1
        ratio_on_server = p if self.server_averaging else 0
        dictionary_mode = "state"

        self.test()
        while self.server.time < time_limit:
            since_last_test = self.server.time - self.last_tested
            in_warmup_window = since_last_test >= 10 and self.server.time <= 100
            if since_last_test >= self.log_period or in_warmup_window:
                self.test()
                print(f"Real time: {time.time() - wall_clock_mark}")
                wall_clock_mark = time.time()

            group = random.sample(self.clients, self.GROUP_COUNT)
            # The pre-round snapshot is not used below; the call is kept in
            # case get_model_SD has side effects -- NOTE(review): confirm.
            self.server.get_model_SD(quantized=True)

            for client in group:
                before = client.get_model_dictionary(quantized=True, mode=dictionary_mode)
                step_lr = lr * (1 / (1 - p) if self.server_averaging else 1)
                _, steps_done = client.run_until(step_lr, self.server.time, self.LOCAL_STEP)
                after = client.get_model_dictionary(quantized=True, mode=dictionary_mode)

                if steps_done == 0:
                    outgoing = after
                else:
                    # Blend in place so the dictionary object read before the
                    # run is the one that gets uploaded.
                    fresh_share = 1 / steps_done
                    for key in before:
                        before[key] = fresh_share * after[key] + (1 - fresh_share) * before[key]
                    outgoing = before
                self.send_dict_to_server(outgoing)

                self.server.seen_local_steps += steps_done
                client.time = self.server.time

            self.server.average_received_SDs(server_model_ratio=ratio_on_server)
            updated_SD = self.server.get_model_SD(quantized=True)
            for client in group:
                client.average_with_server_SD(updated_SD, ratio_on_client)

            self.server.interaction_count += 1
            self.server.time += Server.server_interaction_time
            self.server.time += Trainer.server_waiting_time

        self.test()
        return self.history
    
    
    def train_FLASQ_reweight_avg(self, lr, time_limit):
        """Run the "FLASQ_reweight_avg" loop until ``self.server.time`` reaches ``time_limit``.

        Same round structure as ``train_FLASQ_reweight``, but the blend weight
        does not depend on the steps taken in the round.  Each client uploads
        ``w * after + (1 - w) * before`` per key with
        ``w = GROUP_COUNT * client.mean_step_time / client_count``, where
        ``before``/``after`` are the client dictionaries read around
        ``client.run_until``.

        NOTE(review): ``Trainer.server_waiting_time`` is not assigned in the
        class body visible here -- presumably the launching script sets it;
        confirm.

        Args:
            lr: Base learning rate; multiplied by ``1 / (1 - p)`` when
                ``self.server_averaging`` is set.
            time_limit: Value of ``self.server.time`` at which the loop stops.

        Returns:
            ``self.history``, the list of records appended by ``self.test``.
        """
        wall_clock_mark = time.time()
        self.diverged = False

        # Server ratio in averaging.
        p = 1 / (self.GROUP_COUNT + 1)
        ratio_on_client = p if self.client_averaging else 1
        ratio_on_server = p if self.server_averaging else 0
        dictionary_mode = "state"

        self.test()
        while self.server.time < time_limit:
            waited = self.server.time - self.last_tested
            if waited >= self.log_period or (waited >= 10 and self.server.time <= 100):
                self.test()
                print(f"Real time: {time.time() - wall_clock_mark}")
                wall_clock_mark = time.time()

            group = random.sample(self.clients, self.GROUP_COUNT)
            # The pre-round snapshot is not used below; the call is kept in
            # case get_model_SD has side effects -- NOTE(review): confirm.
            self.server.get_model_SD(quantized=True)

            for client in group:
                before = client.get_model_dictionary(quantized=True, mode=dictionary_mode)
                step_lr = lr * (1 / (1 - p) if self.server_averaging else 1)
                _, steps_done = client.run_until(step_lr, self.server.time, self.LOCAL_STEP)
                after = client.get_model_dictionary(quantized=True, mode=dictionary_mode)

                # Blend in place so the dictionary object read before the run
                # is the one that gets uploaded.
                for key in before:
                    weight = self.GROUP_COUNT * client.mean_step_time / self.client_count
                    before[key] = weight * after[key] + (1 - weight) * before[key]
                self.send_dict_to_server(before)

                self.server.seen_local_steps += steps_done
                client.time = self.server.time

            self.server.average_received_SDs(server_model_ratio=ratio_on_server)
            updated_SD = self.server.get_model_SD(quantized=True)
            for client in group:
                client.average_with_server_SD(updated_SD, ratio_on_client)

            self.server.interaction_count += 1
            self.server.time += Server.server_interaction_time
            self.server.time += Trainer.server_waiting_time

        self.test()
        return self.history
    
    
    def train_AsyncSGD_v2(self, lr, time_limit, buffersize=10):
        """Run the "AsyncSGD_v2" loop until ``self.server.time`` reaches ``time_limit``.

        Per-client bookkeeping:
          * ``stacking_counters[i]`` -- 0 when client ``i`` has no pending
            work, 1 when it holds one finished-but-unsent update, larger when
            a further update computed from a later server model is stacked
            behind it.
          * ``stacked_models[i]`` / ``delayed_info[i]`` -- state dict and
            ``(taken_steps, run_time)`` of the most recent stacked update.
          * ``run_times[i]`` -- accumulated run time of client ``i``.

        Each round, the clients in ``indices`` receive the current server
        model (idle clients train on it directly; busy clients compute a
        stacked update and then restore their pending model).  The
        ``GROUP_COUNT`` active clients with the smallest run time upload
        their model; a client with stacked work switches to its stacked model
        and has the delayed steps/run time accounted for.  A fresh random
        group of ``GROUP_COUNT`` clients is then chosen to receive the next
        server model, and the server clock is set to the largest selected run
        time plus ``Server.server_interaction_time``.

        Run times and the server clock are kept as plain Python numbers; the
        values coming out of ``torch.topk`` are converted instead of being
        stored as 0-dim tensors (which previously ended up in ``self.history``).

        Args:
            lr: Learning rate passed to ``client.take_step``.
            time_limit: Value of ``self.server.time`` at which the loop stops.
            buffersize: Not read by this method.

        Returns:
            ``self.history``, the list of records appended by ``self.test``.
        """
        self.diverged = False
        self.test()
        indices = list(range(self.client_count))
        run_times = [0] * self.client_count
        stacked_models = [0] * self.client_count
        stacking_counters = [0] * self.client_count
        delayed_info = [(0, 0)] * self.client_count
        while self.server.time < time_limit:
            since_last_test = self.server.time - self.last_tested
            if (since_last_test >= self.log_period or
                    (since_last_test >= 10 and self.server.time <= 100)):
                self.test()

            server_SD = self.server.get_model_SD(quantized=True)
            for index in indices:
                client = self.clients[index]
                if stacking_counters[index] == 0:
                    client.load_SD(server_SD, quantized=True)
                    run_time, taken_steps = client.take_step(self.LOCAL_STEP, lr)
                    self.server.seen_local_steps += taken_steps
                    run_times[index] += run_time
                    stacking_counters[index] = 1
                else:
                    # Client still holds an unsent update: compute the new one
                    # on the side and put the pending model back.
                    local_model = client.get_model_SD(quantized=True)
                    client.load_SD(server_SD, quantized=True)
                    run_time, taken_steps = client.take_step(self.LOCAL_STEP, lr)
                    delayed_info[index] = (taken_steps, run_time)
                    stacked_models[index] = client.get_model_SD(quantized=True)
                    client.load_SD(local_model, quantized=True)
                    stacking_counters[index] += 1

            # Idle clients get a sentinel beyond the horizon so that active
            # clients are always preferred by the selection below.
            run_times_active_clients = [
                run_times[i] if stacking_counters[i] > 0 else 2 * time_limit
                for i in range(self.client_count)
            ]
            values, finished = torch.topk(-torch.tensor(run_times_active_clients), self.GROUP_COUNT)
            round_time = (-values).max().item()
            for index in finished.tolist():
                client = self.clients[index]
                run_times[index] = round_time
                client_SD = client.get_model_SD(quantized=True)
                if stacking_counters[index] == 1:
                    stacking_counters[index] = 0
                else:
                    # Promote the stacked update to be the pending one and
                    # account for its (so far unrecorded) steps and run time.
                    stacking_counters[index] = 1
                    self.server.seen_local_steps += delayed_info[index][0]
                    run_times[index] += delayed_info[index][1]
                    client.load_SD(stacked_models[index], quantized=True)
                    stacked_models[index] = 0
                # NOTE: uploads are sent with weight 1; re-weighting by the
                # stacking count is intentionally not applied.
                for key in client_SD:
                    client_SD[key] = 1 * client_SD[key]
                self.send_dict_to_server(client_SD)

            indices = random.sample(list(range(self.client_count)), self.GROUP_COUNT)
            for index in indices:
                if stacking_counters[index] == 0:
                    run_times[index] = round_time
            self.server.time = round_time
            self.server.time += Server.server_interaction_time
            self.server.average_received_SDs(server_model_ratio=0)
            self.server.interaction_count += 1
        self.test()

        return self.history
    
    def test(self):
        """Evaluate the server model and append the record to ``self.history``.

        The record returned by ``evaluate_on_dataloader`` is extended with the
        server clock (``'Time'``), the number of server rounds
        (``'Server steps'``), the number of local steps seen
        (``'Local steps'``) and ``'Variance'``.  Despite its name, the latter
        is the sum over clients of the Euclidean distance between the server's
        flattened parameters and the vector returned by
        ``client.get_model(quantized=True)``.

        Also prints a progress line and records the evaluation time in
        ``self.last_tested``.
        """
        server = self.server
        result = evaluate_on_dataloader(server.model, self.dataset_name, self.test_loader)
        result['Time'] = server.time
        result['Server steps'] = server.interaction_count
        result['Local steps'] = server.seen_local_steps

        server_vector = torch.nn.utils.parameters_to_vector(server.model.parameters())
        client_vectors = torch.stack([client.get_model(quantized = True) for client in self.clients])
        # One distance per client (norm along dim=1), then summed.
        distances = torch.linalg.norm(server_vector - client_vectors, dim=1)
        result['Variance'] = distances.sum().cpu().detach().numpy()
        print(result['Variance'])

        self.last_tested = server.time
        summary = 'Train: Step: {:5.0f} Val-Loss: {:.4f}  Val-Acc: {:.2f} Time: {:6.2f} Local Steps: {:5.0f}'
        print(summary.format(server.interaction_count, result["Loss"], result["Accuracy"],
                             server.time, server.seen_local_steps))
        self.history.append(result)
        
    def send_dict_to_server(self, client_dict):
        """Queue ``client_dict`` on the server by appending it to ``received_dicts``."""
        inbox = self.server.received_dicts
        inbox.append(client_dict)
    
    def train(self, lr, time_limit):
        """Run the training loop selected by ``self.algorithm``.

        Args:
            lr: Learning rate forwarded to the selected loop.
            time_limit: Time limit forwarded to the selected loop.

        Returns:
            Whatever the selected ``train_*`` method returns (the history).

        Raises:
            ValueError: If ``self.algorithm`` is not one of the supported
                names.  (Previously an unknown name silently returned
                ``None``.)
        """
        loop_names = {
            "our": "train_our",
            "quantized_fl": "train_quantized_fl",
            "Fed_Avg": "train_Fed_Avg",
            "Fed_Buff": "train_Fed_Buff",
            "langevin": "train_langevin",
            "FLASQ_reweight": "train_FLASQ_reweight",
            "FLASQ_reweight_avg": "train_FLASQ_reweight_avg",
            "AsyncSGD_v2": "train_AsyncSGD_v2",
        }
        loop_name = loop_names.get(self.algorithm)
        if loop_name is None:
            raise ValueError(
                f"Unknown algorithm {self.algorithm!r}; expected one of {sorted(loop_names)}")
        return getattr(self, loop_name)(lr, time_limit)
        