# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.

from sklearn.manifold import TSNE
import numpy as np
import matplotlib.pyplot as plt


import time

import numpy as np

import copy

import os

import pytorch_lightning as pl
import torch
import torchvision

import utilities
from utilities import load_model

from methods import METHODS
from coopt import CLIENT_COMB

from coopt.server_client.client import Client
from coopt.server_client.server import Server

from coopt.server_client.data_utils import (
    fix_seed,
    index_extraction,
    split_dataset,
    subset_original,
    split_dataset_with_possible_overlap,
    get_median_index
)



class Opt_Data(torch.utils.data.Dataset):
    """Dataset that pairs each transformed input with its optimised feature.

    Item ``idx`` is ``((transform(x[idx][0]), optimal_data[idx]), y[idx])``.
    ``x`` is indexed as ``x[idx][0]``, so it is presumably a sequence of
    ``(image, label)`` pairs (ImageFolder-style samples) — verify against callers.
    """

    def __init__(
        self, x: torch.Tensor, optimal_data: torch.Tensor, y: torch.Tensor, transform, feature_dim
    ):
        # Store the optimal features without autograd history and in host memory,
        # so every item handed out is a plain CPU tensor.
        features = optimal_data.detach()
        self.optimal_data = features.cpu()
        self.x = x
        self.y = y
        self.transform = transform
        # Recorded for callers; not used by this class itself.
        self.feature_dim = feature_dim

    def __len__(self) -> int:
        # The label container defines the dataset length.
        return len(self.y)

    def __getitem__(self, idx: int):
        sample = self.x[idx]
        view = self.transform(sample[0])
        target = self.y[idx]
        return (view, self.optimal_data[idx]), target

def t_SNE_draw(features, save_path=None, labels=("BYOL", "SupCE")):
    """Embed several feature sets jointly with t-SNE and scatter-plot them.

    Args:
        features: non-empty sequence of 2-D arrays/tensors, one per group
            (rows are samples). They are concatenated along dim 0 before the
            embedding, so all groups must have the same number of columns.
            Groups may differ in size; each group's points are sliced out by
            its own row count.
        save_path: path prefix; the figure is written to ``f"{save_path}.pdf"``.
            When ``None`` the figure is only shown, not saved.
        labels: legend label per group, in order. Groups beyond the given
            labels are labelled ``"group <i>"``.

    Raises:
        ValueError: if ``features`` is empty.
    """
    fontdict = {
        "family": "Times New Roman",
        "weight": "bold",
        "size": 24,
    }
    ticks_fontdict = {
        "family": "Times New Roman",
        "weight": "bold",
        "size": 16,
    }

    if len(features) == 0:
        raise ValueError("t_SNE_draw needs at least one feature set")

    print(len(features))

    # sklearn works on NumPy arrays: detach and move to CPU so tensors that
    # require grad or live on a GPU are accepted too.
    groups = [torch.as_tensor(f).detach().cpu() for f in features]
    data = torch.cat(groups, dim=0)
    tsne = TSNE(n_components=2, random_state=42)
    data_2d = tsne.fit_transform(data.numpy())

    fig = plt.figure(figsize=(10, 8))

    # The first two styles reproduce the BYOL / SupCE figure; any further
    # groups fall back to Matplotlib's default colour cycle and marker.
    styles = [{"c": "#0984e3", "marker": "*"}, {"c": "pink"}]

    # Row offsets of each group inside the concatenated embedding.
    offsets = np.cumsum([0] + [g.shape[0] for g in groups])
    for i, (start, stop) in enumerate(zip(offsets[:-1], offsets[1:])):
        label = labels[i] if i < len(labels) else f"group {i}"
        style = styles[i] if i < len(styles) else {}
        plt.scatter(
            data_2d[start:stop, 0],
            data_2d[start:stop, 1],
            alpha=0.7,
            s=40,
            label=label,
            **style,
        )

    plt.gca().set_axis_off()

    plt.xlabel('', fontdict=fontdict)
    plt.ylabel('', fontdict=fontdict)

    plt.legend(loc='lower right', prop=fontdict)

    plt.xticks(**ticks_fontdict)
    plt.yticks(**ticks_fontdict)

    if save_path is not None:
        plt.savefig(f'{save_path}.pdf', format='pdf', bbox_inches='tight')
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)



_ALIGN_METHODS = ("align_to_best", "align_to_medium", "align_to_worst")


def _select_align_client_index(model_uniform_values, align_method):
    """Return the index of the client whose features the others are aligned to."""
    if align_method == "align_to_best":
        return model_uniform_values.index(min(model_uniform_values))
    if align_method == "align_to_medium":
        return get_median_index(model_uniform_values)
    if align_method == "align_to_worst":
        return model_uniform_values.index(max(model_uniform_values))
    raise ValueError(
        f"Unknown align_method {align_method!r}; expected one of {_ALIGN_METHODS}"
    )


def _resolve_feature_conflicts(
    server, index_feature_map, model_uniform_values, multi_type, avg_method
):
    """Write one feature per sample into ``server.global_optimal_data`` from the clients holding it."""
    for data_idx, inner_dict in index_feature_map.items():
        best_client_idx = server.global_optimal_data_index[data_idx]
        data_feature = []

        for client_idx, feature in inner_dict.items():
            if len(feature) == 0:
                continue  # this client was not assigned data_idx

            data_feature.append(feature)

            if multi_type:
                # Lower uniformity value wins (consistent with "align_to_best"
                # picking the minimum). Replace the current best when it is not
                # a client id, holds no feature for this sample, or is worse.
                # The membership test comes first so a non-client index is never
                # used to index model_uniform_values.
                if (
                    best_client_idx not in inner_dict
                    or len(inner_dict[best_client_idx]) == 0
                    or model_uniform_values[best_client_idx]
                    > model_uniform_values[client_idx]
                ):
                    best_client_idx = client_idx
                chosen_idx = best_client_idx
            else:
                # Single model type: the last client holding the sample wins.
                chosen_idx = client_idx

            server.global_optimal_data[data_idx] = inner_dict[chosen_idx]
            server.global_optimal_data_index[data_idx] = chosen_idx

        if avg_method == "equal_avg" and len(data_feature) > 1:
            # Element-wise average across clients. A plain .mean() would collapse
            # the stack to one scalar and fill the whole row with it.
            server.global_optimal_data[data_idx] = torch.stack(data_feature).mean(dim=0)


def data_allocate(dataset: utilities.dataset.ImageFolder, args, e_round=1):
    """Build the optimised-feature training set for the global model.

    Steps:
      1. Create the server and the clients listed in ``CLIENT_COMB[args.client_comb]``.
      2. Split ``dataset`` among the clients (overlap ``args.p_overlap``). Each
         client computes features for its split with ``Client.optimize_data``,
         which also returns a uniformity value and ``(W, b)``.
      3. Pick an alignment client by ``args.align_method``. If
         ``args.p_data_align`` selects any samples, clients of a different model
         type are aligned to it via ``Client.align``.
      4. Reduce samples held by several clients to one feature per sample on the
         server, or average them when ``args.avg_method == "equal_avg"``.
      5. Wrap copies of the server's inputs, labels and features in ``Opt_Data``.

    Args:
        dataset: the dataset to distribute; only ``len(dataset)`` is used here
            directly, the rest is passed to ``Server`` / ``Client`` helpers.
        args: config namespace. Reads ``p_overlap``, ``p_data_align``,
            ``dim_up``, ``client_comb``, ``align_method``, ``avg_method``,
            ``method``, ``dataset`` and ``input_size``; also passed to
            ``Server`` and ``Client``.
        e_round: forwarded to ``split_dataset_with_possible_overlap``.

    Returns:
        Opt_Data: dataset whose feature width is ``align_feature_dim`` (the
        largest client feature dim when ``dim_up`` is truthy, otherwise the
        alignment client's feature dim).

    Raises:
        ValueError: if ``args.align_method`` is not one of ``_ALIGN_METHODS``.
        AssertionError: if ``p_data_align`` and ``dim_up`` are both falsy while
            the clients' feature dims differ.
    """
    # Validate before any expensive model optimisation runs.
    if args.align_method not in _ALIGN_METHODS:
        raise ValueError(
            f"Unknown align_method {args.align_method!r}; expected one of {_ALIGN_METHODS}"
        )

    p_overlap = args.p_overlap
    p_data_align = args.p_data_align
    dim_up = args.dim_up

    # client and server
    server = Server(args, dataset)

    client_comb = CLIENT_COMB[args.client_comb]
    clients = [
        Client(client_id, model_type, resolution, feature_dim, args)
        for client_id, (model_type, (resolution, feature_dim)) in enumerate(client_comb)
    ]

    clients_type = [client.model_type for client in clients]
    multi_type = len(set(clients_type)) > 1

    clients_feature_dim = [client.feature_dim for client in clients]
    assert p_data_align or dim_up or len(set(clients_feature_dim)) == 1, "Cannot set p_data_align=0 and dim_up=None (when types are not the same!) at the SAME time"

    max_clients_feature_dim = max(clients_feature_dim)

    # data split
    split_index = split_dataset_with_possible_overlap(
        len(dataset), p_overlap, num=len(clients), e_round=e_round
    )

    # index_feature_map[data_idx][client_id] -> that client's feature for the
    # sample, or [] when the client was not assigned it.
    index_feature_map = {
        data_idx: {client_id: [] for client_id in range(len(clients))}
        for data_idx in range(len(dataset))
    }

    # Result is unused. The call is kept because index_extraction presumably
    # samples randomly; removing it could change the alignment subset drawn
    # below — verify before deleting.
    index_extraction(len(dataset), p_data=0.05)

    # Per-client features on the client's own split.
    model_uniform_values = []
    client_optimal_datas = []
    W_b = []
    for client in clients:
        client.load_pretrained_model()
        client_optimal_data_first, uniform_value, W, b = client.optimize_data(
            (dataset, split_index[client.client_id]),
            max_clients_feature_dim,
            dim_up=dim_up,
        )
        model_uniform_values.append(uniform_value)
        print(f"{client.model_type}: ", uniform_value)
        client_optimal_datas.append(client_optimal_data_first)
        W_b.append((W, b))

    # NOTE: t_SNE_draw(client_optimal_datas, save_path='outputs/figures/fig3_b')
    # plots the pre-alignment features.

    # Choose the alignment target.
    align_client_index = _select_align_client_index(model_uniform_values, args.align_method)
    align_client = clients[align_client_index]
    align_model_type = align_client.model_type
    align_feature_dim = max_clients_feature_dim if dim_up else align_client.feature_dim
    align_W, align_b = W_b[align_client_index]

    print(align_model_type)

    align_data_index = index_extraction(len(dataset), p_data=p_data_align)
    has_align_data = len(align_data_index) > 0

    if has_align_data:
        align_features = align_client.optimize_data(
            (dataset, align_data_index),
            max_clients_feature_dim,
            dim_up=dim_up,
            align_W=align_W, align_b=align_b,
        )
    else:
        align_features = 0  # placeholder; never used without alignment data

    # Align other model types to the target and record every client's features.
    client_optimal_datas_2 = []  # post-alignment features, for optional plotting
    for client in clients:
        client.load_pretrained_model()
        own_features = client_optimal_datas[client.client_id]

        if multi_type and client.model_type != align_model_type and has_align_data:
            own_W, own_b = W_b[client.client_id]
            client_optimal_data = client.align(
                own_features,
                align_feature_dim,
                (dataset, align_data_index),
                align_features,
                dim_up=dim_up,
                align_W=own_W, align_b=own_b,
            )
        else:
            client_optimal_data = own_features

        client_optimal_datas_2.append(client_optimal_data)

        for idx, feature in zip(split_index[client.client_id], client_optimal_data):
            index_feature_map[idx][client.client_id] = feature

    # NOTE: t_SNE_draw(client_optimal_datas_2, save_path='outputs/figures/fig3_c')
    # plots the post-alignment features.

    # deal with conflict
    start_time = time.time()

    # Re-allocate with align_feature_dim columns. Rows for samples no client
    # holds are never written below and keep these random values.
    server.global_optimal_data = torch.randn(len(server.global_y), align_feature_dim)

    _resolve_feature_conflicts(
        server, index_feature_map, model_uniform_values, multi_type, args.avg_method
    )

    print("deal with confilct time: ", time.time() - start_time)

    # return optimal dataset for global model training
    _, TRANSFORM = METHODS[args.method]
    transform = TRANSFORM(args.dataset, args.input_size)
    x = copy.deepcopy(server.global_x)
    y = copy.deepcopy(server.global_y)
    optimal_data = copy.deepcopy(server.global_optimal_data)

    trainset = Opt_Data(x, optimal_data, y, transform, align_feature_dim)
    print(trainset.transform)
    return trainset

