import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
from sklearn.metrics import accuracy_score
import torch.optim as optim
import numpy as np
import os

def save_model_in_client_list(client_list,client_index_of_each_group,Result_log_folder,dataset_name):
    """Save one representative model per client group with ``torch.save``.

    For every group in ``client_index_of_each_group`` the weights returned by
    ``client_list[group[0]].get_model_weight()`` are written to
    ``<Result_log_folder><dataset_name>/model_path/<dataset_name>Best_model_<index>``.

    Paths are built by plain string concatenation, so ``Result_log_folder``
    should already end with a path separator.

    Args:
        client_list: indexable collection of clients exposing ``get_model_weight()``.
        client_index_of_each_group: iterable of index sequences; only the first
            index of each sequence is saved (an empty sequence raises IndexError).
        Result_log_folder: output root prefix.
        dataset_name: used both as a sub-directory name and as a file-name prefix.
    """
    print("start save model")
    father_path = Result_log_folder + dataset_name + '/model_path/'
    # exist_ok=True replaces the racy "check exists, then create" pair.
    os.makedirs(father_path, exist_ok=True)

    for client_index_list in client_index_of_each_group:
        representative = client_index_list[0]
        model_save_path = father_path + dataset_name + 'Best_model_' + str(representative)
        torch.save(client_list[representative].get_model_weight(), model_save_path)
    print("finish")


def calculate_acc(client_list):
    """Gather the train and test accuracy of every client.

    For each client, ``get_train_acc()`` is called and then ``get_test_acc()``,
    in list order.

    Returns:
        A pair of NumPy arrays ``(train_accs, test_accs)``, one entry per client.
    """
    # One pass keeps the per-client train/test call interleaving.
    scores = [(client.get_train_acc(), client.get_test_acc()) for client in client_list]
    train_accs = np.asarray([train for train, _ in scores])
    test_accs = np.asarray([test for _, test in scores])
    return train_accs, test_accs


def covert_numpy_data_to_torch_dataLoader(X_data, y_data, device, batch_size = 128):
    """Wrap NumPy features and labels in a shuffling DataLoader on ``device``.

    Both arrays are converted with ``torch.from_numpy`` and moved to ``device``.
    A 2-D feature array ``(N, D)`` gets a singleton axis inserted at dim 1,
    becoming ``(N, 1, D)``; other ranks are left unchanged.

    Returns:
        ``DataLoader`` over ``TensorDataset(features, labels)`` with
        ``shuffle=True`` and the given ``batch_size``.
    """
    features = torch.from_numpy(X_data).to(device)
    features.requires_grad = False
    labels = torch.from_numpy(y_data).to(device)

    if features.dim() == 2:
        features = features.unsqueeze(1)

    return DataLoader(TensorDataset(features, labels), batch_size=batch_size, shuffle=True)


class Model_trainner:
    """Hold a model with its Adam optimiser and cross-entropy loss, and train it.

    Args:
        model: torch module to train; the optimiser is bound to its parameters.
        X_train, y_train: NumPy arrays turned into a shuffling DataLoader on
            ``device`` (batch size 128).
        device: device that the training batches live on and that the model is
            moved to while training.
        lr: Adam learning rate.
    """

    def __init__(self, model, X_train, y_train, device, lr):
        self.lr = lr
        self.model = model
        self.device = device
        self.train_loader = covert_numpy_data_to_torch_dataLoader(X_train, y_train, self.device)
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)

    def forward_n_setp(self, steps):
        """Run ``steps`` full passes over the training loader.

        The model is moved to ``self.device`` in train mode. Afterwards it is
        moved back to CPU and the CUDA cache is emptied.

        Returns:
            One entry per pass: the sample-weighted mean loss over that pass.
        """
        net = self.model
        net.to(self.device).train()
        epoch_losses = []
        for _ in range(steps):
            loss_total = 0
            seen = 0
            for batch_x, batch_y in self.train_loader:
                self.optimizer.zero_grad()
                loss = self.criterion(net(batch_x), batch_y)
                loss.backward()
                self.optimizer.step()
                # Weight by batch size so a short final batch is not over-counted.
                n_rows = batch_x.shape[0]
                loss_total += loss.item() * n_rows
                seen += n_rows
            epoch_losses.append(loss_total / seen)
        net.cpu()
        torch.cuda.empty_cache()
        return epoch_losses

class Client:
    """One federated-learning participant: local data, a model, and its trainer.

    Args:
        device: device used for training batches and for inference in ``predict``.
    """

    def __init__(self,device):
        self.device = device
        # Per-pass training losses accumulated across train_n_steps calls.
        self.running_loss_list = []

    def get_data(self, X_train, y_train, X_test, y_test):
        """Store this client's local train/test NumPy arrays."""
        self.X_train = X_train
        self.y_train = y_train
        self.X_test = X_test
        self.y_test = y_test

    def initialize_model(self, model):
        """Attach the torch module this client trains and evaluates."""
        self.model = model

    def initialize_model_trainner(self,lr = 0.001):
        """Create the trainer; requires get_data and initialize_model first."""
        self.model_trainner = Model_trainner(self.model, self.X_train, self.y_train, self.device, lr)

    def train_n_steps(self,n):
        """Train for ``n`` passes and append the per-pass losses to the history."""
        running_loss_list = self.model_trainner.forward_n_setp(n)
        self.running_loss_list = self.running_loss_list + running_loss_list

    def load_model_weight(self, weight_dict):
        """Load a state dict into the model."""
        self.model.load_state_dict(weight_dict)

    def get_model_weight(self):
        """Return the model's state dict (tensors are shared, not copied)."""
        return self.model.state_dict()

    def get_train_acc(self):
        """Return the accuracy of argmax predictions on the local training set."""
        # predict() switches the model to eval mode itself.
        predict_result = self.predict(self.X_train)
        return accuracy_score(self.y_train, predict_result)

    def get_test_acc(self):
        """Return the accuracy of argmax predictions on the local test set."""
        predict_result = self.predict(self.X_test)
        return accuracy_score(self.y_test, predict_result)

    def predict(self, X_input, output_probability = False, batch_size = 16):
        """Run the model over a NumPy array in mini-batches.

        A 2-D input ``(N, D)`` is reshaped to ``(N, 1, D)`` before inference.

        Args:
            X_input: NumPy array of inputs.
            output_probability: if falsy, return ``argmax`` over axis 1 of each
                batch's output; otherwise return the raw model outputs.
            batch_size: inference batch size (default 16, as before).

        Returns:
            NumPy array of per-sample predictions or outputs. An empty input
            yields ``np.array([])``.

        The model is left on CPU in eval mode afterwards, even if inference raises.
        """
        self.model.eval()

        X_tensor = torch.from_numpy(X_input).to(self.device)
        if X_tensor.dim() == 2:
            X_tensor = X_tensor.unsqueeze(1)

        predict_loader = DataLoader(TensorDataset(X_tensor), batch_size=batch_size, shuffle=False)

        # Collect per-batch results and concatenate once (avoids quadratic copying).
        batch_outputs = []
        self.model.to(self.device)
        try:
            # No autograd graph is needed for inference; saves memory.
            with torch.no_grad():
                for (batch,) in predict_loader:
                    y_predict = self.model(batch).detach().cpu().numpy()
                    if not output_probability:
                        y_predict = np.argmax(y_predict, axis=1)
                    batch_outputs.append(y_predict)
        finally:
            self.model.cpu()
            torch.cuda.empty_cache()

        if not batch_outputs:
            return np.array([])
        return np.concatenate(batch_outputs, axis=0)