import torch
import torch.optim as optim
import torch.nn as nn
import copy
import random
from models.cnn import CNN
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR


class client_pFedDIL():
    """Federated client holding a local model and a binary auxiliary classifier.

    The auxiliary classifier (``auxiliary_classifier_current_task``) is trained to
    separate this client's own training data (label 1) from caller-supplied
    ``future_data`` (label 0). Scores produced by previously trained auxiliary
    classifiers (see ``calculate_aux_score``) weight an L2 proximal term that
    pulls the local model towards each corresponding previous model.
    """

    def __init__(self, model, train_dataset, name, args):
        """Create the client.

        Args:
            model: template network; it is deep-copied, and the copy's Conv2d /
                Linear layers are re-initialised (``init_weights`` is ``None``).
            train_dataset: local dataset, iterated as ``(image, label)`` pairs.
            name: client identifier (stored only).
            args: config namespace; this class reads ``batch_size``,
                ``lr_local`` and ``epochs``.
        """
        self.aux_classifier_score = None
        self.args = args
        self.model = copy.deepcopy(model)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=self.args.batch_size, shuffle=True, drop_last=True)
        self.optimizer = optim.SGD(self.model.parameters(), lr=self.args.lr_local, momentum=0.9)
        self.scheduler = StepLR(self.optimizer, step_size=5, gamma=0.96)

        self.theta_reg = None
        self.init_weights = None
        self.initialize_model()
        self.name = name
        self.train_dataset = train_dataset
        self.auxiliary_classifier_current_task = CNN(num_classes=2).to(self.device)

    def set_theta_reg(self, theta_reg):
        """Store ``theta_reg`` (not read anywhere else in this class)."""
        self.theta_reg = theta_reg

    def _random_init_weights(self, m):
        """Kaiming-normal init for Conv2d/Linear weights; zero their biases."""
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
            torch.nn.init.kaiming_normal_(m.weight)
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)

    def initialize_model(self):
        """Load ``init_weights`` if set, otherwise randomly re-initialise the model."""
        if self.init_weights is not None:
            self.model.load_state_dict(self.init_weights)
        else:
            self.model.apply(self._random_init_weights)

    def update_model(self, new_model):
        """Copy the weights of module ``new_model`` into the local model in place."""
        self.model.load_state_dict(new_model.state_dict())

    def update_model2(self, new_model):
        """Load a state dict ``new_model`` into the local model in place."""
        self.model.load_state_dict(new_model)

    def calculate_aux_score(self, auxiliary_classifier):
        """Score this client's data against each auxiliary classifier.

        For every classifier, average its class-1 probability over all full
        batches of ``self.train_loader`` and store the per-classifier means, in
        order, in ``self.aux_classifier_score``. A classifier that sees no batch
        (e.g. dataset smaller than ``batch_size`` with ``drop_last=True``) gets 0.0.

        Class 1 is the label this class gives to a client's own data when
        training its auxiliary classifier. The classifiers passed in are
        presumably trained the same way — confirm with the caller.
        """
        all_score = []
        for model_idx, pre_model in enumerate(auxiliary_classifier):
            # Inference must not update BatchNorm stats / apply dropout; restore
            # the caller's mode afterwards.
            was_training = pre_model.training
            pre_model.eval()
            score = []
            try:
                with torch.no_grad():
                    for images, _ in self.train_loader:
                        outputs = pre_model(images.to(self.device))
                        # Classifiers are trained with (softmax) cross-entropy over
                        # two logits, so softmax gives the matching probability.
                        probs = torch.softmax(outputs, dim=1)
                        score.append(probs[:, 1].cpu())
            finally:
                pre_model.train(was_training)

            if score:
                mean_score = torch.cat(score).mean().item()
            else:
                print(f"[Warning] No valid score collected for model #{model_idx}, assigning 0.")
                mean_score = 0.0

            all_score.append(mean_score)

        self.aux_classifier_score = all_score

    # ------------------------------------------------------------------ helpers

    def _prepare_for_training(self, include_aux):
        """Move the model (and optionally the aux classifier) to device; set train mode."""
        self.model = self.model.to(self.device)
        self.model.train()
        if include_aux:
            self.auxiliary_classifier_current_task = self.auxiliary_classifier_current_task.to(self.device)
            self.auxiliary_classifier_current_task.train()

    def _build_relabel_loader(self, future_data):
        """Build a balanced binary loader: ``future_data`` -> 0, own data -> 1.

        The larger side is randomly subsampled to the size of the smaller one.
        """
        future_data_relabel = [(img, 0) for (img, _) in future_data]
        current_data_relabel = [(img, 1) for (img, _) in self.train_dataset]
        min_len = min(len(current_data_relabel), len(future_data_relabel))
        if len(future_data_relabel) > len(current_data_relabel):
            future_data_relabel = random.sample(future_data_relabel, min_len)
        else:
            current_data_relabel = random.sample(current_data_relabel, min_len)
        data_relabel = future_data_relabel + current_data_relabel
        return torch.utils.data.DataLoader(data_relabel, batch_size=self.args.batch_size, shuffle=True,
                                           drop_last=True)

    def _make_aux_optimizer(self):
        """Create an SGD optimizer bound to the current aux classifier's parameters."""
        # Built per training call so it always tracks the current classifier
        # object, even if it was replaced externally.
        return optim.SGD(self.auxiliary_classifier_current_task.parameters(),
                         lr=self.args.lr_local, momentum=0.9)

    def _train_aux_one_epoch(self, loader, aux_optimizer):
        """Run one epoch of cross-entropy training for the aux classifier (update theta)."""
        for images, labels in loader:
            images, labels = images.to(self.device), labels.to(self.device)
            aux_optimizer.zero_grad()
            outputs = self.auxiliary_classifier_current_task(images)
            loss = F.cross_entropy(outputs, labels)
            loss.backward()
            aux_optimizer.step()

    def _collect_proximal_targets(self, previous_model):
        """Return ``[(weight, {param_name: prev_tensor})]`` for the proximal term.

        Raises:
            RuntimeError: previous models are given but ``aux_classifier_score``
                does not hold one score for each of them.
        """
        previous_model = list(previous_model)
        if not previous_model:
            return []
        scores = self.aux_classifier_score
        if scores is None or len(scores) < len(previous_model):
            raise RuntimeError(
                "aux_classifier_score must hold one score per previous model; "
                "call calculate_aux_score first")
        targets = []
        for idx, pre_model in enumerate(previous_model):
            # Fetched once per training call instead of once per batch.
            pre_state_dict = pre_model.state_dict()
            prev_params = {
                name: pre_state_dict[name].to(theta.device)
                for name, theta in self.model.named_parameters()
                if name in pre_state_dict
            }
            targets.append((scores[idx], prev_params))
        return targets

    def _proximal_term(self, targets):
        """Sum of ``weight * ||theta - theta_prev||_2`` over targets and shared params."""
        proximal_term = 0.0
        for weight, prev_params in targets:
            for name, theta in self.model.named_parameters():
                if name in prev_params:
                    proximal_term = proximal_term + (theta - prev_params[name]).norm(2) * weight
        return proximal_term

    def _train_model_one_epoch(self, targets):
        """Run one epoch of local training (update w) with an optional proximal term."""
        for images, labels in self.train_loader:
            images = images.to(self.device)
            labels = labels.to(self.device)
            self.optimizer.zero_grad()
            outputs = self.model(images)
            loss = F.cross_entropy(outputs, labels)
            if targets:
                loss = loss + self._proximal_term(targets)
            loss.backward()
            self.optimizer.step()

    # --------------------------------------------------------------- training

    def train_with_aux(self, future_data, previous_model):
        """Train the aux classifier and the local model (with proximal term).

        Each epoch first updates the aux classifier on the balanced relabelled
        data, then the local model, then steps the LR scheduler.

        Args:
            future_data: iterable of ``(image, label)`` pairs used as class 0.
            previous_model: iterable of modules whose parameters the model is
                pulled towards, weighted by ``aux_classifier_score[idx]``.
        """
        self._prepare_for_training(include_aux=True)
        data_relabel_loader = self._build_relabel_loader(future_data)
        aux_optimizer = self._make_aux_optimizer()
        targets = self._collect_proximal_targets(previous_model)
        for _ in range(self.args.epochs):
            self._train_aux_one_epoch(data_relabel_loader, aux_optimizer)
            self._train_model_one_epoch(targets)
            self.scheduler.step()

    def train_without_aux(self, previous_model):
        """Train only the local model, with the score-weighted proximal term."""
        self._prepare_for_training(include_aux=False)
        targets = self._collect_proximal_targets(previous_model)
        for _ in range(self.args.epochs):
            self._train_model_one_epoch(targets)
            self.scheduler.step()

    def train_first_task(self, future_data):
        """Train the aux classifier and the local model without a proximal term."""
        self._prepare_for_training(include_aux=True)
        data_relabel_loader = self._build_relabel_loader(future_data)
        aux_optimizer = self._make_aux_optimizer()
        for _ in range(self.args.epochs):
            self._train_aux_one_epoch(data_relabel_loader, aux_optimizer)
            self._train_model_one_epoch([])
            self.scheduler.step()







