import os
import torch
import torchvision
import argparse
import numpy as np
import timm

import utils
import resnet
import wrn
import vgg

# Seed PyTorch's global RNG at import time so runs start from the same random
# state. Only torch is seeded here; no other RNG is seeded in this file.
torch.manual_seed(0)   

def train():
    """Train a classifier, alternating clean, prompted and adversarial-prompt steps.

    Command-line arguments:
        --net / -n:        architecture, one of ``resnet18``, ``wrn28``, ``vgg11``
                           (default ``resnet18``).
        --data / -d:       dataset name (required). ``conf/<data>.json`` is read as
                           the run configuration and the name must contain
                           ``cifar``, the only loader wired up here.
        --gpu / -g:        CUDA device index, used as ``cuda:<gpu>`` (default ``0``).
        --save_path / -s:  suffix appended to ``config['save_path']`` to form the
                           checkpoint directory (required).

    Each batch runs exactly one phase, selected from its index:
        * phase 2 when ``batch_idx % 50 == 3``: freeze everything except the
          prompt, build adversarial inputs with 10 signed-gradient steps, then
          nudge ``model.prompt_embeddings`` by ``alpha * sign(grad)`` of the
          negated cross-entropy. The optimizer is not stepped.
        * phase 0 for the remaining even indices: cross-entropy step, no prompt.
        * phase 1 for the remaining odd indices: cross-entropy step with prompt.

    After every epoch the model is validated, the LR scheduler is stepped and a
    checkpoint is saved with the validation accuracy as its metric.

    Raises:
        ValueError: if the dataset name does not contain ``cifar`` or the
            checkpoint directory already exists.
        RuntimeError: if the prompt embeddings receive no gradient in phase 2.
    """
    supported_nets = ('resnet18', 'wrn28', 'vgg11')

    parser = argparse.ArgumentParser()
    # `choices` rejects unknown architectures up front; previously they fell
    # through every branch and crashed later with a NameError.
    parser.add_argument('--net', '-n', default='resnet18', type=str, choices=supported_nets)
    # Both paths below are concatenated into strings, so a missing value used
    # to crash with a TypeError; make the requirement explicit instead.
    parser.add_argument('--data', '-d', type=str, required=True)
    parser.add_argument('--gpu', '-g', default='0', type=str)
    parser.add_argument('--save_path', '-s', type=str, required=True)
    args = parser.parse_args()

    config = utils.read_conf('conf/' + args.data + '.json')
    device = 'cuda:' + args.gpu

    # Fail before touching the filesystem: only the CIFAR loader exists, and an
    # unsupported dataset would otherwise leave an empty checkpoint directory
    # behind that blocks the next run.
    if 'cifar' not in args.data:
        raise ValueError("unsupported dataset %r: only 'cifar' datasets are handled" % args.data)

    model_name = args.net
    dataset_path = config['id_dataset']
    save_root = config['save_path']
    save_path = save_root + args.save_path  # plain concatenation, as configured
    num_classes = int(config['num_classes'])
    class_range = list(range(0, num_classes))

    # Per-architecture optimisation settings.
    batch_size = int(config['batch_size'])
    wd = 1e-03
    if model_name == 'resnet18':
        max_epoch = int(config['epoch'])
        lrde = [100, 150, 190]
        lr = 0.1
    elif model_name == 'wrn28':
        max_epoch = 200
        # NOTE(review): milestone 390 lies beyond max_epoch=200 and never
        # fires -- possibly 190 was intended; confirm before changing.
        lrde = [100, 150, 390]
        lr = 0.1
    else:  # 'vgg11' -- argparse `choices` rules out anything else
        max_epoch = int(config['epoch'])
        lrde = [75, 90]
        lr = 0.05

    print(model_name, dataset_path.split('/')[-2], batch_size, class_range)

    # makedirs also copes with a nested, not-yet-existing save root.
    os.makedirs(save_root, exist_ok=True)
    if os.path.exists(save_path):
        # Refuse to overwrite the checkpoints of an earlier run.
        raise ValueError('save_path already exists')
    os.makedirs(save_path)

    train_loader, valid_loader = utils.get_cifar(args.data, dataset_path, batch_size)

    if model_name == 'resnet18':
        model = resnet.resnet18(num_classes=num_classes)
    elif model_name == 'wrn28':
        model = wrn.WideResNet(depth=28, widen_factor=10, num_classes=num_classes)
    else:  # 'vgg11'
        model = vgg.VGG(vgg_name='VGG11', num_classes=num_classes)
    model.to(device)

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=wd)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, lrde)
    alpha = 0.1  # step size of the signed update applied to the prompt embeddings

    saver = timm.utils.CheckpointSaver(model, optimizer, checkpoint_dir=save_path, max_history=2)

    for epoch in range(max_epoch):
        ## training
        model.train()
        total_loss = 0.0
        total = 0
        correct = 0
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            # Kept per batch on purpose: the freeze helpers are defined
            # elsewhere and may alter module modes -- TODO confirm.
            model.train()

            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()

            # Batches 3, 53, 103, ... run the adversarial prompt update; all
            # others alternate between the clean (0) and prompted (1) steps.
            if batch_idx % 50 == 3:
                phase = 2
            else:
                phase = batch_idx % 2

            if phase == 0:
                # Phase 0: cross-entropy without the prompt, all layers trainable.
                model.freeze_parameters(freeze=False)
                outputs = model(inputs, use_prompt=False)
                loss = criterion(outputs, targets)

            elif phase == 1:
                # Phase 1: cross-entropy with the prompt, all layers trainable.
                # The returned embeddings are not needed here.
                model.freeze_parameters(freeze=False)
                _, outputs = model(inputs, use_prompt=True, return_embeddings=True)
                loss = criterion(outputs, targets)

            else:
                # Phase 2: only the prompt stays trainable; craft adversarial
                # inputs, then move the prompt along the sign of its gradient.
                model.freeze_except_prompt()

                adv_inputs = inputs.detach().requires_grad_(True)

                # Ten signed-gradient ascent steps on the inputs.
                # NOTE(review): the step size is 1.0 and the result is clamped
                # to [0, 1]; this presumes un-normalised inputs in that range
                # and has no epsilon-ball projection -- confirm this is intended.
                for _ in range(10):
                    outputs = model(adv_inputs, use_prompt=False)
                    loss = criterion(outputs, targets)
                    grad = torch.autograd.grad(loss, adv_inputs)[0]
                    adv_inputs = torch.clamp(adv_inputs + grad.sign(), 0, 1)
                    # Cut the graph so every step differentiates a fresh leaf.
                    adv_inputs = adv_inputs.detach().requires_grad_(True)

                # NOTE(review): the forward passes in this phase use
                # use_prompt=False, yet the prompt gradient is read below.
                # Whether the prompt takes part in this graph depends on the
                # model implementation, which is not visible here.
                outputs = model(adv_inputs, use_prompt=False)
                loss = -criterion(outputs, targets)
                loss.backward()

                prompt_grad = model.prompt_embeddings.grad
                if prompt_grad is None:
                    raise RuntimeError(
                        'prompt_embeddings received no gradient in the adversarial '
                        'phase; check that the prompt is part of the forward pass')

                adversarial_noise = alpha * prompt_grad.sign()
                # Manual update outside autograd; the optimizer is not stepped
                # in this phase.
                with torch.no_grad():
                    model.prompt_embeddings.add_(adversarial_noise.mean(dim=0, keepdim=True))

            # Regular optimisation step for phases 0 and 1.
            if phase != 2:
                loss.backward()
                optimizer.step()

            # .item() detaches the value; summing the tensor itself would keep
            # autograd history alive for the whole epoch.
            # NOTE(review): phase-2 batches contribute the negated loss here.
            total_loss += loss.item()
            total += targets.size(0)
            _, predicted = outputs.max(1)
            correct += predicted.eq(targets).sum().item()
            print('\r', batch_idx, len(train_loader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                        % (total_loss/(batch_idx+1), 100.*correct/total, correct, total), end = '')
        train_accuracy = correct / total
        train_avg_loss = total_loss / len(train_loader)
        print()

        ## validation
        model.eval()
        valid_accuracy = utils.validation_accuracy(model, valid_loader, device)
        scheduler.step()

        saver.save_checkpoint(epoch, metric=valid_accuracy)
        print('EPOCH {:4}, TRAIN [loss - {:.4f}, acc - {:.4f}], VALID [acc - {:.4f}]\n'.format(epoch, train_avg_loss, train_accuracy, valid_accuracy))
        print(scheduler.get_last_lr())
# Run training only when the file is executed as a script, not when imported.
if __name__ =='__main__':
    train()