import os
import os.path as osp
import wandb
import warnings
from copy import deepcopy
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.optim import AdamW

from dataset.process_datasets import get_finetune_graph
from model.encoder import Encoder
from model.vq import VectorQuantize
from model.ft_model import TaskModel, GraphLLM, GQATaskModel
from utils.loader import get_loader
from utils.early_stop import EarlyStopping
from utils.logger import Logger
from utils.args import get_args_finetune
from utils.preprocess import pre_node, pre_link, pre_graph, pre_GQA
from utils.others import seed_everything
from utils.splitter import get_split, get_split_graph

from task.node import ft_node, eval_node
from task.link import ft_link, eval_link
from task.graph import ft_graph, eval_graph
from task.GQA import ft_GQA, eval_GQA

from data.GQA.GQA_utils import llama_model_path  # type: ignore

warnings.filterwarnings("ignore")

# Maps every supported fine-tuning dataset name to its task type
# ("node", "link", "graph" or "GQA"). The runs below are listed in the order
# the datasets are registered, so iteration order of the mapping is stable.
dataset2task = {
    name: task_type
    for task_type, names in (
        ("node", ("cora", "citeseer", "pubmed", "arxiv", "wikics")),
        ("link", ("WN18RR", "FB15K237")),
        (
            "graph",
            (
                "chemhiv",
                "chemcyp450",
                "chembace",
                "chembbbp",
                "chemmuv",
                "chempcba",
                "chemtoxcast",
                "chemtox21",
            ),
        ),
        ("node", ("bookhis", "elecomp", "elephoto", "sportsfit", "products")),
        ("GQA", ("expla_graph", "scene_graphs")),
        ("node", ("reddit", "instagram")),
    )
    for name in names
}


# Script entry point: load a pretrained encoder and vector quantizer, fine-tune
# them on one downstream dataset for every split, and report the mean/std of
# the best results (optionally mirrored to Weights & Biases).
if __name__ == "__main__":
    params = get_args_finetune()

    dataset_name = params["finetune_dataset"]
    task = dataset2task[dataset_name]
    params['task'] = task

    # Few-shot graph tasks are configured as 2-way with `n_train` instances per class.
    if params["setting"] == "few_shot" and task == 'graph':
        params['n_way'] = 2
        params['num_instances_per_class'] = params['n_train']

    # At least one classifier has to stay enabled. This is argument validation,
    # so raise explicitly instead of using `assert` (asserts vanish under -O).
    if params['no_lin_clf'] and params['no_proto_clf']:
        raise ValueError(
            "no_lin_clf and no_proto_clf cannot both be set: at least one classifier is required"
        )
    # Disabling one classifier pins the trade-off to the other one.
    if params['no_lin_clf']:
        params['trade_off'] = 0
    if params['no_proto_clf']:
        params['trade_off'] = 1

    if params['wandb'] is True:
        wandb.init(
            project="GFM-Finetune",
            name="{} - Pretrain Epoch {}".format(
                params["finetune_dataset"],
                params["pretrain_epochs"]
            ),
            config=params,
            mode='online',
            tags=[params['setting']],
        )

        # Read the configuration back from wandb so the run uses whatever
        # values wandb holds (e.g. presumably sweep overrides -- TODO confirm).
        params = dict(wandb.config)

    seed_everything(params["seed"])

    # From here on params["activation"] holds the activation *class*, not its name.
    params["activation"] = nn.ReLU if params["activation"] == "relu" else nn.LeakyReLU
    device = torch.device(f"cuda:{params['gpu']}") if torch.cuda.is_available() else torch.device("cpu")
    task = params["task"]

    # Pick the task-specific preprocessing / training / evaluation routines.
    if task == 'node':
        preprocess, finetune, evaluate = pre_node, ft_node, eval_node
    elif task == 'link':
        preprocess, finetune, evaluate = pre_link, ft_link, eval_link
    elif task == 'graph':
        preprocess, finetune, evaluate = pre_graph, ft_graph, eval_graph
    elif task == 'GQA':
        preprocess, finetune, evaluate = pre_GQA, ft_GQA, eval_GQA
        params['llm_model_path'] = llama_model_path[params['llm_model_name']]
    else:
        raise NotImplementedError('The task is not implemented')

    dataset, splits, labels, num_classes, questions, desc = get_finetune_graph(
        task,
        params['data_path'],
        params["finetune_dataset"],
        params["graph_llm_name"],
        params["llm_b_size"],
        params["root_path"]
    )

    dataset = preprocess(dataset)
    if task in ["node", "link"]:
        # Node/link tasks work on the first (presumably only) graph of the dataset.
        dataset = dataset[0]
        dataset.y = labels

    # A single split dict is reused for every repetition; a list is taken as-is.
    if isinstance(splits, dict):
        splits = [splits] * params["repeat"]

    encoder = Encoder(
        params["input_dim"],
        params["hidden_dim"],
        params["activation"],
        params["num_layers"],
        params["normalize"],
        params["dropout"],
    )

    vq = VectorQuantize(
        params["hidden_dim"],
        params["codebook_size"],
        params["num_expert"],
        params["codebook_heads"],
        params["topk"]
    )

    # Load Pretrained Model
    path = osp.join(params['model_path'], "codebook_size_{}_layer_{}_pretrain_on_{}_seed_{}".format(
        params["codebook_size"], params["num_layers"], params["pretrain_dataset"], params['seed']
    ))

    # map_location="cpu": the freshly built modules have not been moved to a
    # device yet, and this lets checkpoints that were saved on a GPU be loaded
    # on machines where that device does not exist.
    encoder_state = torch.load(
        osp.join(path, f'encoder_{params["pretrain_epochs"]}.pt'), map_location="cpu"
    )
    vq_state = torch.load(
        osp.join(path, f'vq_{params["pretrain_epochs"]}.pt'), map_location="cpu"
    )
    encoder.load_state_dict(encoder_state)
    vq.load_state_dict(vq_state)

    print("Loader the pretrained encoder and vq model from {}".format(path))

    # NOTE: to freeze the pretrained weights, set `requires_grad = False` on the
    # encoder / vq parameters here; the optimizer below only receives
    # parameters that still require gradients.

    # train_batch_size == 0 keeps the whole dataset (and tensor labels) on the
    # compute device -- presumably full-batch training, confirm in get_loader.
    if params["train_batch_size"] == 0:
        dataset = dataset.to(device)
        if not isinstance(labels, list):
            labels = labels.to(device)

    logger = Logger()

    for idx, split in enumerate(splits):
        print(f"Split: {idx+1} / {len(splits)}")
        seed_everything(idx)

        # The "standard" setting uses the split unchanged; the other settings
        # re-sample it for node/link/graph tasks.
        setting = params["setting"]
        if setting in ["few_shot", "zero_shot", "in_context"]:
            if task in ["node", "link"]:
                split = get_split(split, labels, params)
            elif task == "graph":
                split = get_split_graph(split, labels, params)
        elif setting != "standard":
            raise ValueError("Invalid Setting")

        # Every split fine-tunes its own copy of the pretrained modules.
        # Sharing the same encoder / vq objects would let the weights tuned on
        # one split leak into the following ones (assuming the task model
        # trains the modules it is given, as the optimizer setup below implies).
        if task == "GQA":
            task_model = GQATaskModel(
                encoder=deepcopy(encoder),
                vq=deepcopy(vq),
                llm=GraphLLM(params),
                params=params,
            ).to(device)
        else:
            task_model = TaskModel(
                encoder=deepcopy(encoder),
                vq=deepcopy(vq),
                num_classes=num_classes,
                params=params,
            ).to(device)

        opt_params = [p for _, p in task_model.named_parameters() if p.requires_grad]
        task_opt = AdamW(opt_params, lr=params["lr"])
        stopper = EarlyStopping(patience=params["early_stop"])

        if task in ["node", "link"]:
            train_loader, subgraph_loader = get_loader(dataset, split, labels, params)
        elif task in ["graph", "GQA"]:
            train_loader, val_loader, test_loader = get_loader(dataset, split, labels, params)

        # The GQA evaluation directory does not depend on the epoch: create it once.
        if task == "GQA":
            gqa_eval_dir = "{}/{}".format(params['GQA_eval_path'], params['finetune_dataset'])
            os.makedirs(gqa_eval_dir, exist_ok=True)

        for epoch in tqdm(range(params["finetune_epochs"])):
            loss = finetune(
                model=task_model,
                dataset=dataset,
                loader=train_loader,
                optimizer=task_opt,
                split=split,
                num_classes=num_classes,
                labels=labels,
                params=params
            )

            # Only GQA evaluation gets a per-epoch, per-split path prefix.
            eval_GQA_path = None
            if task == "GQA":
                eval_GQA_path = '{}/llm_model_name_{}_num_epochs_{}_seed{}_'.format(
                    gqa_eval_dir, params["llm_model_name"], epoch, idx
                )

            result = evaluate(
                model=task_model,
                dataset=dataset,
                loader=subgraph_loader if task in ["node", "link"] else [train_loader, val_loader, test_loader],
                split=split,
                num_classes=num_classes,
                labels=labels,
                params=params,
                path=eval_GQA_path
            )

            is_stop = stopper(result)
            logger.log(idx, epoch, loss, result)

            # Log to wandb before the early-stop check so the last evaluated
            # epoch is recorded there too, matching what `logger` receives.
            if params['wandb'] is True:
                wandb.log({
                    "train/proto_loss": loss['proto_loss'],
                    "train/lin_loss": loss['act_loss'],
                    "train/loss": loss['loss'],
                    "train/train_value": result['train'],
                    "train/val_value": result['val'],
                    "train/test_value": result['test'],
                })

            if is_stop:
                print("Early Stopping at Epoch:", epoch)
                break

        single_best = logger.get_single_best(idx)

        if params['wandb'] is True:
            wandb.log({
                "best/train": single_best["train"],
                "best/val": single_best["val"],
                "best/test": single_best["test"],
            })

    # Aggregate the per-split best results.
    best = logger.get_best()

    if params['wandb'] is True:
        wandb.log({
            "final/train": "{:.2f} ± {:.2f}".format(best['train']['mean'], best['train']['std']),
            "final/val": "{:.2f} ± {:.2f}".format(best['val']['mean'], best['val']['std']),
            "final/test": "{:.2f} ± {:.2f}".format(best['test']['mean'], best['test']['std']),
            "final/train_mean": best['train']['mean'],
            "final/val_mean": best['val']['mean'],
            "final/test_mean": best['test']['mean'],
            "final/train_std": best['train']['std'],
            "final/val_std": best['val']['std'],
            "final/test_std": best['test']['std'],
        })
        wandb.log({'meta/run': logger.get_run_raw(), 'meta/best': logger.get_best_raw()})

        wandb.finish()

    # Persist the headline test result. The text contains "±", so the encoding
    # is fixed to UTF-8 instead of depending on the platform default.
    os.makedirs("best", exist_ok=True)
    with open(f"best/{dataset_name}.txt", "w", encoding="utf-8") as f:
        f.write(f"Best value: {best['test']['mean']:.2f}±{best['test']['std']:.2f}\n")
