import copy
import datetime
import numpy as np
import sys
import time
import torch
import torch.nn.functional as F
import models
from itertools import compress
from collections import deque # 使用deque来高效管理历史记录窗口
from config import cfg
from data import make_data_loader, make_batchnorm_stats, FixTransform, MixDataset
from utils import to_device, make_optimizer, collate, to_device
from metrics import Accuracy

def _calculate_prototypes_from_loader(model, loader):
    """A helper function to calculate prototypes given a model and a data loader."""
    if not loader: return {}
    model.train(False)
    all_features = []
    all_labels = []
    with torch.no_grad():
        for i, input in enumerate(loader):
            input = collate(input)
            input = to_device(input, cfg['device'])
            # We assume the loader provides 'data' for feature extraction
            features, _ = model.features_and_logits(input['data'])
            all_features.append(features)
            all_labels.append(input['target'])
    
    if not all_features: return {}

    all_features = torch.cat(all_features, dim=0)
    all_labels = torch.cat(all_labels, dim=0)
    
    prototypes = {}
    for label_val in all_labels.unique():
        label_val = label_val.item()
        mask = (all_labels == label_val)
        if mask.any():
            prototypes[label_val] = all_features[mask].mean(dim=0)
    return prototypes

class Server:
    def __init__(self, model):
        """Set up server state from an initial global model.

        Stores CPU snapshots of the model weights and of freshly created
        'local' and 'global' optimizer states, then initialises the
        bookkeeping used by the staged prototype schedule.

        Args:
            model: Model whose ``state_dict()`` seeds the global weights. In
                'fmatch' loss mode it must also provide
                ``make_sigma_parameters()`` and ``make_phi_parameters()``.
        """
        self.model_state_dict = save_model_state_dict(model.state_dict())

        fmatch_mode = 'fmatch' in cfg['loss_mode']
        if fmatch_mode:
            local_optimizer = make_optimizer(model.make_sigma_parameters(), 'local')
            server_optimizer = make_optimizer(model.make_phi_parameters(), 'global')
        else:
            local_optimizer = make_optimizer(model.parameters(), 'local')
            server_optimizer = make_optimizer(model.parameters(), 'global')
        self.optimizer_state_dict = save_optimizer_state_dict(local_optimizer.state_dict())
        self.global_optimizer_state_dict = save_optimizer_state_dict(server_optimizer.state_dict())

        # Staged-training bookkeeping.
        self.stage = 1
        self.threshold = 0
        self.gpc_history = []  # history of Global Prototype Consistency scores

        # Server/client mixing weight used in stage 3 and its per-round decay.
        self.alpha_decay_rate = cfg.get('alpha_decay_rate', 0.995)
        self.current_alpha = cfg.get('server_proto_weight_initial', 0.7)

        # Global (momentum-smoothed) prototypes and the server's own prototypes.
        self.prototypes = {}
        self.server_own_prototypes = {}

        # Fixed-length windows of the metrics that drive stage transitions.
        self.fss_history = deque(maxlen=cfg['fss_history_window'])
        self.gcplca_history = deque(maxlen=cfg['gcplca_history_window'])

    def distribute(self, client, batchnorm_dataset=None):
        """Push the current global state to every active client.

        Args:
            client: Sequence of client objects; only those with a truthy
                ``active`` attribute are updated.
            batchnorm_dataset: Optional dataset handed to
                ``make_batchnorm_stats`` before the weights are snapshotted.

        Each active client receives its own deep copy of the model state, the
        server's current ``threshold`` and ``stage``, and (from stage 2 on) a
        deep copy of the global prototypes.
        """
        # NOTE(review): eval instantiates whatever cfg['model_name'] names;
        # this is only safe while the configuration is trusted.
        global_model = eval('models.{}().to(cfg["device"])'.format(cfg['model_name']))
        global_model.load_state_dict(self.model_state_dict)
        if batchnorm_dataset is not None:
            global_model = make_batchnorm_stats(batchnorm_dataset, global_model, 'global')
        snapshot = save_model_state_dict(global_model.state_dict())

        for participant in client:
            if not participant.active:
                continue
            participant.model_state_dict = copy.deepcopy(snapshot)
            participant.threshold = self.threshold
            if self.stage >= 2:
                participant.server_prototypes = copy.deepcopy(self.prototypes)
            participant.stage = self.stage
        return

    def update_model(self, client):
        """Aggregate the active clients' weights into the global model.

        The uniform average of the active clients' parameters is turned into a
        pseudo-gradient (``global - average``) and applied with the 'global'
        optimizer. Only parameters whose final name component contains
        'weight' or 'bias' take part. Nothing happens if no client is active.

        Args:
            client: Sequence of client objects exposing ``active`` and
                ``model_state_dict``.

        Raises:
            ValueError: If the configured loss mode contains 'fmatch'.
        """
        if 'fmatch' in cfg['loss_mode']:
            raise ValueError('Not valid loss mode')

        with torch.no_grad():
            participants = [c for c in client if c.active]
            if not participants:
                return

            # NOTE(review): eval instantiates whatever cfg['model_name'] names.
            global_model = eval('models.{}()'.format(cfg['model_name']))
            global_model.load_state_dict(self.model_state_dict)
            server_optimizer = make_optimizer(global_model.parameters(), 'global')
            server_optimizer.load_state_dict(self.global_optimizer_state_dict)
            server_optimizer.zero_grad()

            # Every active client contributes with the same weight.
            mixing = torch.ones(len(participants))
            mixing = mixing / mixing.sum()

            for name, param in global_model.named_parameters():
                leaf_name = name.split('.')[-1]
                if 'weight' not in leaf_name and 'bias' not in leaf_name:
                    continue
                averaged = param.data.new_zeros(param.size())
                for share, participant in zip(mixing, participants):
                    averaged += share * participant.model_state_dict[name]
                # The optimizer step moves the global weights toward the average.
                param.grad = (param.data - averaged).detach()

            server_optimizer.step()
            self.global_optimizer_state_dict = save_optimizer_state_dict(server_optimizer.state_dict())
            self.model_state_dict = save_model_state_dict(global_model.state_dict())
        return
    
    def update_prototypes(self, client):
        """Check for a stage transition, refresh the global prototypes, deactivate clients.

        This should be called AFTER ``train`` so that ``server_own_prototypes``
        reflects the latest server model.

        Behaviour per stage (evaluated after ``check_stage_transition``):

        * Stage 1 and 2: this round's consensus is the server's own prototypes.
        * Stage 3: client prototypes are averaged per label, weighted by each
          client's ``num_samples_for_proto``, then blended with the server's
          prototype using ``current_alpha`` (which afterwards decays toward
          ``server_proto_weight_final``). A label known to only one side uses
          that side's prototype as is.

        The consensus is merged into ``self.prototypes`` with momentum
        ``cfg['proto_momentum']``; prototypes containing NaN are ignored.
        Finally every client in ``client`` is marked inactive, whether or not
        any client was active.

        Args:
            client: Sequence of client objects exposing ``active`` and, in
                stage 3, ``local_prototypes`` and ``num_samples_for_proto``.
        """
        valid_client = [c for c in client if c.active]
        if valid_client:
            self.check_stage_transition(valid_client)
            current_round_consensus = {}
            proto_momentum = cfg['proto_momentum']

            if self.stage < 3:
                print(f"Stage {self.stage}: Updating global prototypes with the latest server's knowledge.")
                current_round_consensus = self.server_own_prototypes
            else:
                print(f"Stage 3: Aggregating latest server and client prototypes...")
                # Accumulate count-weighted sums and the matching weight totals
                # PER LABEL. Normalising by the total over all clients would
                # shrink the prototype of any label that only a subset of the
                # clients reports, because its weights would sum to less than 1.
                weighted_sums = {}
                weight_totals = {}
                for c in valid_client:
                    local_prototypes = getattr(c, 'local_prototypes', None)
                    if not local_prototypes:
                        continue
                    count = c.num_samples_for_proto
                    if count <= 0:
                        continue
                    for label, proto in local_prototypes.items():
                        contribution = proto.detach() * count
                        if label in weighted_sums:
                            weighted_sums[label] = weighted_sums[label] + contribution
                            weight_totals[label] += count
                        else:
                            weighted_sums[label] = contribution
                            weight_totals[label] = count

                avg_client_prototypes = {
                    label: weighted_sums[label] / weight_totals[label]
                    for label in weighted_sums
                }

                all_labels = set(self.server_own_prototypes) | set(avg_client_prototypes)
                for label in all_labels:
                    server_part = self.server_own_prototypes.get(label)
                    client_part = avg_client_prototypes.get(label)

                    if server_part is not None and client_part is not None:
                        current_round_consensus[label] = (
                            self.current_alpha * server_part + (1 - self.current_alpha) * client_part
                        )
                    elif server_part is not None:
                        current_round_consensus[label] = server_part
                    else:  # only the clients know this label
                        current_round_consensus[label] = client_part
                self.current_alpha = max(self.current_alpha * self.alpha_decay_rate,
                                         cfg.get('server_proto_weight_final', 0.1))

            for label, proto in current_round_consensus.items():
                if torch.isnan(proto).any():
                    continue  # never let a NaN prototype poison the running estimate
                if label in self.prototypes:
                    self.prototypes[label] = proto_momentum * self.prototypes[label] + (1 - proto_momentum) * proto
                else:
                    self.prototypes[label] = proto

        for c in client:
            c.active = False

    def calculate_fss(self, features, labels, prototypes):
        """Compute this round's FSS score and append it to ``self.fss_history``.

        For every prototype class that has at least one sample, the score
        ratio is ``(distance to the nearest other prototype) /
        (mean distance of the class's features to its prototype + 1e-8)``.
        The FSS score is the mean of these ratios.

        Args:
            features: Tensor of per-sample features, first dimension = samples.
            labels: Tensor of per-sample labels aligned with ``features``.
            prototypes: Dict mapping label -> prototype tensor.

        Returns:
            None. ``0`` is appended (without printing) when fewer than two
            prototypes exist or when no prototype class has any sample.
        """
        if len(prototypes) < 2:
            self.fss_history.append(0)
            return

        unique_labels = sorted(prototypes.keys())
        # Keep prototypes on the same device as the features they are compared with.
        proto_tensor = torch.stack([prototypes[k] for k in unique_labels]).to(features.device)
        inter_dist = torch.cdist(proto_tensor, proto_tensor, p=2)
        inter_dist.fill_diagonal_(float('inf'))  # a prototype is not its own neighbour
        min_inter_dist = inter_dist.min(dim=1).values

        fss_ratios = []
        for idx, label_val in enumerate(unique_labels):
            class_features = features[labels == label_val]
            if class_features.size(0) == 0:
                # Without samples the intra-class distance is undefined (NaN);
                # skipping the class keeps the round's score finite.
                continue
            intra_dist = torch.norm(class_features - proto_tensor[idx], p=2, dim=1).mean()
            fss_ratios.append(min_inter_dist[idx] / (intra_dist + 1e-8))

        if not fss_ratios:
            self.fss_history.append(0)
            return

        fss_score = torch.stack(fss_ratios).mean().item()

        self.fss_history.append(fss_score)
        print(f"FSS Score this round: {fss_score:.4f}, History (mean): {np.mean(self.fss_history):.4f}, (std): {np.std(self.fss_history):.4f}")

    def train(self, dataset, lr, metric, logger):
        """Train the global model on the server's dataset and refresh server statistics.

        Steps:
          1. Train for ``cfg['server']['num_epochs']`` epochs (optionally
             truncated to a fraction of one epoch via ``cfg['local_epoch'][1]``).
             In stage 3 an MSE loss pulling features toward the global
             prototypes is added, scaled by ``proto_lambda_server``.
          2. Set ``self.threshold`` to the midpoint between the mean and the
             maximum per-batch average confidence, capped at 0.95.
          3. Store the trained weights and optimizer state.
          4. Recompute ``self.server_own_prototypes`` as per-label mean
             features over the same loader in evaluation mode; in stage 1 also
             compute the FSS score.

        Args:
            dataset: Dataset passed to ``make_data_loader`` under the 'server' tag.
            lr: Learning rate written into the first optimizer param group.
            metric: Object whose ``evaluate`` computes 'Loss' and 'Accuracy'.
            logger: Object whose ``append`` records the evaluation.
        """
        data_loader = make_data_loader({'train': dataset}, 'server')['train']
        # NOTE(review): eval instantiates whatever cfg['model_name'] names;
        # this is only safe while the configuration is trusted.
        model = eval('models.{}().to(cfg["device"])'.format(cfg['model_name']))
        model.load_state_dict(self.model_state_dict)
        self.optimizer_state_dict['param_groups'][0]['lr'] = lr
        optimizer = make_optimizer(model.parameters(), 'local')
        optimizer.load_state_dict(self.optimizer_state_dict)
        model.train(True)
        if cfg['server']['num_epochs'] == 1:
            num_batches = int(np.ceil(len(data_loader) * float(cfg['local_epoch'][1])))
        else:
            num_batches = None

        batch_confidences = []
        max_threshold = 0.0
        for epoch in range(1, cfg['server']['num_epochs'] + 1):
            for i, batch in enumerate(data_loader):
                batch = collate(batch)
                input_size = batch['data'].size(0)
                batch = to_device(batch, cfg['device'])
                optimizer.zero_grad()
                output = model(batch)

                # Record the confidence of EVERY trained batch. Recording after
                # the early-exit check dropped the last batch and left the list
                # empty (ZeroDivisionError) when only one batch was trained.
                after_softmax = F.softmax(output['target'].detach(), dim=1)
                avg_threshold = after_softmax.max(dim=1)[0].mean().item()
                batch_confidences.append(avg_threshold)
                max_threshold = max(avg_threshold, max_threshold)

                # In stage 3 the server also learns from the global consensus prototypes.
                if self.stage == 3 and self.prototypes:
                    features = output['features']
                    proto_losses = [
                        F.mse_loss(features[j], self.prototypes[label])
                        for j, label in enumerate(batch['target'].tolist())
                        if label in self.prototypes
                    ]
                    if proto_losses:
                        proto_loss = torch.stack(proto_losses).mean()
                        # Out-of-place add: do not mutate the model's loss tensor.
                        output['loss'] = output['loss'] + proto_loss * cfg.get('proto_lambda_server', 1.0)

                output['loss'].backward()

                torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
                optimizer.step()
                evaluation = metric.evaluate(['Loss', 'Accuracy'], batch, output)
                logger.append(evaluation, 'train', n=input_size)
                if num_batches is not None and i == num_batches - 1:
                    break

        if batch_confidences:
            mean_threshold = sum(batch_confidences) / len(batch_confidences)
            # Midpoint between mean and max confidence, capped at 0.95.
            self.threshold = min(mean_threshold + (max_threshold - mean_threshold) / 2, 0.95)
        # With no trained batch the previous threshold is kept unchanged.

        self.optimizer_state_dict = save_optimizer_state_dict(optimizer.state_dict())
        self.model_state_dict = save_model_state_dict(model.state_dict())

        model.train(False)
        proto_sums = {}
        class_counts = {}
        collect_for_fss = self.stage == 1  # FSS is only evaluated in stage 1
        all_features_for_fss = []
        all_labels_for_fss = []
        with torch.no_grad():
            # Use the same data_loader as for training.
            for batch in data_loader:
                batch = to_device(collate(batch), cfg['device'])
                features = model(batch)['features']
                labels = batch['target']
                if collect_for_fss:
                    all_features_for_fss.append(features.cpu())
                    all_labels_for_fss.append(labels.cpu())
                # Accumulate per-class feature sums one class at a time rather
                # than one sample at a time.
                for label in labels.unique().tolist():
                    class_features = features[labels == label]
                    if label in proto_sums:
                        proto_sums[label] += class_features.sum(dim=0)
                        class_counts[label] += class_features.size(0)
                    else:
                        proto_sums[label] = class_features.sum(dim=0)
                        class_counts[label] = class_features.size(0)

        # Average the accumulated features to get the final prototypes.
        self.server_own_prototypes = {
            label: total / class_counts[label] for label, total in proto_sums.items()
        }

        if collect_for_fss and all_features_for_fss:
            final_features = torch.cat(all_features_for_fss, dim=0).to(cfg['device'])
            final_labels = torch.cat(all_labels_for_fss, dim=0).to(cfg['device'])
            self.calculate_fss(final_features, final_labels, self.server_own_prototypes)

        print(f"Done. Generated {len(self.server_own_prototypes)} prototypes.")
        return

    def check_stage_transition(self, valid_client):
        """Advance ``self.stage`` when the current stage's metric is stable and high.

        * Stage 1 -> 2: once ``fss_history`` is full, its mean must exceed
          ``cfg['fss_mean_threshold']`` and its standard deviation must be
          below ``cfg['fss_std_threshold']``.
        * Stage 2 -> 3: the mean of the clients' non-None ``cplca_score``
          values is appended to ``gcplca_history``; once that window is full,
          the same mean/std test is applied with the ``gcplca_*`` thresholds.

        At most one transition happens per call; other stages are left untouched.

        Args:
            valid_client: Iterable of client objects that may carry ``cplca_score``.
        """
        rule = "=" * 80

        if self.stage == 1:
            window = self.fss_history
            if len(window) != window.maxlen:
                return
            mean_fss = np.mean(window)
            std_fss = np.std(window)
            if mean_fss > cfg['fss_mean_threshold'] and std_fss < cfg['fss_std_threshold']:
                self.stage = 2
                print("\n" + rule)
                print(f"STAGE TRANSITION: FSS stable and high (mean={mean_fss:.2f}, std={std_fss:.2f}). Moving to STAGE 2: One-way Guidance.")
                print(rule + "\n")
            return

        if self.stage != 2:
            return

        # Clients that did not produce a score this round are ignored.
        all_cplca_scores = [
            c.cplca_score for c in valid_client
            if getattr(c, 'cplca_score', None) is not None
        ]
        if not all_cplca_scores:
            return

        gcplca_score = np.mean(all_cplca_scores)
        self.gcplca_history.append(gcplca_score)
        print(f"GCPLCA Score this round: {gcplca_score:.4f}, History (mean): {np.mean(self.gcplca_history):.4f}, (std): {np.std(self.gcplca_history):.4f}")

        if len(self.gcplca_history) != self.gcplca_history.maxlen:
            return
        mean_gcplca = np.mean(self.gcplca_history)
        std_gcplca = np.std(self.gcplca_history)
        if mean_gcplca > cfg['gcplca_mean_threshold'] and std_gcplca < cfg['gcplca_std_threshold']:
            self.stage = 3
            print("\n" + rule)
            print(f"STAGE TRANSITION: GCPLCA stable and high (mean={mean_gcplca:.2f}, std={std_gcplca:.2f}). Moving to STAGE 3.")
            print(rule + "\n")

class Client:
    def __init__(self, client_id, model, data_split):
        """Create a client with its own copies of the model and optimizer state.

        Args:
            client_id: Identifier stored on the client.
            model: Model whose ``state_dict()`` seeds the client weights. In
                'fmatch' loss mode it must provide ``make_phi_parameters()``.
            data_split: Stored unchanged as ``self.data_split``.
        """
        self.client_id = client_id
        self.data_split = data_split
        self.verbose = cfg['verbose']
        self.active = False
        self.server = None

        self.model_state_dict = save_model_state_dict(model.state_dict())
        if 'fmatch' in cfg['loss_mode']:
            local_optimizer = make_optimizer(model.make_phi_parameters(), 'local')
        else:
            local_optimizer = make_optimizer(model.parameters(), 'local')
        self.optimizer_state_dict = save_optimizer_state_dict(local_optimizer.state_dict())

        # Symmetric Beta(alpha, alpha) distribution sampled for the mixing coefficient.
        concentration = cfg['alpha']
        self.beta = torch.distributions.beta.Beta(torch.tensor([concentration]), torch.tensor([concentration]))

        # State received from the server; threshold and stage are overwritten by distribute().
        self.threshold = 1
        self.stage = 0
        self.server_prototypes = {}

        # State produced locally and read back by the server.
        self.local_prototypes = {}
        self.cplca_score = None
        self.num_samples_for_model = 0
        self.num_samples_for_proto = 0
        self.num_samples_for_ppa = 0

    def make_hard_pseudo_label(self, soft_pseudo_label, threshold):
        """Turn soft predictions into hard labels plus a confidence mask.

        Args:
            soft_pseudo_label: Tensor of scores; the last dimension is reduced.
            threshold: Minimum top score for a sample to be kept.

        Returns:
            Tuple ``(hard_pseudo_label, mask)``: the arg-max index along the
            last dimension and a boolean tensor that is True where the top
            score is greater than or equal to ``threshold``.
        """
        top_score, hard_pseudo_label = soft_pseudo_label.max(dim=-1)
        confident = top_score >= threshold
        return hard_pseudo_label, confident

    def make_dataset(self, dataset, metric, logger):
        """Build this round's training data from the client's dataset.

        In 'sup' loss mode the dataset is returned unchanged. In 'fix' loss
        mode the current model labels every sample; samples whose top softmax
        score reaches ``self.threshold`` form the pseudo-labelled "fix"
        dataset. ``self.cplca_score`` is set to the mean top softmax score
        over all samples, and pseudo-label metrics are logged.

        Args:
            dataset: Dataset exposing ``data``, ``target`` and ``other``.
            metric: Object whose ``evaluate`` computes 'PAccuracy',
                'MAccuracy' and 'LabelRatio'.
            logger: Object whose ``append`` records the evaluation.

        Returns:
            ``dataset`` in 'sup' mode; otherwise ``(fix_dataset, mix_dataset)``
            where ``mix_dataset`` is None unless 'mix' is in the loss mode, or
            ``None`` when no sample passes the threshold.

        Raises:
            ValueError: If the loss mode contains neither 'sup' nor 'fix'.
        """
        loss_mode = cfg['loss_mode']
        if 'sup' in loss_mode:
            return dataset
        if 'fix' not in loss_mode:
            raise ValueError('Not valid client loss mode')

        with torch.no_grad():
            loader = make_data_loader({'train': dataset}, 'global', shuffle={'train': False})['train']
            # NOTE(review): eval instantiates whatever cfg['model_name'] names.
            model = eval('models.{}(track=True).to(cfg["device"])'.format(cfg['model_name']))
            model.load_state_dict(self.model_state_dict)
            model.train(False)

            logit_chunks = []
            target_chunks = []
            for batch in loader:
                batch = to_device(collate(batch), cfg['device'])
                prediction = model(batch)
                logit_chunks.append(prediction['target'].cpu())
                target_chunks.append(batch['target'].cpu())

            all_logits = torch.cat(logit_chunks, dim=0)
            reference = {'target': torch.cat(target_chunks, dim=0)}
            scored = {'target': F.softmax(all_logits, dim=-1)}

            # Mean top-class probability over all local samples.
            self.cplca_score = scored['target'].max(dim=-1)[0].mean().item()

            new_target, keep = self.make_hard_pseudo_label(scored['target'], self.threshold)
            scored['mask'] = keep
            evaluation = metric.evaluate(['PAccuracy', 'MAccuracy', 'LabelRatio'], reference, scored)
            logger.append(evaluation, 'train', n=len(reference['target']))

            if not torch.any(keep):
                return None

            pseudo_labels = new_target.tolist()
            keep_flags = keep.tolist()

            # Keep only the confidently labelled samples, with their pseudo-labels.
            fix_dataset = copy.deepcopy(dataset)
            fix_dataset.data = [item for item, flag in zip(fix_dataset.data, keep_flags) if flag]
            fix_dataset.target = [label for label, flag in zip(pseudo_labels, keep_flags) if flag]
            fix_dataset.other = {'id': list(range(len(fix_dataset.data)))}

            mix_dataset = None
            if 'mix' in cfg['loss_mode']:
                # The mix dataset keeps every sample, relabelled with its pseudo-label.
                mix_dataset = copy.deepcopy(dataset)
                mix_dataset.target = new_target.tolist()
                mix_dataset = MixDataset(len(fix_dataset), mix_dataset)
            return fix_dataset, mix_dataset

    def train(self, dataset, lr, metric, logger):
        """Run local training on the pseudo-labelled (fix) and mix datasets.

        Only the loss mode containing both 'fix' and 'mix' and none of
        'batch', 'frgd', 'fmatch' is supported. From stage 2 on, an MSE loss
        pulling each sample's features toward the server prototype of its
        pseudo-label is added, scaled by ``cfg['proto_lambda_client']``. In
        stage 3, features of high-confidence samples are collected during
        training and averaged per pseudo-label into ``self.local_prototypes``.

        Args:
            dataset: Pair ``(fix_dataset, mix_dataset)``.
            lr: Learning rate written into the first optimizer param group.
            metric: Object whose ``evaluate`` computes 'Loss' and 'Accuracy'.
            logger: Object whose ``append`` records the evaluation.

        Raises:
            ValueError: If the configured loss mode is not supported.
        """
        loss_mode = cfg['loss_mode']
        excluded_tags = ('batch', 'frgd', 'fmatch')
        supported = ('fix' in loss_mode and 'mix' in loss_mode
                     and not any(tag in loss_mode for tag in excluded_tags))
        if not supported:
            raise ValueError('Not valid client loss mode')

        fix_dataset, mix_dataset = dataset
        self.num_samples_for_model = len(fix_dataset.data) if fix_dataset else 0

        fix_data_loader = make_data_loader({'train': fix_dataset}, 'client')['train']
        mix_data_loader = make_data_loader({'train': mix_dataset}, 'client')['train']
        # NOTE(review): eval instantiates whatever cfg['model_name'] names.
        model = eval('models.{}().to(cfg["device"])'.format(cfg['model_name']))
        model.load_state_dict(self.model_state_dict, strict=False)
        self.optimizer_state_dict['param_groups'][0]['lr'] = lr
        optimizer = make_optimizer(model.parameters(), 'local')
        optimizer.load_state_dict(self.optimizer_state_dict)
        model.train(True)

        # With a single epoch, training may stop after a fraction of it.
        num_batches = None
        if cfg['client']['num_epochs'] == 1:
            num_batches = int(np.ceil(len(fix_data_loader) * float(cfg['local_epoch'][0])))

        collected_features = []
        collected_labels = []

        for epoch in range(1, cfg['client']['num_epochs'] + 1):
            for i, (fix_input, mix_input) in enumerate(zip(fix_data_loader, mix_data_loader)):
                batch = collate({
                    'data': fix_input['data'],
                    'target': fix_input['target'],
                    'aug': fix_input['aug'],
                    'mix_data': mix_input['data'],
                    'mix_target': mix_input['target'],
                })
                input_size = batch['data'].size(0)
                lam = self.beta.sample()[0]
                batch['lam'] = lam
                # Interpolate the fix samples with the mix samples and keep both labels.
                batch['mix_data'] = (lam * batch['data'] + (1 - lam) * batch['mix_data']).detach()
                batch['mix_target'] = torch.stack([batch['target'], batch['mix_target']], dim=-1)
                batch['loss_mode'] = cfg['loss_mode']
                batch = to_device(batch, cfg['device'])
                optimizer.zero_grad()
                output = model(batch)

                if self.stage >= 2 and self.server_prototypes:
                    features = output['features']
                    proto_losses = [
                        F.mse_loss(features[j], self.server_prototypes[label])
                        for j, label in enumerate(batch['target'].tolist())
                        if label in self.server_prototypes
                    ]
                    if proto_losses:
                        proto_loss_val = torch.stack(proto_losses).mean()
                        output['loss'] += proto_loss_val * cfg['proto_lambda_client']

                output['loss'].backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
                optimizer.step()
                evaluation = metric.evaluate(['Loss', 'Accuracy'], batch, output)
                logger.append(evaluation, 'train', n=input_size)

                if self.stage == 3:
                    with torch.no_grad():
                        # Slightly stricter than the pseudo-labelling threshold.
                        high_conf_threshold = min(self.threshold + 0.05, 0.99)
                        top_prob = F.softmax(output['target'], dim=1).max(dim=-1)[0]
                        confident = top_prob >= high_conf_threshold
                        if confident.any():
                            collected_features.append(output['features'][confident].detach().cpu())
                            collected_labels.append(batch['target'][confident].detach().cpu())

                if num_batches is not None and i == num_batches - 1:
                    break

        self.optimizer_state_dict = save_optimizer_state_dict(optimizer.state_dict())
        self.model_state_dict = save_model_state_dict(model.state_dict())

        # Local prototypes are only reported in stage 3 and only if something was collected.
        self.local_prototypes = {}
        self.num_samples_for_proto = 0
        if self.stage == 3 and collected_features:
            all_features = torch.cat(collected_features, dim=0)
            all_labels = torch.cat(collected_labels, dim=0)
            for label_val in all_labels.unique().tolist():
                members = all_features[all_labels == label_val]
                self.local_prototypes[label_val] = members.mean(dim=0).to(cfg['device'])
            self.num_samples_for_proto = len(all_features)
        return


def save_model_state_dict(model_state_dict):
    """Return a new dict with every entry of ``model_state_dict`` moved to the CPU.

    Key order is preserved. Each value must provide a ``cpu()`` method.
    """
    cpu_state = {}
    for name, tensor in model_state_dict.items():
        cpu_state[name] = tensor.cpu()
    return cpu_state


def save_optimizer_state_dict(optimizer_state_dict):
    """Return a detached copy of an optimizer state dict.

    The ``'state'`` entry is passed through ``to_device(..., 'cpu')``; every
    other entry is deep-copied so later edits do not affect the original.
    """
    return {
        key: to_device(value, 'cpu') if key == 'state' else copy.deepcopy(value)
        for key, value in optimizer_state_dict.items()
    }
