import argparse
import os
import shutil
import time
import random

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models
import torch.nn.functional as F
from PIL import Image
import numpy as np


class WeightNet(nn.Module):
    """Small MLP that maps a 2048-dimensional feature vector to one weight.

    The final sigmoid keeps every weight inside the interval [0, 1].
    """

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2048, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return one weight per input row.

        Args:
            x: tensor whose last dimension has size 2048.

        Returns:
            Tensor with the trailing singleton dimension removed, i.e. shape
            ``(batch,)`` for a ``(batch, 2048)`` input.
        """
        # Squeeze only the last axis: a bare squeeze() would also drop the
        # batch axis when the batch holds a single sample.
        return self.net(x).squeeze(-1)

# Sorted names of the lower-case, non-dunder callables exposed by
# torchvision.models; used as the allowed values for --arch.
model_names = sorted(
    name
    for name, member in vars(torchvision.models).items()
    if callable(member) and name.islower() and not name.startswith("__")
)
# 'resnet50' is always accepted, even if the scan above did not list it.
model_names.append('resnet50')

# Command-line interface. Arguments are parsed at import time and exposed
# through the module-level `args` namespace.
parser = argparse.ArgumentParser(description='Training with Weight-Net')

# Dataset location and loading.
parser.add_argument('-d', '--dataset', type=str, default='caltech101')
parser.add_argument('-j', '--workers', type=int, default=4, metavar='N')
parser.add_argument('--data_dir', type=str, metavar='PATH',
                    default='data/cifar100_original_data')
parser.add_argument('--num_classes', type=int, required=True)

# Schedule and batch sizes.
parser.add_argument('--epochs', type=int, default=40, metavar='N')
parser.add_argument('--start-epoch', type=int, default=0, metavar='N')
parser.add_argument('--train-batch', type=int, default=64, metavar='N')
parser.add_argument('--test-batch', type=int, default=100, metavar='N')

# Classifier optimiser hyper-parameters.
parser.add_argument('--lr', '--learning-rate', type=float, default=0.01,
                    metavar='LR')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M')
parser.add_argument('--weight-decay', '--wd', type=float, default=1e-4,
                    metavar='W')

# Checkpointing, architecture and runtime options.
parser.add_argument('-c', '--checkpoint', type=str, default='checkpoint',
                    metavar='PATH')
parser.add_argument('--resume', type=str, default='', metavar='PATH')
parser.add_argument('--arch', '-a', default='resnet50', choices=model_names,
                    metavar='ARCH')
parser.add_argument('--manualSeed', type=int)
parser.add_argument('-e', '--evaluate', action='store_true', dest='evaluate')
parser.add_argument('--gpu', type=str, default='0')

# Learning rate of the weight-network optimiser.
parser.add_argument('--meta-lr', default=1e-3, type=float)

args = parser.parse_args()
# Plain-dict snapshot of the parsed options.
state = dict(args._get_kwargs())

class AverageMeter:
    """Keeps the most recent value together with a count-weighted running mean.

    Attributes are ``val`` (last value), ``sum`` (weighted total), ``count``
    (total weight) and ``avg`` (``sum / count``).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the mean."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count

def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) for each k in ``topk``.

    Args:
        output: ``(batch, num_classes)`` tensor of class scores.
        target: ``(batch,)`` tensor of ground-truth class indices.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one 0-dim tensor per entry of ``topk``, in the same order.
        A k larger than the number of classes is evaluated over all classes
        instead of raising.
    """
    num_classes = output.size(1)
    # torch.topk raises if asked for more entries than the dimension holds,
    # so never request more than the number of available classes.
    maxk = min(max(topk), num_classes)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()  # (maxk, batch): row r holds every sample's rank-r guess
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        hits = correct[:min(k, maxk)].reshape(-1).float().sum(0)
        res.append(hits.mul_(100.0 / batch_size))
    return res

def mkdir_p(path):
    """Create ``path`` and any missing parents, like ``mkdir -p``.

    An already existing directory is not an error.

    Raises:
        OSError: if the directory cannot be created, e.g. because ``path``
            exists but is not a directory.
    """
    # The previous implementation compared against os.errno.EEXIST, but the
    # os.errno alias no longer exists on Python 3.7+, so the "already exists"
    # branch crashed with AttributeError. exist_ok=True gives the intended
    # semantics directly.
    os.makedirs(path, exist_ok=True)

def save_accuracies(epoch, train_acc, test_acc, best_acc, is_best=False):
    """Append one line of epoch accuracies to the run's text log.

    The log is ``accuracies_weight_net.txt`` inside ``args.checkpoint``; the
    directory is created first if it does not exist. When ``is_best`` is true
    the line is tagged as the current best accuracy.
    """
    log_path = os.path.join(args.checkpoint, 'accuracies_weight_net.txt')
    if not os.path.exists(args.checkpoint):
        mkdir_p(args.checkpoint)
    with open(log_path, 'a') as log_file:
        tag = ' (Current best accuracy)' if is_best else ''
        log_file.write(f'Epoch {epoch}: Train Acc = {train_acc:.4f}, Test Acc = {test_acc:.4f}, Best Acc = {best_acc:.4f}{tag}\n')

def save_checkpoint(state, is_best, best_acc, checkpoint='checkpoint', filename='checkpoint.pth.tar'):
    """Serialise ``state`` into the checkpoint directory.

    The state is written to ``<checkpoint>/<filename>``. When ``is_best`` is
    true the file is additionally copied to ``<checkpoint>/model_best.pth.tar``.
    ``best_acc`` is accepted for call-site compatibility but is not used.
    """
    latest_path = os.path.join(checkpoint, filename)
    torch.save(state, latest_path)
    if not is_best:
        return
    best_path = os.path.join(checkpoint, 'model_best.pth.tar')
    shutil.copyfile(latest_path, best_path)

def test(val_loader, model, criterion, epoch, use_cuda):
    """Evaluate ``model`` on every batch of ``val_loader``.

    The model is switched to eval mode, and progress is printed every tenth
    batch followed by a final summary line.

    Args:
        val_loader: iterable yielding ``(inputs, targets)`` batches.
        model: network to evaluate.
        criterion: loss returning a scalar tensor for a batch.
        epoch: accepted for call-site compatibility; not used here.
        use_cuda: move each batch to the GPU before the forward pass.

    Returns:
        ``(average loss, average top-1 accuracy)``, both weighted by batch size.
    """
    timer = AverageMeter()
    loss_meter = AverageMeter()
    top1_meter = AverageMeter()
    top5_meter = AverageMeter()

    model.eval()

    last_tick = time.time()
    for step, (images, labels) in enumerate(val_loader):
        if use_cuda:
            images = images.cuda()
            labels = labels.cuda()

        # Only the forward pass runs without gradient tracking.
        with torch.no_grad():
            logits = model(images)
        batch_loss = criterion(logits, labels)

        acc1, acc5 = accuracy(logits.data, labels.data, topk=(1, 5))
        n_samples = images.size(0)
        loss_meter.update(batch_loss.item(), n_samples)
        top1_meter.update(acc1.item(), n_samples)
        top5_meter.update(acc5.item(), n_samples)

        timer.update(time.time() - last_tick)
        last_tick = time.time()

        if step % 10 != 0:
            continue
        print('Test: [{0}/{1}]\tTime {batch_time.val:.3f} ({batch_time.avg:.3f})\tLoss {loss.val:.4f} ({loss.avg:.4f})\tPrec@1 {top1.val:.3f} ({top1.avg:.3f})\tPrec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(step, len(val_loader), batch_time=timer, loss=loss_meter, top1=top1_meter, top5=top5_meter))

    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'.format(top1=top1_meter, top5=top5_meter))

    return loss_meter.avg, top1_meter.avg

def main():
    """Train ResNet-50 with WeightNet-scaled per-sample losses.

    Configuration is read from the module-level ``args``. Every epoch makes
    one pass over ``<data_dir>/train``, evaluates on ``<data_dir>/test``,
    appends a line to the accuracy log and writes a checkpoint into
    ``args.checkpoint``. The best test accuracy seen so far is kept in the
    module-level ``best_acc``.

    NOTE(review): ``args.resume``, ``args.evaluate`` and ``args.arch`` are not
    consulted by this function.

    Raises:
        FileNotFoundError: if ``args.data_dir`` or one of its ``train``,
            ``valid`` or ``test`` sub-directories does not exist.
    """
    global best_acc
    best_acc = 0
    start_epoch = args.start_epoch

    # Export the device mask before the first CUDA query below; changing it
    # after CUDA has been touched is not guaranteed to take effect.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    seeded = args.manualSeed is not None
    if seeded:
        random.seed(args.manualSeed)
        torch.manual_seed(args.manualSeed)
        cudnn.deterministic = True

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        # cuDNN autotuning can pick different kernels from run to run, which
        # would undo the determinism requested via --manualSeed.
        cudnn.benchmark = not seeded

    os.makedirs(args.checkpoint, exist_ok=True)

    if not os.path.exists(args.data_dir):
        raise FileNotFoundError(f"Data directory does not exist: {args.data_dir}")
    train_dir = os.path.join(args.data_dir, 'train')
    val_dir = os.path.join(args.data_dir, 'valid')
    test_dir = os.path.join(args.data_dir, 'test')
    for dir_path in [train_dir, val_dir, test_dir]:
        if not os.path.exists(dir_path):
            raise FileNotFoundError(f"Dataset subdirectory does not exist: {dir_path}")

    normalize = transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276])
    transform_train = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.RandomRotation(15),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    full_trainset = datasets.ImageFolder(train_dir, transform_train)
    valset = datasets.ImageFolder(val_dir, transform_test)
    testset = datasets.ImageFolder(test_dir, transform_test)

    trainloader = data.DataLoader(
        full_trainset, batch_size=args.train_batch, shuffle=True,
        num_workers=args.workers, pin_memory=use_cuda
    )
    valloader = data.DataLoader(
        valset, batch_size=args.test_batch, shuffle=False,
        num_workers=args.workers, pin_memory=use_cuda
    )
    testloader = data.DataLoader(
        testset, batch_size=args.test_batch, shuffle=False,
        num_workers=args.workers, pin_memory=use_cuda
    )

    model = torchvision.models.resnet50(weights=None)
    model.fc = nn.Linear(2048, args.num_classes)
    if use_cuda:
        model = torch.nn.DataParallel(model).cuda()
    # The DataParallel wrapper (and hence `.module`) only exists on the CUDA
    # path; resolve the underlying network once so CPU runs work as well.
    base_model = model.module if isinstance(model, nn.DataParallel) else model

    weight_net = WeightNet()
    if use_cuda:
        weight_net = weight_net.cuda()
    optimizer_vnet = optim.Adam(weight_net.parameters(), lr=args.meta_lr, weight_decay=1e-4)

    criterion = nn.CrossEntropyLoss(reduction='none')
    eval_criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs)

    # Split the network into "everything before fc" and "fc". Both parts share
    # their modules with `model`, so one backbone pass yields the pooled
    # features for weight_net and, via fc, the logits. Previously the backbone
    # was run twice per batch (once for features, once inside model(inputs)).
    backbone = nn.Sequential(*list(base_model.children())[:-1])
    classifier = base_model.fc
    if use_cuda:
        backbone = nn.DataParallel(backbone)

    val_iter = iter(valloader)

    for epoch in range(start_epoch, args.epochs):
        print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, optimizer.param_groups[0]['lr']))

        model.train()
        weight_net.train()

        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()
        end = time.time()

        for batch_idx, (inputs, targets) in enumerate(trainloader):
            data_time.update(time.time() - end)
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()

            # flatten(…, 1) keeps the batch axis even for a single sample.
            features = torch.flatten(backbone(inputs), 1)
            outputs = classifier(features)
            loss_per_sample = criterion(outputs, targets)
            # The weights act as constants for the classifier update.
            with torch.no_grad():
                weights = weight_net(features.detach())
            weighted_loss = (weights * loss_per_sample).mean()

            optimizer.zero_grad()
            weighted_loss.backward()
            optimizer.step()

            # Cycle through the validation loader indefinitely.
            try:
                val_inputs, val_targets = next(val_iter)
            except StopIteration:
                val_iter = iter(valloader)
                val_inputs, val_targets = next(val_iter)
            if use_cuda:
                val_inputs, val_targets = val_inputs.cuda(), val_targets.cuda()

            # NOTE(review): as written, this step cannot train weight_net. The
            # sample weights above are computed under no_grad and meta_loss is
            # a function of the classifier only, so backward() leaves every
            # weight_net parameter without a gradient and
            # optimizer_vnet.step() has nothing to apply. A working
            # meta-update needs the validation loss differentiated through a
            # weighted (virtual) classifier update.
            # TODO: implement that bilevel step; left unchanged here rather
            # than guessed at.
            val_outputs = model(val_inputs)
            meta_loss = F.cross_entropy(val_outputs, val_targets)

            optimizer_vnet.zero_grad()
            meta_loss.backward()
            optimizer_vnet.step()

            prec1, prec5 = accuracy(outputs.detach(), targets, topk=(1, 5))
            losses.update(weighted_loss.item(), inputs.size(0))
            top1.update(prec1.item(), inputs.size(0))
            top5.update(prec5.item(), inputs.size(0))

            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % 10 == 0:
                print('Epoch: [{0}][{1}/{2}]\tTime {batch_time.val:.3f} ({batch_time.avg:.3f})\tData {data_time.val:.3f} ({data_time.avg:.3f})\tLoss {loss.val:.4f} ({loss.avg:.4f})\tPrec@1 {top1.val:.3f} ({top1.avg:.3f})\tPrec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(epoch, batch_idx, len(trainloader), batch_time=batch_time, data_time=data_time, loss=losses, top1=top1, top5=top5))

        train_acc = top1.avg

        # NOTE(review): the "best" model is selected on the test split.
        _, test_acc = test(testloader, model, eval_criterion, epoch, use_cuda)

        is_best = test_acc > best_acc
        if is_best:
            best_acc = test_acc
            print(f"New best accuracy achieved at epoch {epoch + 1}: {best_acc:.4f}")

        save_accuracies(epoch + 1, train_acc, test_acc, best_acc, is_best)

        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'weight_net': weight_net.state_dict(),
            'acc': test_acc,
            'best_acc': best_acc,
            'optimizer': optimizer.state_dict(),
            'meta_optimizer': optimizer_vnet.state_dict()
        }, is_best, best_acc, checkpoint=args.checkpoint)

        if is_best:
            torch.save(weight_net.state_dict(), os.path.join(args.checkpoint, 'weight_net_best.pth.tar'))

        scheduler.step()

# Run training only when the file is executed as a script, not when imported.
if __name__ == '__main__':
    main()