import os
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.optim import Adam
from tqdm import tqdm

from sklearn.feature_extraction.text import TfidfVectorizer
from gensim.models import Word2Vec
import numpy as np
import gensim.downloader
import gensim.downloader as api
from model.GFT_encoder import Encoder, InnerProductDecoder
from model.GFT_ft_model import TaskModel
from model.GFT_pt_model import PretrainModel
from model.GFT_vq import VectorQuantize
from torch.optim import AdamW
from torch_geometric.loader import NeighborLoader
from torch_geometric.utils import mask_feature, dropout_adj
from torch_geometric.data import Data
import copy
from torch_geometric.utils import subgraph
from utils.basic_utils import accuracy, index_to_mask
from utils.basic_utils import EarlyStopping, Logger
from utils.basic_utils import get_loader, seed_everything, get_preprocess,sample_proto_instances_for_graph
from task.node_cls import ft_node, eval_node
from task.link_pre import ft_link, eval_link
from task.graph_cls import ft_graph, eval_graph
from torch_geometric.loader import DataLoader
from model.GFT_ft_model import efficient_compute_class_prototypes
from tqdm import tqdm
from utils.partition_utils import merge_raw


def get_class_prototypes(z, y, num_classes_in_total):
    """Compute one prototype (mean embedding) per class label.

    Args:
        z: Embedding tensor; rows are selected with a boolean mask built from
            the labels, so its first dimension must line up with the
            flattened label vector (dict case).
        y: Either a dict mapping an integer task id to that task's binary
            labels (multi-task graph classification), or any label object
            accepted by ``efficient_compute_class_prototypes`` (node / link
            classification).
        num_classes_in_total: Forwarded to
            ``efficient_compute_class_prototypes``; unused in the dict case.

    Returns:
        Dict case: a list of mean-embedding tensors, one per distinct
        flattened label, in ascending label order.
        Otherwise: whatever ``efficient_compute_class_prototypes`` returns.
    """
    if not isinstance(y, dict):
        # Node and link classification share the generic helper.
        return efficient_compute_class_prototypes(
            z, y, num_classes_in_total, return_head_first=False
        )

    # Multi-task binary graph classification: task t with label b in {0, 1}
    # is mapped to the flat class id 2 * t + b.  The empty float array seeds
    # the concatenation so an empty dict still yields an empty label vector.
    per_task_ids = [np.array([])]
    per_task_ids.extend(task_id * 2 + task_labels for task_id, task_labels in y.items())
    class_ids = torch.tensor(
        np.concatenate(per_task_ids, axis=0), dtype=torch.long, device=z.device
    )

    # torch.unique returns sorted values, which fixes the prototype order.
    return [z[class_ids == class_id].mean(dim=0) for class_id in torch.unique(class_ids)]


    
    
class Client:
    """Per-client wrapper around a local graph dataset and a ``PretrainModel``.

    The client owns an encoder, a vector-quantisation codebook and three
    reconstruction decoders (bundled into ``self.pretrain_model``), and
    exposes:

    * ``pretrain``            - local self-supervised pretraining plus a
                                per-client "domain prototype" embedding,
    * ``finetune``            - repeated task-specific fine-tuning runs,
    * ``get_pretrain_model`` /
      ``set_pretrain_model``  - parameter exchange with an external
                                aggregator.
    """

    def __init__(self, args, client_id, data_tag, device, finetune_params=None, mode="pretrain"):
        """Build the models and store hyper-parameters.

        Args:
            args: Opaque run configuration; stored but not read in this class.
            client_id: Identifier used to index dict-shaped global messages.
            data_tag: Local graph dataset; must expose ``name`` and ``to``.
            device: Torch device the models (and usually the data) live on.
            finetune_params: Optional dict with keys ``lambda_proto``,
                ``lambda_act``, ``num_instances_per_class``, ``trade_off``,
                ``batch_size``, ``finetune_lr``, ``early_stop`` and
                ``finetune_epochs``.  When ``None`` every fine-tune
                hyper-parameter is set to 0.
            mode: ``"pretrain"`` or any other string (treated as fine-tune).
        """
        self.args = args
        self.client_id = client_id
        self.mode = mode
        self.device = device

        # The chem datasets are only moved to the device for pretraining; in
        # every other mode they stay where the caller put them.
        keep_in_place = (
            data_tag.name.lower() in ["chemhiv", "chemblpre", "chempcba"]
            and mode != "pretrain"
        )
        self.data_tag = data_tag if keep_in_place else data_tag.to(device)

        self.dim = 768

        # encoder-decoder
        self.encoder = Encoder(
            input_dim=self.dim,
            hidden_dim=self.dim,
            activation=nn.ReLU,
            num_layers=2,
            backbone="mysage",
            normalize="batch",
            dropout=0.15,
        )

        self.vq = VectorQuantize(
            dim=self.dim,
            codebook_size=128,
            codebook_dim=self.dim,
            heads=4,
            separate_codebook_per_head=True,
            decay=0.8,
            commitment_weight=10,
            use_cosine_sim=True,  # Cosine Codebook Works, Euclidean Codebook Collapses
            orthogonal_reg_weight=1,
            orthogonal_reg_max_codes=32,
            orthogonal_reg_active_codes_only=False,
            kmeans_init=False,
            ema_update=False,
        )

        self.feat_recon_decoder = nn.Linear(self.dim, self.dim)
        self.topo_recon_decoder = InnerProductDecoder(hidden_dim=self.dim, output_dim=self.dim)
        self.topo_sem_recon_decoder = nn.Linear(self.dim * 2, self.dim)

        # pretrain model
        self.pretrain_model = PretrainModel(
            encoder=self.encoder,
            vq=self.vq,
            feat_recon_decoder=self.feat_recon_decoder,
            topo_recon_decoder=self.topo_recon_decoder,
            topo_sem_recon_decoder=self.topo_sem_recon_decoder,
        ).to(device)

        # pretrain model params
        self.pretrain_epochs = 2
        self.pretrain_batch_size = 1024
        self.feat_p = 0.2
        self.edge_p = 0.2
        self.topo_recon_ratio = 0.1
        self.feat_lambda = 100
        self.topo_lambda = 0.01
        self.topo_sem_lambda = 100
        self.sem_lambda = 1
        self.sem_encoder_decay = 0.99
        self.pretrain_lr = 1e-4
        self.separate_codebook_per_head = True
        self.separate_decoder_for_each_head = True
        self.use_cosine_sim = True
        self.use_z_in_predict = True
        self.no_lin_clf = False
        self.no_proto_clf = False

        # finetune params (all zero when no fine-tune configuration is given)
        if finetune_params is None:
            self.lambda_proto = 0
            self.lambda_act = 0
            self.num_instances_per_class = 0
            self.trade_off = 0
            self.finetune_batch_size = 0
            self.finetune_lr = 0
            self.early_stop = 0
            self.finetune_epochs = 0
        else:
            self.lambda_proto = finetune_params["lambda_proto"]
            self.lambda_act = finetune_params["lambda_act"]
            self.num_instances_per_class = finetune_params["num_instances_per_class"]
            self.trade_off = finetune_params["trade_off"]
            self.finetune_batch_size = finetune_params["batch_size"]
            self.finetune_lr = finetune_params["finetune_lr"]
            self.early_stop = finetune_params["early_stop"]
            self.finetune_epochs = finetune_params["finetune_epochs"]

        self.setting = "standard"
        self.query_node_code_first = False
        self.num_prompts = 100

        # Move the raw tensor first and wrap it afterwards: calling ``.to()``
        # on an ``nn.Parameter`` returns a plain, non-leaf tensor whenever the
        # target device differs, which would silently stop it from being a
        # trainable leaf.  Sampling on CPU keeps the random stream unchanged.
        self.trainable_prompt_nodes = nn.Parameter(
            torch.randn(self.num_prompts, self.dim).to(self.device),
            requires_grad=True,
        )

    def _batch_to_graph(self, batch_data):
        """Return ``(x, edge_index, edge_attr)`` for a sampled batch, on ``self.device``."""
        # When ``x`` and ``node_text_feat`` differ in length, ``x`` is used as
        # an index into the text-feature table; otherwise the table itself is
        # taken as the node features.
        if batch_data.x.size(0) != batch_data.node_text_feat.size(0):
            x = batch_data.node_text_feat[batch_data.x].to(self.device)
        else:
            x = batch_data.node_text_feat.to(self.device)

        edge_index = batch_data.edge_index.to(self.device)
        edge_attr = batch_data.edge_text_feat[batch_data.xe].to(self.device)
        return x, edge_index, edge_attr

    def pretrain(self):
        """Run local pretraining, then compute ``self.domain_prototypes``.

        Each epoch samples 2-hop neighbourhoods with ``NeighborLoader``,
        builds an augmented view (feature masking + edge dropout) and
        optimises the weighted sum of the reconstruction and commitment
        losses returned by ``self.pretrain_model``.  Per-epoch average losses
        are printed.  Afterwards the model is switched to eval mode and the
        mean encoder output over all sampled batches is stored in
        ``self.domain_prototypes`` with shape ``(1, self.dim)``.
        """
        self.pretrain_model.train()
        optimizer = AdamW(self.pretrain_model.parameters(), lr=self.pretrain_lr, weight_decay=1e-5)
        # NOTE(review): the lambda is phrased in epochs but ``scheduler.step()``
        # below is called once per batch - confirm this is intended.
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda epoch: (1 + np.cos(epoch * np.pi / self.pretrain_epochs)) * 0.5,
        )

        # Bound outside the epoch loop because the prototype pass below needs
        # it even if the loop body never runs.
        batch_size = self.pretrain_batch_size
        total_idx = torch.arange(0, self.data_tag.x.shape[0]).long()

        loss_names = (
            'feat_recon_loss',
            'topo_recon_loss',
            'topo_sem_recon_loss',
            'sem_recon_loss',
            'commit_loss',
            'loss',
        )

        for _ in range(1, self.pretrain_epochs + 1):
            loader = NeighborLoader(self.data_tag, input_nodes=total_idx,
                                    num_neighbors=[10] * 2,
                                    batch_size=batch_size, shuffle=True)

            running = dict.fromkeys(loss_names, 0)
            batch_count = 0
            for batch_data in loader:
                x, edge_index, edge_attr = self._batch_to_graph(batch_data)
                graph = [x, edge_index, edge_attr]

                aug_x, _ = mask_feature(x, p=self.feat_p)
                aug_edge_index, aug_edge_attr = dropout_adj(
                    edge_index, edge_attr, p=self.edge_p, force_undirected=True, num_nodes=x.size(0)
                )
                aug_graph = [aug_x, aug_edge_index, aug_edge_attr]

                _, _, _, losses = self.pretrain_model(
                    aug_graph, graph, self.topo_recon_ratio, bs=batch_size, no_codebook=False
                )

                feat_recon_loss = self.feat_lambda * losses['feat_recon_loss']
                topo_recon_loss = self.topo_lambda * losses['topo_recon_loss']
                topo_sem_recon_loss = self.topo_sem_lambda * losses['topo_sem_recon_loss']
                sem_recon_loss = self.sem_lambda * losses['sem_recon_loss']
                commit_loss = losses['commit_loss']
                loss = feat_recon_loss + topo_recon_loss + topo_sem_recon_loss + sem_recon_loss + commit_loss

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.pretrain_model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                self.pretrain_model.ema_update_sem_encoder(decay=self.sem_encoder_decay)

                running['feat_recon_loss'] += feat_recon_loss.item()
                running['topo_recon_loss'] += topo_recon_loss.item()
                running['topo_sem_recon_loss'] += topo_sem_recon_loss.item()
                running['sem_recon_loss'] += sem_recon_loss.item()
                running['commit_loss'] += commit_loss.item()
                running['loss'] += loss.item()
                batch_count += 1

            print({
                'avg_losses/feat_recon_loss': running['feat_recon_loss'] / batch_count,
                'avg_losses/topo_recon_loss': running['topo_recon_loss'] / batch_count,
                'avg_losses/topo_sem_recon_loss': running['topo_sem_recon_loss'] / batch_count,
                'avg_losses/sem_recon_loss': running['sem_recon_loss'] / batch_count,
                'avg_losses/commit_loss': running['commit_loss'] / batch_count,
                'avg_losses/loss': running['loss'] / batch_count,
            })

        self.pretrain_model.eval()

        # domain prototypes
        loader = NeighborLoader(self.data_tag, input_nodes=total_idx,
                                num_neighbors=[10] * 2,
                                batch_size=batch_size, shuffle=True)

        z_all = []
        # Pure inference: no autograd graph is needed for the embeddings.
        with torch.no_grad():
            for batch_data in loader:
                x, edge_index, edge_attr = self._batch_to_graph(batch_data)
                z = self.pretrain_model.encode(x, edge_index, edge_attr).detach().cpu()
                z_all.append(z)

        self.domain_prototypes = torch.vstack(z_all).mean(dim=0).view(-1, self.dim).to(self.device)

    def finetune(self, standard=3):
        """Fine-tune a fresh ``TaskModel`` ``standard`` times and print summary stats.

        Run ``idx`` is seeded with ``seed_everything(idx)`` and starts from
        deep copies of this client's encoder and codebook.  Training stops
        after ``self.finetune_epochs`` epochs or when ``EarlyStopping``
        fires.  Results are accumulated in ``self.logger`` and the aggregated
        train / val / test mean and std are printed at the end.

        Args:
            standard: Number of repeated runs (seeds ``0 .. standard - 1``).

        Raises:
            ValueError: If ``self.setting`` is not ``"standard"`` or the
                dataset's task is not recognised by ``get_ft`` / ``get_eval``.
        """
        num_tasks = self.data_tag.num_tasks
        task = self.data_tag.task

        train_loader = None
        val_loader = None
        test_loader = None
        subgraph_loader = None

        process = get_preprocess(task)
        dataset = process(self.data_tag)
        labels = self.data_tag.y

        num_classes = num_tasks if task == "graph_cls" else self.data_tag.num_global_classes
        split = {"train": dataset.train_mask, "valid": dataset.val_mask, "test": dataset.test_mask}
        is_node_or_link = task in ["node_cls", "link_pre"]

        self.logger = Logger()
        for idx in range(standard):  # debug: only 3 split
            seed_everything(idx)
            # Only the standard setting (dataset-provided split) is supported.
            if self.setting != "standard":
                raise ValueError("Invalid Setting")

            task_model = TaskModel(
                encoder=copy.deepcopy(self.encoder),
                vq=copy.deepcopy(self.vq),
                num_classes=num_classes,
                separate_decoder_for_each_head=self.separate_decoder_for_each_head,
                use_z_in_predict=self.use_z_in_predict,
                use_cosine_sim=self.use_cosine_sim,
                lambda_proto=self.lambda_proto,
                lambda_act=self.lambda_act,
                trade_off=self.trade_off,
                num_instances_per_class=self.num_instances_per_class,
            ).to(self.device)

            task_opt = AdamW(task_model.parameters(), lr=self.finetune_lr)
            stopper = EarlyStopping(patience=self.early_stop)

            # A batch size of 0 means "no mini-batching": loaders stay as they are.
            if self.finetune_batch_size != 0 and is_node_or_link:
                train_loader, subgraph_loader = get_loader(
                    dataset, split, labels, task, self.finetune_batch_size
                )
            elif self.finetune_batch_size != 0 and task == "graph_cls":
                train_loader, val_loader, test_loader = get_loader(
                    dataset, split, labels, task, self.finetune_batch_size
                )
            finetune = get_ft(task)
            evaluate = get_eval(task)

            pbar = tqdm(range(self.finetune_epochs), desc=f"Finetuning - Dataset {self.data_tag.name} - Standard {idx} - {self.data_tag.task}")

            for epoch in pbar:
                loss = finetune(
                    model=task_model,
                    dataset=dataset,
                    loader=train_loader,
                    optimizer=task_opt,
                    split=split,
                    labels=labels,
                    num_classes=num_classes,
                    no_proto_clf=self.no_proto_clf,
                    no_lin_clf=self.no_lin_clf,
                    use_z_in_predict=self.use_z_in_predict,
                    query_node_code_first=self.query_node_code_first,
                    lambda_proto=self.lambda_proto,
                    lambda_act=self.lambda_act,
                    num_instances_per_class=self.num_instances_per_class,
                    num_neighbors=[30] * 2,
                )

                result = evaluate(
                    model=task_model,
                    dataset=dataset,
                    loader=subgraph_loader if is_node_or_link else [train_loader, val_loader, test_loader],
                    split=split,
                    labels=labels,
                    num_classes=num_classes,
                    no_proto_clf=self.no_proto_clf,
                    no_lin_clf=self.no_lin_clf,
                    use_z_in_predict=self.use_z_in_predict,
                    query_node_code_first=self.query_node_code_first,
                    num_instances_per_class=self.num_instances_per_class,
                    task=task,
                    num_neighbors=[-1] * 2,
                )

                is_stop = stopper(result)
                self.logger.log(idx, epoch, loss, result)
                if is_stop:
                    print("Early Stopping at Epoch:", epoch)
                    break

            # Return value is not used here; the call is kept in case the
            # logger records or reports per-run results as a side effect.
            self.logger.get_single_best(idx)

        best = self.logger.get_best()
        print({
            "final/train": "{:.2f} ± {:.2f}".format(best['train']['mean'], best['train']['std']),
            "final/val": "{:.2f} ± {:.2f}".format(best['val']['mean'], best['val']['std']),
            "final/test": "{:.2f} ± {:.2f}".format(best['test']['mean'], best['test']['std']),
            "final/train_mean": best['train']['mean'],
            "final/val_mean": best['val']['mean'],
            "final/test_mean": best['test']['mean'],
            "final/train_std": best['train']['std'],
            "final/val_std": best['val']['std'],
            "final/test_std": best['test']['std'],
        })

    def get_pretrain_model(self):
        """Return the local message: a fixed sample count of 1 and the parameter list."""
        local_message = {
            "num_samples": 1,
            "weight": list(self.pretrain_model.parameters()),
        }
        return local_message

    def set_pretrain_model(self, global_message):
        """Overwrite local pretrain-model parameters from ``global_message``.

        Args:
            global_message: Either a mapping indexed by ``self.client_id``
                whose entry is a sequence aligned with
                ``self.pretrain_model.named_parameters()`` - in which case
                only parameters whose name starts with ``"vq."`` are copied -
                or a plain sequence aligned with
                ``self.pretrain_model.parameters()``, in which case every
                parameter is copied.
        """
        # isinstance (rather than an exact type check) so that dict
        # subclasses such as OrderedDict take the per-client branch too.
        if isinstance(global_message, dict):
            with torch.no_grad():
                for (name, local_param), global_param in zip(
                    self.pretrain_model.named_parameters(), global_message[self.client_id]
                ):
                    if name.startswith('vq.'):
                        local_param.data.copy_(global_param)
        else:
            with torch.no_grad():
                for local_param, global_param in zip(self.pretrain_model.parameters(), global_message):
                    local_param.data.copy_(global_param)



def get_ft(c_task):
    """Return the fine-tuning routine for ``c_task``.

    ``"node_cls"`` -> ``ft_node``, ``"link_pre"`` -> ``ft_link``,
    ``"graph_cls"`` -> ``ft_graph``.

    Raises:
        ValueError: For any other task name.
    """
    if c_task == "node_cls":
        return ft_node
    if c_task == "link_pre":
        return ft_link
    if c_task == "graph_cls":
        return ft_graph
    raise ValueError("Invalid Task")
                
def get_eval(c_task):
    """Return the evaluation routine for ``c_task``.

    ``"node_cls"`` -> ``eval_node``, ``"link_pre"`` -> ``eval_link``,
    ``"graph_cls"`` -> ``eval_graph``.

    Raises:
        ValueError: For any other task name.
    """
    if c_task == "node_cls":
        return eval_node
    if c_task == "link_pre":
        return eval_link
    if c_task == "graph_cls":
        return eval_graph
    raise ValueError("Invalid Task")