import matplotlib
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import copy
import torch
import os
from tqdm import tqdm
import pickle
import random

from utils.options import args_parser
from models.Update import LocalUpdate
from models.Nets import MLP, CNNMnist, CNNCifar, Fin_CNNMnist, Fin_CNNCifar
from models.Fed import FedAvg, FedWeightedAvg
from models.test import test_img
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F

import Coordinator
import Processor

import numpy as np

# Autograd anomaly detection is switched on globally for every backward pass.
# NOTE(review): PyTorch documents this as a debugging aid that slows execution —
# confirm it is still wanted for full training runs.
torch.autograd.set_detect_anomaly(True)

# Select the non-interactive Agg backend; figures in this file are written with savefig.
matplotlib.use('Agg')

# Dataset tag: passed to Processor.get_input, used to build output file/directory
# names, and used to choose between the CIFAR and MNIST code paths in this file.
data = 'cifar10'

class GetLocalDataset(Dataset):
    """View over ``dataset`` restricted to the positions listed in ``idxs``.

    Item ``i`` of the view is item ``idxs[i]`` of the wrapped dataset, which
    must yield ``(image, label)`` pairs.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        # Materialise the indices so they can be indexed and counted repeatedly.
        self.idxs = list(idxs)

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        source_position = self.idxs[item]
        image, label = self.dataset[source_position]
        return image, label

def train(net_glob, cod, w_glob, args):
    """Run the phase-1 federated rounds over coordinator groups.

    Every round samples a fraction ``args.frac`` of the groups in
    ``cod.coordinator`` (at least one), trains a deep copy of ``net_glob`` on
    the pooled training indices of each sampled group via ``LocalUpdate``,
    merges the returned weights with ``FedWeightedAvg`` (weighted by group
    data size) and loads the result back into ``net_glob``.  The global
    weights are saved after every round; the per-round average loss is written
    to ``./save`` as a text file and as a plot.

    Args:
        net_glob: global model; updated in place.
        cod: coordinator object exposing ``coordinator`` (an indexable
            collection of client-id groups) and ``pcs`` (the dataset, with
            ``local_train_index``).
        w_glob: initial global state dict; only used to seed the per-group
            weight slots when ``args.all_clients`` is set.
        args: options; reads ``lr``, ``epochs``, ``frac``, ``all_clients``,
            ``device`` and ``num_users``.

    Returns:
        ``net_glob`` after the last round.

    Raises:
        ValueError: if ``cod.coordinator`` is empty.
    """
    coordinators = cod.coordinator
    num_coordinators = len(coordinators)
    if num_coordinators == 0:
        raise ValueError("train() needs at least one coordinator group")

    loss_train = []
    lr_glob = args.lr

    params_dir = f"{data}_saved_params_glob"
    os.makedirs(params_dir, exist_ok=True)

    if args.all_clients:
        print("Aggregation over all clients")
        # One slot per coordinator group; groups not sampled in a round keep
        # whatever weights their slot already holds.
        w_locals = [w_glob for _ in range(num_coordinators)]

    for round_idx in range(args.epochs):
        loss_locals = []
        if not args.all_clients:
            w_locals = []

        # Always train at least one group, otherwise there is nothing to
        # aggregate and no next learning rate to carry over.
        num_selected = min(num_coordinators, max(1, int(args.frac * num_coordinators)))
        # Sample positions rather than the groups themselves so each result
        # can be stored in the slot that belongs to its group.
        selected_ids = random.sample(range(num_coordinators), num_selected)
        selected_coordinators = [coordinators[j] for j in selected_ids]
        print(selected_coordinators)

        for coord_id in tqdm(selected_ids, desc=f"Epoch {round_idx+1} coordinators", total=num_selected, leave=False, unit="coordinator"):
            coord = coordinators[coord_id]
            need_index = [cod.pcs.local_train_index[k] for k in coord]
            # Use the coordinator's dataset (the object the indices above come
            # from) instead of relying on a module-level ``pcs`` global.
            local = LocalUpdate(args=args, dataset=cod.pcs, idxs=np.hstack(need_index))
            w, loss, lr_glob_next = local.train(
                net=copy.deepcopy(net_glob).to(args.device), lr_glob=lr_glob)
            if args.all_clients:
                w_locals[coord_id] = copy.deepcopy(w)
            else:
                w_locals.append(copy.deepcopy(w))
            loss_locals.append(copy.deepcopy(loss))
            print(loss_locals)

        # The size list must line up one-to-one with w_locals: every group in
        # all_clients mode, only the sampled groups otherwise.
        aggregated = coordinators if args.all_clients else selected_coordinators
        coordinator_data_sizes = [
            sum(len(cod.pcs.local_train_index[k]) for k in coord)
            for coord in aggregated
        ]

        # update global weights
        w_glob = FedWeightedAvg(w_locals, coordinator_data_sizes)
        # The learning rate reported by the last trained group is used next round.
        lr_glob = lr_glob_next

        # copy weight to net_glob and checkpoint it
        net_glob.load_state_dict(w_glob)
        torch.save(net_glob.state_dict(), os.path.join(params_dir, 'global_params.pth'))

        # print loss
        loss_avg = sum(loss_locals) / len(loss_locals)
        print('Round {:3d}, Average loss {:.3f}'.format(round_idx, loss_avg))
        loss_train.append(loss_avg)

    # np.savetxt and savefig do not create missing directories.
    os.makedirs('./save', exist_ok=True)
    loss_file_path = os.path.join('./save', f'train_loss_{data}.txt')
    np.savetxt(loss_file_path, loss_train, fmt='%.4f', header='Training Loss', comments='')

    # plot loss curve
    plt.figure()
    plt.plot(range(len(loss_train)), loss_train)
    plt.ylabel('phase1_train_loss')
    plt.savefig('./save/fed_{}_{}_N{}_{}_C{}.png'.format(data, "Conv2d", args.num_users, "Dir(0.1)", args.frac))
    # Release the figure; pyplot keeps figures alive until they are closed.
    plt.close()
    return net_glob

def train_to(net_glob, cod, w_glob, args):
    """Run the phase-2 rounds: per-group local training without aggregation.

    Each round picks one random client from every coordinator group, tops the
    selection up with random other clients until ``int(args.frac *
    cod.pcs.size_device)`` clients are chosen, and then calls
    ``LocalUpdate.train_to_cifar`` / ``train_to_mnist`` (depending on the
    module-level ``data`` tag) for every group with a deep copy of
    ``net_glob``.  No weights are merged back here, so ``net_glob`` is returned
    as it was passed in.
    NOTE(review): presumably the ``LocalUpdate`` methods persist per-client
    parameters themselves (``test_local_networks`` reads
    ``{data}_saved_params/client_<k>_params.pth``) — confirm in models.Update.

    Args:
        net_glob: model whose deep copies are handed to the local trainers.
        cod: coordinator object exposing ``coordinator`` (groups of client
            ids) and ``pcs`` (dataset with ``size_device`` and
            ``local_train_index``).
        w_glob: unused; kept so existing callers keep working.
        args: options; reads ``lr``, ``epochs``, ``frac`` and ``device``.

    Returns:
        ``net_glob``.

    Raises:
        ValueError: if ``cod.coordinator`` is empty.
    """
    coordinators = cod.coordinator
    if len(coordinators) == 0:
        raise ValueError("train_to() needs at least one coordinator group")

    lr_glob = args.lr
    use_cifar_trainer = data in ("cifar10", "cifar100")

    for round_idx in range(args.epochs):
        loss_locals = []

        # ---- client selection -------------------------------------------------
        num_clients = cod.pcs.size_device
        num_selected = int(args.frac * num_clients)

        selected_clients = []
        remaining_clients = set(range(num_clients))

        # Guarantee every group is represented by one of its clients.
        for coord in coordinators:
            chosen = random.choice(list(coord))
            selected_clients.append(chosen)
            remaining_clients.discard(chosen)

        shortfall = num_selected - len(selected_clients)
        if shortfall > 0:
            # random.sample needs a sequence: passing a set raises TypeError
            # on Python 3.11+. Sorting also makes the draw reproducible for a
            # given seed.
            selected_clients += random.sample(sorted(remaining_clients), shortfall)
        # ------------------------------------------------------------------------

        for coord in tqdm(coordinators, desc=f"Epoch {round_idx+1} coordinators", total=len(coordinators), leave=False, unit="coordinator"):
            need_index = [cod.pcs.local_train_index[k] for k in coord]
            need_index_dict = {k: cod.pcs.local_train_index[k] for k in coord}
            # Use the coordinator's dataset (the object the indices above come
            # from) instead of relying on a module-level ``pcs`` global.
            local = LocalUpdate(args=args, dataset=cod.pcs, idxs=np.hstack(need_index), coodinator_dict=need_index_dict, coord=coord)

            local_net = copy.deepcopy(net_glob).to(args.device)
            if use_cifar_trainer:
                loss, lr_next_epoch = local.train_to_cifar(net=local_net, lr_glob=lr_glob, selected_clients=selected_clients)
            else:
                loss, lr_next_epoch = local.train_to_mnist(net=local_net, lr_glob=lr_glob, selected_clients=selected_clients)

            loss_locals.append(copy.deepcopy(loss))

        loss_avg = sum(loss_locals) / len(loss_locals)
        print(f'Round {round_idx}, Average loss: {loss_avg:.3f}')
        # The learning rate reported by the last trained group is used next round.
        lr_glob = lr_next_epoch
        print(lr_glob)

    return net_glob


def test(net_glob, pcs, args, is_self_balanced, imbalanced_way):
    """Print the accuracy of ``net_glob`` before and after switching ``pcs`` to test mode.

    The first ``test_img`` call runs with whatever ``pcs.type`` the caller
    left in place and is reported as the training accuracy; ``pcs.type`` is
    then set to ``'test'`` (and stays that way on return) for the second call.
    Nothing is returned.
    """
    net_glob.eval()

    evaluate_args = (net_glob, pcs, args, is_self_balanced, imbalanced_way)

    acc_train, _ = test_img(*evaluate_args)

    pcs.type = 'test'
    acc_test, _ = test_img(*evaluate_args)

    print(f"Training accuracy: {acc_train:.2f}")
    print(f"Testing accuracy: {acc_test:.2f}")

def test_local_networks(net, pcs, args, net_glob):
    """Evaluate each client's saved parameters on that client's test indices.

    For every client ``k`` the file ``{data}_saved_params/client_<k>_params.pth``
    is loaded into ``net`` (when present) and the second output of ``net`` is
    scored against the labels of ``pcs.local_test_index[k]``.  Per-client
    accuracies and their data-size-weighted mean are written to
    ``{data}_Management_Information/P-FL_PM.txt``; the unweighted mean is
    printed.

    Args:
        net: network returning a pair of outputs; its weights are overwritten
            by each client checkpoint that is found.
        pcs: dataset; ``pcs.type`` is set to ``'test'`` and left that way.
        args: options; reads ``device``, ``bs`` and ``verbose``.
        net_glob: only switched to eval mode; kept for call compatibility.

    Returns:
        The weighted average client accuracy in percent (``0.0`` when no
        client has test samples).
    """
    net.eval()
    net_glob.eval()
    pcs.type = 'test'

    output_dir = f"{data}_Management_Information"
    # open(..., 'w') does not create missing directories.
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, 'P-FL_PM.txt')

    client_accuracies = []
    client_data_sizes = []

    with open(output_file, 'w') as f:
        for k in range(pcs.size_device):
            param_file = os.path.join(f'{data}_saved_params', f'client_{k}_params.pth')
            if os.path.exists(param_file):
                # map_location lets checkpoints saved on another device load here.
                net.load_state_dict(torch.load(param_file, map_location=args.device))
            else:
                # Without a checkpoint the weights left in ``net`` are reused;
                # make that visible instead of reporting it silently.
                print(f'Warning: {param_file} not found; client {k} is evaluated with the weights currently loaded')

            idxs = np.hstack([pcs.local_test_index[k]])
            client_data_size = len(idxs)
            if client_data_size == 0:
                # No samples means no defined accuracy; skip rather than divide by zero.
                f.write(f'Client {k}, Accuracy: n/a (no test samples)\n')
                continue

            local_test_loader = DataLoader(GetLocalDataset(pcs, idxs), batch_size=args.bs)

            correct = 0
            total = 0
            with torch.no_grad():
                for images, labels in local_test_loader:
                    images, labels = images.to(args.device), labels.to(args.device)
                    labels = labels.long()

                    # Only the second output is scored.
                    _, log_probs_fc3 = net(images)
                    _, predicted = torch.max(log_probs_fc3, 1)

                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()

            accuracy = 100.0 * correct / total

            client_accuracies.append(accuracy)
            client_data_sizes.append(client_data_size)

            f.write(f'Client {k}, Accuracy: {accuracy:.2f}%\n')

            if args.verbose:
                print(f'Client {k} , Accuracy: {accuracy:.2f}%')

        total_data_size = sum(client_data_sizes)
        if client_accuracies:
            avg_client_accuracy = sum(client_accuracies) / len(client_accuracies)
            weighted_accuracy = sum(
                acc * size for acc, size in zip(client_accuracies, client_data_sizes)
            ) / total_data_size
        else:
            avg_client_accuracy = 0.0
            weighted_accuracy = 0.0
        print(avg_client_accuracy)

        f.write(f'Weighted Average Client Accuracy: {weighted_accuracy:.2f}%\n')

    return weighted_accuracy

def test_global_networks(net, pcs, args):
    """Evaluate the saved global parameters on every client's test indices.

    ``{data}_saved_params_glob/global_params.pth`` is loaded into ``net`` once
    (when present) and the model is scored separately on
    ``pcs.local_test_index[k]`` for each client.  Per-client accuracies and
    their unweighted mean are written to
    ``{data}_Management_Information/P-FL_GM.txt``.

    Args:
        net: network returning a single output; its weights are overwritten
            by the checkpoint when it exists.
        pcs: dataset; ``pcs.type`` is set to ``'test'`` and left that way.
        args: options; reads ``device``, ``bs`` and ``verbose``.

    Returns:
        The unweighted average client accuracy in percent as a Python float
        (``0.0`` when no client has test samples).
    """
    net.eval()
    pcs.type = 'test'

    output_dir = f"{data}_Management_Information"
    # open(..., 'w') does not create missing directories.
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, 'P-FL_GM.txt')

    # The checkpoint is the same for every client, so load it once up front.
    param_file = os.path.join(f'{data}_saved_params_glob', 'global_params.pth')
    if os.path.exists(param_file):
        # map_location lets checkpoints saved on another device load here.
        net.load_state_dict(torch.load(param_file, map_location=args.device))
    net.eval()

    client_accuracies = []
    count = 0

    with open(output_file, 'w') as f:
        for k in range(pcs.size_device):
            idxs = np.hstack([pcs.local_test_index[k]])
            local_test_loader = DataLoader(GetLocalDataset(pcs, idxs), batch_size=args.bs)
            num_samples = len(local_test_loader.dataset)

            count += num_samples
            print(num_samples)

            if num_samples == 0:
                # No samples means no defined accuracy; skip rather than divide by zero.
                f.write(f'Client {k}, Accuracy: n/a (no test samples)\n')
                continue

            correct = 0
            with torch.no_grad():
                for images, labels in local_test_loader:
                    images, labels = images.to(args.device), labels.to(args.device)
                    labels = labels.long()

                    log_probs_fc3 = net(images)
                    predicted = log_probs_fc3.data.max(1, keepdim=True)[1]

                    # .item() keeps the running count a plain int so the
                    # reported accuracies are floats, not tensors.
                    correct += predicted.eq(labels.data.view_as(predicted)).long().cpu().sum().item()

            accuracy = 100.0 * correct / num_samples

            client_accuracies.append(accuracy)

            f.write(f'Client {k}, Accuracy: {accuracy:.2f}%\n')

            if args.verbose:
                print(f'Client {k} , Accuracy: {accuracy:.2f}%')

        print(count)

        if client_accuracies:
            avg_client_accuracy = sum(client_accuracies) / len(client_accuracies)
        else:
            avg_client_accuracy = 0.0

        f.write(f'Average Client Accuracy: {avg_client_accuracy:.2f}%\n')

    return avg_client_accuracy


def save_pcs(pcs, filename='pcs.pkl'):
    """Pickle ``pcs`` into the ``{data}_Management_Information`` directory.

    The directory is created when it does not exist yet.  The confirmation
    message prints the bare ``filename``, not the full path.
    """
    target_dir = f'{data}_Management_Information'

    directory_missing = not os.path.exists(target_dir)
    if directory_missing:
        os.makedirs(target_dir)

    target_path = os.path.join(target_dir, filename)
    with open(target_path, 'wb') as handle:
        pickle.dump(pcs, handle)

    print(f"Processor saved to {filename}")


def load_pcs(filename='pcs.pkl'):
    """Unpickle and return the object stored under ``{data}_Management_Information/filename``.

    The confirmation message prints the bare ``filename``, not the full path.
    """
    source_path = os.path.join(f'{data}_Management_Information', filename)

    # NOTE: unpickling can execute arbitrary code; only load files that this
    # project wrote itself.
    with open(source_path, 'rb') as handle:
        restored = pickle.load(handle)

    print(f"Processor loaded from {filename}")
    return restored

def extract_features(model, data_loader, feature_extractor=True):
    """Collect intermediate activations of ``model`` for every batch of ``data_loader``.

    Each batch goes through the three conv/bn/relu/pool stages of ``model``
    and is flattened to rows of ``64 * 4 * 4`` values.  With
    ``feature_extractor=True`` those rows are returned as-is; otherwise they
    are additionally passed through ``fc1``, ReLU and the model's dropout.

    NOTE: relies on the module-level ``pcs`` (its ``type`` is set to
    ``"test"``) and ``args`` (for ``args.device``), which only exist when this
    file runs as a script.

    Returns:
        ``(features, labels)`` as NumPy arrays concatenated over all batches.
    """
    pcs.type = "test"

    feature_batches = []
    label_batches = []

    with torch.no_grad():
        for inputs, target in data_loader:
            inputs = inputs.to(args.device)
            target = target.to(args.device)

            stage1 = model.pool(F.relu(model.bn1(model.conv1(inputs))))
            stage2 = model.pool(F.relu(model.bn2(model.conv2(stage1))))
            stage3 = model.pool(F.relu(model.bn3(model.conv3(stage2))))
            flat = stage3.view(-1, 64 * 4 * 4)

            if not feature_extractor:
                # Continue through the first fully connected layer.
                flat = model.dropout(F.relu(model.fc1(flat)))

            feature_batches.append(flat.cpu().numpy())
            label_batches.append(target.cpu().numpy())

    stacked_features = np.concatenate(feature_batches)
    stacked_labels = np.concatenate(label_batches)
    return stacked_features, stacked_labels

def visualize_features(net_glob, pcs):
    """Save t-SNE scatter plots of two feature levels of ``net_glob``.

    Features are extracted over all of ``pcs`` twice with ``extract_features``:
    once after the convolutional stages (saved to
    ``tsne_visualization_extractor.pdf``) and once after ``fc1`` (saved to
    ``tsne_visualization_combined.pdf``).  Points are coloured by label; the
    ten-colour palette is reused cyclically when there are more than ten
    classes.

    NOTE: reads the batch size from the module-level ``args``.

    Args:
        net_glob: model to inspect; switched to eval mode.
        pcs: dataset wrapped in a ``DataLoader``.
    """
    net_glob.eval()

    data_loader = DataLoader(pcs, batch_size=args.bs)

    palette = ['r', 'g', 'b', 'c', 'm', 'y', 'k', 'orange', 'purple', 'brown']

    def _embed_and_save(features, labels, out_path):
        # Project the features to 2-D and write one borderless scatter plot.
        embedded = TSNE(n_components=2, random_state=42).fit_transform(features)
        # Wrap around the palette so class ids >= 10 do not raise IndexError.
        point_colors = [palette[int(label) % len(palette)] for label in labels]

        plt.figure(figsize=(6, 6))
        plt.scatter(embedded[:, 0], embedded[:, 1], c=point_colors)
        plt.axis('off')
        plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
        plt.savefig(out_path)
        plt.close()

    features_extractor, labels_extractor = extract_features(net_glob, data_loader, feature_extractor=True)
    features_combined, labels_combined = extract_features(net_glob, data_loader, feature_extractor=False)

    _embed_and_save(features_extractor, labels_extractor, 'tsne_visualization_extractor.pdf')
    _embed_and_save(features_combined, labels_combined, 'tsne_visualization_combined.pdf')

def visualize_client_features(net, pcs, num_clients):
    """Save a joint t-SNE plot of post-``fc1`` features from all client models.

    For each client ``k`` with a checkpoint at
    ``{data}_saved_params/client_<k>_params.pth`` the weights are loaded into
    ``net`` and features of that client's test indices are extracted with
    ``extract_features(..., feature_extractor=False)``.  All features are
    embedded together and written to ``client_features_tsne.pdf``, coloured by
    label (the ten-colour palette is reused cyclically beyond ten classes).

    NOTE: reads ``bs`` and ``device`` from the module-level ``args``.

    Args:
        net: network whose weights are overwritten by each client checkpoint.
        pcs: dataset providing ``local_test_index``.
        num_clients: number of client ids (``0 .. num_clients - 1``) to try.

    Raises:
        ValueError: if no client contributed any features (no checkpoint
            found, or every client with a checkpoint has no test samples).
    """
    net.eval()

    palette = ['r', 'g', 'b', 'c', 'm', 'y', 'k', 'orange', 'purple', 'brown']

    all_features = []
    all_labels = []

    for k in range(num_clients):
        param_file = os.path.join(f'{data}_saved_params', f'client_{k}_params.pth')
        if not os.path.exists(param_file):
            continue

        idxs = np.hstack([pcs.local_test_index[k]])
        if len(idxs) == 0:
            # extract_features cannot concatenate zero batches.
            continue

        # map_location lets checkpoints saved on another device load here.
        net.load_state_dict(torch.load(param_file, map_location=args.device))
        local_test_loader = DataLoader(GetLocalDataset(pcs, idxs), batch_size=args.bs)

        features, labels = extract_features(net, local_test_loader, feature_extractor=False)
        all_features.append(features)
        all_labels.append(labels)

    if not all_features:
        raise ValueError(
            f"no client features to visualize: no usable checkpoints under {data}_saved_params"
        )

    all_features = np.vstack(all_features)
    all_labels = np.hstack(all_labels)

    tsne = TSNE(n_components=2, random_state=42)
    features_embedded = tsne.fit_transform(all_features)

    plt.figure(figsize=(6, 6))
    # Plot the labels that are actually present, one colour per label.
    for label in np.unique(all_labels):
        label_indices = np.where(all_labels == label)[0]
        plt.scatter(features_embedded[label_indices, 0], features_embedded[label_indices, 1],
                    c=palette[int(label) % len(palette)], alpha=0.5)

    plt.axis('off')
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)

    plt.savefig('client_features_tsne.pdf')
    # This file selects the non-interactive Agg backend, so there is nothing to
    # show; close the figure instead to release it.
    plt.close()


def moving_average(data, window_size):
    """Return the simple moving average of ``data`` over full windows only.

    Args:
        data: one-dimensional sequence of numbers.
        window_size: number of consecutive samples per average; must be >= 1.

    Returns:
        A float array of length ``len(data) - window_size + 1``.  When the
        window is longer than ``data`` there is no full window and an empty
        array is returned.

    Raises:
        ValueError: if ``window_size`` is smaller than 1.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")

    values = np.asarray(data, dtype=float)
    if values.size < window_size:
        # np.convolve(mode='valid') would swap the roles of signal and kernel
        # here and return values that are not moving averages.
        return np.empty(0, dtype=float)

    kernel = np.full(window_size, 1.0 / window_size)
    return np.convolve(values, kernel, mode='valid')

if __name__ == '__main__':
    # Script entry point.  Phase 1 trains and evaluates a shared global model
    # with train(); phase 2 copies its weights into a Fin_* network and runs
    # train_to(), then evaluates per-client and global parameters.
    # NOTE: ``args`` and ``pcs`` defined here are read as module-level globals
    # by extract_features, visualize_features and visualize_client_features.
    # parse args
    args = args_parser()
    args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')
    # new instances for Processor and Coordinator
    pcs = Processor.Processor()
    pcs.get_input(data)
    imbalanced_way = ""
    pcs.gen_local_imbalance_hsu(args.num_users, 5000, 0.3)
    imbalanced_way = "local"
    
    # Persist the Processor; load_pcs reads the same file back.
    save_pcs(pcs, filename='pcs.pkl')

    # NOTE(review): img_size is not used anywhere in the visible code.
    img_size = pcs[0][0].shape

    # build new model
    # NOTE(review): the model is chosen from args.dataset ('cifar' / 'mnist')
    # while the data loaded above and the phase-2 branch below depend on the
    # module-level ``data`` tag; nothing here checks that the two agree.
    net_glob = None
    if args.model == 'cnn' and args.dataset == 'cifar':
        net_glob = CNNCifar(args=args).to(args.device)
    elif args.model == 'cnn' and args.dataset == 'mnist':
        net_glob = CNNMnist(args=args).to(args.device)
    else:
        exit('Error: unrecognized model')
    print(net_glob)
    net_glob.train()
    # copy weights
    w_glob = net_glob.state_dict()
    # self balanced
    cod = Coordinator.Coordinator(pcs)
    cod.assign_clients()
    # Phase 1: federated training, then evaluation of the global model.
    pcs.type = "train"
    train(net_glob, cod, w_glob, args)
    test(net_glob, pcs, args,  "self_balanced", imbalanced_way)
    test_global_networks(net_glob, pcs, args)
    
    
    # pcs = Processor.Processor()
    # pcs.get_input(data)
    # pcs.gen_local_imbalance_hsu(args.num_users, 5000, 0.1)

    # net_glob = CNNCifar(args=args).to(args.device)
    

    # param_file = os.path.join(f"{data}_saved_params_glob", f'global_params.pth')
    # if os.path.exists(param_file):
    #             net_glob.load_state_dict(torch.load(param_file))

    # net_glob.train()
    # w_glob = net_glob.state_dict()
    
    # pcs = load_pcs(filename='pcs.pkl')
    # save_pcs(pcs, filename='pcs.pkl')

    # cod = Coordinator.Coordinator(pcs)
    # cod.assign_clients()
    # pcs.type = "train"
    # train(net_glob, cod, w_glob, args)

    # test(net_glob, pcs, args,  "self_balanced", "local")
    # test_global_networks(net_glob, pcs, args)

    # NOTE(review): this instance is replaced in both branches below, and it is
    # built as the CIFAR variant even when ``data`` selects the MNIST branch.
    net = Fin_CNNCifar(args=args).to(args.device)


    # Phase 2 setup: snapshot the trained global weights, rebuild a plain
    # global model from them, and seed a Fin_* network layer by layer.
    # Layers of the Fin_* network that are not listed here (if it has any)
    # keep their initial values.
    if data == 'cifar10' or data == 'cifar100':
        global_model_weights = copy.deepcopy(net_glob.state_dict())
        conv1_weights = copy.deepcopy(net_glob.conv1.state_dict())
        bn1_weights = copy.deepcopy(net_glob.bn1.state_dict())
        conv2_weights = copy.deepcopy(net_glob.conv2.state_dict())
        bn2_weights = copy.deepcopy(net_glob.bn2.state_dict())
        fc1_weights = copy.deepcopy(net_glob.fc1.state_dict())
        fc2_weights = copy.deepcopy(net_glob.fc2.state_dict())

        
        #-----------------------------------------------------------
        conv3_weights = copy.deepcopy(net_glob.conv3.state_dict())
        bn3_weights = copy.deepcopy(net_glob.bn3.state_dict())
        #-----------------------------------------------------------

        global_model = CNNCifar(args=args).to(args.device)
        global_model.load_state_dict(global_model_weights)

        net = Fin_CNNCifar(args=args).to(args.device)
        net.conv1.load_state_dict(conv1_weights)
        net.bn1.load_state_dict(bn1_weights)
        net.conv2.load_state_dict(conv2_weights)
        net.bn2.load_state_dict(bn2_weights)
        net.fc1.load_state_dict(fc1_weights)
        net.fc2.load_state_dict(fc2_weights)

        
        #-----------------------------------------------------------
        net.conv3.load_state_dict(conv3_weights)
        net.bn3.load_state_dict(bn3_weights)
        #-----------------------------------------------------------
    else:
        global_model_weights = copy.deepcopy(net_glob.state_dict())
        conv1_weights = copy.deepcopy(net_glob.conv1.state_dict())
        fc1_weights = copy.deepcopy(net_glob.fc1.state_dict())
        fc2_weights = copy.deepcopy(net_glob.fc2.state_dict())

        
        global_model = CNNMnist(args=args).to(args.device)
        global_model.load_state_dict(global_model_weights)

        net = Fin_CNNMnist(args=args).to(args.device)
        net.conv1.load_state_dict(conv1_weights)
        net.fc1.load_state_dict(fc1_weights)
        net.fc2.load_state_dict(fc2_weights)
    
    # Write the global weights to the same checkpoint path that train() saves
    # to after every round.
    torch.save(global_model.state_dict(), os.path.join(f"{data}_saved_params_glob", f'global_params.pth'))

    print(net)
    
    net.train()
    w_glob = net.state_dict()

    
    # Phase 2: a fresh Coordinator with the assign_clients_to() grouping, then
    # per-group training of the Fin_* network.
    cod = Coordinator.Coordinator(pcs)
    cod.assign_clients_to()
    pcs.type = "train"
    train_to(net, cod, w_glob, args)

    

    # Final evaluation: per-client parameters first, then the global model.
    test_local_networks(net, pcs, args, net_glob)
    
    test_global_networks(net_glob, pcs, args)

    #------------------------------------------------------------------------------------------------


    # test(net_glob, pcs, args,  "self_balanced", "local")

    # net = Fin_CNNMnist(args=args).to(args.device)

    # param_file = os.path.join(f"{data}_saved_params", f'global_params.pth')
    # if os.path.exists(param_file):
    #             net_glob.load_state_dict(torch.load(param_file))
    
    # pcs = load_pcs(filename='pcs.pkl')

    
    # visualize_features(net_glob, pcs)

    # visualize_client_features(net, pcs, pcs.size_device)

    # window_size = 3

    
    # fedavg_loss = np.loadtxt('./save/train_loss_cifar10_FedAvg.txt', skiprows=1)
    # fed3_2p_loss = np.loadtxt('./save/train_loss_cifar10_Fed3+2p.txt', skiprows=1)

    # fed3_2p_loss_smooth = moving_average(fed3_2p_loss, window_size)

   
    # epochs_fedavg = len(fedavg_loss)
    # epochs_fed3_2p = len(fed3_2p_loss)

  
    # plt.figure()
    # plt.plot(range(epochs_fedavg), fedavg_loss, label='FedAvg', color='blue')
    # plt.plot(range(epochs_fed3_2p), fed3_2p_loss, label='Fed3+2p', color='green')
    # # plt.plot(range(len(fed3_2p_loss_smooth)), fed3_2p_loss_smooth, label='Fed3+2p (Smoothed)', color='green')

   
    # plt.xlabel('Epochs')
    # plt.ylabel('Train Loss')
    # plt.title('Training Loss Comparison')
    # plt.legend()

  
    # plt.savefig('./save/train_loss_comparison.pdf')