from collections import OrderedDict
from copy import deepcopy
from typing import Any

import torch
from omegaconf import DictConfig
from torch.utils.data import DataLoader, Subset,SubsetRandomSampler

import numpy as np
import math
from data.utils.datasets import BaseDataset
from src.utils.functional import evaluate_model, get_optimal_cuda_device
from src.utils.metrics import Metrics
from src.utils.models import DecoupledModel
from src.utils.constants import (
    DATA_MEAN,
    DATA_STD,
    FLBENCH_ROOT,
    LR_SCHEDULERS,
    MODE,
    OPTIMIZERS,
)

import pickle


class FedAvgClient:
    """FedAvg client: receives a server package, trains locally, evaluates,
    and packages the results for the server.

    The instance appears designed to be reused for many clients:
    `set_parameters()` switches `client_id`, the data-split indices,
    optimizer / LR-scheduler state and model weights before every
    `train()` / `test()` call.
    """

    def __init__(
        self,
        model: DecoupledModel,
        optimizer_cls: type[torch.optim.Optimizer],
        lr_scheduler_cls: type[torch.optim.lr_scheduler._LRScheduler],
        args: DictConfig,
        dataset: BaseDataset,
        data_indices: list,
        device: torch.device | None,
        return_diff: bool,
        data_partition=None,
    ):
        """Build the client-agnostic training state.

        Args:
            model: Model to train; moved to `self.device`.
            optimizer_cls: Optimizer factory, called as `optimizer_cls(params=...)`.
            lr_scheduler_cls: LR scheduler factory, called as
                `lr_scheduler_cls(optimizer=...)`; `None` disables scheduling.
            args: Experiment config (reads `args.common.*`, `args.dataset.name`).
            dataset: Full dataset; each split is a `Subset` view over it.
            data_indices: Per-client split indices,
                `[{"train": [...], "val": [...], "test": [...]}, ...]`.
            device: Target device; `None` picks one via `get_optimal_cuda_device`.
            return_diff: If True, `package()` sends `global - trained`
                parameter differences instead of the trained parameters.
            data_partition: Optional data-partition info (kept for per-domain
                testing). When omitted, it is loaded best-effort from
                `FLBENCH_ROOT/data/<dataset>/partition.pkl`, else `None`.
        """
        self.client_id: int = None
        self.args = args
        if device is None:
            self.device = get_optimal_cuda_device(use_cuda=self.args.common.use_cuda)
        else:
            self.device = device
        self.dataset = dataset
        self.model = model.to(self.device)
        self.regular_model_params: OrderedDict[str, torch.Tensor]
        self.personal_params_name: list[str] = []
        # Every state_dict entry (parameters AND buffers) is "regular".
        # With buffers == "local", buffers are also personal; since personal
        # params are loaded after regular ones in `set_parameters()`, the
        # client's own buffers win.
        self.regular_params_name = list(key for key, _ in self.model.state_dict().items())
        if self.args.common.buffers == "local":
            self.personal_params_name.extend(
                [name for name, _ in self.model.named_buffers()]
            )
        elif self.args.common.buffers == "drop":
            self.init_buffers = deepcopy(OrderedDict(self.model.named_buffers()))

        self.optimizer = optimizer_cls(params=self.model.parameters())
        self.init_optimizer_state = deepcopy(self.optimizer.state_dict())

        self.lr_scheduler: torch.optim.lr_scheduler.LRScheduler = None
        self.init_lr_scheduler_state: dict = None
        self.lr_scheduler_cls = None
        if lr_scheduler_cls is not None:
            self.lr_scheduler_cls = lr_scheduler_cls
            self.lr_scheduler = self.lr_scheduler_cls(optimizer=self.optimizer)
            self.init_lr_scheduler_state = deepcopy(self.lr_scheduler.state_dict())

        # [{"train": [...], "val": [...], "test": [...]}, ...]
        self.data_indices = data_indices
        # The placeholder index [0] only avoids the runtime error raised by
        # `DataLoader(shuffle=True)` over an empty `Subset`; the real indices
        # are installed by `load_data_indices()`.
        self.trainset = Subset(self.dataset, indices=[0])
        self.valset = Subset(self.dataset, indices=[])
        self.testset = Subset(self.dataset, indices=[])
        self.trainloader = DataLoader(
            self.trainset, batch_size=self.args.common.batch_size, shuffle=True
        )
        self.valloader = DataLoader(self.valset, batch_size=self.args.common.batch_size)
        self.testloader = DataLoader(
            self.testset, batch_size=self.args.common.batch_size
        )

        self.testing = False

        self.local_epoch = self.args.common.local_epoch
        self.criterion = torch.nn.CrossEntropyLoss().to(self.device)

        self.eval_results = {}

        self.return_diff = return_diff
        # Domain of the current client, if the dataset's stats record one;
        # refreshed on every `set_parameters()` call.
        self.domain = None
        self.domain_info = ""

        # Data partition, kept to support per-domain testing.
        if data_partition is not None:
            self.data_partition = data_partition
        else:
            # Not passed in: try loading it from disk (best-effort).
            # NOTE: `pickle.load` can execute arbitrary code; the partition
            # file must be trusted (locally generated).
            try:
                partition_path = (
                    FLBENCH_ROOT / "data" / self.args.dataset.name / "partition.pkl"
                )
                with open(partition_path, "rb") as f:
                    self.data_partition = pickle.load(f)
            except Exception:
                # `except Exception` rather than a bare `except:` so that
                # KeyboardInterrupt / SystemExit are not swallowed.
                self.data_partition = None

    def load_data_indices(self):
        """Point the train/val/test `Subset`s at the indices of client
        `self.client_id`."""
        self.trainset.indices = self.data_indices[self.client_id]["train"]
        self.valset.indices = self.data_indices[self.client_id]["val"]
        self.testset.indices = self.data_indices[self.client_id]["test"]

    def train_with_eval(self):
        """Wrap `fit()` with `evaluate()` calls and store the results in
        `self.eval_results`.

        The stored dict is::

            {
                "before": {...},   # per-split metrics before local training
                "after": {...},    # per-split metrics after local training
                "message": [...],  # formatted per-split log lines
            }

        When `local_epoch` is 0 no training happens and `"after"` keeps empty
        `Metrics()` objects.
        """
        eval_results = {
            "before": {"train": Metrics(), "val": Metrics(), "test": Metrics()},
            "after": {"train": Metrics(), "val": Metrics(), "test": Metrics()},
        }
        eval_results["before"] = self.evaluate()
        if self.local_epoch > 0:
            self.fit()
            eval_results["after"] = self.evaluate()

        eval_msg = []
        # NOTE(review): messages for every split (train/val included) are only
        # produced when `test.client.test` is enabled — confirm this is intended.
        if self.args.common.test.client.test:
            for split, color, flag, subset in [
                ["train", "yellow", self.args.common.test.client.train, self.trainset],
                ["val", "green", self.args.common.test.client.val, self.valset],
                ["test", "cyan", self.args.common.test.client.test, self.testset],
            ]:
                if len(subset) > 0 and flag:
                    eval_msg.append(
                        f"client [{self.client_id}]\t"
                        f"{self.domain_info}"
                        f"[{color}]({split}set)[/{color}]\t"
                        f"[red]loss: {eval_results['before'][split].loss:.4f} -> "
                        f"{eval_results['after'][split].loss:.4f}\t[/red]"
                        f"[blue]accuracy: {eval_results['before'][split].accuracy:.2f}% -> {eval_results['after'][split].accuracy:.2f}%[/blue]"
                    )
        eval_results["message"] = eval_msg
        self.eval_results = eval_results

    @staticmethod
    def _read_client_domain(stats_path, client_id):
        """Return the first domain listed for `client_id` under the
        `"domain_distribution"` key of the JSON file at `stats_path`.

        Returns None when the file, the key, or the client's entry is missing.
        Read/parse errors propagate to the caller.
        """
        import json

        if not stats_path.exists():
            return None
        with open(stats_path, "r") as f:
            stats = json.load(f)
        if "domain_distribution" not in stats:
            return None
        distribution = stats["domain_distribution"]
        key = str(client_id)
        if key not in distribution:
            return None
        return distribution[key][0]

    def _update_domain_info(self):
        """Refresh `self.domain` / `self.domain_info` for the current client.

        Both are reset first so a client without a recorded domain never
        inherits the previous client's domain. The lookup is best-effort:
        any failure leaves the reset values in place.
        """
        self.domain = None
        self.domain_info = ""
        try:
            stats_path = (
                FLBENCH_ROOT / "data" / self.args.dataset.name / "all_stats.json"
            )
            domain = self._read_client_domain(stats_path, self.client_id)
        except Exception:
            # Domain info only decorates log messages; never fail on it.
            return
        if domain is not None:
            self.domain = domain
            self.domain_info = f"[magenta]domain: {self.domain}[/magenta]\t"

    def set_parameters(self, package: dict[str, Any]):
        """Install the server package for the client it names.

        Sets `client_id` / epoch counters, refreshes the client's domain info
        and data indices, restores optimizer and LR-scheduler state (falling
        back to the initial state), and loads model weights. When
        `return_diff` is set, snapshots the regular parameters so `package()`
        can compute the update difference.
        """
        self.client_id = package["client_id"]
        self.local_epoch = package["local_epoch"]
        self.current_epoch = package["current_epoch"]

        self._update_domain_info()

        self.load_data_indices()

        if (
            package["optimizer_state"]
            and not self.args.common.reset_optimizer_on_global_epoch
        ):
            self.optimizer.load_state_dict(package["optimizer_state"])
        else:
            self.optimizer.load_state_dict(self.init_optimizer_state)

        if self.lr_scheduler is not None:
            if package["lr_scheduler_state"]:
                self.lr_scheduler.load_state_dict(package["lr_scheduler_state"])
            else:
                self.lr_scheduler.load_state_dict(self.init_lr_scheduler_state)

        # Personal params are loaded last so they override regular ones.
        self.model.load_state_dict(package["regular_model_params"], strict=False)
        self.model.load_state_dict(package["personal_model_params"], strict=False)
        if self.args.common.buffers == "drop":
            # Discard received buffers in favour of the initial ones.
            self.model.load_state_dict(self.init_buffers, strict=False)

        if self.return_diff:
            model_params = self.model.state_dict()
            self.regular_model_params = OrderedDict(
                (key, model_params[key].clone().cpu())
                for key in self.regular_params_name
            )

    def train(self, server_package: dict[str, Any]) -> dict:
        """Load `server_package`, train with evaluation, and return the
        client package (see `package()`)."""
        self.set_parameters(server_package)
        self.train_with_eval()
        client_package = self.package()
        return client_package

    def package(self):
        """Package data that client needs to transmit to the server. You can
        override this function and add more parameters.

        Returns:
            A dict: {
                `weight`: Client weight. Defaults to the size of client training set.
                `regular_model_params`: Client model parameters that will join parameter aggregation.
                `model_params_diff`: The parameter difference between the client trained and the global. `diff = global - trained`.
                `eval_results`: Client model evaluation results.
                `personal_model_params`: Client model parameters that absent to parameter aggregation.
                `optimizer_state`: Client optimizer's state dict.
                `lr_scheduler_state`: Client learning rate scheduler's state dict.
            }
            When `return_diff` is set, `model_params_diff` replaces
            `regular_model_params`.
        """
        model_params = self.model.state_dict()
        client_package = dict(
            weight=len(self.trainset),
            eval_results=self.eval_results,
            regular_model_params={
                key: model_params[key].clone().cpu() for key in self.regular_params_name
            },
            personal_model_params={
                key: model_params[key].clone().cpu()
                for key in self.personal_params_name
            },
            optimizer_state=deepcopy(self.optimizer.state_dict()),
            lr_scheduler_state=(
                {}
                if self.lr_scheduler is None
                else deepcopy(self.lr_scheduler.state_dict())
            ),
        )
        if self.return_diff:
            # diff = global (snapshot from set_parameters) - trained, matched by
            # key so the result never depends on dict ordering.
            client_package["model_params_diff"] = {
                key: self.regular_model_params[key] - param_new
                for key, param_new in client_package["regular_model_params"].items()
            }
            client_package.pop("regular_model_params")
        return client_package

    def _train_epochs(self, num_epochs: int):
        """Run `num_epochs` passes over `self.trainloader`, taking one
        optimizer step per batch."""
        for _ in range(num_epochs):
            for x, y in self.trainloader:
                # A batch of size 1 makes BatchNorm layers raise in train mode,
                # so such batches are skipped.
                if len(x) <= 1:
                    continue
                x, y = x.to(self.device), y.to(self.device)
                logit = self.model(x)
                loss = self.criterion(logit, y)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

    def fit(self):
        """Train locally for `self.local_epoch` epochs, then step the LR
        scheduler once (if any)."""
        self.model.train()
        self.dataset.train()
        self._train_epochs(self.local_epoch)
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

    @torch.no_grad()
    def evaluate(self, model: torch.nn.Module = None) -> dict[str, Metrics]:
        """Evaluating client model.

        A split is evaluated only if it is non-empty, evaluation is active
        (`self.testing` or `args.common.client_side_evaluation`), and its
        `args.common.test.client.<split>` flag is set; otherwise its entry is
        an empty `Metrics()`.

        Args:
            model: Used model. Defaults to None, which will fallback to `self.model`.

        Returns:
            A evalution results dict: {
                `train`: results on client training set.
                `val`: results on client validation set.
                `test`: results on client test set.
            }
        """
        target_model = self.model if model is None else model
        target_model.eval()
        self.dataset.eval()
        results = {"train": Metrics(), "val": Metrics(), "test": Metrics()}
        criterion = torch.nn.CrossEntropyLoss(reduction="sum")

        # Evaluated in the order test, val, train; each config flag is named
        # after its split and only read when the split is a candidate.
        for split, subset, loader in (
            ("test", self.testset, self.testloader),
            ("val", self.valset, self.valloader),
            ("train", self.trainset, self.trainloader),
        ):
            if (
                len(subset) > 0
                and (self.testing or self.args.common.client_side_evaluation)
                and getattr(self.args.common.test.client, split)
            ):
                results[split] = evaluate_model(
                    model=target_model,
                    dataloader=loader,
                    criterion=criterion,
                    device=self.device,
                )

        return results

    def test(self, server_package: dict[str, Any]) -> dict[str, dict[str, Metrics]]:
        """Test client model. If `finetune_epoch > 0`, `finetune()` will be
        activated and the pre-finetune weights restored afterwards.

        Args:
            server_package: Parameter package.

        Returns:
            A model evaluation results dict : {
                `before`: {...}
                `after`: {...}
            }
            `before` means pre-finetuning.
            `after` means post-finetuning (empty `Metrics()` if no finetuning).
        """
        self.testing = True
        try:
            self.set_parameters(server_package)

            results = {
                "before": {"train": Metrics(), "val": Metrics(), "test": Metrics()},
                "after": {"train": Metrics(), "val": Metrics(), "test": Metrics()},
            }

            results["before"] = self.evaluate()
            if self.args.common.test.client.finetune_epoch > 0:
                frz_params_dict = deepcopy(self.model.state_dict())
                self.finetune()
                results["after"] = self.evaluate()
                self.model.load_state_dict(frz_params_dict)
        finally:
            # Always leave testing mode, even if loading/evaluation raised.
            self.testing = False
        return results

    def finetune(self):
        """Client model finetuning for `test.client.finetune_epoch` epochs.

        This function will only be activated in `test()`; unlike `fit()` it
        does not step the LR scheduler.
        """
        self.model.train()
        self.dataset.train()
        self._train_epochs(self.args.common.test.client.finetune_epoch)