import torch
import torch.optim as optim
import torch.nn as nn
import copy
import numpy as np
from tqdm import tqdm
from torch.optim.lr_scheduler import StepLR
from clients.toolkit import combine_data


class client_MFCL():
    """Federated client with its own model copy, SGD optimizer and StepLR scheduler.

    ``train``/``trainI`` run plain cross-entropy training on the local loader.
    ``trainII`` additionally replays synthetic samples drawn from a previous-task
    teacher and adds a classifier fine-tuning loss and a feature-distillation loss.
    (The class name suggests the MFCL federated class-incremental method — confirm.)
    """

    def __init__(self, model, train_dataset, name, args):
        self.args = args
        # Private copy so local training never mutates the caller's model.
        self.model = copy.deepcopy(model)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=self.args.batch_size, shuffle=True, drop_last=True)
        self.optimizer = optim.SGD(self.model.parameters(), lr=self.args.lr_local, momentum=0.9, weight_decay=1e-5)
        # Stepped once per train()/trainI() call, so step_size counts calls, not epochs.
        self.scheduler = StepLR(self.optimizer, step_size=5, gamma=0.96)
        self.theta_reg = None  # set via set_theta_reg(); not read elsewhere in this class
        self.init_weights = None  # optional state_dict for initialize_model(); random init when None
        # Parameters are re-initialised in place, so the optimizer above still tracks them.
        self.initialize_model()
        self.name = name
        # Class counts; presumably updated by the caller between tasks — confirm.
        self.last_valid_dim = 0  # classes from earlier tasks (old-logit slice in kd())
        self.valid_dim = 0  # classes valid so far (sizes the class-weight table in trainII)
        self.mu = 1e-1  # scale of the distillation loss in kd()
        self.ft_weight = 1  # multiplier on the head fine-tuning loss in trainII
        self.syn_size = 128 if args.dataset_list == 'DIGIT10' else 16  # synthetic samples per batch
        self.kd_criterion = nn.MSELoss(reduction='none')

    def set_theta_reg(self, theta_reg):
        """Store ``theta_reg``, the model of the last task used as the initial point."""
        self.theta_reg = theta_reg

    def _random_init_weights(self, m):
        """Kaiming-normal init for Conv2d/Linear weights; zero their biases."""
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            torch.nn.init.kaiming_normal_(m.weight)
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)

    def initialize_model(self):
        """Load ``self.init_weights`` if set, otherwise randomly re-initialise the model."""
        if self.init_weights is not None:
            self.model.load_state_dict(self.init_weights)
        else:
            self.model.apply(self._random_init_weights)

    def update_model(self, new_model):
        """Copy ``new_model``'s weights into the local model (parameter objects are kept)."""
        self.model.load_state_dict(new_model.state_dict())

    def _train_ce(self):
        """Run ``args.epochs`` epochs of cross-entropy training, then step the scheduler once."""
        self.model = self.model.to(self.device)
        self.model.train()
        base_loss = nn.CrossEntropyLoss()
        for _ in range(self.args.epochs):
            for images, labels in self.train_loader:
                images = images.to(self.device)
                labels = labels.to(self.device)
                # The optimizer owns every model parameter, so this clears all grads.
                self.optimizer.zero_grad()
                loss = base_loss(self.model(images), labels)
                loss.backward()
                self.optimizer.step()
        self.scheduler.step()

    def trainI(self):
        """Plain supervised training round (same procedure as ``train``)."""
        self._train_ce()

    def trainII(self, teacher):
        """Train on the local data while replaying synthetic data from a previous-task teacher.

        Args:
            teacher: pair ``(previous_teacher, previous_linear)``. Both are deep-copied,
                so the caller's objects are not modified. ``previous_teacher`` must
                provide ``sample``, ``generate_scores`` and ``generate_scores_pen``.
                ``previous_linear`` is applied to penultimate features in ``kd``.

        Per batch the objective is ``5 * loss_class + 0.5 * loss_ft + 0.5 * loss_kd``:
        ``loss_class`` is cross-entropy on the real batch, ``loss_ft`` trains only the
        head ``model.fc`` on detached features of the real+synthetic batch, and
        ``loss_kd`` is the distillation term from ``kd``. Unlike ``train``, this does
        not step the LR scheduler.
        """
        self.model = self.model.to(self.device)
        # Distillation weights; kd() only reads the last entry.
        self.dw_k = torch.ones(self.valid_dim + 1, dtype=torch.float32, device=self.device)
        self.model.train()
        previous_teacher, previous_linear = copy.deepcopy(teacher[0]), copy.deepcopy(teacher[1])
        base_loss = nn.CrossEntropyLoss()
        # Per-class weights for loss_ft (currently uniform); indexed by label below.
        mappings = torch.ones(self.valid_dim, dtype=torch.float32, device=self.device)
        for _ in range(self.args.epochs):
            for x, y in self.train_loader:
                x, y = x.to(self.device), y.to(self.device)
                x_replay, y_replay, y_replay_hat = self.sample(previous_teacher, self.syn_size)
                # NOTE(review): y_hat / y_hat_com are never used below; kept only in case
                # generate_scores/combine_data have side effects — consider removing.
                y_hat = previous_teacher.generate_scores(x, allowed_predictions=np.arange(self.last_valid_dim))
                _, y_hat_com = combine_data(((x, y_hat), (x_replay, y_replay_hat)))
                x_com, y_com = combine_data(((x, y), (x_replay, y_replay)))
                # Penultimate features with grad: distilled toward the teacher in kd().
                logits_pen = self.model.feature(x_com)
                dw_cls = mappings[y_com.long()]
                loss_class = base_loss(self.model(x), y)
                # Detached features so loss_ft only updates the classifier head.
                with torch.no_grad():
                    feat_class = self.model.feature(x_com).detach()
                loss_ft = self.criterion(self.model.fc(feat_class), y_com, dw_cls) * self.ft_weight
                loss_kd = self.kd(x_com, previous_linear, logits_pen, previous_teacher)
                total_loss = 5 * loss_class + 0.5 * loss_ft + 0.5 * loss_kd
                self.optimizer.zero_grad()
                total_loss.backward()
                self.optimizer.step()

    def train(self):
        """Plain supervised training round: CE for ``args.epochs`` epochs, one scheduler step."""
        self._train_ce()

    def kd(self, x_com, previous_linear, logits_pen, previous_teacher):
        """Feature-distillation loss on the first ``last_valid_dim`` (old-class) outputs.

        Both the student features ``logits_pen`` and the teacher features
        ``previous_teacher.generate_scores_pen(x_com)`` are passed through
        ``previous_linear``. The squared error over the old-class outputs is summed
        per sample, weighted by ``self.dw_k[-1]``, averaged over the batch, divided by
        the number of old classes and scaled by ``self.mu``. Returns a zero scalar when
        there are no old classes, instead of dividing by zero (NaN).
        """
        if self.last_valid_dim <= 0:
            return logits_pen.new_zeros(())
        # The same weight applies to every sample.
        dw_KD = self.dw_k[-1].to(self.device)
        logits_KD = previous_linear(logits_pen)[:, :self.last_valid_dim]
        logits_KD_past = previous_linear(previous_teacher.generate_scores_pen(x_com))[:, :self.last_valid_dim]
        per_sample = self.kd_criterion(logits_KD, logits_KD_past).sum(dim=1)
        return self.mu * (per_sample * dw_KD).mean() / logits_KD.size(1)

    def sample(self, teacher, dim, return_scores=True):
        """Delegate to ``teacher.sample(dim, return_scores=...)``.

        trainII unpacks three values (inputs, labels, scores) from the result.
        """
        return teacher.sample(dim, return_scores=return_scores)

    def criterion(self, logits, targets, data_weights):
        """Per-sample weighted cross-entropy, averaged over the batch.

        Args:
            logits: class scores, one row per sample.
            targets: integer class index per sample.
            data_weights: one weight per sample (same length as ``targets``).
        """
        # reduction='none' keeps one loss per sample so the weights apply per sample;
        # the default 'mean' would collapse the batch before weighting.
        per_sample = nn.CrossEntropyLoss(reduction='none')(logits, targets)
        return (per_sample * data_weights).mean()



