import copy
import torch
import numpy as np
from diffusers import DDPMScheduler, StableDiffusionPipeline
from PIL import Image
import os
import torch.nn as nn
import torch.nn.functional as F
from diffusers.optimization import get_scheduler
from typing import Optional, Union, List
from peft import LoraConfig
from .utils import get_transform, PROMPTS, CIFAR_10_CLASSES, CIFAR_100_CLASSES, TINY_IMAGENET_CLASSES, _set_trainable_adapter, _adapter_params, _get_class_name, check_tensor
from PIL import Image
from torch.optim import AdamW
from collections import defaultdict

# Module-level knobs for the synthetic-data server.
GEN_NUM: int = 10  # not referenced by Server in this file; presumably used elsewhere — verify
SUP_NUM_PER_CLI: int = 10  # likewise not referenced by Server in this file
# Exclusive upper bound on the synthetic image index a client's scan may reach in supply_data.
THRESHOLD_KEEP_SUP: int = 5000
# dtype for the diffusion pipeline weights, the VAE inputs and the DPO noise tensors.
wtype = torch.float32

class Server:
    """Federated-learning server that aggregates client models and supplies
    synthetic Stable-Diffusion images to clients.

    Every client owns a LoRA adapter (``adapter_{i}``) on the shared UNet.
    That adapter is tuned with a DPO-style preference loss on denoising
    errors, measured relative to a frozen copy of the base UNet
    (``self.ref_unet``). The preferences come from rewards computed by
    comparing the global model with the client's model on generated images.
    """

    def __init__(self, device, clients, global_model, args, test_loader, logger=None):
        """Build the server, the diffusion pipeline and the per-client adapters.

        Args:
            device: device for the pipeline and the reference UNet.
            clients: client objects. Used here via ``trainloader``; other
                methods also read ``data_num``, ``local_model``, ``criterion``
                and ``add_generate_datas``.
            global_model: the aggregated classifier.
            args: run configuration namespace (``num_users``, ``device``,
                ``supply_alpha``, ``max_supply_num``, ``epochs``, ...).
            test_loader: iterable of ``(images, labels)`` batches used by
                :meth:`evaluate`.
            logger: optional logger. When ``None``, :meth:`log_info` only prints.
        """
        self.device = device
        self.args = args
        self.logger = logger
        self.clients = clients
        self.global_model = global_model
        self.total_data = 0
        self.criterion = torch.nn.CrossEntropyLoss().to(self.args.device)
        self.test_loader = test_loader
        self.cnt = 0
        self.w_avg = {}
        # Next synthetic-image index each client's scan resumes from.
        self.current_idx = [0] * self.args.num_users
        # Synthetic image index <-> position in the shared training dataset.
        self.idx_map = {}
        self.idx_map_reverse = {}
        self.global_idx = 0
        # from_pretrained does not expand "~" itself, so resolve it here.
        model_path = os.path.expanduser("~/models/stable-diffusion-v1-4")
        self.pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=wtype).to(self.device)
        self.pipe.safety_checker = None
        self.noise_scheduler = DDPMScheduler.from_config(self.pipe.scheduler.config)

        self.client_origin_data_num = [len(client.trainloader.dataset) for client in self.clients]
        print(self.client_origin_data_num)

        # Per-epoch generation budget.
        if self.args.supply_alpha != 0:
            self.gen_list = self.get_gen_list(40, self.args.supply_alpha).tolist()
        else:
            self.gen_list = [10] * (self.args.max_supply_num // 10)
        if len(self.gen_list) < self.args.epochs:
            self.gen_list.extend([0] * (self.args.epochs - len(self.gen_list)))
        # NOTE(review): this floor also lifts the zero padding above to 10, so
        # every epoch asks for at least 10 images; confirm that is intended.
        self.gen_list = [max(10, value) for value in self.gen_list]
        print(self.gen_list)

        # DPO Diffusion Adapter part: freeze the base pipeline.
        self.pipe.unet.requires_grad_(False)
        self.pipe.text_encoder.requires_grad_(False)
        self.pipe.vae.requires_grad_(False)

        # Frozen copy of the base UNet, taken before any adapter is attached.
        self.ref_unet = copy.deepcopy(self.pipe.unet).to(self.device)
        for p in self.ref_unet.parameters():
            p.requires_grad_(False)
        self.ref_unet.eval()

        lora_cfg = LoraConfig(
            r=4, lora_alpha=4, lora_dropout=0.0, bias="none",
            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
            inference_mode=False,
            init_lora_weights="gaussian",
        )
        for i in range(len(self.clients)):
            self.pipe.unet.add_adapter(lora_cfg, adapter_name=f"adapter_{i}")
        # One optimizer/scheduler per client adapter.
        self.cli_opt, self.cli_sch = {}, {}
        for i in range(len(self.clients)):
            _set_trainable_adapter(self.pipe, f"adapter_{i}")
            params = _adapter_params(self.pipe, f"adapter_{i}")
            opt = AdamW(params, lr=5e-5)
            self.cli_opt[i] = opt
            self.cli_sch[i] = get_scheduler("constant", optimizer=opt, num_warmup_steps=0)
        _set_trainable_adapter(self.pipe, "")

        self.sd_grad_accum = 4            # micro-batches per optimizer step
        self.sd_clip_grad = 1.0           # max grad norm for adapter params
        self.sd_train_steps_per_round = 20
        self.dpo_beta = 1.0
        self.min_reward_gap = 0.05        # min pos/neg mean-reward gap per label

    def get_gen_list(self, sup_epoch_num, alpha_mild):
        """Split ``args.max_supply_num`` over ``sup_epoch_num`` epochs.

        Epoch ``t`` (1-based) gets a share proportional to ``1 / t**alpha_mild``.

        Returns:
            numpy float array of length ``sup_epoch_num`` whose sum is
            ``args.max_supply_num``.
        """
        t_values = np.arange(1, sup_epoch_num + 1)
        coefficients = 1 / (t_values ** alpha_mild)
        normalized = coefficients / np.sum(coefficients)
        return self.args.max_supply_num * normalized

    def aggreate_model(self, avg_keys: list):
        """Average the clients' ``avg_keys`` entries into ``global_model``.

        Each client is weighted by ``client.data_num / sum(data_num)``. The
        result is kept in ``self.w_avg`` and loaded with ``strict=False``.
        """
        self.total_data = sum(client.data_num for client in self.clients)
        # Fetch every client's state dict once instead of once per key.
        states = [client.local_model.state_dict() for client in self.clients]
        self.w_avg = {}
        for avg_key in avg_keys:
            w = 0
            for client, state in zip(self.clients, states):
                w = w + (client.data_num / self.total_data) * state[avg_key]
            # The arithmetic above always produces a fresh tensor, so no copy is needed.
            self.w_avg[avg_key] = w
        self.global_model.load_state_dict(self.w_avg, strict=False)
        for i, client in enumerate(self.clients):
            self.log_info(f"client {i} data num: {client.data_num}")

    def sync_model(self, avg_keys: list):
        """Load the last aggregated weights (``self.w_avg``) into every client.

        ``avg_keys`` is accepted for interface symmetry but is not used.
        """
        for client in self.clients:
            client.local_model.load_state_dict(self.w_avg, strict=False)

    def generate_image(self, i: int) -> "Image.Image":
        """Generate synthetic image number ``i`` for the configured dataset.

        The class is ``i % num_classes``. The prompt template rotates every
        ``num_classes`` indices. The output is resized to the dataset's
        native resolution.

        Raises:
            NotImplementedError: for an unsupported ``args.dataset``.
        """
        # dataset -> (class names, number of classes, output side length)
        specs = {
            'cifar-10': (CIFAR_10_CLASSES, 10, 32),
            'cifar-100': (CIFAR_100_CLASSES, 100, 32),
            'tiny-imagenet': (TINY_IMAGENET_CLASSES, 200, 64),
        }
        if self.args.dataset not in specs:
            raise NotImplementedError
        class_names, num_classes, side = specs[self.args.dataset]
        prompt = PROMPTS[(i // num_classes) % len(PROMPTS)].replace("{class}", class_names[i % num_classes])
        img = self.pipe(prompt=prompt).images[0]
        return img.resize((side, side))

    def compute_reward(self, global_output, client_output, label_t):
        """Score one sample from global and client logits.

        reward = rw_alpha * p_global(label) - rw_beta * p_client(label)
                 + rw_gamma * KL(p_client || p_global)
                 + rw_delta * [argmax_global != argmax_client]

        Only row 0 is read for the confidences, and ``.item()`` needs a
        single-element result, so inputs are expected to have batch size 1.
        """
        pg = torch.softmax(global_output, dim=1)
        pc = torch.softmax(client_output, dim=1)
        label = label_t.item()
        g_conf = pg[0, label]
        c_conf = pc[0, label]
        # F.kl_div(log q, p) = KL(p || q); the epsilon avoids log(0).
        kl = F.kl_div(torch.log(pg + 1e-8), pc, reduction="batchmean")
        disagree = (pg.argmax(dim=1) != pc.argmax(dim=1)).float()
        r = self.args.rw_alpha * g_conf - self.args.rw_beta * c_conf + self.args.rw_gamma * kl + self.args.rw_delta * disagree
        return r.item()

    def supply_data(self, epoch: int):
        """Score synthetic images and give the useful ones to clients.

        For client ``i``, images are scanned from ``self.current_idx[i]``.
        Each image is generated on first use, cached as
        ``<gen_data_dir>/<dataset>/<alpha>/<seed>/<idx>.png``, added once to
        the shared training dataset, and scored with the global and client
        models. An image is supplied when the global model's confidence on its
        label is above chance and the client's loss exceeds the global loss.

        The scan also stops when ``THRESHOLD_KEEP_SUP`` is reached or the
        client reaches its supply cap. During the first three epochs, the
        collected feedback then tunes the client's adapter.

        Args:
            epoch: current round; selects the budget from ``self.gen_list``.

        Returns:
            List (one entry per client) of sets of shared-dataset indices
            supplied this round.
        """
        supply_idx_list = [set() for _ in range(len(self.clients))]

        file_path = os.path.join(self.args.gen_data_dir, f'{self.args.dataset}', f'{self.args.alpha}', f'{self.args.seed}')
        os.makedirs(file_path, exist_ok=True)
        # Shared underlying dataset that synthetic images are appended to.
        dataset = self.clients[0].trainloader.dataset.dataset
        # Does not depend on the loop variables; build it once.
        transform = get_transform(self.args.dataset)

        self.global_model.eval()
        for client in self.clients:
            client.local_model.eval()

        gen_num = self.gen_list[epoch]
        for i, client in enumerate(self.clients):
            self.pipe.unet.set_adapter(f"adapter_{i}")
            feedback = []
            idx = self.current_idx[i]
            cnt = 0
            max_supply_num = self.client_origin_data_num[i] + self.args.max_supply_num
            # The client's dataset only grows in add_generate_datas() below, so
            # this round's selections (cnt) must count against the cap too.
            while cnt < gen_num and idx < THRESHOLD_KEEP_SUP and len(client.trainloader.dataset) + cnt < max_supply_num:
                img_path = os.path.join(file_path, f'{idx}.png')
                if not os.path.exists(img_path):
                    img = self.generate_image(idx)
                    img.save(img_path)
                    print(f'Generete {idx} image finished.')

                label = idx % self.args.num_classes
                img = Image.open(img_path)
                if idx not in self.idx_map:
                    dataset.add_data(img, label)
                    self.idx_map[idx] = len(dataset) - 1
                    self.idx_map_reverse[len(dataset) - 1] = idx

                x = transform(img).unsqueeze(0).to(self.args.device)
                label_t = torch.tensor([label]).long().to(self.args.device)
                # Scoring only: no autograd graph is needed for these passes.
                with torch.no_grad():
                    global_output = self.global_model(x)
                    g_loss = self.criterion(global_output, label_t)
                    client_output = client.local_model(x)
                    c_loss = client.criterion(client_output, label_t)
                g_conf = torch.softmax(global_output, dim=1)[0, label].item()

                reward = self.compute_reward(global_output, client_output, label_t)
                feedback.append((img_path, reward, label))

                # Keep images the global model gets better than chance but on
                # which the client does worse than the global model.
                if g_conf > 1 / self.args.num_classes and c_loss > g_loss:
                    supply_idx_list[i].add(self.idx_map[idx])
                    cnt += 1
                idx += 1

            self.current_idx[i] = idx
            if len(client.trainloader.dataset) < max_supply_num and epoch < 3:
                self.online_align_sd(feedback, i)

        for i, st in enumerate(supply_idx_list):
            self.log_info(f"{i}: {list(st)}")

        for i, client in enumerate(self.clients):
            client.add_generate_datas(list(supply_idx_list[i]))
        return supply_idx_list

    def online_align_sd(self, feedback, index):
        """Tune client ``index``'s LoRA adapter from reward feedback.

        Args:
            feedback: list of ``(image_path, reward, label)`` tuples.
            index: client index; selects ``adapter_{index}`` and its optimizer.

        If there is nothing to train on, the adapter is never made trainable.
        After training it is made non-trainable again and the UNet goes back
        to eval mode, even when training raises.
        """
        adapter_name = f"adapter_{index}"
        self.pipe.unet.set_adapter(adapter_name)
        if len(feedback) < 2:
            return

        pos_dict, neg_dict, labels = self._build_dpo_pairs(feedback, index)
        if not labels:
            print("Insufficient per-label pairs for DPO")
            return

        _set_trainable_adapter(self.pipe, adapter_name)
        self.pipe.unet.train()
        try:
            self._run_dpo_steps(pos_dict, neg_dict, labels, index)
        finally:
            _set_trainable_adapter(self.pipe, "")
            self.pipe.unet.eval()

    def _build_dpo_pairs(self, feedback, index):
        """Split positive-reward feedback into per-label preferred/rejected lists.

        Within each label, samples are sorted by reward. The top half is the
        preferred set and the bottom half the rejected set; for an odd count
        the middle sample is dropped. Labels whose mean-reward gap is below
        ``self.min_reward_gap`` are skipped.

        Returns:
            ``(pos_dict, neg_dict, labels)``. ``pos_dict[l]`` and
            ``neg_dict[l]`` are equally long lists of ``(path, reward)``.
        """
        buckets = defaultdict(list)
        for path, reward, label in feedback:
            if reward > 0:
                buckets[label].append((path, reward))
                self.log_info(f"{label}, {path}, {reward}")

        pos_dict, neg_dict, labels = {}, {}, []
        for label, lst in buckets.items():
            lst.sort(key=lambda x: x[1], reverse=True)
            m = len(lst) // 2
            if m == 0:
                continue
            pos_mean = float(np.mean([r for _, r in lst[:m]]))
            neg_mean = float(np.mean([r for _, r in lst[-m:]]))
            if pos_mean - neg_mean < self.min_reward_gap:
                self.log_info(
                    f"[DPO][client_{index}] skip label={label} "
                    f"m={m} pos_mean={pos_mean:.4f} neg_mean={neg_mean:.4f} gap={(pos_mean - neg_mean):.4f} "
                    f"(min_gap={self.min_reward_gap:.4f})"
                )
                continue
            pos_dict[label] = lst[:m]
            neg_dict[label] = lst[-m:]
            labels.append(label)
        return pos_dict, neg_dict, labels

    def _encode_images(self, paths, side):
        """Load ``paths`` as RGB and return VAE latents scaled by ``scaling_factor``."""
        imgs = [Image.open(pth).convert("RGB") for pth in paths]
        px = self.pipe.image_processor.preprocess(imgs, height=side, width=side).to(self.device, dtype=wtype)
        with torch.no_grad():
            return self.pipe.vae.encode(px).latent_dist.sample() * self.pipe.vae.config.scaling_factor

    def _encode_prompt(self, label):
        """Return text-encoder hidden states for ``"an image of a <class name>"``."""
        text_inputs = self.pipe.tokenizer(f"an image of a {_get_class_name(self.args.dataset, label)}", padding="max_length", max_length=self.pipe.tokenizer.model_max_length, truncation=True, return_tensors="pt")
        text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
        with torch.no_grad():
            return self.pipe.text_encoder(**text_inputs)[0]

    def _run_dpo_steps(self, pos_dict, neg_dict, labels, index):
        """Run ``self.sd_train_steps_per_round`` optimizer steps on client ``index``'s adapter.

        Each micro-step samples a label, in proportion to its number of pairs,
        and draws up to 8 preferred and 8 rejected images with replacement.
        The loss is ``-logsigmoid(beta * ((mse_n - mse_p) - (mse_n_ref - mse_p_ref)))``;
        where that margin is below ``tau``, the reference terms are dropped.
        """
        opt = self.cli_opt[index]
        scheduler = self.cli_sch[index]
        # Loop-invariant: label sampling distribution, image side, prompt cache.
        lens = np.array([len(pos_dict[lbl]) for lbl in labels], dtype=np.float64)
        probs = lens / lens.sum()
        side = self.pipe.unet.config.sample_size * self.pipe.vae_scale_factor
        num_t = self.noise_scheduler.config.num_train_timesteps
        prompt_embeds = {}

        steps, acc = 0, 0
        opt.zero_grad(set_to_none=True)
        while steps < self.sd_train_steps_per_round:
            lbl = int(np.random.choice(labels, p=probs))
            m = len(pos_dict[lbl])
            bs = min(8, m)
            self.log_info(f"{lbl}, {bs}")
            idxs = np.arange(m)
            a = np.random.choice(idxs, size=bs, replace=True)
            b = np.random.choice(idxs, size=bs, replace=True)
            p_lat = self._encode_images([pos_dict[lbl][j][0] for j in a], side)
            n_lat = self._encode_images([neg_dict[lbl][k][0] for k in b], side)

            noise_p = torch.randn_like(p_lat, dtype=wtype)
            noise_n = torch.randn_like(n_lat, dtype=wtype)
            t_p = torch.randint(0, num_t, (p_lat.shape[0],), device=p_lat.device).long()
            t_n = torch.randint(0, num_t, (n_lat.shape[0],), device=n_lat.device).long()
            p_noisy = self.noise_scheduler.add_noise(p_lat, noise_p, t_p)
            n_noisy = self.noise_scheduler.add_noise(n_lat, noise_n, t_n)

            if lbl not in prompt_embeds:
                prompt_embeds[lbl] = self._encode_prompt(lbl)
            enc = prompt_embeds[lbl]
            if enc.shape[0] != p_noisy.shape[0]:
                enc = enc.repeat(p_noisy.shape[0], 1, 1)

            pred_p = self.pipe.unet(p_noisy, t_p, enc).sample
            pred_n = self.pipe.unet(n_noisy, t_n, enc).sample
            with torch.no_grad():
                ref_pred_p = self.ref_unet(p_noisy, t_p, enc).sample
                ref_pred_n = self.ref_unet(n_noisy, t_n, enc).sample
            # Per-sample denoising MSE for the tuned and reference UNets.
            mse_p = ((pred_p.float() - noise_p.float()) ** 2).mean(dim=[1, 2, 3])
            mse_n = ((pred_n.float() - noise_n.float()) ** 2).mean(dim=[1, 2, 3])
            mse_p_ref = ((ref_pred_p.float() - noise_p.float()) ** 2).mean(dim=[1, 2, 3])
            mse_n_ref = ((ref_pred_n.float() - noise_n.float()) ** 2).mean(dim=[1, 2, 3])

            delta = mse_n - mse_p
            delta_ref = mse_n_ref - mse_p_ref
            tau = 1e-4
            use_ipo_mask = (delta - delta_ref).abs() < tau
            logits_dpo = self.dpo_beta * (delta - delta_ref)
            logits_ipo = self.dpo_beta * delta
            logits = torch.where(use_ipo_mask, logits_ipo, logits_dpo)
            logits = torch.clamp(logits, -20.0, 20.0)
            # logsigmoid is the numerically stable form of log(sigmoid(x)).
            loss = -F.logsigmoid(logits).mean()
            # NOTE(review): the loss is not divided by sd_grad_accum, so the
            # accumulated gradient is a sum over micro-batches.
            loss.backward()
            self.log_info(f"{acc}: {use_ipo_mask}")

            acc += 1
            if acc % self.sd_grad_accum == 0:
                params_to_clip = [p for p in self.pipe.unet.parameters() if p.requires_grad and p.grad is not None]
                torch.nn.utils.clip_grad_norm_(params_to_clip, self.sd_clip_grad)
                opt.step()
                opt.zero_grad()
                scheduler.step()
                steps += 1
                self.log_info(
                    f"[DPO][client_{index}] step {steps} "
                    f"label={lbl} batch={bs} loss={loss.item():.6f} "
                    f"logits_mean={logits.mean().item():.3f} "
                    f"logits_min={logits.min().item():.3f} "
                    f"logits_max={logits.max().item():.3f}"
                )

    def log_info(self, message: str):
        """Print ``message`` and forward it to ``self.logger`` when one was given."""
        print(message)
        if self.logger is not None:
            self.logger.info(message)

    def evaluate(self):
        """Evaluate the global model on ``self.test_loader``.

        Returns:
            ``(accuracy, average_loss)``. ``accuracy`` is a percentage;
            ``average_loss`` is ``self.criterion`` weighted by batch size and
            averaged over all samples.
        """
        self.global_model.eval()
        correct = 0
        total = 0
        total_loss = 0.0

        with torch.no_grad():
            for images, labels in self.test_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                outputs = self.global_model(images)
                loss = self.criterion(outputs, labels)
                # The criterion averages over the batch; re-weight by batch size.
                total_loss += loss.item() * images.size(0)
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        accuracy = 100 * correct / total
        average_loss = total_loss / total
        return accuracy, average_loss
    
    
