from unittest import loader
import numpy as np
import scipy.sparse as sp
from sklearn.metrics import f1_score
import random
import time
from models import LogReg
#from preprompt import PrePrompt,pca_compression
import preprompt as preprompt
from utils import process
import pdb
import aug
import os
import tqdm
import argparse
from downprompt import downprompt
import csv
from tqdm import tqdm
import torch.nn.functional as F
from models.gcn import GCNPYG
from preprompt import MDGMIX
from preprompt import pca_compression
import copy

# Command-line interface. Options are registered in the order listed below
# (this order is what `--help` displays); each entry is (flag, add_argument kwargs).
parser = argparse.ArgumentParser("MDGMIX")
_cli_options = (
    ('--dataset', dict(type=str, default="Cora", help='data')),
    ('--pretrain_datasets', dict(
        nargs='+', type=str, help='pretrain datasets',
        default=['Cora', 'Photo', 'Citeseer', 'Pubmed', 'Computers', 'chameleon', 'squirrel'])),
    ('--drop_percent', dict(type=float, default=0.5, help='drop percent')),
    ('--lr', dict(type=float, default=0.02, help='pretrain lr')),
    ('--downstreamlr', dict(type=float, default=0.003, help='downstream lr')),
    ('--epochs', dict(type=int, default=2, help='epoch')),
    ('--shot_num', dict(type=int, default=3, help='shotnum')),
    ('--skip_pretrain', dict(type=int, default=0, help='try to use trained models')),
    ('--seed', dict(type=int, default=39, help='seed')),
    ('--gpu', dict(type=int, default=0, help='gpu')),
    ('--save_name', dict(type=str, default='model_add_node_lay3_computers.pkl', help='save ckpt name')),
    ('--val_name', dict(type=str, default='noval_graphcl_BZR.pkl', help='save val')),
    ('--combinetype', dict(type=str, default='mul', help='the type of text combining')),
)
for _flag, _options in _cli_options:
    parser.add_argument(_flag, **_options)
# Keep the script namespace free of the registration temporaries.
del _cli_options, _flag, _options

args = parser.parse_args()
print(args)

# Seed every RNG the script uses. The GPU is selected through `--gpu`; set
# CUDA_VISIBLE_DEVICES in the environment if device masking is needed.
seed = args.seed
random.seed(seed)
np.random.seed(seed)

import torch
import torch.nn as nn
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
from torch_geometric.datasets import TUDataset,Planetoid,Amazon,Coauthor,Reddit,Actor,WikipediaNetwork,WebKB,Flickr
from torch_geometric.loader import DataLoader
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist
from datasets import *

from utils.BoundaryMixup import *


from utils.data_utils import preprocess, iterate_datasets
from torch_geometric.utils import normalize_edge_index
from utils.data_utils import x_svd
# Grid for the boundary-mixup sweep driven by the loop below. The run is pinned
# to one configuration; for a full sweep set either list to
# np.linspace(0.1, 0.9, 9).tolist() (0.1, 0.2, ..., 0.9).
boundary_ratio_list = [0.1]
similarity_threshold_list = [0.3]
def _load_fewshot_tensor(dataset_name, shotnum, trial, split, kind, device):
    """Load one pre-generated few-shot split tensor as ``torch.long`` on ``device``.

    The file read is
    ``/fewshot_dataset/fewshot_<name>_node/<shotnum>-shot_<name>/<split>set/<trial>/<split>_<kind>.pt``
    where ``<name>`` is the lower-cased ``dataset_name``. Callers pass ``split`` as
    ``"train"``/``"test"`` and ``kind`` as ``"index"``/``"labels"``.
    """
    name = dataset_name.lower()
    path = "/fewshot_dataset/fewshot_{}_node/{}-shot_{}/{}set/{}/{}_{}.pt".format(
        name, shotnum, name, split, trial, split, kind)
    # map_location lets splits saved from a GPU run load on a CPU-only host.
    return torch.load(path, map_location=device).type(torch.long).to(device)


# For every (boundary_ratio, similarity_threshold) pair:
#   1. load the pre-training datasets (downstream dataset excluded) and build
#      cross-domain positive/negative mixup subgraphs,
#   2. pre-train MDGMIX on them (unless --skip_pretrain is non-zero) and
#      checkpoint the lowest-loss weights to args.save_name,
#   3. reload that checkpoint and tune a `downprompt` head on 100 few-shot
#      splits of the downstream dataset, reporting accuracy and micro/macro F1.
for boundary_ratio in boundary_ratio_list:
    for similarity_threshold in similarity_threshold_list:

        print('-' * 100)
        batch_size = 128
        nb_epochs = args.epochs
        # Never pre-train on the downstream dataset itself.
        args.pretrain_datasets = [item for item in args.pretrain_datasets if item not in [args.dataset]]
        print(args.pretrain_datasets)
        num_pretrain_dataset = len(args.pretrain_datasets)
        patience = 50
        lr_list = args.lr
        l2_coef = 1e-4
        drop_prob = 0.5
        hid_units = 256
        sparse = True
        useMLP = False
        LP = False
        shot_num = args.shot_num
        downstreamlrlist = args.downstreamlr
        nonlinearity = 'prelu'
        dataset = args.dataset
        device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")

        print(f"Using device: {device}")

        args.save_name = f'model_{args.dataset}.pkl'

        data_list = {i: preprocess(data).to(device) for i, data in enumerate(iterate_datasets(args.pretrain_datasets))}

        mixup = CrossDomainMixup(num_domains=num_pretrain_dataset, alpha=0.2, hop=1,
                                 boundary_ratio=boundary_ratio, similarity_threshold=similarity_threshold)

        # Row-wise L2-normalise the node features of every pre-training graph.
        for data in data_list.values():
            data.x = F.normalize(data.x, dim=1, p=2)

        positive_pairs = mixup.generate_positive_pairs(data_list, nodes_per_domain=200, pairs_per_domain=5)
        print(f"Generated {len(positive_pairs)} positive pairs")

        negative_pairs = mixup.generate_negative_pairs(data_list, num_samples=10)
        print(f"Generated {len(negative_pairs)} negative pairs")

        # Rebuild the list so the moved objects are kept even if `.to` is not
        # in-place (rebinding a loop variable would silently discard them).
        subgraph_list = [subgraph.to(device) for subgraph in positive_pairs + negative_pairs]

        # Targets and domain centers are needed both for pre-training and for
        # constructing the model that loads the checkpoint, so they are built
        # outside the `skip_pretrain` branch (otherwise `--skip_pretrain 1`
        # hit an undefined `domain_centers`).
        labels = torch.cat([subgraph.y.reshape(1, -1) for subgraph in subgraph_list], dim=0).to(device)
        domain_labels = torch.cat([subgraph.is_cross.reshape(1, -1) for subgraph in subgraph_list], dim=0).to(device)

        # Column sums of `labels`; a pre-training dataset contributes its mean
        # feature vector as a center only when its column is non-zero.
        # NOTE(review): assumes column i of `labels` corresponds to data_list[i] — confirm.
        domain_indices = torch.sum(labels, dim=0)
        domain_centers = torch.cat(
            [data.x.mean(0).reshape(1, -1) for i, data in data_list.items() if domain_indices[i] > 0],
            dim=0,
        )

        cnt_wait = 0
        b_xent = nn.BCEWithLogitsLoss()
        xent = nn.CrossEntropyLoss()
        unify_dim = 50

        if args.skip_pretrain == 0:
            for lr in [lr_list]:
                best = 1e9

                model = MDGMIX(input_dim=unify_dim, hidden_dim=hid_units, num_domains=num_pretrain_dataset,
                               domain_center=domain_centers, device=device).to(device)
                optimiser = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=l2_coef)

                # Shuffle subgraphs and their targets with one shared permutation.
                num_graph = labels.shape[0]
                mask = torch.randperm(num_graph)
                labels = labels[mask]
                domain_labels = domain_labels[mask]
                merge_graph_list = [subgraph_list[i] for i in mask]

                epoch_bar = tqdm(range(nb_epochs), desc='Training', unit='epoch')

                for epoch in epoch_bar:
                    # CUDA memory bookkeeping raises on CPU-only hosts; only do it on GPU.
                    if device.type == 'cuda':
                        torch.cuda.reset_peak_memory_stats(device)
                        torch.cuda.empty_cache()
                    # Re-seed every epoch (kept from the original training recipe).
                    np.random.seed(seed)
                    torch.manual_seed(seed)
                    torch.cuda.manual_seed(seed)

                    model.train()
                    optimiser.zero_grad()
                    loss = model(merge_graph_list, labels, domain_labels)
                    loss.backward()
                    optimiser.step()

                    # Track the best loss as a Python float so no autograd graph is retained.
                    loss_value = loss.item()
                    epoch_bar.set_postfix({
                        'Loss': f'{loss_value:.4f}',
                        'Best_Loss': f'{best:.4f}',
                    })

                    if loss_value < best:
                        best = loss_value
                        cnt_wait = 0
                        torch.save(model.state_dict(), args.save_name)
                    else:
                        cnt_wait += 1
                    if cnt_wait == patience:
                        print('Early stopping!')
                        break

        model = MDGMIX(input_dim=unify_dim, hidden_dim=hid_units, num_domains=num_pretrain_dataset,
                       domain_center=domain_centers, device=device).to(device)

        print('#' * 50)
        print('Downastream dataset is ', args.dataset)

        print(args.dataset)

        downstream_dataset = load_dataset(args.dataset)
        loader = DataLoader(downstream_dataset)
        # Preprocess each downstream graph; the last one yielded is the graph used below.
        for data in loader:
            print(data)
            features, adj = process.process_tu(data, data.x.shape[1])
            features = pca_compression(features, k=unify_dim)
            adj = process.normalize_adj(adj + sp.eye(adj.shape[0]))
            sp_adj = process.sparse_mx_to_torch_sparse_tensor(adj).to(device)
            features = F.normalize(torch.FloatTensor(features), p=2, dim=1).to(device)
            data.x = features
            print(features.shape)
            data = data.to(device)
            labels = data.y
            nb_classes = len(np.unique(np.array(data.y.cpu())))
            print(nb_classes)

        model = model.to(device)
        # map_location allows a checkpoint written on another device to be reused.
        model.load_state_dict(torch.load(args.save_name, map_location=device))

        num_trials = 100
        for downstreamlr in [downstreamlrlist]:

            print(labels.shape)
            print('-' * 100)
            for shotnum in range(shot_num, shot_num + 1):
                # Per-shot accumulators (reset so each shot setting is reported independently).
                tot = 0.0
                accs = []
                macrof = []
                microf = []
                print("shotnum", shotnum)
                for i in tqdm(range(num_trials)):

                    domain_token = model.domain_tokens
                    num_domains = len(domain_token)

                    log = downprompt(features, num_domains, hid_units, nb_classes, unify_dim).to(device)

                    idx_train = _load_fewshot_tensor(args.dataset, shotnum, i, "train", "index", device)
                    train_lbls = _load_fewshot_tensor(args.dataset, shotnum, i, "train", "labels", device)
                    idx_test = _load_fewshot_tensor(args.dataset, shotnum, i, "test", "index", device)
                    test_lbls = _load_fewshot_tensor(args.dataset, shotnum, i, "test", "labels", device)

                    # Only the prompt head is optimised; the pre-trained encoder stays fixed.
                    opt = torch.optim.Adam([{'params': log.parameters()}], lr=downstreamlr)
                    best = 1e9
                    cnt_wait = 0
                    patience = 50

                    for _ in range(400):
                        log.train()
                        opt.zero_grad()
                        logits = log(features, sp_adj, data, domain_token, model.gnn, idx_train, train_lbls, 1).float().to(device)
                        loss = xent(logits, train_lbls)
                        loss_value = loss.item()
                        if loss_value < best:
                            best = loss_value
                            cnt_wait = 0
                        else:
                            cnt_wait += 1
                        if cnt_wait == patience:
                            print('Early stopping!')
                            break

                        loss.backward()
                        opt.step()

                    # Evaluate in inference mode (Module.eval) without building an
                    # autograd graph, then put the encoder back in train mode so the
                    # next trial starts from the same module state as before.
                    log.eval()
                    model.eval()
                    with torch.no_grad():
                        logits = log(features, sp_adj, data, domain_token, model.gnn, idx_test)
                    model.train()
                    preds = torch.argmax(logits, dim=1).to(device)

                    # Fraction of test nodes whose predicted class matches the label.
                    acc = (preds == test_lbls).float().mean().item()
                    preds_cpu = preds.cpu().numpy()
                    test_lbls_cpu = test_lbls.cpu().numpy()
                    micro_f1 = f1_score(test_lbls_cpu, preds_cpu, average='micro')
                    macro_f1 = f1_score(test_lbls_cpu, preds_cpu, average='macro')
                    microf.append(micro_f1 * 100)
                    macrof.append(macro_f1 * 100)
                    accs.append(acc * 100)
                    print('Average Micro:[{:.4f}]'.format(micro_f1))
                    tot += acc
                print("-" * 100)
                print("Average accuracy:[{:.4f}]".format(tot / num_trials))

                acc_mean = sum(accs) / len(accs)
                acc_std = torch.std(torch.tensor(accs)).item()
                microf_mean = sum(microf) / len(microf)
                macrof_mean = sum(macrof) / len(macrof)
                microf_std = torch.std(torch.tensor(microf)).item()
                macrof_std = torch.std(torch.tensor(macrof)).item()
                print('ACC:{:.2f}±{:.2f}'.format(acc_mean, acc_std))
                print('Micro:{:.2f}±{:.2f}'.format(microf_mean, microf_std))
                print('Macro:{:.2f}±{:.2f}'.format(macrof_mean, macrof_std))
