import os
import torch
from args import parse
import argparse
from Models.ab_exp import *
from Models.CNNs import *
from Models.resnets import *
from Models.Googlenet import *
from Models.Mobilenet import *
from torch.utils.data import DataLoader, random_split
import torch.optim as optim
from client import Client
from utils import *
import random
import torch.nn as nn


class Server:
    """Federated-distillation server.

    Owns an MoE model trained on the public dataset, drives the client
    rounds, and exchanges logits computed on the public validation split.
    """

    # Number of output classes for each supported ``args.data_name``.
    OUT_DIMS = {"cifar10": 10, "cifar100": 100, "tinyimagenet": 200, "flowers102": 102}

    def __init__(self, parse: argparse.ArgumentParser, logger):
        """Build the server, client checkpoints, public loaders and server model.

        :param parse: argument parser; parsed here and also handed to every
            ``Client`` created by ``train()`` so both sides share one config.
        :param logger: logger used for all progress output.
        """
        self.parser = parse
        self.args = parse.parse_args()
        self.logger = logger
        self.data_name = self.args.data_name
        self.num_clients = self.args.num_clients
        self.server_epochs = self.args.server_epochs
        self.batch_size = self.args.batch_size
        self.client_epochs = self.args.client_epochs
        self.join = self.args.join
        self.device = self.args.device
        self.clientsID_list = list(range(1, self.num_clients + 1))
        # client id -> logits that client uploaded on the public validation split.
        self.client_logits = {}

        # Also creates args.model_save_path, which save_server_model() writes into.
        self.generate_clients_model(random_generate=self.args.random)

        self.public_train_loader, self.public_val_loader, self.public_test_loader = None, None, None
        self.get_public_dataloader()
        self.model_path = os.path.join(self.args.model_save_path, "server.pth")
        self.model = None
        self.init_server_model()
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.SGD(self.model.parameters(), lr=self.args.learning_rate, momentum=0.9)

        self.kldiv = KLDivergenceLoss(temperature=self.args.KLtemperature, reduction="batchmean")

    def get_out_dim(self):
        """Return the number of classes for ``args.data_name`` (also caches it on ``self.data_name``).

        :raises ValueError: if the dataset name is not supported.
        """
        self.data_name = self.args.data_name
        try:
            return self.OUT_DIMS[self.data_name]
        except KeyError:
            raise ValueError(
                f"Unsupported data_name {self.data_name!r}; expected one of {sorted(self.OUT_DIMS)}"
            ) from None

    def init_server_model(self):
        """Create the server MoE model for the current dataset and checkpoint it.

        :raises ValueError: if the dataset name is not supported.
        """
        if self.data_name in ("cifar10", "cifar100"):
            model_cls = CNNWithMoE
        elif self.data_name in ("tinyimagenet", "flowers102"):
            model_cls = CNNWithMoE_tiny
        else:
            raise ValueError(f"Unsupported data_name {self.data_name!r}")
        self.model = model_cls(self.get_out_dim(), self.args.num_experts, self.args.topK, self.args.dropout)
        self.save_server_model()

    def load_server_model(self):
        """Load and return the full server model pickled at ``self.model_path``."""
        # NOTE(review): this unpickles a whole nn.Module; torch releases whose
        # torch.load defaults to weights_only=True will reject it -- confirm the
        # pinned torch version or pass weights_only=False explicitly.
        return torch.load(self.model_path)

    def save_server_model(self):
        """Pickle the whole server model to ``self.model_path``."""
        torch.save(self.model, self.model_path)

    def load_public_data(self):
        """Load the server's public datasets from ``args.data_save_path``.

        :returns: ``[train, val, test]`` datasets, in that order.
        """
        return [
            torch.load(os.path.join(self.args.data_save_path, f"server_{split}.pth"))
            for split in ("train", "val", "test")
        ]

    def get_public_dataloader(self):
        """Build the public train/val/test DataLoaders."""
        train_dataset, val_dataset, test_dataset = self.load_public_data()
        # At least 2 samples per batch.
        batch_size = max(2, self.batch_size)

        self.public_train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            drop_last=True,
        )

        # Unshuffled and equal-sized batches: server_model_logits() stacks the
        # per-batch outputs, and logits are compared position-by-position.
        self.public_val_loader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            drop_last=True,
        )

        self.public_test_loader = DataLoader(
            test_dataset,
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
        )

    def _train_one_epoch(self, loader):
        """Run one SGD epoch of the server model; return ``(mean_loss, accuracy)`` over samples seen."""
        self.model.train()
        total_loss, correct, seen = 0.0, 0, 0
        for inputs, labels in loader:
            inputs, labels = inputs.to(self.device), labels.to(self.device)
            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.criterion(outputs, labels)
            loss.backward()
            self.optimizer.step()

            batch = inputs.size(0)
            total_loss += loss.item() * batch
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            seen += batch
        if seen == 0:
            raise ValueError("server training loader yielded no samples")
        return total_loss / seen, correct / seen

    def server_model_train(self):
        """Train the server on the public train split for ``client_epochs`` epochs.

        Evaluates on the public test split after each epoch, checkpoints the
        model, and returns its logits on the public validation split (see
        ``server_model_logits``).
        """
        # The in-memory model is checkpointed after every update, so it always
        # matches the file; the former load_server_model() call here discarded
        # its result and only cost a disk read.
        self.model.to(self.device)
        self.logger.info("---------- server start training ----------")
        server_acc = []
        server_loss = []
        for epoch in range(self.client_epochs):
            train_loss, train_acc = self._train_one_epoch(self.public_train_loader)
            test_loss, test_acc = self.evaluate(self.public_test_loader)
            server_acc.append(test_acc)
            server_loss.append(test_loss)
            self.logger.info(f"--server --epoch:{epoch+1}/{self.client_epochs} --train_loss :{train_loss:.4f} --train_acc :{train_acc:.4f} --test_loss : {test_loss:.4f} --test_acc : {test_acc:.4f}")
        server_logits = self.server_model_logits(self.public_val_loader)
        self.save_server_model()
        if server_acc:  # nothing to summarise when client_epochs == 0
            avgacc = sum(server_acc) / len(server_acc)
            avgloss = sum(server_loss) / len(server_loss)
            maxacc = max(server_acc)
            minloss = min(server_loss)
            self.logger.info(f"--server -- avgloss : {avgloss:.4f} -- avgacc : {avgacc:.4f} -- minloss : {minloss:.4f} -- maxacc : {maxacc:.4f}")
        self.logger.info("---------- server end training ----------\n")
        return server_logits

    def evaluate(self, val_loader):
        """Return ``(mean_loss, accuracy)`` of the server model over ``val_loader``.

        Both are averaged over the samples the loader actually yields, so
        loaders built with ``drop_last=True`` are scored correctly.

        :raises ValueError: if the loader yields no samples.
        """
        self.model.eval()
        total_loss, correct, seen = 0.0, 0, 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                outputs = self.model(inputs)
                batch = inputs.size(0)
                # criterion is a mean-reduced CrossEntropyLoss; re-weight by batch size.
                total_loss += self.criterion(outputs, labels).item() * batch
                correct += (outputs.argmax(dim=1) == labels).sum().item()
                seen += batch
        if seen == 0:
            raise ValueError("evaluation loader yielded no samples")
        return total_loss / seen, correct / seen

    def server_model_logits(self, public_loader):
        """Return the server model's raw outputs on ``public_loader``, stacked per batch.

        The result has one leading entry per batch, so all batches must be the
        same size. It is computed under ``torch.no_grad()`` and carries no
        autograd graph back to the model.
        """
        self.model.to(self.device)
        self.model.eval()
        with torch.no_grad():
            outputs = [self.model(inputs.to(self.device)) for inputs, _ in public_loader]
        return torch.stack(outputs, dim=0)

    def generate_clients_model(self, random_generate=True, model_list=None):
        """Instantiate and save one model per client under ``args.model_save_path``.

        Client ``i`` receives ``model_list[i % len(model_list)]``. When
        ``model_list`` is None the default zoo is used: a small CNN (``CNN_1``
        for CIFAR, ``CNN_tiny`` otherwise), ResNet18/34/50/101/152, GoogleNet
        and MobileNet. Models are only written when ``random_generate`` is
        true; the save directory is created either way.

        :param model_list: optional sequence of callables accepting ``out_dim=``.
        :raises ValueError: if the dataset name is not supported.
        """
        save_dir = self.args.model_save_path
        os.makedirs(save_dir, exist_ok=True)

        out_dim = self.get_out_dim()
        if not random_generate:
            return

        if model_list is None:
            # get_out_dim() has validated data_name, so the non-CIFAR case is
            # tinyimagenet or flowers102.
            small_cnn = CNN_1 if self.data_name in ("cifar10", "cifar100") else CNN_tiny
            model_list = [small_cnn, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, GoogleNet, MobileNet]

        for client_id in self.clientsID_list:
            model_cls = model_list[client_id % len(model_list)]
            torch.save(model_cls(out_dim=out_dim), os.path.join(save_dir, f"{client_id}.pth"))

    def aggregation(self, server_logits):
        """Average ``self.kldiv`` loss between server logits and each stored client's smoothed logits.

        Side effect: marks ``server_logits`` and every stored client-logit
        tensor as requiring grad (presumably so the per-client distillation
        losses built in ``train()`` are differentiable -- confirm against
        ``Client.renew_client_model``).

        :returns: scalar loss tensor; ``0.0`` when no client logits are stored.
        """
        # NOTE(review): when server_logits comes from server_model_logits()
        # (computed under torch.no_grad()), this loss's gradient stops at
        # server_logits and never reaches the model's parameters.
        if not server_logits.requires_grad:
            server_logits.requires_grad_(True)
        if not self.client_logits:
            return torch.tensor(0.0, device=self.device)
        temperature = self.args.temperature
        total_loss = torch.tensor(0.0, device=self.device)
        for logits in self.client_logits.values():
            if not logits.requires_grad:
                logits.requires_grad_(True)
            total_loss = total_loss + self.kldiv(server_logits, logits_smooth(logits, temperature))
        # Mean over the clients that actually contributed logits.
        return total_loss / len(self.client_logits)

    def renew_server_model(self, loss):
        """Apply one gradient-clipped optimizer step from ``loss`` (if it has a graph) and checkpoint."""
        self.model.train()
        self.optimizer.zero_grad()
        if loss.requires_grad:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()
        self.save_server_model()

    def train(self):
        """Run ``server_epochs`` communication rounds.

        Each round, ``join`` sampled clients train locally and upload logits on
        the public validation split. The server then trains on public data and
        takes a step on the aggregation loss. Finally, its smoothed logits are
        used to build a distillation loss for each participating client.
        """
        temperature = self.args.temperature
        for c_T in range(self.server_epochs):
            self.logger.info(f"=============== T:{c_T+1} start !!! ================\n")
            acc_all = []
            loss_all = []
            selected_clientsID = random.sample(self.clientsID_list, self.join)
            for client_id in selected_clientsID:
                client = Client(parse=self.parser, client_id=client_id, c_T=c_T, logger=self.logger)
                logits, client_loss, client_acc = client.train(self.public_val_loader)
                self.client_logits[client_id] = logits
                acc_all.append(client_acc)
                loss_all.append(client_loss)
            avgacc = sum(acc_all) / len(acc_all)
            avgloss = sum(loss_all) / len(loss_all)
            maxacc = max(acc_all)
            minloss = min(loss_all)

            server_logits = logits_smooth(self.server_model_train(), temperature)
            middle_loss = self.aggregation(server_logits)
            self.renew_server_model(middle_loss)
            final_logits = logits_smooth(self.server_model_logits(self.public_val_loader), temperature)

            # Only this round's participants have fresh logits; iterating every
            # client raised KeyError whenever join < num_clients.
            for client_id in selected_clientsID:
                Client(parse=self.parser, client_id=client_id, c_T=c_T, logger=self.logger).renew_client_model(
                    self.kldiv(self.client_logits[client_id], final_logits)
                )
            # The server received one logit tensor from each participant this round.
            download_cost = calculate_communication_cost(self.client_logits[selected_clientsID[0]]) * len(selected_clientsID)
            upload_cost = calculate_communication_cost(final_logits)

            self.logger.info("++++++++++++++++++++ T_avg ++++++++++++++++++++")
            self.logger.info(f"+++++ client -- avgloss:{avgloss} -- avgacc:{avgacc} -- minloss:{minloss} -- maxacc:{maxacc}")
            self.logger.info("++++++++++++++++++++ communication cost ++++++++++++++++++++")
            self.logger.info(f"+++++ server -- download:{download_cost}MB  -- upload:{upload_cost}MB\n")
            torch.cuda.empty_cache()

    def ab_aggregation(self, server_logits):
        """Ablation: ``self.kldiv`` loss between the server logits and the mean of the smoothed client logits.

        Same requires-grad side effects as ``aggregation``.

        :returns: scalar loss tensor; ``0.0`` when no client logits are stored.
        """
        if not server_logits.requires_grad:
            server_logits.requires_grad_(True)
        if not self.client_logits:
            return torch.tensor(0.0, device=self.device)
        temperature = self.args.temperature
        smoothed = []
        for logits in self.client_logits.values():
            if not logits.requires_grad:
                logits.requires_grad_(True)
            smoothed.append(logits_smooth(logits, temperature))
        # Stack-and-mean instead of `+=` into a 0-dim tensor, which cannot
        # broadcast up to the logits' shape.
        clients_mean = torch.stack(smoothed, dim=0).mean(dim=0)
        # Call the configured loss module; the old code constructed a new
        # KLDivergenceLoss with the tensors as constructor arguments.
        return self.kldiv(server_logits, clients_mean)