from __future__ import print_function
import os, time

import torch
import torch.nn as nn
import torch.nn.parallel
from torch_geometric.loader import DataLoader
import torch.optim as optim
from torch.optim import lr_scheduler

from datasets.ZincFluor import get_ZincFluor

from models.GCN import *

from train.train_GNN import train_base
from train.validate_GNN import valid_base

from utils.config import *
from utils.common import hms_string

from utils.logger import logger
import copy

# Command-line configuration is parsed once at import time and shared by every
# function in this script through the module-level ``args`` object.
args = parse_args()
# NOTE(review): presumably seeds the RNGs from ``args.seed`` -- confirm in utils.
reproducibility(args.seed)
# The logger is attached to ``args`` and later called as ``args.logger(msg, level=...)``.
args.logger = logger(args)

# Best-so-far bookkeeping; rebound by ``train_stage2`` through ``global``.
best_model = None  # deep copy of the best model seen so far
best_acc = 0  # best test accuracy
many_best = med_best = few_best = 0

def train_stage2(args, model, trainloader, testloader, N_SAMPLES_PER_CLASS):
    """Run the stage-2 fine-tuning loop and save the best model.

    Only the parameters of ``model.lins[0]`` and ``model.fc`` are handed to the
    SGD optimizer; the learning rate follows a cosine annealing schedule over
    ``args.finetune_epoch`` epochs.  After every epoch the model is evaluated
    on ``testloader`` and, if the summed accuracy of the last six entries of
    the class-wise accuracy is at least as high as the best so far, a deep
    copy of the model is kept as the new best.

    Args:
        args: Parsed configuration.  Reads ``label_smooth``, ``finetune_lr``,
            ``finetune_wd``, ``finetune_epoch``, ``num_class``, ``out`` and
            ``logger``.
        model: Model to fine-tune; must expose ``lins`` and ``fc``.
        trainloader: Loader passed to ``train_base``.
        testloader: Loader passed to ``valid_base``.
        N_SAMPLES_PER_CLASS: Per-class sample counts forwarded to
            ``valid_base``.

    Side effects:
        Rebinds the module globals ``best_acc``, ``many_best``, ``med_best``,
        ``few_best`` and ``best_model``, and writes the best model to
        ``<args.out>/best_model_stage2.pth``.

    Note:
        The logged "best" accuracy is the maximum of ``test_acc`` over all
        epochs, which is not necessarily the epoch whose model is saved
        (that one is selected by the class-wise criterion above).
    """
    global best_acc, many_best, med_best, few_best, best_model
    classwise_best_acc = [0, 0, 0]
    train_criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smooth)
    test_criterion = nn.CrossEntropyLoss()  # For test, validation
    # Only the first `lins` layer and the classifier head are optimized.
    optimizer = optim.SGD(
        list(model.lins[0].parameters()) + list(model.fc.parameters()),
        lr=args.finetune_lr, momentum=0.9, weight_decay=args.finetune_wd)
    scheduler = lr_scheduler.CosineAnnealingLR(optimizer, args.finetune_epoch, eta_min=0.0)

    best_model = None
    test_accs = []
    start_time = time.time()
    for epoch in range(args.finetune_epoch):

        train_loss, train_acc = train_base(trainloader, model, optimizer, train_criterion, args=args)
        test_loss, test_acc, test_cls, classwise_acc = valid_base(
            testloader, model, test_criterion, N_SAMPLES_PER_CLASS,
            num_class=args.num_class, mode='test Valid')

        # Read the LR that was used for this epoch before advancing the schedule.
        lr = scheduler.get_last_lr()[0]
        scheduler.step()

        # Model selection: compare the summed accuracy of the last six
        # class-wise entries; `<=` means ties are won by the later epoch.
        if sum(classwise_best_acc[-6:]) <= sum(classwise_acc[-6:]):
            classwise_best_acc = classwise_acc
            best_acc = test_acc
            many_best = test_cls[0]
            med_best = test_cls[1]
            few_best = test_cls[2]

            # Snapshot the weights so later epochs cannot overwrite them.
            best_model = copy.deepcopy(model)
        test_accs.append(test_acc)

        args.logger(f'Epoch: [{epoch + 1} | {args.finetune_epoch}]', level=1)

        args.logger(f'[Train]\tLoss:\t{train_loss:.4f}\tAcc:\t{train_acc:}', level=2)
        args.logger(f'[Test ]\tLoss:\t{test_loss:.4f}\tAcc:\t{test_acc:.4f}', level=2)
        args.logger(f'[Stats]\tMany:\t{test_cls[0]:.4f}\tMedium:\t{test_cls[1]:.4f}\tFew:\t{test_cls[2]:.4f}', level=2)
        args.logger(
            f'[Best ]\tAcc:\t{np.max(test_accs):.4f}\tMany:\t{100 * many_best:.4f}\tMedium:\t{100 * med_best:.4f}\tFew:\t{100 * few_best:.4f}',
            level=2)
        args.logger(f'[Param]\tLR:\t{lr:.8f}', level=2)
        # '\\:' emits the same backslash-colon text as before, but without
        # relying on the invalid escape sequence '\:'.
        args.logger(f'[Classwise_acc]\\:\t{classwise_acc}', level=2)
        args.logger(f'[Best][Classwise_acc]\\:\t{classwise_best_acc}', level=2)

    end_time = time.time()

    file_name = os.path.join(args.out, 'best_model_stage2.pth')
    # The whole module object is saved, not just its state dict.
    torch.save(best_model, file_name)

    # Print the final results
    args.logger('Finish Training Stage 2...', level=1)
    args.logger('Final performance...', level=1)
    args.logger(f'best bAcc (test):\t{np.max(test_accs)}', level=2)
    args.logger(f'best statistics:\tMany:\t{many_best}\tMed:\t{med_best}\tFew:\t{few_best}', level=2)
    args.logger(f'Training Time: {hms_string(end_time - start_time)}', level=1)

def load_model(args, model, testloader, N_SAMPLES_PER_CLASS):
    """Load a pickled stage-1 model, log its test metrics and return it.

    Args:
        args: Parsed configuration.  Reads ``pretrained_pth``, ``imb_ratio``
            (only for the fallback path), ``num_class`` and ``logger``.
        model: Ignored -- the name is rebound to the object loaded from disk.
            Kept in the signature for backward compatibility with callers.
        testloader: Loader passed to ``valid_base`` for the sanity evaluation.
        N_SAMPLES_PER_CLASS: Per-class sample counts forwarded to
            ``valid_base``.

    Returns:
        The model object unpickled from the checkpoint file.
    """
    if args.pretrained_pth is not None:
        pth = args.pretrained_pth
    else:
        # NOTE(review): the fallback still points at a "cifar100" directory;
        # confirm this is intended for the ZincFluor pipeline.
        pth = f'pretrained/cifar100/IR={args.imb_ratio}/best_model_stage1.pt'

    # SECURITY: weights_only=False unpickles arbitrary Python objects, which
    # can execute code.  Only load checkpoints from trusted sources.
    model = torch.load(pth, weights_only=False)

    test_criterion = nn.CrossEntropyLoss()  # For test, validation
    test_loss, test_acc, test_cls, classwise_acc = valid_base(
        testloader, model, test_criterion, N_SAMPLES_PER_CLASS,
        num_class=args.num_class, mode='test Valid')

    args.logger('Loaded performance...', level=1)
    args.logger(f'[Test ]\t Acc:\t{test_acc:.4f}', level=2)
    args.logger(f'[Stats]\tMany:\t{test_cls[0]:.4f}\tMedium:\t{test_cls[1]:.4f}\tFew:\t{test_cls[2]:.4f}', level=2)
    # '\\:' emits the same backslash-colon text as before, but without
    # relying on the invalid escape sequence '\:'.
    args.logger(f'[Classwise_acc]\\:\t{classwise_acc}', level=2)

    return model

def main():
    """Build the ZincFluor loaders, restore the stage-1 model and fine-tune it."""
    print('==> Preparing ZINC_FLUOR')
    trainset, testset = get_ZincFluor()
    samples_per_class = trainset.img_num_list
    print(samples_per_class)

    # Options common to both loaders; the train loader spells out two extras.
    shared_loader_opts = dict(
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
    )
    trainloader = DataLoader(trainset, drop_last=False, sampler=None, **shared_loader_opts)
    testloader = DataLoader(testset, **shared_loader_opts)

    # Later stages read the class count from ``args``.
    args.num_class = trainset.num_classes

    # Model
    print("==> creating GCN")
    gcn = model_gcn(
        args,
        num_features=trainset.num_features,
        num_classes=trainset.num_classes,
    )
    gcn = gcn.cuda()

    # ``load_model`` returns the model it reads from the checkpoint, which
    # replaces the freshly built one before fine-tuning starts.
    gcn = load_model(args, gcn, testloader, samples_per_class)
    train_stage2(args, gcn, trainloader, testloader, samples_per_class)

# Run the stage-2 pipeline only when this file is executed as a script.
if __name__ == '__main__':
    main()





# CUDA_VISIBLE_DEVICES=2 python main_stage2_gnn.py --cur_stage stage2 --SupCon True --label_smooth 0.98 --pretrained_pth /data/zz/LOC_LT/LOS/output/ZincFluor_IR=1000_stage1/5-11_0-55-24/best_model_stage1.pth
# CUDA_VISIBLE_DEVICES=2 python main_stage2_gnn.py --cur_stage stage2 --label_smooth 0.98 --pretrained_pth /data/zz/LOC_LT/LOS/output/ZincFluor_IR=1000_stage1/5-11_0-6-15/best_model_stage1.pth