import copy
import os.path
import random
import sys
import time

sys.path.append('dd')
from einops import rearrange, repeat

from .ufedbase import UnlearnBasicClient, UnlearnBasicServer
import numpy as np
from utils import fmodule
import torch
import torch.nn as nn
import collections
import json
from tqdm import tqdm
from utils.utils_unlearn import agg_func
from utils.finch import FINCH
import os
from sklearn.cluster import KMeans
import torchvision.utils as tvu
import torchvision
import pickle
from utils.clustering import SVDD_clustering, DBSCAN_clustering, FINCH_clustering, k_means_clustering
import torch.optim as optim
from torchvision import transforms
from omegaconf import OmegaConf
from PIL import Image
from dd.ldm.models.diffusion.ddim import DDIMSampler
import importlib
from torch import autocast
from torchvision import transforms as T
import accelerate
from tqdm import tqdm, trange
from pytorch_lightning import seed_everything
from itertools import islice
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

def load_model_from_config(config, ckpt, verbose=False):
    """Instantiate ``config.model.target`` and load weights from a checkpoint.

    Args:
        config: OmegaConf config. ``config.model.target`` is a dotted
            ``module.ClassName`` path; ``config.model.params`` (optional) holds
            the constructor keyword arguments.
        ckpt: Path to a checkpoint whose weights live under ``"state_dict"``
            (here: the DEADiff checkpoint).
        verbose: If True, print the parameter names the non-strict load found
            missing from the checkpoint or unused by the model.

    Returns:
        The model, moved to CUDA and switched to eval mode.
    """
    print(f"Loading model from {ckpt}")
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]

    module, cls = config.model.target.rsplit(".", 1)
    model_cls = getattr(importlib.import_module(module, package=None), cls)
    model = model_cls(**config.model.get("params", dict()))

    # strict=False tolerates key mismatches; report them on request so a wrong
    # checkpoint does not load silently.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if verbose and missing:
        print("missing keys:")
        print(missing)
    if verbose and unexpected:
        print("unexpected keys:")
        print(unexpected)
    print('loading done')
    model.cuda()
    model.eval()
    return model

def chunk(it, size):
    """Lazily split ``it`` into consecutive tuples of at most ``size`` items.

    The final tuple may be shorter; iteration ends at the first empty slice.
    """
    source = iter(it)

    def _pieces():
        while True:
            piece = tuple(islice(source, size))
            if piece == ():
                return
            yield piece

    return _pieces()
class Server(UnlearnBasicServer):
    """Federated post-training server for DEADiff-conditioned classifiers.

    Each round it aggregates the client models and fine-tunes a copy of the
    aggregate on an auxiliary image folder (:meth:`domain_recycle`). It then
    blends that copy back into the global model. The server owns the DEADiff
    model and its DDIM sampler, and hands the same DEADiff instance to every
    client.
    """

    def __init__(self, option, model, clients, data_loader, device=None):
        super(Server, self).__init__(option, model, clients, data_loader, device)
        self.label_name = self.dataloader.label_dict.names[:self.dataloader.target_class_num]
        self.deadiff, self.sampler = self.load_style_extractor()
        # Clients extract features with the server's DEADiff instance.
        for c in self.clients:
            c.deadiff = self.deadiff
        self.gen_bs = 1  # TODO: args

        # None marks "no auxiliary data"; domain_recycle() refuses to run without it.
        self.aux_data = None
        try:
            self.aux_data = self.load_aux_data('./utils/data/Gen/testdomain')
        except Exception as err:  # not a bare except, so Ctrl-C still interrupts start-up
            print('load auxiliary data failed')
            print(f'reason: {err!r}')
            # Fall back to (cached) client style prototypes, which domain_gen()
            # can consume as ``ref_images``.
            if os.path.exists('./syf.pkl'):
                # NOTE: pickle runs arbitrary code on load; the cache must be trusted.
                with open('./syf.pkl', 'rb') as file:
                    self.local_features = pickle.load(file)
            else:
                self.local_features, _ = self.get_local_features()
                print(len(self.local_features))
                torch.cuda.empty_cache()
                with open('./syf.pkl', 'wb') as file:
                    pickle.dump(self.local_features, file)
        else:
            print('load auxiliary data successfully')

        self.re_rounds = 10  # TODO: needs tuning
        self.criterion = torch.nn.CrossEntropyLoss()

    def load_aux_data(self, path, batch_size=32, num_workers=4):
        """Return a shuffled DataLoader over the ImageFolder at ``path``.

        Images are converted to tensors, resized to 224x224 (bicubic, no
        antialiasing) and normalized with the mean/std constants below. Any
        error raised by ``datasets.ImageFolder`` (e.g. a missing directory) is
        propagated to the caller.
        """
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(
                (224, 224),
                interpolation=transforms.InterpolationMode.BICUBIC,
                antialias=False,
            ),
            transforms.Normalize(
                [0.48145466, 0.4578275, 0.40821073],
                [0.26862954, 0.26130258, 0.27577711]),
        ])
        a_dataset = datasets.ImageFolder(root=path, transform=transform)
        return DataLoader(a_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

    def load_style_extractor(self,
                             config_path='dd/configs/inference_deadiff_512x512.yaml',
                             ckpt_path='dd/pretrained/deadiff_v1.ckpt'):
        """Load the DEADiff model (on CUDA, eval mode) and wrap it in a DDIM sampler.

        Returns:
            ``(model, sampler)``.
        """
        config = OmegaConf.load(config_path)
        model = load_model_from_config(config, ckpt_path)
        sampler = DDIMSampler(model)
        return model, sampler

    def get_local_features(self):
        """Collect every client's style prototypes via ``Client.get_class_proto``.

        Returns:
            ``(local_features, feature_centers)``: two lists with one entry per
            client, in ``self.clients`` order.
        """
        local_features = []
        feature_centers = []
        for c in self.clients:
            local_feature, feature_center = c.get_class_proto(self.deadiff)
            local_features.append(local_feature)
            feature_centers.append(feature_center)
        return local_features, feature_centers

    def get_global_features(self, local_feas):
        """Average each client's per-class feature tensors over dim 0.

        Args:
            local_feas: One mapping per client, ``{class_index: tensor}``.

        Returns:
            One dict per client with a key for every class index in
            ``range(self.dataloader.target_class_num)``. The value is the mean
            (keepdim) for classes the client reported and ``None`` otherwise.

        NOTE(review): ``get_local_features`` yields per-client *lists*, not
        dicts, so its output cannot be passed here unchanged.
        """
        global_features = [{i: None for i in range(self.dataloader.target_class_num)} for _ in local_feas]
        for cidx, fea in enumerate(local_feas):
            for label, feats in fea.items():
                global_features[cidx][label] = torch.mean(feats, dim=0, keepdim=True)
        return global_features

    def run(self):
        """Run ``self.p_rounds`` post-training rounds, then save a checkpoint.

        After every round the learning rate is stepped, the global model is
        evaluated on the clients' test data and the log is written.
        """
        self.current_rounds = 0
        self.stage = 'PT'
        for rnd in tqdm(range(1, self.p_rounds + 1), desc='Post-training Rounds'):
            self.current_rounds = rnd
            self.pt_iterate()
            self.global_lr_scheduler(self.p_rounds)

            test_metric = self.test_on_clients(dataflag='test', model=self.model)
            self.outFunc(test_metric)
            self.save_log(self.out_log)
        self.save_ckp()

    @staticmethod
    def _exclude_unlearned(selected, unlearn_ids):
        """Return ``selected`` as an ndarray with every id in ``unlearn_ids`` removed.

        ``unlearn_ids`` may be a single id, a sequence of ids, or None.
        """
        selected = np.asarray(selected)
        if unlearn_ids is None:
            return selected
        excluded = np.atleast_1d(np.asarray(unlearn_ids))
        return selected[~np.isin(selected, excluded)]

    def pt_iterate(self):
        """One post-training round: train on the retained clients, aggregate, then
        blend in a model fine-tuned on the auxiliary data (90% / 10%)."""
        # Clients being unlearned never take part in post-training. A membership
        # test is used because ``==`` against a list (or a list of ids) does not
        # test membership.
        self.selected_clients = self._exclude_unlearned(self.sample(), self.unlearn_clients_id)
        reply = self.communicate(self.selected_clients)
        models = reply['model']

        self.model = self.aggregate(models)
        server_model, _ = self.domain_recycle()
        self.model = 0.9 * self.model + 0.1 * server_model
        del models, server_model

    def domain_recycle(self):
        """Fine-tune a copy of ``self.model`` on the auxiliary data.

        Runs ``self.re_rounds`` passes over ``self.aux_data`` with momentum-free
        SGD (``self.lr``, ``self.weight_decay``) and cross-entropy loss. The
        classifier is conditioned on DEADiff content features.

        TODO (from the original design notes): the intended objective is one
        unified loss, classification + contrastive; only the classification
        term is implemented.

        NOTE(review): ``load_aux_data`` already mean/std-normalizes the images,
        while ``2 * (x - 0.5)`` maps a [0, 1] range to [-1, 1]. Confirm which
        input range DEADiff expects.

        Returns:
            ``(server_model, g_s)`` where ``g_s`` is the flattened parameter
            difference ``(self.model - server_model) / self.lr`` (no grad).

        Raises:
            RuntimeError: if the auxiliary data failed to load in ``__init__``.
        """
        if self.aux_data is None:
            raise RuntimeError('domain_recycle needs auxiliary data, but loading it failed at start-up')
        server_model = copy.deepcopy(self.model)
        server_model.train()
        optimizer = optim.SGD(server_model.parameters(), lr=self.lr, weight_decay=self.weight_decay, momentum=0)
        for _ in range(self.re_rounds):
            for batch_x, batch_y in self.aux_data:
                server_model.zero_grad()
                batch_x = self.data_to_device(batch_x, device=self.device)
                batch_y = self.data_to_device(batch_y, device=self.device)
                # DEADiff only supplies features; keep autograd out of it so
                # backward() does not write gradients into the shared extractor.
                with torch.no_grad():
                    c = self.deadiff.get_learned_conditioning({
                        'target_text': 'A cat.',
                        'inp_image': 2 * (batch_x - 0.5),
                        'subject_text': ['content'] * len(batch_y),  # selects content (vs. style) extraction
                    }, rse_only=True)
                local_content_features = c[1]
                del c

                outputs = server_model(batch_x, local_content_features)
                loss = self.criterion(outputs, batch_y)
                loss.backward()
                optimizer.step()
        del optimizer

        with torch.no_grad():
            server_params_vector = torch.nn.utils.parameters_to_vector(server_model.parameters())
            self_params_vector = torch.nn.utils.parameters_to_vector(self.model.parameters())
            g_s = -1 * (server_params_vector - self_params_vector) / self.lr
        return server_model, g_s

    def pack(self, client_id, model=None):
        """Build the download package for a client.

        The package holds a deep copy of ``model`` (or of ``self.model`` when
        ``model`` is None) plus the current round, learning rate, momentum,
        weight decay and stage.
        """
        return {
            "model": copy.deepcopy(model if model is not None else self.model),
            "current_rounds": self.current_rounds,
            "lr": self.lr,
            "momentum": self.local_momentum,
            "weight_decay": self.weight_decay,
            "stage": self.stage,
        }

    def domain_gen(self, ref_images, prompts=None, seeds=None, repeat=1, gen_path=None):
        """Generate one image per class label for every reference feature.

        The DEADiff model is conditioned on ``[[style, prompt], [content, prompt]]``.

        Args:
            ref_images: Iterable of ``(ref_styles, ref_contents)`` pairs, e.g.
                the first value returned by ``get_local_features``. Each pair
                is zipped element-wise and every element is used as a batch
                of one.
            prompts, seeds, repeat: Unused; kept for interface compatibility.
            gen_path: Output root. Images are written to
                ``<gen_path>/<label>/<n>.png``, with ``n`` being the number of
                files already in that folder. Defaults to the previously
                hard-coded location.
        """
        if gen_path is None:
            gen_path = '/home/dell/workspace/wzc/myunlearn/utils/data/Gen/testdomain'
        for ln in self.label_name:
            os.makedirs(os.path.join(gen_path, f'{ln}'), exist_ok=True)

        # A dedicated CPU generator makes the per-image seeds reproducible from option['seed'].
        gener = torch.Generator('cpu')
        gener.manual_seed(self.option['seed'])

        accelerator = accelerate.Accelerator()
        shape = [4, 512 // 8, 512 // 8]  # latent shape; 8 is the down-sampling factor
        for ref_styles, ref_contents in ref_images:
            for ref_style, ref_content in zip(ref_styles, ref_contents):
                ref_style = ref_style.to(self.device).unsqueeze(0)
                ref_content = ref_content.to(self.device).unsqueeze(0)
                with torch.no_grad(), autocast("cuda"), self.deadiff.ema_scope():
                    for prompt in self.label_name:
                        seed = torch.randint(low=0, high=2 ** 31 - 1, size=(1,), generator=gener).item()
                        seed_everything(seed)
                        # Negative prompt for classifier-free guidance.
                        uc = self.deadiff.get_learned_conditioning({'target_text': ref_style.shape[0] * [
                            "over-exposure, under-exposure, saturated, duplicate, out of frame, lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, mutilated, out of frame, ugly, bad anatomy, bad proportions, deformed, blurry, duplicate"]})

                        print(f'generating {prompt} image')
                        text_encoder_hidden_states = self.deadiff.cond_stage_model.encode(prompt)
                        encoder_hidden_states = [[ref_style, text_encoder_hidden_states],
                                                 [ref_content, text_encoder_hidden_states]]

                        samples_ddim, _ = self.sampler.sample(S=50,
                                                              conditioning=encoder_hidden_states,
                                                              batch_size=ref_style.shape[0],
                                                              shape=shape,
                                                              verbose=False,
                                                              unconditional_guidance_scale=7.5,
                                                              unconditional_conditioning=[uc, uc])
                        x_samples_ddim = self.deadiff.decode_first_stage(samples_ddim)
                        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
                        x_samples_ddim = accelerator.gather(x_samples_ddim)

                        if accelerator.is_main_process:
                            for x_sample in x_samples_ddim:
                                x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                                base_count = len(os.listdir(f'{gen_path}/{prompt}'))
                                Image.fromarray(x_sample.astype(np.uint8)).save(
                                    os.path.join(gen_path, f'{prompt}', f'{base_count}.png'))

class Client(UnlearnBasicClient):
    """Federated client whose classifier is conditioned on DEADiff content features.

    ``self.deadiff`` starts as None and is assigned by the server after
    construction.
    """

    def __init__(self, option, id, model=None):
        super(Client, self).__init__(option, id, model)
        self.K = 5
        self.beta = 0.0
        self.deadiff = None  # injected by the server

    def _extract_content_features(self, batch_x, batch_size):
        """Return DEADiff content features for ``batch_x``, with autograd disabled.

        The extractor is shared and not optimized here, so no graph is kept.
        """
        with torch.no_grad():
            c = self.deadiff.get_learned_conditioning({
                'target_text': 'A cat.',
                'inp_image': 2 * (batch_x - 0.5),
                'subject_text': ['content'] * batch_size,  # selects content (vs. style) extraction
            }, rse_only=True)
        return c[1]

    def train(self):
        """Train ``self.model`` locally for ``self.epochs`` epochs.

        In the 'Unlearn' stage of an unlearning client the loss is negated
        (gradient ascent) and gradients are clipped to norm 1.0.

        Returns:
            Sample-weighted mean of the optimized loss. In the unlearn stage
            this is the negated loss.
        """
        self.model.train()
        total_loss = 0.0
        optimizer = self.get_optimizer(self.model)
        unlearning = self.unlearn and self.stage == 'Unlearn'
        for _ in range(self.epochs):
            for batch_data in self.train_data:
                batch_x, batch_y = batch_data['image'], batch_data['label']
                self.model.zero_grad()
                batch_x = self.data_to_device(batch_x, device=self.device)
                batch_y = self.data_to_device(batch_y, device=self.device)

                local_content_features = self._extract_content_features(batch_x, len(batch_y))
                outputs = self.model(batch_x, local_content_features)
                loss = self.criterion(outputs, batch_y)

                if unlearning:
                    loss = -loss  # gradient ascent on the data to forget
                loss.backward()
                if unlearning:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

                optimizer.step()
                total_loss += loss.item() * len(batch_y)
        del optimizer
        return total_loss / (self.datavol * self.epochs)

    def get_class_proto(self, deadiff):
        """Summarize this client's DEADiff style features with k-means (5 clusters).

        Returns:
            ``([style_prototypes, zeros], [style_mean, zeros])``. The prototypes
            are viewed as ``(-1, *token_shape)`` and the mean as
            ``(1, *token_shape)``, where ``token_shape`` is the per-image style
            feature shape. The content slots are zero tensors of the same
            shapes, not real content features.

        Raises:
            RuntimeError: if ``self.train_data`` yields no batches.
        """
        local_style_features = []
        with torch.no_grad(), autocast("cuda"), deadiff.ema_scope():
            for batch_data in self.train_data:
                batch_x, batch_y = batch_data['image'], batch_data['label']
                batch_x = self.data_to_device(batch_x, device=self.device)
                # With rse_only=True, c is [style, content].
                c = deadiff.get_learned_conditioning({
                    'target_text': 'A cat.',
                    'inp_image': 2 * (batch_x - 0.5),
                    'subject_text': [['style', 'content']] * len(batch_y),
                }, rse_only=True)
                local_style_features.append(c[0])
                del c

        if not local_style_features:
            raise RuntimeError('get_class_proto: train_data yielded no batches')
        local_style_feature = torch.cat(local_style_features, dim=0).cpu()
        # The original code hard-coded (16, 768); derive it from the features instead.
        token_shape = tuple(local_style_feature.shape[1:])
        flat_features = local_style_feature.view(local_style_feature.shape[0], -1)
        local_mean, local_features = k_means_clustering(flat_features, clusters_num=5)

        local_tensor = torch.from_numpy(local_features.astype(np.float32)).view(-1, *token_shape)
        local_mean_tensor = torch.from_numpy(local_mean.astype(np.float32)).view(1, *token_shape)

        # NOTE(review): an earlier comment said "save content", yet zeros are returned
        # in the content slots -- confirm this is intended.
        return [local_tensor, torch.zeros_like(local_tensor)], [local_mean_tensor, torch.zeros_like(local_mean_tensor)]

    def unpack(self, received_pkg):
        """Adopt the model and hyper-parameters from the server's ``pack`` output."""
        self.current_rounds = received_pkg['current_rounds']
        self.lr = received_pkg['lr']
        self.momentum = received_pkg['momentum']
        self.weight_decay = received_pkg['weight_decay']
        self.stage = received_pkg['stage']
        self.model = received_pkg['model']

    def _evaluate(self, model, loader):
        """Evaluate ``model`` on ``loader``.

        When this client is unlearning and ``self.bd`` is set, the backdoor
        trigger is applied first.

        Returns:
            ``(num_correct, summed_loss, correct_by_domain)``. ``summed_loss`` is
            the batch-mean loss times the batch size, summed over batches.
            ``correct_by_domain`` maps each domain id to its correct count.
        """
        num_correct = 0
        total_loss = 0.0
        correct_by_domain = {}
        for batch_data in loader:
            batch_x, batch_y, batch_d = batch_data['image'], batch_data['label'], batch_data['domain']
            if self.unlearn and self.bd:
                batch_x, batch_y = self.bd_maker.add_backdoor(batch_x, batch_y)
            batch_x = self.data_to_device(batch_x, device=self.device)
            batch_y = self.data_to_device(batch_y, device=self.device)

            local_content_features = self._extract_content_features(batch_x, len(batch_y))
            outputs = model(batch_x, local_content_features)
            batch_mean_loss = self.criterion(outputs, batch_y).item()
            y_pred = outputs.data.max(1, keepdim=True)[1]
            # Per-sample 0/1 correctness on the CPU, where batch_d lives.
            hits = y_pred.eq(batch_y.data.view_as(y_pred)).view(-1).long().cpu()
            domains = batch_d.reshape(-1)
            for dn in torch.unique(domains):
                dn = dn.item()
                correct_by_domain[dn] = correct_by_domain.get(dn, 0) + hits[domains == dn].sum().item()
            num_correct += hits.sum().item()
            total_loss += batch_mean_loss * len(batch_y)
        return num_correct, total_loss, correct_by_domain

    def test(self, model=None, dataflag='test'):
        """Evaluate ``model`` (default ``self.model``) and return a metric dict.

        Uses the training split when ``dataflag == 'train'`` and the test split
        otherwise. Retaining clients report ``retain_*`` metrics. Unlearning
        clients report ``Backdoor_*`` metrics and, in addition, ``Unlearn_Memory_*``
        metrics on ``self.UM_test_data``. Accuracies are percentages of the
        recorded data volume.
        """
        test_model = model if model is not None else self.model
        test_model.eval()

        if dataflag == 'train':
            dataset, datavol = self.train_data, self.datavol
        else:
            dataset, datavol = self.test_data, self.test_datavol

        local_metric = {}
        with torch.no_grad():
            num_correct, total_loss, correct_by_domain = self._evaluate(test_model, dataset)
        if not self.unlearn:
            local_metric.update({'retain_accuracy': 100 * num_correct / datavol, 'retain_loss': total_loss / datavol,
                                 'domain_metric': correct_by_domain})
        else:
            local_metric.update({'Backdoor_accuracy': 100 * num_correct / datavol,
                                 'Backdoor_loss': total_loss / datavol, 'domain_metric': correct_by_domain})

        if self.unlearn:
            # Unlearn-memory metrics on the data the client is forgetting.
            with torch.no_grad():
                um_correct, um_loss, um_by_domain = self._evaluate(test_model, self.UM_test_data)
            local_metric.update({'Unlearn_Memory_accuracy': 100 * um_correct / self.UM_test_datavol,
                                 'Unlearn_Memory_loss': um_loss / self.UM_test_datavol,
                                 'UM_domain_metric': um_by_domain})
        return local_metric