import torch
import torch.optim as optim
import torch.nn as nn
import copy
import random
from torch.utils.data import Subset

from tqdm import tqdm
from torch.optim.lr_scheduler import StepLR

from models.gen_dis import Discriminator, Generator
from torch.utils.data import ConcatDataset, Dataset, DataLoader

from utils.toolkit import dirichlet_split
from collections import defaultdict
import numpy as np
import torch.nn.functional as F



class client_SRFDIL():
    """Local client holding a classifier, a generator and a discriminator.

    The client owns a private copy of the classifier it is given, an SGD
    optimizer with a ``StepLR`` schedule for that classifier, and a
    generator/discriminator pair that ``trainII`` trains adversarially on a
    small random subset of the local training data.

    Attributes are plain instance fields (``model``, ``train_dataset``,
    ``train_loader``, ``optimizer``, ``scheduler``, ``generator``,
    ``discriminator``, ``theta_reg``, ``init_weights``, ``name``, ``args``,
    ``device``) and are read/written directly by the methods below.
    """

    # Length of the noise vector fed to the generator.
    NOISE_DIM = 100
    # Upper bound on the number of real samples drawn per ``trainII`` call.
    GAN_REAL_SAMPLES = 32

    def __init__(self, model, train_dataset, name, args):
        """Build the client state.

        Args:
            model: classifier to copy; the caller's instance is not modified.
            train_dataset: dataset yielding ``(image, label)`` pairs.
            name: identifier stored on the client as ``self.name``.
            args: configuration object; this class reads ``batch_size``,
                ``lr_local``, ``epochs``, ``dataset_list`` and, optionally,
                ``device`` from it.
        """
        self.args = args
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = copy.deepcopy(model).to(self.device)
        self.train_dataset = train_dataset
        self.train_loader = self._build_train_loader()
        self.optimizer = optim.SGD(self.model.parameters(), lr=self.args.lr_local, momentum=0.9)
        self.scheduler = StepLR(self.optimizer, step_size=5, gamma=0.96)
        self.theta_reg = None
        self.init_weights = None
        # NOTE: ``init_weights`` is None here, so this re-initialises the
        # copied model's Conv2d/Linear layers with fresh random weights.
        self.initialize_model()
        self.name = name

        img_size = 32 if self.args.dataset_list == 'DIGIT10' else 224
        self.generator = Generator(noise_dim=self.NOISE_DIM, img_channels=3, img_size=img_size)
        self.discriminator = Discriminator(img_channels=3, img_size=img_size)
        self.discriminator = self.discriminator.to(self.device)
        self.generator = self.generator.to(self.device)

    def _build_train_loader(self):
        """Return a shuffling loader over ``self.train_dataset``.

        ``drop_last=True`` means a trailing partial batch is discarded, so a
        dataset smaller than ``args.batch_size`` yields no batches at all.
        """
        return DataLoader(self.train_dataset, batch_size=self.args.batch_size,
                          shuffle=True, drop_last=True)

    def set_theta_reg(self, theta_reg):
        """Store ``theta_reg`` on the client (not used elsewhere in this class)."""
        self.theta_reg = theta_reg

    def update_dis(self, discriminator):
        """Copy the weights of ``discriminator`` into the local one; ``None`` is a no-op."""
        if discriminator is not None:
            self.discriminator.load_state_dict(discriminator.state_dict())

    def _random_init_weights(self, m):
        """Kaiming-normal init for Conv2d/Linear weights, zeros for their biases."""
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            torch.nn.init.kaiming_normal_(m.weight)
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)

    def initialize_model(self):
        """Load ``self.init_weights`` if set, otherwise randomly re-initialise the model."""
        if self.init_weights is not None:
            self.model.load_state_dict(self.init_weights)
        else:
            self.model.apply(self._random_init_weights)

    def update_model(self, new_model):
        """Overwrite the local classifier's weights with those of ``new_model``."""
        self.model.load_state_dict(new_model.state_dict())

    def update_train_date(self, model, discriminator, previous_data, num_classes):
        """Append the highest-scoring samples of ``previous_data`` to the training set.

        Every sample receives a score that is the sum of

        * the first output value of ``discriminator`` for that sample, and
        * the cosine similarity between ``model``'s output for the sample and
          the mean output of all samples sharing its label (only for labels in
          ``range(num_classes)``; other labels get no similarity term).

        The top ``max(len(previous_data) // 10, 1)`` samples are concatenated
        to ``self.train_dataset`` and ``self.train_loader`` is rebuilt.

        Args:
            model: classifier used for the similarity term; switched to eval mode.
            discriminator: scorer for the first term; switched to eval mode.
            previous_data: indexable dataset of ``(image, label)`` pairs.
            num_classes: number of label values to compute class means for.

        Side effects:
            ``model`` and ``discriminator`` are left in eval mode. Nothing is
            changed when ``previous_data`` is empty.
        """
        n_prev = len(previous_data)
        if n_prev == 0:
            return

        # Prefer the configured device (original behaviour); fall back to the
        # client's own device when ``args`` does not define one.
        device = getattr(self.args, 'device', self.device)
        batch_size = self.args.batch_size

        # Group sample indices by label. Labels are normalised to ``int`` so
        # tensor/NumPy labels land under the same key the class loop looks up
        # (tensors hash by identity and would otherwise never match).
        label_dict = defaultdict(list)
        for idx in range(n_prev):
            label_dict[int(previous_data[idx][1])].append(idx)

        model.eval()
        discriminator.eval()

        with torch.no_grad():
            # Term 1: discriminator output, first value per sample.
            disc_scores = []
            for img, _ in DataLoader(previous_data, batch_size=batch_size, shuffle=False):
                disc_out = discriminator(img.to(device))
                disc_scores.extend(disc_out.reshape(disc_out.size(0), -1)[:, 0].tolist())
            score = np.asarray(disc_scores, dtype=np.float64)

            # Term 2: similarity of each sample's output to its class mean.
            for label in range(num_classes):
                data_index_set = label_dict.get(label)
                if not data_index_set:
                    continue
                label_loader = DataLoader(Subset(previous_data, data_index_set),
                                          batch_size=batch_size, shuffle=False)
                # One forward pass per class; rows follow ``data_index_set`` order
                # because the loader does not shuffle.
                outputs = torch.cat([model(img.to(device)).float().cpu()
                                     for img, _ in label_loader])
                class_mean = outputs.mean(dim=0, keepdim=True)
                similarity = F.cosine_similarity(outputs, class_mean, dim=1)
                score[data_index_set] += similarity.double().numpy()

        # Keep the top 10% (at least one sample), best first.
        selected_size = max(n_prev // 10, 1)
        selected_index = np.argsort(score)[-selected_size:][::-1].tolist()
        selected_sample = Subset(previous_data, selected_index)

        self.train_dataset = ConcatDataset([self.train_dataset, selected_sample])
        self.train_loader = self._build_train_loader()

    def _train_classifier_epoch(self, criterion):
        """Run one SGD pass of the classifier over ``self.train_loader``."""
        for images, labels in self.train_loader:
            images = images.to(self.device)
            labels = labels.to(self.device)
            self.optimizer.zero_grad()
            loss = criterion(self.model(images), labels)
            loss.backward()
            self.optimizer.step()

    def trainI(self):
        """Train the classifier for ``args.epochs`` epochs with cross-entropy.

        The LR scheduler is stepped once per call (not once per epoch), so
        with ``StepLR(step_size=5)`` the rate changes every fifth call.
        """
        self.model = self.model.to(self.device)
        self.model.train()
        criterion = nn.CrossEntropyLoss()
        for _ in range(self.args.epochs):
            self._train_classifier_epoch(criterion)
        self.scheduler.step()

    def trainII(self):
        """Train the classifier and, each epoch, the generator/discriminator pair.

        A random subset of at most ``GAN_REAL_SAMPLES`` training samples is
        drawn once per call and reused every epoch as the "real" data for the
        adversarial update. Fresh Adam optimizers (lr=1e-4) for the generator
        and discriminator are created on every call. As in ``trainI``, the
        classifier's LR scheduler is stepped once per call.

        NOTE(review): ``BCELoss`` requires the discriminator to output values
        in [0, 1] with shape ``(batch, 1)`` -- confirm against ``Discriminator``.
        """
        self.model = self.model.to(self.device)
        self.model.train()

        n_total = len(self.train_dataset)
        real_indices = random.sample(range(n_total), min(self.GAN_REAL_SAMPLES, n_total))
        real_dataloader = DataLoader(Subset(self.train_dataset, real_indices),
                                     batch_size=self.args.batch_size, shuffle=True,
                                     drop_last=False)

        self.generator.to(self.device)
        self.discriminator.to(self.device)
        loss_BCE = nn.BCELoss()
        criterion = nn.CrossEntropyLoss()
        gen_optimizer = optim.Adam(self.generator.parameters(), lr=1e-4)
        dis_optimizer = optim.Adam(self.discriminator.parameters(), lr=1e-4)
        self.generator.train()
        self.discriminator.train()

        for _ in range(self.args.epochs):
            # Classifier update.
            self._train_classifier_epoch(criterion)

            # Adversarial update.
            for real_images, _ in real_dataloader:
                real_images = real_images.to(self.device)
                batch_size = real_images.size(0)
                real_labels = torch.ones(batch_size, 1, device=self.device)
                fake_labels = torch.zeros(batch_size, 1, device=self.device)

                # Discriminator step. The fake batch is detached so this
                # backward pass does not needlessly traverse the generator;
                # the generator's gradients are reset before its own step
                # anyway, so the parameter updates are unchanged.
                noise = torch.randn(batch_size, self.NOISE_DIM, device=self.device)
                fake_images = self.generator(noise).detach()
                d_real = self.discriminator(real_images)
                d_fake = self.discriminator(fake_images)
                d_loss = loss_BCE(d_real, real_labels) + loss_BCE(d_fake, fake_labels)
                dis_optimizer.zero_grad()
                d_loss.backward()
                dis_optimizer.step()

                # Generator step: try to make the discriminator output "real".
                noise = torch.randn(batch_size, self.NOISE_DIM, device=self.device)
                fake_images = self.generator(noise)
                g_loss = loss_BCE(self.discriminator(fake_images), real_labels)
                gen_optimizer.zero_grad()
                g_loss.backward()
                gen_optimizer.step()

        self.scheduler.step()






