import os
import sys
import random
import argparse
import numpy as np
import torch.nn.functional as F
import torch
from utils.config import _C as cfg
from utils.lnl_methods import *
from model import *
from model import ViT16B, ResNet50, ConvNeXtB
from utils.bmm import draw_wrong_event
import line_profiler
import copy
import matplotlib.pyplot as plt


# Command-line interface: a config file plus optional per-run overrides.
# Numeric overrides are converted by argparse so that a malformed value is
# rejected with a usage error at parse time.
parser = argparse.ArgumentParser()
parser.add_argument("--cfg", type=str, default="", help="path to config file")
parser.add_argument("--gpuid", type=int, default=None, help="overrides cfg.gpuid")
parser.add_argument("--noise_mode", default=None, help="overrides cfg.noise_mode")
parser.add_argument("--noise_ratio", type=float, default=None, help="overrides cfg.noise_ratio")
parser.add_argument("--backbone", default=None, help="overrides cfg.backbone")

args = parser.parse_args()

cfg.defrost()
# --cfg defaults to the empty string; only merge when a path was actually
# given, otherwise keep the defaults already present in cfg.
if args.cfg:
    cfg.merge_from_file(args.cfg)
# Command-line overrides take precedence over values from the config file.
if args.noise_mode is not None:
    cfg.noise_mode = args.noise_mode
if args.noise_ratio is not None:
    cfg.noise_ratio = args.noise_ratio
if args.gpuid is not None:
    cfg.gpuid = args.gpuid
if args.backbone is not None:
    cfg.backbone = args.backbone

def set_seed():
    """Select the CUDA device from ``cfg.gpuid`` and seed the RNGs with ``cfg.seed``.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and CUDA), exports
    ``PYTHONHASHSEED``, and switches cuDNN to deterministic mode with
    benchmarking disabled.
    """
    torch.cuda.set_device(cfg.gpuid)
    seed = cfg.seed

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    os.environ['PYTHONHASHSEED'] = str(seed)

    # Every seeding entry point takes the same integer seed.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


set_seed()

# Train
def train(epoch, dataloader, wrong_event):
    """Train the global ``model`` for one epoch and update label-wave state.

    Besides the optimisation step, this records the predicted class of every
    training sample and counts how many predictions differ from the previous
    epoch. While that count keeps decreasing, a snapshot of the weights is
    kept in ``prev_model``; the first time it stops decreasing (or on the
    final epoch) a snapshot is written to ``./base_model`` once.

    Args:
        epoch: Epoch number; ``1`` is treated as the first epoch.
        dataloader: Iterable yielding ``(inputs, targets, index)`` batches and
            exposing ``.dataset``. ``index`` holds the dataset positions of
            the samples in the batch.
        wrong_event: Array indexed by dataset position; incremented in place
            for every sample misclassified in this epoch.

    Returns:
        Tuple ``(accuracy, wrong_event)`` where ``accuracy`` is the training
        accuracy of this epoch in percent.

    Side effects:
        Reads the module-level ``model``, ``optimizer`` and ``cfg``; updates
        the module-level ``save_basemodel``, ``prev_prediction``,
        ``prev_prediction_change`` and ``prev_model``.
    """
    global save_basemodel
    global prev_prediction
    global prev_prediction_change
    global prev_model

    model.train()
    # len(dataloader) is the real batch count; the floor-division-plus-one
    # estimate over-counts when the dataset size is a multiple of batch_size.
    num_iter = len(dataloader)
    correct, total = 0, 0
    current_prediction = np.full(len(dataloader.dataset), -1)

    for batch_idx, (inputs, targets, index) in enumerate(dataloader):
        optimizer.zero_grad()

        inputs, targets = inputs.cuda(non_blocking=True), targets.cuda(non_blocking=True)

        _, logits = model(inputs)
        prediction = F.log_softmax(logits, dim=1)
        loss = F.nll_loss(prediction, targets, reduction='none').mean()
        predicted = prediction.detach().argmax(dim=1)

        # Per-sample correctness mask, reused for both the wrong-event
        # counter and the running accuracy.
        hit = (predicted == targets).cpu()
        wrong_event[index[~hit]] += 1
        current_prediction[index] = predicted.cpu().numpy()

        loss.backward()
        optimizer.step()

        total += targets.size(0)
        correct += hit.sum().item()

        sys.stdout.write('\r')
        sys.stdout.write('Epoch [%3d/%3d] Iter[%3d/%3d]\t total-loss: %.4f'
                         % (epoch, cfg.epochs, batch_idx + 1, num_iter, loss.item()))
        sys.stdout.flush()
    print("\n| Train Epoch #%d\t Accuracy: %.2f\n" % (epoch, 100. * correct / total))

    # label wave
    if epoch == 1:
        prev_prediction = current_prediction
        # state_dict() returns references to the live parameters, so a deep
        # copy is required to keep this epoch's weights once training goes on.
        prev_model = copy.deepcopy(model.state_dict())
    else:
        prediction_change = np.sum(current_prediction != prev_prediction)
        sys.stdout.write('\n')
        sys.stdout.write('last epoch prediction change:%6d \ncurrent epoch prediction change:%6d \n'
                         % (prev_prediction_change, prediction_change))
        sys.stdout.flush()

        if not save_basemodel:
            still_decreasing = prediction_change < prev_prediction_change
            if still_decreasing and epoch != cfg.epochs:
                # Predictions are still settling: remember this epoch's weights.
                prev_model = copy.deepcopy(model.state_dict())
            else:
                if still_decreasing:
                    # Final epoch and still decreasing: the current weights are
                    # the ones to keep. They are written to disk immediately,
                    # so no copy is needed.
                    prev_model = model.state_dict()
                init_tag = "pretrained" if cfg.pretrained else "scratch"
                os.makedirs("./base_model", exist_ok=True)
                torch.save(prev_model, "./base_model/{}_{}_{}_{}_{}.pt".format(
                    cfg.backbone, init_tag, cfg.dataset, cfg.noise_mode, cfg.noise_ratio))
                save_basemodel = True
                sys.stdout.write('\n')
                sys.stdout.write("Epoch [%3d] saves last epoch' model as the base model by calculating prediction change." % (epoch))
                sys.stdout.flush()

        # Always roll the reference forward so the "last epoch" figures in the
        # log stay accurate after the base model has been saved. The save
        # decision above has already been made using the previous values.
        prev_prediction_change = prediction_change
        prev_prediction = current_prediction

    return 100. * correct / total, wrong_event

# Test
def test(epoch, dataloader):
    """Evaluate the global ``model`` on ``dataloader`` and print the accuracy.

    Args:
        epoch: Epoch number, used only in the printed report.
        dataloader: Iterable yielding ``(inputs, targets)`` batches.

    Returns:
        Classification accuracy over all batches, in percent.
    """
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.cuda(), targets.cuda()
            _, logits = model(inputs)
            # log-softmax is monotonic per row, so the arg-max of the raw
            # logits is the predicted class; no loss is needed for evaluation.
            predicted = logits.argmax(dim=1)

            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    acc = 100. * correct / total

    print("\n| Test Epoch #%d\t Accuracy: %.2f\n" % (epoch, acc))
    return acc

# ======== Data ========
# Build the train/test loaders for the configured dataset. The loader modules
# are imported lazily so only the selected dataset's module is loaded.
if cfg.dataset == "clothing1m":
    from dataloader import dataloader_clothing1M as dataloader
    train_loader, _, test_loader = dataloader.build_loader(cfg)
elif cfg.dataset == "webvision":
    from dataloader import dataloader_webvision as dataloader
    train_loader, _, test_loader = dataloader.build_loader(cfg)
elif cfg.dataset.startswith("cifar"):
    from dataloader import dataloader_cifar as dataloader
    loader = dataloader.cifar_dataloader(
        cfg.dataset,
        noise_mode=cfg.noise_mode,
        noise_ratio=cfg.noise_ratio,
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        root_dir=cfg.data_path,
        model=cfg.model,
        stage=cfg.stage,
        pretrained=cfg.pretrained,
    )
    train_loader = loader.run('train')
    test_loader = loader.run('test')
elif cfg.dataset == "tiny_imagenet":
    print("Loading Tiny-ImageNet...")
    from dataloader import dataloader_tiny_imagenet as dataloader
    train_loader, _, test_loader = dataloader.build_loader(cfg)
else:
    # Fail here with a clear message instead of a NameError on train_loader.
    raise ValueError('please check the dataset name')

num_class = cfg.num_class

# ======== Model ========
# Each backbone is paired with its own optimizer and learning rate.
if cfg.backbone == "ViT-B":
    model = ViT16B.ViTBase16(cfg)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, weight_decay=1e-5)
elif cfg.backbone == 'resnet50':
    model = ResNet50.ResNet50(cfg)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-5)
elif cfg.backbone == 'convnext-B':
    model = ConvNeXtB.ConvNeXtB(cfg)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
else:
    raise ValueError('please check the backbone of the model')

model.cuda()

# The epoch loop steps a scheduler with the training accuracy when training
# from scratch, but no scheduler was ever created. A metric-driven scheduler
# is the one that matches a `step(metric)` call; mode="max" because accuracy
# should increase.
# NOTE(review): factor/patience are left at the PyTorch defaults - confirm the
# intended schedule.
scheduler = None
if not cfg.pretrained:
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max")

# State shared with train() through module-level globals.
wrong_event = np.zeros(shape=len(train_loader.dataset), dtype=np.int32)
save_basemodel = False
prev_prediction_change = len(train_loader.dataset)
prev_prediction = None
prev_model = None

print(f'len of wrong_event:{len(wrong_event)}')
print(f'image shape:{train_loader.dataset[0][0].shape}')

init_tag = "pretrained" if cfg.pretrained else "scratch"

for epoch in range(1, cfg.epochs + 1):

    train_acc, wrong_event = train(epoch, train_loader, wrong_event)
    test_acc = test(epoch, test_loader)
    # if cfg.dataset == "webvision":
    #     imagenet_acc = test(epoch, imagenet_loader)
    if scheduler is not None:
        scheduler.step(train_acc)
    if epoch == cfg.epochs:
        # Persist the per-sample misclassification counts after the last epoch.
        os.makedirs("./stage1", exist_ok=True)
        torch.save(wrong_event, "./stage1/{}_{}_{}_{}_{}_wrongevent.pt".format(
            cfg.backbone, init_tag, cfg.dataset, cfg.noise_mode, cfg.noise_ratio))
        draw_wrong_event(cfg, wrong_event)