import torch
from args import parse
import argparse
from torch.utils.data import DataLoader, random_split
import torch.optim as optim
from torch import nn
from utils import *

class Client:
    """One client's model, optimizer and data loaders, loaded from disk by id.

    The model is read from ``<model_save_path>/<client_id>.pth`` and the
    train/val/test datasets from ``<data_save_path>/<client_id>_<split>.pth``,
    where both directories come from the parsed command-line arguments.
    """

    def __init__(self, parse: argparse.ArgumentParser, client_id: int, c_T: int, logger):
        """Build the client from command-line arguments and on-disk files.

        Args:
            parse: parser whose ``parse_args()`` is called here (so it reads
                ``sys.argv``); it must define ``device``, ``client_epochs``,
                ``learning_rate``, ``batch_size``, ``model_save_path`` and
                ``data_save_path``.
            client_id: id used to locate this client's model and data files.
            c_T: stored as ``self.c_T``; not used inside this class
                (meaning TODO confirm against callers).
            logger: object with an ``info(str)`` method used for progress logs.
        """
        self.args = parse.parse_args()
        self.client_id = client_id
        self.c_T = c_T
        self.logger = logger
        self.model = self.load_client_model()
        self.device = self.args.device

        self.client_epochs = self.args.client_epochs
        self.criterion = nn.CrossEntropyLoss()
        # The optimizer is bound to the parameters of the model loaded above;
        # if ``self.model`` is ever replaced, a new optimizer is needed too.
        self.optimizer = optim.SGD(self.model.parameters(), lr=self.args.learning_rate, momentum=0.9)

        self.batch_size = self.args.batch_size
        self.train_loader, self.val_loader, self.test_loader = None, None, None
        self.get_dataloader()

    @staticmethod
    def _load_pickled(path):
        """``torch.load`` a whole pickled object (model or dataset) from *path*.

        Since PyTorch 2.6, ``torch.load`` defaults to ``weights_only=True``,
        which refuses arbitrary pickled objects such as ``nn.Module`` or
        ``Dataset`` instances. Full unpickling is therefore requested
        explicitly (the keyword exists since torch 1.13).

        SECURITY: unpickling can execute arbitrary code. Only load files
        produced by a trusted source.
        """
        return torch.load(path, weights_only=False)

    def load_client_model(self):
        """Load and return this client's full pickled model.

        Also records the file path in ``self.model_path``, so that
        ``save_client_model`` writes back to the same file.
        """
        self.model_path = self.args.model_save_path + f"/{self.client_id}.pth"
        return self._load_pickled(self.model_path)

    def save_client_model(self):
        """Pickle the whole ``self.model`` object to ``self.model_path``."""
        torch.save(self.model, self.model_path)

    def load_client_data(self):
        """Load and return ``[train, val, test]`` datasets for this client."""
        return [
            self._load_pickled(self.args.data_save_path + f"/{self.client_id}_{split}.pth")
            for split in ("train", "val", "test")
        ]

    def get_dataloader(self):
        """Create train (shuffled), val and test (unshuffled) loaders."""
        train_dataset, val_dataset, test_dataset = self.load_client_data()

        self.train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
        self.val_loader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False)
        self.test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False)

    def train(self, public_loader):
        """Train locally for ``client_epochs`` epochs, then compute public logits.

        After every epoch the model is evaluated on the validation split. The
        test split is not touched here. The trained model is saved to
        ``self.model_path`` before computing the logits on ``public_loader``.

        Args:
            public_loader: iterable of ``(inputs, labels)`` batches passed to
                ``get_logits``.

        Returns:
            ``(client_logits, avgloss, avgacc)``: the result of
            ``get_logits(public_loader)`` and the validation loss / accuracy
            averaged over epochs.
        """
        self.model.to(self.device)
        self.logger.info(f"---------- client {self.client_id} start training ----------")
        val_accs = []
        val_losses = []
        num_train = len(self.train_loader.dataset)
        for epoch in range(self.client_epochs):
            self.model.train()
            running_loss = 0.0
            running_corrects = 0
            for inputs, labels in self.train_loader:
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                self.optimizer.zero_grad()
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)
                loss.backward()
                self.optimizer.step()
                # Criterion returns a batch mean; re-weight by batch size so
                # the epoch figure is a per-sample mean over the dataset.
                running_loss += loss.item() * inputs.size(0)
                _, preds = torch.max(outputs, 1)
                running_corrects += torch.sum(preds == labels)

            train_loss = running_loss / num_train
            train_acc = running_corrects.double() / num_train
            # Per-epoch validation uses the validation split, not the test split.
            val_loss, val_acc = self.evaluate(self.val_loader)
            val_accs.append(val_acc)
            val_losses.append(val_loss)
            self.logger.info(f"--client: {self.client_id} --epoch:{epoch+1}/{self.client_epochs} --train_loss :{train_loss:.4f} --train_acc :{train_acc:.4f} --val_loss : {val_loss:.4f} --val_acc : {val_acc:.4f}")
        self.save_client_model()
        client_logits = self.get_logits(public_loader)
        avgacc = sum(val_accs) / len(val_accs)
        avgloss = sum(val_losses) / len(val_losses)
        maxacc = max(val_accs)
        minloss = min(val_losses)
        self.logger.info(f"--client: {self.client_id} --avgloss : {avgloss:.4f} --avgacc : {avgacc:.4f} --minloss : {minloss:.4f} -- maxacc : {maxacc:.4f}")
        return client_logits, avgloss, avgacc

    def evaluate(self, val_loader):
        """Return ``(mean loss, accuracy)`` of the model over ``val_loader``.

        Runs in eval mode without gradients. The loss is a Python float and
        the accuracy is a 0-dim double tensor (on ``self.device``).
        """
        self.model.eval()
        running_loss = 0.0
        running_corrects = 0

        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(self.device), labels.to(self.device)

                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)

                running_loss += loss.item() * inputs.size(0)
                _, preds = torch.max(outputs, 1)
                running_corrects += torch.sum(preds == labels)

        num_samples = len(val_loader.dataset)
        val_loss = running_loss / num_samples
        val_acc = running_corrects.double() / num_samples
        return val_loss, val_acc

    def get_logits(self, public_loader):
        """Return the model's raw outputs for every batch of ``public_loader``.

        Per-batch outputs are stacked along a new leading dimension, giving
        shape ``(num_batches, *per_batch_output_shape)``. Labels from the
        loader are ignored.

        NOTE(review): ``torch.stack`` needs every batch to have the same size,
        so a smaller final batch raises. Confirm that ``public_loader`` uses
        ``drop_last`` or an evenly divisible dataset.
        """
        # Make this usable on its own, not only right after ``train``.
        self.model.to(self.device)
        self.model.eval()
        with torch.no_grad():
            logits_list = [self.model(inputs.to(self.device)) for inputs, _ in public_loader]

        return torch.stack(logits_list, dim=0)

    def renew_client_model(self, loss):
        """Apply one update from an externally computed ``loss``, save, then drop the model.

        If ``loss`` requires grad, gradients are clipped to a total norm of 1.0
        and one optimizer step is taken; otherwise no step is taken. The model
        is saved to ``self.model_path`` in both cases. Afterwards the
        ``self.model`` attribute is deleted, so the client must reload a model
        before it can be trained or evaluated again.
        """
        self.model.train()
        self.optimizer.zero_grad()
        if loss.requires_grad:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()
        self.save_client_model()
        # NOTE(review): self.optimizer still references the parameters (and
        # momentum buffers), so deleting self.model alone may not free their
        # memory. A reloaded model would also need a fresh optimizer.
        del self.model
        torch.cuda.empty_cache()
