import torch
import copy
import random
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
from torch import nn
from collections import defaultdict


class DatasetSplit(Dataset):
    """View of ``dataset`` restricted to the positions listed in ``idxs``.

    Each item is returned as an ``(image, label)`` pair of tensors that are
    fresh copies of what the underlying dataset returned.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        # Normalise indices (e.g. numpy integers) to plain Python ints.
        self.idxs = [int(i) for i in idxs]

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        image, label = self.dataset[self.idxs[item]]
        return self._to_tensor(image), self._to_tensor(label)

    @staticmethod
    def _to_tensor(value):
        """Return a new tensor holding a copy of ``value``.

        ``torch.tensor(t)`` on an existing tensor also copies, but emits a
        UserWarning recommending ``clone().detach()``; this avoids the warning
        while keeping the copy semantics.
        """
        if isinstance(value, torch.Tensor):
            return value.detach().clone()
        return torch.tensor(value)


class DatasetSplit_domain(Dataset):
    """View of a domain-labelled ``dataset`` restricted to the positions in ``idxs``.

    The underlying dataset must return ``(image, label, domain)`` triples; each
    item is returned as three tensors that are fresh copies of those values.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        # Normalise indices (e.g. numpy integers) to plain Python ints.
        self.idxs = [int(i) for i in idxs]

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        image, label, domain = self.dataset[self.idxs[item]]
        return self._to_tensor(image), self._to_tensor(label), self._to_tensor(domain)

    @staticmethod
    def _to_tensor(value):
        """Return a new tensor holding a copy of ``value``.

        ``torch.tensor(t)`` on an existing tensor also copies, but emits a
        UserWarning recommending ``clone().detach()``; this avoids the warning
        while keeping the copy semantics.
        """
        if isinstance(value, torch.Tensor):
            return value.detach().clone()
        return torch.tensor(value)


class LocalUpdate(object):
    """Client-side trainer for a local image encoder scored against fixed text features.

    ``local_train_split`` fine-tunes ``self.model`` either with supervised
    cross-entropy (SFT) or with a GRPO-style clipped-surrogate objective, then
    embeds the client's training data and packages a per-label random subset of
    the L2-normalised image projections for upload.

    ``args`` is read for: ``data_num``, ``local_bs``, ``seed``, ``shot_num``,
    ``local_ep``, ``dataname``, ``IID``, ``GRPO_sampling_epochs``,
    ``GRPO_epochs``, ``noise_std``, ``clip_eps``, ``kl_coef`` and
    ``random_label_ratio``.
    """

    def __init__(self, args, train_data, idxs, client_index, device, model, logger, user_base_labels):
        """Build the client's train loader and store its model and settings.

        When ``args.data_num == "few-shot"`` the loader holds at most
        ``args.shot_num`` samples per label; otherwise it covers all of ``idxs``.
        ``logger`` and ``user_base_labels`` are stored but not used in this class.
        """
        self.args = args
        self.testloader_cache = None
        if self.args.data_num == "few-shot":
            self.trainloader = self.get_loader_few(train_data, list(idxs))
        else:
            self.trainloader = self.get_loader(train_data, list(idxs))
        self.device = device
        self.client_index = client_index
        self.criterion_CE = nn.CrossEntropyLoss()
        self.model = model
        self.logger = logger
        self.base_labels = user_base_labels

    def update_model(self, model_params):
        """Copy, in place, every entry of ``model_params`` whose name matches a parameter of ``self.model``.

        ``None`` is a no-op; names absent from ``model_params`` keep their current values.
        """
        if model_params is None:
            return
        for name, param in self.model.named_parameters():
            if name in model_params:
                param.data.copy_(model_params[name].data)

    def get_loader(self, train_data, idxs):
        """Return a shuffled DataLoader over all of ``idxs`` in ``train_data``."""
        train_dataset = DatasetSplit(train_data, idxs)
        return DataLoader(train_dataset, batch_size=self.args.local_bs, shuffle=True)

    def get_loader_few(self, train_data, idxs):
        """Return a shuffled DataLoader with at most ``args.shot_num`` samples per label.

        The global ``random`` module is reseeded with ``args.seed`` so the subset is
        reproducible (note: this also resets global ``random`` state for later callers).

        Raises:
            AttributeError: if ``train_data`` has neither ``targets`` nor ``labels``.
        """
        random.seed(self.args.seed)
        if hasattr(train_data, 'targets'):
            all_labels = train_data.targets
        elif hasattr(train_data, 'labels'):
            all_labels = train_data.labels
        else:
            raise AttributeError("train_data must have labels")

        selected_indices = []
        for indices in self._group_indices_by_label(all_labels, idxs).values():
            selected_indices.extend(random.sample(indices, min(self.args.shot_num, len(indices))))
        train_dataset = DatasetSplit(train_data, selected_indices)
        return DataLoader(train_dataset, batch_size=self.args.local_bs, shuffle=True)

    @staticmethod
    def _group_indices_by_label(all_labels, idxs):
        """Map each integer label to the list of ``idxs`` carrying it, in first-seen order."""
        label_to_indices = defaultdict(list)
        for idx in idxs:
            # int() is essential: indexing a tensor of targets yields a 0-d tensor,
            # which hashes by object identity, so using it as a key would put every
            # sample in its own group and silently disable few-shot subsampling.
            label_to_indices[int(all_labels[idx])].append(idx)
        return dict(label_to_indices)

    def _iter_batches(self):
        """Yield ``(images, labels)`` batches from the train loader, moved to ``self.device``."""
        for images, labels in self.trainloader:
            yield images.to(self.device), labels.to(self.device)

    @staticmethod
    def _logit_scale():
        """Return the fixed logit scale exp(log(1 / 0.07)) as a 0-d tensor without grad."""
        with torch.no_grad():
            return (torch.ones([]) * np.log(1 / 0.07)).exp()

    @staticmethod
    def _scaled_logits(img_proj, text_features, scale):
        """L2-normalise ``img_proj`` row-wise; return it together with ``scale * img_proj @ text_features.t()``."""
        img_proj = img_proj / img_proj.norm(dim=1, keepdim=True)
        return img_proj, scale * img_proj @ text_features.t()

    @staticmethod
    def _label_log_probs(logits, labels):
        """Return, for each row, the log-softmax probability assigned to its ground-truth label."""
        return logits.log_softmax(dim=-1).gather(1, labels.unsqueeze(1)).squeeze(1)

    @staticmethod
    def _percent(correct, total):
        """Return ``100 * correct / total``, or 0.0 when nothing was counted."""
        return 100 * correct / total if total else 0.0

    def _run_sft(self, text_features, scale):
        """Supervised fine-tuning: ``args.local_ep`` epochs of cross-entropy over the train loader."""
        for _ in range(self.args.local_ep):
            total, correct = 0.0, 0.0
            # A fresh Adam optimizer (so fresh moment estimates) is built every local epoch.
            optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()), lr=1e-3)

            for images, labels in self._iter_batches():
                _, img_proj = self.model(images)
                _, logits = self._scaled_logits(img_proj, text_features, scale)
                loss = self.criterion_CE(logits, labels)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total += labels.size(0)
                correct += (logits.argmax(dim=1) == labels).sum().item()

            print('[SFT] Client:{}, Train ing Acc: {:.2f}%'.format(self.client_index, self._percent(correct, total)))

    def _build_reference_model(self):
        """Return a frozen, eval-mode copy of ``self.model`` blended 50/50 with the saved SFT checkpoint.

        The checkpoint is read from ``./{dataname}_{IID}_SFT_final1.pt``; parameters not present
        in it keep the current local values.
        """
        reference_model = copy.deepcopy(self.model)
        save_name = f"./{self.args.dataname}_{self.args.IID}_SFT_final1.pt"
        sft_params = torch.load(save_name, map_location=self.device, weights_only=True)
        with torch.no_grad():
            for name, param in reference_model.named_parameters():
                if name in sft_params:
                    param.data.copy_(0.5 * sft_params[name].data + param * 0.5)
        reference_model.eval()
        for p in reference_model.parameters():
            p.requires_grad = False
        return reference_model

    def _sample_group(self, images, labels, text_features, scale):
        """Draw ``args.GRPO_sampling_epochs`` noisy rollouts for one batch.

        Each rollout adds ``torch.randn_like`` noise scaled by ``args.noise_std`` to the image
        projection before normalisation; its reward is 1 where the arg-max class equals the
        label and 0 otherwise.

        Returns:
            (log_probs, adv): label log-probabilities of each rollout and the group-normalised
            advantages, both stacked along a leading rollout dimension.
        """
        rewards, log_probs = [], []
        for _ in range(self.args.GRPO_sampling_epochs):
            _, img_proj = self.model(images)
            img_proj = img_proj + torch.randn_like(img_proj) * self.args.noise_std
            _, logits = self._scaled_logits(img_proj, text_features, scale)
            log_probs.append(self._label_log_probs(logits, labels))
            rewards.append((logits.argmax(dim=1) == labels).float())

        rewards = torch.stack(rewards)
        log_probs = torch.stack(log_probs)
        if rewards.size(0) > 1:
            adv = (rewards - rewards.mean(dim=0, keepdim=True)) / (rewards.std(dim=0, keepdim=True) + 1e-8)
        else:
            # A single rollout has no group to compare against: its unbiased std is NaN,
            # which would poison the loss and the weights. Its relative advantage is zero.
            adv = torch.zeros_like(rewards)
        return log_probs, adv

    def _grpo_loss(self, new_log_probs, old_log_probs, adv, ref_log_probs):
        """Clipped-surrogate policy loss averaged over rollouts and batch, plus a KL penalty to the reference.

        Equivalent to averaging, over rollouts i, ``-(mean_b(min(r_i*A_i, clip(r_i)*A_i)) - kl_coef * mean_b(kl))``;
        the KL term does not depend on the rollout, so it is computed once.
        """
        ratio = torch.exp(new_log_probs.unsqueeze(0) - old_log_probs)
        surr1 = ratio * adv
        surr2 = torch.clamp(ratio, 1 - self.args.clip_eps, 1 + self.args.clip_eps) * adv
        policy_obj = torch.min(surr1, surr2).mean()

        log_ratio_ref = ref_log_probs - new_log_probs
        # exp(x) - x - 1 is non-negative and zero only when the two log-probs agree.
        kl_div = (torch.exp(log_ratio_ref) - log_ratio_ref - 1).mean()
        return -(policy_obj - kl_div * self.args.kl_coef)

    def _run_grpo(self, text_features, scale):
        """Reinforcement fine-tuning: per batch, sample rollouts then take ``args.GRPO_epochs`` optimizer steps."""
        reference_model = self._build_reference_model()
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()), lr=1e-3)
        total, correct = 0.0, 0.0

        for images, labels in self._iter_batches():
            with torch.no_grad():
                _, ref_proj = reference_model(images)
                _, ref_logits = self._scaled_logits(ref_proj, text_features, scale)
                ref_log_probs = self._label_log_probs(ref_logits, labels)
                old_log_probs, adv = self._sample_group(images, labels, text_features, scale)

            for step in range(self.args.GRPO_epochs):
                _, img_proj = self.model(images)
                _, logits = self._scaled_logits(img_proj, text_features, scale)
                new_log_probs = self._label_log_probs(logits, labels)

                loss = self._grpo_loss(new_log_probs, old_log_probs, adv, ref_log_probs)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Accuracy is tallied from the (pre-step) logits of the last GRPO epoch only.
                if step + 1 == self.args.GRPO_epochs:
                    total += labels.size(0)
                    correct += (logits.argmax(dim=1) == labels).sum().item()

        print('[GRPO] Client:{}, Train ing Acc: {:.2f}%'.format(self.client_index, self._percent(correct, total)))

    def _sample_per_label(self, all_labels):
        """Return indices of a random ``args.random_label_ratio`` fraction (at least one) of each label's samples."""
        chosen = []
        for label in torch.unique(all_labels):
            label_indices = (all_labels == label).nonzero(as_tuple=True)[0]
            count = len(label_indices)
            keep = min(max(1, int(count * self.args.random_label_ratio)), count)
            chosen.append(label_indices[torch.randperm(count)[:keep]])
        return torch.cat(chosen)

    def _extract_and_sample_features(self, text_features, scale):
        """Embed the whole train loader, then sample a per-label subset of projections for upload.

        Returns:
            (local_img_proj_dict, label_dict, train_acc): two ``defaultdict(list)`` keyed by upload
            batch index (``args.local_bs`` items per key) holding numpy feature rows and numpy
            labels, and the current model's accuracy on the train loader (0.0 if it is empty).
        """
        local_img_proj_dict, label_dict = defaultdict(list), defaultdict(list)
        proj_chunks, label_chunks = [], []
        correct, total = 0.0, 0.0

        with torch.no_grad():
            for images, labels in self._iter_batches():
                _, img_proj = self.model(images)
                img_proj, logits = self._scaled_logits(img_proj, text_features, scale)
                correct += (logits.argmax(dim=1) == labels).sum().item()
                total += labels.size(0)
                proj_chunks.append(img_proj.cpu().detach())
                label_chunks.append(labels.long().cpu().detach())

        if not proj_chunks:
            # Empty loader: nothing to upload and no accuracy to report.
            print(f"The number of uploaded image features is {0}")
            return local_img_proj_dict, label_dict, 0.0

        all_img_projs = torch.cat(proj_chunks, dim=0)
        all_labels = torch.cat(label_chunks, dim=0)
        selected = self._sample_per_label(all_labels)

        for i, (feature, label) in enumerate(zip(all_img_projs[selected], all_labels[selected])):
            batch_idx = i // self.args.local_bs
            local_img_proj_dict[batch_idx].append(feature.numpy())
            label_dict[batch_idx].append(label.numpy())

        print(f"The number of uploaded image features is {len(selected)}")
        return local_img_proj_dict, label_dict, correct / total

    def local_train_split(self, text_features, flag):
        """Fine-tune the local encoder, then extract image features for upload.

        Args:
            text_features: class text embeddings; logits are ``scale * img_proj @ text_features.t()``
                (presumably (num_classes, dim) -- confirm with caller).
            flag: a value equal to ``False`` (e.g. ``False`` or 0) runs supervised fine-tuning;
                anything else runs GRPO against a reference model built from the saved SFT checkpoint.

        Returns:
            (updated_params, local_img_proj_dict, label_dict, train_acc). ``updated_params`` maps
            names to the live trainable ``nn.Parameter`` objects of ``self.model`` (not copies).
        """
        scale = self._logit_scale()

        if flag == False:  # noqa: E712 -- equality (not truthiness) keeps the original dispatch
            self._run_sft(text_features, scale)
        else:
            self._run_grpo(text_features, scale)

        updated_params = {name: param for name, param in self.model.named_parameters() if param.requires_grad}
        local_img_proj_dict, label_dict, train_acc = self._extract_and_sample_features(text_features, scale)
        return updated_params, local_img_proj_dict, label_dict, train_acc


class LocalUpdate_domain(object):
    """Client-side trainer for a single-domain client, scored against fixed text features.

    Mirrors ``LocalUpdate`` but its dataset yields ``(image, label, domain)``
    triples. The client's single domain index is inferred at construction, and
    the SFT checkpoint path omits the ``IID`` component.

    ``args`` is read for: ``data_num``, ``local_bs``, ``seed``, ``local_ep``,
    ``dataname``, ``GRPO_sampling_epochs``, ``GRPO_epochs``, ``noise_std``,
    ``clip_eps``, ``kl_coef`` and ``random_label_ratio``.
    """

    def __init__(self, args, train_data, idxs, client_index, device, model, logger):
        """Build the client's train loader, store its model and settings, and infer its domain.

        When ``args.data_num == "few-shot"`` the loader holds at most 16 samples per
        label; otherwise it covers all of ``idxs``.

        Raises:
            ValueError: if the client's data does not contain exactly one domain.
        """
        self.args = args
        self.testloader_cache = None
        if self.args.data_num == "few-shot":
            self.trainloader = self.get_loader_few(train_data, list(idxs))
        else:
            self.trainloader = self.get_loader(train_data, list(idxs))
        self.device = device
        self.client_index = client_index
        self.criterion_CE = nn.CrossEntropyLoss()
        self.model = model
        self.logger = logger
        self.domain_idx = self.infer_client_domain()

    def update_model(self, model_params):
        """Copy, in place, every entry of ``model_params`` whose name matches a parameter of ``self.model``.

        ``None`` is a no-op; names absent from ``model_params`` keep their current values.
        """
        if model_params is None:
            return
        for name, param in self.model.named_parameters():
            if name in model_params:
                param.data.copy_(model_params[name].data)

    def get_loader(self, train_data, idxs):
        """Return a shuffled DataLoader over all of ``idxs`` in ``train_data``."""
        train_dataset = DatasetSplit_domain(train_data, idxs)
        return DataLoader(train_dataset, batch_size=self.args.local_bs, shuffle=True)

    def get_loader_few(self, train_data, idxs):
        """Return a shuffled DataLoader with at most 16 samples per label.

        The global ``random`` module is reseeded with ``args.seed`` so the subset is
        reproducible (note: this also resets global ``random`` state for later callers).

        Raises:
            AttributeError: if ``train_data`` has neither ``targets`` nor ``labels``.
        """
        random.seed(self.args.seed)
        if hasattr(train_data, 'targets'):
            all_labels = train_data.targets
        elif hasattr(train_data, 'labels'):
            all_labels = train_data.labels
        else:
            raise AttributeError("train_data must have labels")

        # NOTE(review): LocalUpdate uses args.shot_num here; this class keeps the
        # original hard-coded 16 -- confirm whether that difference is intended.
        shots_per_label = 16
        selected_indices = []
        for indices in self._group_indices_by_label(all_labels, idxs).values():
            selected_indices.extend(random.sample(indices, min(shots_per_label, len(indices))))
        train_dataset = DatasetSplit_domain(train_data, selected_indices)
        return DataLoader(train_dataset, batch_size=self.args.local_bs, shuffle=True)

    @staticmethod
    def _group_indices_by_label(all_labels, idxs):
        """Map each integer label to the list of ``idxs`` carrying it, in first-seen order."""
        label_to_indices = defaultdict(list)
        for idx in idxs:
            # int() is essential: indexing a tensor of targets yields a 0-d tensor,
            # which hashes by object identity, so using it as a key would put every
            # sample in its own group and silently disable few-shot subsampling.
            label_to_indices[int(all_labels[idx])].append(idx)
        return dict(label_to_indices)

    def infer_client_domain(self):
        """Return the single domain index present in the train loader.

        Raises:
            ValueError: if the loader yields zero or several distinct domains.
        """
        domain_set = set()
        for _, _, domains in self.trainloader:
            domain_set.update(domains.tolist())
        if len(domain_set) != 1:
            raise ValueError(f"[Client {self.client_index}] expected exactly one domain, but got {domain_set}")
        domain_idx = next(iter(domain_set))
        print(f"client {self.client_index} domain idx: {domain_idx}")
        return int(domain_idx)

    def _iter_batches(self):
        """Yield ``(images, labels)`` batches from the train loader (domain dropped), moved to ``self.device``."""
        for images, labels, _ in self.trainloader:
            yield images.to(self.device), labels.to(self.device)

    @staticmethod
    def _logit_scale():
        """Return the fixed logit scale exp(log(1 / 0.07)) as a 0-d tensor without grad."""
        with torch.no_grad():
            return (torch.ones([]) * np.log(1 / 0.07)).exp()

    @staticmethod
    def _scaled_logits(img_proj, text_features, scale):
        """L2-normalise ``img_proj`` row-wise; return it together with ``scale * img_proj @ text_features.t()``."""
        img_proj = img_proj / img_proj.norm(dim=1, keepdim=True)
        return img_proj, scale * img_proj @ text_features.t()

    @staticmethod
    def _label_log_probs(logits, labels):
        """Return, for each row, the log-softmax probability assigned to its ground-truth label."""
        return logits.log_softmax(dim=-1).gather(1, labels.unsqueeze(1)).squeeze(1)

    @staticmethod
    def _percent(correct, total):
        """Return ``100 * correct / total``, or 0.0 when nothing was counted."""
        return 100 * correct / total if total else 0.0

    def _run_sft(self, text_features, scale):
        """Supervised fine-tuning: ``args.local_ep`` epochs of cross-entropy over the train loader."""
        for _ in range(self.args.local_ep):
            total, correct = 0.0, 0.0
            # A fresh Adam optimizer (so fresh moment estimates) is built every local epoch.
            optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()), lr=1e-3)

            for images, labels in self._iter_batches():
                _, img_proj = self.model(images)
                _, logits = self._scaled_logits(img_proj, text_features, scale)
                loss = self.criterion_CE(logits, labels)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total += labels.size(0)
                correct += (logits.argmax(dim=1) == labels).sum().item()

            print('[SFT] Client:{}, Train ing Acc: {:.2f}%'.format(self.client_index, self._percent(correct, total)))

    def _build_reference_model(self):
        """Return a frozen, eval-mode copy of ``self.model`` blended 50/50 with the saved SFT checkpoint.

        The checkpoint is read from ``./{dataname}_SFT_final1.pt``; parameters not present in it
        keep the current local values.
        """
        reference_model = copy.deepcopy(self.model)
        save_name = f"./{self.args.dataname}_SFT_final1.pt"
        sft_params = torch.load(save_name, map_location=self.device, weights_only=True)
        with torch.no_grad():
            for name, param in reference_model.named_parameters():
                if name in sft_params:
                    param.data.copy_(0.5 * sft_params[name].data + param * 0.5)
        reference_model.eval()
        for p in reference_model.parameters():
            p.requires_grad = False
        return reference_model

    def _sample_group(self, images, labels, text_features, scale):
        """Draw ``args.GRPO_sampling_epochs`` noisy rollouts for one batch.

        Each rollout adds ``torch.randn_like`` noise scaled by ``args.noise_std`` to the image
        projection before normalisation; its reward is 1 where the arg-max class equals the
        label and 0 otherwise.

        Returns:
            (log_probs, adv): label log-probabilities of each rollout and the group-normalised
            advantages, both stacked along a leading rollout dimension.
        """
        rewards, log_probs = [], []
        for _ in range(self.args.GRPO_sampling_epochs):
            _, img_proj = self.model(images)
            img_proj = img_proj + torch.randn_like(img_proj) * self.args.noise_std
            _, logits = self._scaled_logits(img_proj, text_features, scale)
            log_probs.append(self._label_log_probs(logits, labels))
            rewards.append((logits.argmax(dim=1) == labels).float())

        rewards = torch.stack(rewards)
        log_probs = torch.stack(log_probs)
        if rewards.size(0) > 1:
            adv = (rewards - rewards.mean(dim=0, keepdim=True)) / (rewards.std(dim=0, keepdim=True) + 1e-8)
        else:
            # A single rollout has no group to compare against: its unbiased std is NaN,
            # which would poison the loss and the weights. Its relative advantage is zero.
            adv = torch.zeros_like(rewards)
        return log_probs, adv

    def _grpo_loss(self, new_log_probs, old_log_probs, adv, ref_log_probs):
        """Clipped-surrogate policy loss averaged over rollouts and batch, plus a KL penalty to the reference.

        Equivalent to averaging, over rollouts i, ``-(mean_b(min(r_i*A_i, clip(r_i)*A_i)) - kl_coef * mean_b(kl))``;
        the KL term does not depend on the rollout, so it is computed once.
        """
        ratio = torch.exp(new_log_probs.unsqueeze(0) - old_log_probs)
        surr1 = ratio * adv
        surr2 = torch.clamp(ratio, 1 - self.args.clip_eps, 1 + self.args.clip_eps) * adv
        policy_obj = torch.min(surr1, surr2).mean()

        log_ratio_ref = ref_log_probs - new_log_probs
        # exp(x) - x - 1 is non-negative and zero only when the two log-probs agree.
        kl_div = (torch.exp(log_ratio_ref) - log_ratio_ref - 1).mean()
        return -(policy_obj - kl_div * self.args.kl_coef)

    def _run_grpo(self, text_features, scale):
        """Reinforcement fine-tuning: per batch, sample rollouts then take ``args.GRPO_epochs`` optimizer steps."""
        reference_model = self._build_reference_model()
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()), lr=1e-3)
        total, correct = 0.0, 0.0

        for images, labels in self._iter_batches():
            with torch.no_grad():
                _, ref_proj = reference_model(images)
                _, ref_logits = self._scaled_logits(ref_proj, text_features, scale)
                ref_log_probs = self._label_log_probs(ref_logits, labels)
                old_log_probs, adv = self._sample_group(images, labels, text_features, scale)

            for step in range(self.args.GRPO_epochs):
                _, img_proj = self.model(images)
                _, logits = self._scaled_logits(img_proj, text_features, scale)
                new_log_probs = self._label_log_probs(logits, labels)

                loss = self._grpo_loss(new_log_probs, old_log_probs, adv, ref_log_probs)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Accuracy is tallied from the (pre-step) logits of the last GRPO epoch only.
                if step + 1 == self.args.GRPO_epochs:
                    total += labels.size(0)
                    correct += (logits.argmax(dim=1) == labels).sum().item()

        print('[GRPO] Client:{}, Train ing Acc: {:.2f}%'.format(self.client_index, self._percent(correct, total)))

    def _sample_per_label(self, all_labels):
        """Return indices of a random ``args.random_label_ratio`` fraction (at least one) of each label's samples."""
        chosen = []
        for label in torch.unique(all_labels):
            label_indices = (all_labels == label).nonzero(as_tuple=True)[0]
            count = len(label_indices)
            keep = min(max(1, int(count * self.args.random_label_ratio)), count)
            chosen.append(label_indices[torch.randperm(count)[:keep]])
        return torch.cat(chosen)

    def _extract_and_sample_features(self, text_features, scale):
        """Embed the whole train loader, then sample a per-label subset of projections for upload.

        Returns:
            (local_img_proj_dict, label_dict, train_acc): two ``defaultdict(list)`` keyed by upload
            batch index (``args.local_bs`` items per key) holding numpy feature rows and numpy
            labels, and the current model's accuracy on the train loader (0.0 if it is empty).
        """
        local_img_proj_dict, label_dict = defaultdict(list), defaultdict(list)
        proj_chunks, label_chunks = [], []
        correct, total = 0.0, 0.0

        with torch.no_grad():
            for images, labels in self._iter_batches():
                _, img_proj = self.model(images)
                img_proj, logits = self._scaled_logits(img_proj, text_features, scale)
                correct += (logits.argmax(dim=1) == labels).sum().item()
                total += labels.size(0)
                proj_chunks.append(img_proj.cpu().detach())
                label_chunks.append(labels.long().cpu().detach())

        if not proj_chunks:
            # Empty loader: nothing to upload and no accuracy to report.
            print(f"The number of uploaded image features is {0}")
            return local_img_proj_dict, label_dict, 0.0

        all_img_projs = torch.cat(proj_chunks, dim=0)
        all_labels = torch.cat(label_chunks, dim=0)
        selected = self._sample_per_label(all_labels)

        for i, (feature, label) in enumerate(zip(all_img_projs[selected], all_labels[selected])):
            batch_idx = i // self.args.local_bs
            local_img_proj_dict[batch_idx].append(feature.numpy())
            label_dict[batch_idx].append(label.numpy())

        print(f"The number of uploaded image features is {len(selected)}")
        return local_img_proj_dict, label_dict, correct / total

    def local_train_split(self, text_features, flag):
        """Fine-tune the local encoder, then extract image features for upload.

        Args:
            text_features: class text embeddings; logits are ``scale * img_proj @ text_features.t()``
                (presumably (num_classes, dim) -- confirm with caller).
            flag: a value equal to ``False`` (e.g. ``False`` or 0) runs supervised fine-tuning;
                anything else runs GRPO against a reference model built from the saved SFT checkpoint.

        Returns:
            (updated_params, local_img_proj_dict, label_dict, train_acc). ``updated_params`` maps
            names to the live trainable ``nn.Parameter`` objects of ``self.model`` (not copies).
        """
        scale = self._logit_scale()

        if flag == False:  # noqa: E712 -- equality (not truthiness) keeps the original dispatch
            self._run_sft(text_features, scale)
        else:
            self._run_grpo(text_features, scale)

        updated_params = {name: param for name, param in self.model.named_parameters() if param.requires_grad}
        local_img_proj_dict, label_dict, train_acc = self._extract_and_sample_features(text_features, scale)
        return updated_params, local_img_proj_dict, label_dict, train_acc
