import torch
import torch.nn.functional as F
from models.GCN3 import GCN3
from models.GCN2 import GCN2
from models.GCN1 import GCN1
import os
import wandb
import argparse
from eval_utils import *
from dataset import *
from torch_geometric.data import DataLoader, Data
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from relplot.metrics import smECE_slow as smece
from sklearn.metrics import roc_auc_score, brier_score_loss
# Validation cadence in epochs: main() evaluates when ``epoch % eval_freq == 0``.
eval_freq: int = 1

def main(args):
    isval=args.is_val
    device_num = args.device
    device_name = 'cuda:'+str(device_num)
    data_name = args.data_name
    llm_name = args.llm_name
    bsz = args.bsz
    model_name = args.model_name
    epochs = args.epochs
    lr = args.lr
    num_samples = args.num_samples
    p_list = [args.p0,args.p1,args.p2]
    split_size = args.split_size
    
    if torch.cuda.is_available():
        device = torch.device(device_name)
        print("CUDA is available. Using GPU.")
    else:
        device = torch.device("cpu")
        print("CUDA is not available. Using CPU.")
    
    with open(f'data/{llm_name}/{data_name}/{llm_name}_generations_{data_name}-df.pkl','rb') as f: #
        df = pkl.load(f)
    hidden_channel, hidden_channel2, hidden_channel3 = args.hidden_channel,args.hidden_channel2,args.hidden_channel3 

    if model_name == 'GCN2':
        name=f'{num_samples}samples_{data_name}_{model_name}_{hidden_channel}_{hidden_channel2}_{epochs}_{bsz}_lr_{lr}'
        cfg = {'bsz':bsz,'epoch':epochs,'model_name':model_name,'hidden_channel':hidden_channel, 
               'hidden_channel2':hidden_channel2}
    elif model_name == 'GCN3':
        name=f'neweval_{num_samples}samples_{data_name}_{model_name}_{hidden_channel}_{hidden_channel2}_{hidden_channel3}_{epochs}_{bsz}_lr_{lr}_p0_{p_list[0]}_p1_{p_list[1]}_p2_{p_list[2]}'
        cfg = {'bsz':bsz,'epoch':epochs,'model_name':model_name,'hidden_channel':hidden_channel, 
               'hidden_channel2':hidden_channel2,'hidden_channel3':hidden_channel3,'p0':p_list[0],'p1':p_list[1],'p2':p_list[2]}
    elif model_name == 'GCN1':
        name=f'{num_samples}samples_{data_name}_{model_name}_{hidden_channel}_{epochs}_{bsz}_isval_{isval}_lr_{lr}'
        cfg = {'bsz':bsz,'epoch':epochs,'model_name':model_name,'hidden_channel':hidden_channel}
        
    dirpath=f'gnn_res/{llm_name}/{data_name}/{num_samples}samples'
    wandb.init(config=cfg,settings=wandb.Settings(start_method="fork"),
                   project=f'gnn_uq_{llm_name}_{data_name}',
                   name=name,
                   dir=dirpath,
                   job_type="training",
                   reinit=True)
    directory_path = os.path.join(dirpath,f'{llm_name}_'+name)
    
    dataset = MyCustomDataset(root=f'data/{llm_name}/{data_name}/{llm_name}_{data_name}_graphs.pkl')
    dataset_size = len(dataset)
    print('Total dataset size is: ', dataset_size)
    train_size = int(dataset_size * split_size) 
    val_size = dataset_size - train_size 
    print('Train dataset size is: ', train_size, 'Val dataset size is: ', val_size)
    train_dataset, val_dataset = dataset[:train_size], dataset[-val_size:]

    train_loader = DataLoader(train_dataset, batch_size=bsz, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=bsz, shuffle=False)

    if model_name == 'GCN1':
        gnnmodel = GCN1(num_node_features=dataset[0].x.size(-1), num_classes=1,hidden_channels=hidden_channel)
    elif model_name == 'GCN2':
        gnnmodel = GCN2(num_node_features=dataset[0].x.size(-1), num_classes=1,hidden_channels=hidden_channel,
                        hidden_channels2=hidden_channel2)
    elif model_name == 'GCN3':
        gnnmodel = GCN3(num_node_features=dataset[0].x.size(-1), num_classes=1,hidden_channels=hidden_channel, 
                        hidden_channels2=hidden_channel2,hidden_channels3=hidden_channel3)
    gnnmodel = gnnmodel.to(device)
    if isval:
        gnnmodel.load_state_dict(torch.load('gnn.pth'))#,map_location='cpu'))#
        gnnmodel.eval()
        ret = eval_val(val_size, train_size, df,gnnmodel,val_dataset,device,directory_path=directory_path)  
    else:
        optimizer = optim.Adam(gnnmodel.parameters(), lr=lr)
        scheduler = StepLR(optimizer, step_size=10, gamma=0.95)

        loss_func = torch.nn.BCEWithLogitsLoss()
        train_loss_list, val_loss_list = [], []

        if not os.path.exists(directory_path):
            os.mkdir(directory_path)

        for epoch in range(epochs):
            train_loss = 0
            gnnmodel.train()
            for idx, batch in enumerate(train_loader):
                batch=batch.to(device)
                optimizer.zero_grad()
                batch.x=batch.x.to(torch.float32)
                out = gnnmodel(batch,p_list)
                loss = loss_func(out.squeeze(), batch.y.to(torch.float))
                
                loss.backward()
                optimizer.step()
                train_loss += loss.item() 
            scheduler.step()
            current_lr = optimizer.param_groups[0]['lr']
            avg_trainloss = train_loss / (len(train_loader.dataset))
            train_loss_list.append(avg_trainloss)
            
            if epoch%eval_freq == 0:
                gnnmodel.eval()
                val_loss = 0
                ysece,ytece=[],[]
                with torch.no_grad():
                    for idx, batch in enumerate(val_loader):
                        batch=batch.to(device)
                        batch.x=batch.x.to(torch.float32)
                        out = gnnmodel(batch)
                        loss = loss_func(out.squeeze(), batch.y.to(torch.float))
                        
                        val_loss += loss.item()
                        score = torch.sigmoid(out)

                        ylabel = batch.y.reshape(-1,30)
                        ylabel = ylabel[rows, cols]
                        labellist=ylabel.tolist()
                        
                        ysece.extend(scorelist)
                        ytece.extend(labellist)
                ret=dict()
                ret['auroc'] = roc_auc_score(y_true=ytece, y_score=ysece)
                ret['smece'] = smece(f=np.array(ysece),y=np.array(ytece))
                ret['bier_score'] = brier_score_loss(y_true=ytece, y_prob=ysece)
                        
                avg_valloss = val_loss / (len(val_loader.dataset))
                val_loss_list.append(avg_valloss)
                y_true = np.array(y_true)
                y_pred = np.array(y_pred)
                prob_true, prob_pred,weight = calibration_curve(y_true, y_pred, n_bins=10)
                ece = np.sum(weight*(np.abs(prob_true - prob_pred)))/np.sum(weight)
                ret['ece'] = ece
                print(ret['auroc'],ret['smece'],ret['bier_score'])

            wandb.log({'train_loss':avg_trainloss,'val_loss':avg_valloss,'ece':ece,'tauroc':ret['tauroc'],'auroc':ret['auroc'],'smece':ret['smece'],'brier_score':ret['bier_score']})
            if epoch>200:
                torch.save(gnnmodel.state_dict(), os.path.join(directory_path,
                            f'{llm_name}_{model_name}_{hidden_channel}_{hidden_channel2}_{hidden_channel3}_{epoch}_{bsz}.pth'))

            print(f'Epoch {epoch+1}, Train Loss: {avg_trainloss}, Val Loss: {avg_valloss}, Lr:{current_lr},ece:{ece}')
    wandb.finish()
 

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description=".")
    parser.add_argument("--llm_name", help="",default='llama3')
    parser.add_argument("--model_name", help="",default='GCN3')
    parser.add_argument("--data_name", help="",default='truthfulqa')
    parser.add_argument("--is_val",  action='store_true',help="") 
    parser.add_argument("--device", help="",default='0')
    parser.add_argument("--bsz", type=int,help="",default='32')
    parser.add_argument("--split_size", type=int,help="",default='0.8')
    parser.add_argument("--epochs", type=int,help="",default='600')
    parser.add_argument("--hidden_channel", type=int,help="",default='512')
    parser.add_argument("--hidden_channel2", type=int,help="",default='1024')
    parser.add_argument("--hidden_channel3", type=int,help="",default='2048')
    parser.add_argument("--lr", type=float,help="",default=1e-4)
    parser.add_argument("--p0", type=float,help="",default=0)
    parser.add_argument("--p1", type=float,help="",default=0)
    parser.add_argument("--p2", type=float,help="",default=0)
    args = parser.parse_args()
    main(args)
