import yaml  
import torch
import argparse
import numpy as np
import os,json,random
from tqdm import trange
from collections import defaultdict
from HyperNetworks import ConvNetHyper, Hyper
from display import show_plot,ExperimentLogger

from torchvision import datasets, transforms  
from torch.utils.data import Dataset, DataLoader  
from utils import get_args, get_network, get_time, DiffAugment, ParamDiffAug, set_seed,get_universum

from dataset import get_classes, dirichlet_distribution

class CompressedDataset(Dataset):
    """In-memory dataset built from a file written with ``torch.save``.

    The loaded object must expose a ``'data'`` entry: a sequence whose
    elements hold an input tensor at index 0 and a label tensor at index 1.
    All inputs (and all labels) are concatenated along dim 0. Every access
    applies a random 32-pixel crop (4-pixel padding) followed by a random
    horizontal flip to the input.
    """

    def __init__(self, load_path):
        # NOTE(review): weights_only=False unpickles arbitrary Python objects;
        # only point this at files you produced yourself.
        payload = torch.load(load_path, weights_only=False)
        chunks = payload['data']

        self.X = torch.cat([chunk[0] for chunk in chunks], dim=0)
        self.labels = torch.cat([chunk[1] for chunk in chunks], dim=0)

        print(self.labels.shape)

        augmentations = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
        ]
        self.transform = transforms.Compose(augmentations)

    def __len__(self):
        """Number of samples (length of the concatenated label tensor)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(augmented_input, label)`` for sample ``idx``."""
        sample = self.X[idx]
        target = self.labels[idx]
        return self.transform(sample), target

def subtract_(W, W_old):
    """Return ``{k: W[k] - W_old[k]}`` for every key ``k`` of ``W_old``.

    Keys present only in ``W`` are ignored; a key of ``W_old`` missing from
    ``W`` raises ``KeyError``. Key order follows ``W_old``.
    """
    return {name: W[name] - W_old[name] for name in W_old}
from loss_fn import Distance_loss
class LocalTrainer:
    """Client-side trainer/evaluator for the hypernetwork-driven federated setup.

    Client ``client_id`` owns ``nets[client_id]``, whose architecture name is
    ``nets_name[client_id]``. Only the parameters listed for that architecture
    in ``config/All_Layers.json`` are supplied by the caller; they arrive as one
    flat tensor and are unflattened in ``named_parameters()`` order.
    """

    def __init__(self, args, nets, nets_name, device):
        """Partition the data across clients and build per-client loaders.

        Args:
            args: run configuration. Read here: data_distribution, data_name,
                data_path, num_nodes, classes_per_node, seed, least_nums, alpha,
                batch_size. Read in ``train``: inner_lr, inner_wd, inner_steps,
                topk and (optional, default 0.3) topk_ratio.
            nets: one model per client, indexed by client id.
            nets_name: architecture name per client; key into All_Layers.json.
            device: torch.device that every batch is moved to.
        """
        if args.data_distribution == 'incomplete_label':
            self.dst_train, self.mean, self.std, self.global_test_loader = get_classes(
                args.data_name, args.data_path, args.num_nodes, args.classes_per_node, args.seed)
        else:
            self.dst_train, self.mean, self.std, self.global_test_loader = dirichlet_distribution(
                args.data_name, args.data_path, args.num_nodes, args.seed, args.least_nums, args.alpha)

        self.train_loaders, self.test_loaders, self.data_size = [], [], []
        for data in self.dst_train:
            # Hold out 25% of each client's samples as its local test split.
            n_test = int(len(data) * 0.25)
            n_train = len(data) - n_test
            train_data, test_data = torch.utils.data.random_split(data, [n_train, n_test])
            self.train_loaders.append(DataLoader(train_data, batch_size=args.batch_size, shuffle=True))
            self.test_loaders.append(DataLoader(test_data, batch_size=args.batch_size, shuffle=False))
            self.data_size.append(len(train_data))

        self.device = device
        self.args = args
        self.nets = nets
        self.criteria = torch.nn.CrossEntropyLoss()

        with open('config/All_Layers.json', 'r') as f:
            self.layer_list = json.load(f)
        self.personalized_layer_name = nets_name

    def _load_client_weights(self, weights, client_id):
        """Unflatten ``weights`` into the client's personalised parameters.

        Returns the loaded parameter names in the order they were consumed from
        ``weights``, so the caller can flatten results back in the same layout.
        """
        net = self.nets[client_id]
        wanted = set(self.layer_list[self.personalized_layer_name[client_id]])
        state, keys, offset = {}, [], 0
        for key, value in net.named_parameters():
            if key in wanted:
                n = value.numel()
                state[key] = weights[offset:offset + n].reshape_as(value)
                keys.append(key)
                offset += n
        net.load_state_dict(state, strict=False)
        return keys

    def train(self, weights, client_id, distill_loader=None):
        """Run local SGD for one client and return its parameter update.

        Args:
            weights: flat tensor with the client's personalised parameters.
            client_id: index of the client to train.
            distill_loader: optional loader of (inputs, labels) batches. When
                given, a cross-entropy term and a ``Distance_loss`` regulariser on
                those batches are added, each weighted by 0.1.

        Returns:
            ``weights - trained_params`` as a dense, detached 1-D tensor laid out
            like ``weights``. When ``args.topk`` is set, only the largest-magnitude
            ``topk_ratio`` fraction (default 0.3) is kept, returned as a sparse COO
            tensor.
        """
        net = self.nets[client_id]
        keys = self._load_client_weights(weights, client_id)
        net.train()
        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.inner_lr, momentum=.9,
                                    weight_decay=self.args.inner_wd)

        # The regulariser is only needed when distillation data is supplied.
        distance_loss = None
        distill_iter = None
        if distill_loader is not None:
            distance_loss = Distance_loss('UniConLoss', device=self.device)
            distill_iter = iter(distill_loader)

        for _ in range(self.args.inner_steps):
            for x, y in self.train_loaders[client_id]:
                # Use this trainer's device, not a module-level global.
                x, y = x.to(self.device), y.to(self.device)

                features, local_pred = net(x)
                loss = self.criteria(local_pred, y)

                if distill_iter is not None:
                    # Cycle the (usually shorter) distillation loader indefinitely.
                    try:
                        dx, dy = next(distill_iter)
                    except StopIteration:
                        distill_iter = iter(distill_loader)
                        dx, dy = next(distill_iter)
                    dx, dy = dx.to(self.device), dy.to(self.device)
                    reg_feature, d_pred = net(dx)
                    distill_loss = self.criteria(d_pred, dy)

                    universum = get_universum(x, dx, y, dy, self.args, self.device)
                    uni_features = net.get_features(universum)
                    uni_features = torch.cat([uni_features, reg_feature.detach()], dim=0)
                    loss_reg = distance_loss(features, reg_feature.detach(), y, dy, uni_features)

                    loss = loss + 0.1 * distill_loss + 0.1 * loss_reg

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(net.parameters(), 50)
                optimizer.step()

        final_state = net.state_dict()
        # Flatten in exactly the order used to unflatten, so dW lines up with weights.
        final_flat = torch.cat([final_state[k].reshape(-1) for k in keys], dim=0)
        # Detach: dW is plain data (used by the caller as grad_outputs) and must
        # not hold on to the hypernetwork's autograd graph.
        dW = weights.detach() - final_flat

        if self.args.topk:
            with torch.no_grad():
                ratio = getattr(self.args, 'topk_ratio', 0.3)
                _, index = torch.topk(dW.abs(), int(dW.numel() * ratio))
                return torch.sparse_coo_tensor(index.unsqueeze(0), dW[index],
                                               size=(dW.numel(),), dtype=torch.float32)
        return dW

    @torch.no_grad()
    def _eval_loader(self, net, loader):
        """Return (mean per-batch loss, #correct, #samples) of ``net`` over ``loader``."""
        total_loss, correct, samples, n_batches = 0., 0., 0., 0
        for x, y in loader:
            x = x.to(self.device)
            y = y.to(self.device)
            _, pred = net(x)
            total_loss += self.criteria(pred, y).item()
            correct += pred.argmax(1).eq(y).sum().item()
            samples += len(y)
            n_batches += 1
        # Average over the real number of batches; max() guards an empty loader.
        return total_loss / max(n_batches, 1), correct, samples

    @torch.no_grad()
    def evalute(self, weights, client_id):
        """Evaluate a client on its local test split and on the global test set.

        Returns:
            (local_loss, local_correct, local_samples,
             global_loss, global_correct, global_samples), where each loss is
            the mean cross-entropy per batch.
        """
        self._load_client_weights(weights, client_id)
        net = self.nets[client_id]
        net.eval()
        local = self._eval_loader(net, self.test_loaders[client_id])
        global_ = self._eval_loader(net, self.global_test_loader)
        return (*local, *global_)
    
def _summarize(stats):
    """Reduce {client: {'loss', 'correct', 'total'}} to (avg_loss, avg_acc, all_acc, all_loss)."""
    rows = list(stats.values())
    total_correct = sum(r['correct'] for r in rows)
    total_samples = sum(r['total'] for r in rows)
    avg_loss = np.mean([r['loss'] for r in rows])
    avg_acc = total_correct / total_samples  # sample-weighted accuracy
    all_acc = [r['correct'] / r['total'] for r in rows]
    all_loss = [r['loss'] for r in rows]
    return avg_loss, avg_acc, all_acc, all_loss


def evaluate(hnet, trainer, clients):
    """Evaluate every client's hypernetwork-generated model, locally and globally.

    Args:
        hnet: module mapping a client id to that client's flat weight tensor.
        trainer: object whose ``evalute(weights, client_id)`` returns
            (loss, correct, total, global_loss, global_correct, global_total).
        clients: iterable of client ids.

    Returns:
        (avg_loss, avg_acc, all_acc, all_loss,
         global_avg_loss, global_avg_acc, global_all_acc, global_all_loss).
        ``avg_loss`` is the unweighted mean of per-client losses, and ``avg_acc``
        is total correct / total samples. The ``all_*`` lists are per client,
        in the order of ``clients``.

    ``hnet`` is switched to eval mode for the duration and restored to its
    previous mode afterwards, so subsequent training is unaffected.
    """
    was_training = hnet.training
    hnet.eval()
    local_stats, global_stats = {}, {}
    try:
        # Generated weights are only copied into the client nets; no gradients needed.
        with torch.no_grad():
            for client_id in clients:
                weights = hnet(client_id)
                (loss, correct, total,
                 g_loss, g_correct, g_total) = trainer.evalute(weights, client_id)
                local_stats[client_id] = {'loss': loss, 'correct': correct, 'total': total}
                global_stats[client_id] = {'loss': g_loss, 'correct': g_correct, 'total': g_total}
    finally:
        hnet.train(was_training)

    return (*_summarize(local_stats), *_summarize(global_stats))


if __name__ == "__main__":
    # Script entry point: build per-client networks and a hypernetwork, then run
    # `args.num_steps` rounds. Best accuracies are appended to JSON files and
    # per-client accuracy/loss curves are plotted to PNGs under args.output_path.
    parser = argparse.ArgumentParser()
    args = get_args(parser)
    set_seed(args.seed)
    device = torch.device("cpu") if args.cuda == -1 else torch.device("cuda:%d" % args.cuda)

    # data_name -> (classes_per_node, number of output classes, input channels)
    dataset_specs = {
        'cifar100': (10, 100, 3),
        'tiny_imageNet': (20, 200, 3),
        'cifar10': (2, 10, 3),
        'emnist': (6, 62, 1),
    }
    if args.data_name not in dataset_specs:
        raise ValueError("choose data_name from ['emnist', 'cifar10', 'tiny_imageNet' ,'cifar100']")
    args.classes_per_node, out_dim, in_channels = dataset_specs[args.data_name]

    # train_clients == -1 means "train every client".
    train_list = range(args.num_nodes if args.train_clients == -1 else args.train_clients)

    os.makedirs(args.output_path, exist_ok=True)
    title_prefix = 'cnn_' if args.homogeneous else 'more_model_'
    save_title = (title_prefix + args.data_name + '_' + args.data_distribution
                  + '_cn_' + str(args.num_nodes) + '_tr_' + str(len(train_list)))

    file_path = os.path.join(args.output_path, save_title + '.json')
    global_file_path = os.path.join(args.output_path, save_title + '_global.json')
    # Seed each results file with an empty JSON list so later rounds can append to it.
    for results_path in (file_path, global_file_path):
        if not os.path.exists(results_path):
            with open(results_path, 'w') as json_file:
                json_file.write('[]\n')

    with open('config/defaults.yaml', 'r') as file:
        config = yaml.safe_load(file)

    if args.homogeneous:
        arch_counts = [('ConvNet', len(train_list))]
    else:
        # Spread clients over five architectures as evenly as possible; the
        # first `remainder` architectures each receive one extra client.
        arch_names = ['ConvNet', 'LeNet', 'VGG8', 'MLP', 'ResNet9']
        base, remainder = divmod(len(train_list), len(arch_names))
        arch_counts = [(name, base + (1 if i < remainder else 0))
                       for i, name in enumerate(arch_names)]

    nets, nets_name, nets_param_nums = [], [], []
    for arch, count in arch_counts:
        for _ in range(count):
            nets.append(get_network(arch, in_channels, out_dim).to(device))
            nets_name.append(arch)
            nets_param_nums.append(config[args.data_name][arch])

    hnet_cls = ConvNetHyper if args.homogeneous else Hyper
    hnet = hnet_cls(nets_param_nums, args.embed_dim, args.hnet_output_size, args.hidden_layers,
                    args.hidm, norm_var=args.norm_var).to(device)

    trainer = LocalTrainer(args, nets, nets_name, device)

    hnet.train()
    optimizer_cls = {'adam': torch.optim.Adam, 'adamw': torch.optim.AdamW}[args.optim]
    hnet_optim = optimizer_cls(params=hnet.parameters(), lr=args.lr)

    best_acc = -1
    global_best_acc = -1
    client_stats = ExperimentLogger()
    global_stats = ExperimentLogger()

    # Distilled data shared by all clients for the distillation/regularisation terms.
    mixup_dataset = CompressedDataset(args.distill_data_dir)
    mixup_loader = DataLoader(mixup_dataset, batch_size=50, shuffle=True)

    png_path = os.path.join(args.output_path, save_title + '.png')
    global_png_path = os.path.join(args.output_path, save_title + '_global.png')

    def _append_record(path, record):
        """Append `record` to the JSON list stored at `path`, rewriting the file in place."""
        with open(path, 'r+') as json_file:
            records = json.load(json_file)
            records.append(record)
            json_file.seek(0)
            json.dump(records, json_file, indent=4)
            json_file.truncate()  # drop stale trailing bytes if the new dump is shorter

    for step in trange(args.num_steps):
        # evaluate() may leave hnet in eval mode; make sure every round trains in train mode.
        hnet.train()
        idc = random.sample(list(train_list), len(train_list))  # random client order

        for client_id in idc:
            weights = hnet(client_id)
            delta = trainer.train(weights, client_id, mixup_loader)
            if args.topk:
                delta = delta.to_dense()

            # Push the local update back through the hypernetwork as a
            # vector-Jacobian product: grads = (d weights / d hnet params)^T @ delta.
            hnet_optim.zero_grad()
            hnet_grads = torch.autograd.grad(weights, hnet.parameters(), grad_outputs=delta, allow_unused=True)
            for p, g in zip(hnet.parameters(), hnet_grads):
                p.grad = g
            torch.nn.utils.clip_grad_norm_(hnet.parameters(), args.grad_clip)
            hnet_optim.step()

        (avg_loss, avg_acc, all_acc, all_loss,
         global_avg_loss, global_avg_acc, global_all_acc, global_all_loss) = evaluate(hnet, trainer, train_list)
        print(f"Step: {step+1}, AVG Loss: {avg_loss:.4f},  AVG Acc: {avg_acc:.4f}")
        print(f"Step: {step+1}, Global AVG Loss: {global_avg_loss:.4f}, Global AVG Acc: {global_avg_acc:.4f}")

        client_stats.log({"rounds": step, 'client_acc': all_acc, 'client_loss': all_loss})
        global_stats.log({"rounds": step, 'client_acc': global_all_acc, 'client_loss': global_all_loss})

        if avg_acc > best_acc:
            best_acc = avg_acc
            _append_record(file_path, {"Rounds": step, "best_average": best_acc})

        if global_avg_acc > global_best_acc:
            global_best_acc = global_avg_acc
            _append_record(global_file_path, {"Rounds": step, "best_average": global_best_acc})

        show_plot(client_stats, args.num_steps, png_path)
        show_plot(global_stats, args.num_steps, global_png_path)

    if args.save_model:
        hynet_dir = os.path.join(args.output_path, save_title + '_' + 'hynet.pt')
        fc_dir = os.path.join(args.output_path, save_title + '_' + 'fc.pt')
        torch.save(hnet.hynet.state_dict(), hynet_dir)
        torch.save(hnet.fc.state_dict(), fc_dir)




        
        
       
        




        
 

