from __future__ import print_function
import argparse
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_butterfly
from torch_butterfly.complex_utils import view_as_real, view_as_complex
import k_operation as kop
import numpy as np
import os
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from auto_augment import AutoAugment, Cutout
from archive import autoaug_paper_cifar10
from FastAutoAugment.data import Augmentation

def train(args, model, device, train_loader, optimizer, arch_optimizer, epoch,
        epoch_fnorms, epoch_gradnorms):
    """Train ``model`` for one epoch, stepping both optimizers on every batch.

    The loss is ``F.nll_loss`` on the model output (so the model is expected
    to emit log-probabilities). When ``args.lam > 0`` the term
    ``args.lam * model.kop_reg() / (2 * batch_size)`` is added to it.

    Args:
        args: Namespace read for ``lam``, ``fixed``, ``log_interval`` and
            ``dry_run``.
        model: Module providing ``get_fnorms_butterfly`` and ``kop_reg``.
        device: Device each batch is moved to.
        train_loader: Iterable of ``(data, target)`` batches with a
            ``dataset`` attribute (used only for the progress message).
        optimizer: Optimizer stepped on every batch.
        arch_optimizer: Optional second optimizer; zeroed every batch and
            stepped unless ``args.fixed`` is set.
        epoch: Epoch number, used only in the progress message.
        epoch_fnorms: List extended with ``model.get_fnorms_butterfly()``
            once before the loop and once after every backward pass.
        epoch_gradnorms: List extended with
            ``model.get_fnorms_butterfly(grad=True)`` at the same points.
    """
    def record_norms():
        # Append the current parameter-side and gradient-side norm readings.
        epoch_fnorms.append(model.get_fnorms_butterfly())
        epoch_gradnorms.append(model.get_fnorms_butterfly(grad=True))

    model.train()
    record_norms()

    for step, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        if arch_optimizer:
            arch_optimizer.zero_grad()

        # TODO regularization is only implemented for simultaneous opt
        loss = F.nll_loss(model(inputs), labels)
        if args.lam > 0:
            penalty = model.kop_reg() / (2 * inputs.shape[0])
            loss += (args.lam * penalty)
        loss.backward()

        # Gradients are fresh here, before either optimizer consumes them.
        record_norms()

        optimizer.step()
        if arch_optimizer and not args.fixed:
            arch_optimizer.step()

        # Everything below only runs on logging steps.
        if step % args.log_interval != 0:
            continue

        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, step * len(inputs), len(train_loader.dataset),
            100. * step / len(train_loader), loss.item()))
        if args.dry_run:
            # Dry run: stop after the first logged batch.
            break

def train_alt(args, model, device, train_loader_sw, train_loader_arch,
        optimizer, arch_optimizer, epoch, epoch_fnorms, epoch_gradnorms):
    """Train for one epoch, alternating architecture and shared-weight steps.

    In bilevel mode each iteration draws one batch from each loader. The
    architecture batch drives an ``arch_optimizer`` step (first-order, or the
    second-order approximation from ``so_grad_approx`` when ``args.sdarts``
    is set); the shared-weight batch then drives an ``optimizer`` step. Both
    losses are ``F.nll_loss`` on the model output.

    The two loaders are consumed in lockstep and the epoch ends as soon as
    either one is exhausted.

    Args:
        args: Namespace read for ``bilevel``, ``unilevel``, ``sdarts``,
            ``fdarts``, ``warm_start``, ``fixed``, ``lr``, ``log_interval``
            and ``dry_run``.
        model: Module providing ``get_fnorms_butterfly``.
        device: Device each batch is moved to.
        train_loader_sw: Loader of ``(data, target)`` batches for the shared
            weights; its ``dataset`` and length are used for the progress
            message.
        train_loader_arch: Loader of ``(data, target)`` batches for the
            architecture parameters.
        optimizer: Optimizer for the shared weights.
        arch_optimizer: Optional optimizer for the architecture parameters;
            the architecture step is skipped entirely when it is falsy.
        epoch: Epoch number, used only in the progress message.
        epoch_fnorms: List extended once with ``model.get_fnorms_butterfly()``
            before the loop.
        epoch_gradnorms: List extended once with
            ``model.get_fnorms_butterfly(grad=True)`` before the loop.

    Raises:
        NotImplementedError: If ``args.unilevel`` is set without
            ``args.bilevel``.
        ValueError: If neither ``args.bilevel`` nor ``args.unilevel`` is set.
    """
    model.train()

    epoch_fnorms.append(model.get_fnorms_butterfly())
    epoch_gradnorms.append(model.get_fnorms_butterfly(grad=True))

    # Filter architectural parameters
    arch_params, model_params = get_a_w_params(model)

    # zip() pairs the loaders and stops at the shorter one, so a short
    # architecture loader ends the epoch instead of raising StopIteration.
    paired_batches = zip(train_loader_sw, train_loader_arch)
    for batch_idx, (batch_sw, batch_arch) in enumerate(paired_batches):

        data_sw, targets_sw = batch_sw
        data_arch, targets_arch = batch_arch

        data_sw, targets_sw = data_sw.to(device), targets_sw.to(device)
        data_arch, targets_arch = data_arch.to(device), targets_arch.to(device)

        if args.bilevel:

            # Take a step with the architecture optimizer
            optimizer.zero_grad()
            if arch_optimizer:
                arch_optimizer.zero_grad()

                # if second order DARTS, modify gradient
                if args.sdarts:
                    so_grad_approx(model,
                        data_arch, targets_arch,
                        data_sw, targets_sw,
                        args.lr)
                else:
                    output = model(data_arch)
                    loss = F.nll_loss(output, targets_arch)
                    loss.backward()

                if args.warm_start:
                    torch.nn.utils.clip_grad_norm_(arch_params, 1.0)

                # NOTE(review): this clips the shared-weight gradients, which
                # are zeroed again below before optimizer.step() -- confirm
                # whether the clip was meant for the weight step instead.
                if args.fdarts or args.sdarts:
                    torch.nn.utils.clip_grad_norm_(model_params, 5.0)

                if not args.fixed:
                    arch_optimizer.step()

            # Take a step with the shared weights optimizer
            optimizer.zero_grad()
            if arch_optimizer:
                arch_optimizer.zero_grad()

            output = model(data_sw)
            loss = F.nll_loss(output, targets_sw)
            loss.backward()
            optimizer.step()

        elif args.unilevel:
            raise NotImplementedError

        else:
            # Without a mode no loss is ever computed; fail explicitly rather
            # than with an unbound `loss` at the logging call below.
            raise ValueError(
                'train_alt requires args.bilevel or args.unilevel to be set')

        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data_sw), len(train_loader_sw.dataset),
                100. * batch_idx / len(train_loader_sw), loss.item()))
            if args.dry_run:
                break

def so_grad_approx(model, x_val, y_val, x_train, y_train, lr):
    """Write second-order DARTS architecture gradients into ``model``.

    Approximates the unrolled architecture gradient

        dL_val(w', a)/da  -  xi * [d2 L_train(w, a) / (da dw)] . v

    where ``w' = w - xi * dL_train(w, a)/dw`` is one virtual gradient step on
    the shared weights, ``v = dL_val(w', a)/dw'`` and ``xi = lr``. The
    Hessian-vector product is estimated by a central finite difference of the
    architecture gradients of the training loss at ``w + eps * v`` and
    ``w - eps * v``, with ``eps = 0.01 / ||v||``.

    Every forward/backward pass runs on deep copies of ``model``. The only
    change made to ``model`` itself is assigning ``.grad`` on its
    architecture parameters (as selected by ``get_a_w_params``). Architecture
    parameters that receive no validation gradient are left untouched.

    Args:
        model: Module whose architecture gradients are set. Its outputs are
            passed to ``F.nll_loss``.
        x_val: Input batch used for the validation (architecture) loss.
        y_val: Targets for ``x_val``.
        x_train: Input batch used for the training (shared-weight) loss.
        y_train: Targets for ``x_train``.
        lr: Step size ``xi`` of the virtual shared-weight step.
    """
    eps = 0.01
    xi = lr

    # Get w': one virtual gradient step on the training loss, on a copy.
    model_prime = copy.deepcopy(model)
    model_prime.zero_grad()
    loss = F.nll_loss(model_prime(x_train), y_train)
    loss.backward()
    a_prime, w_prime = get_a_w_params(model_prime)
    for p in w_prime:
        if p.grad is not None:
            p.data.add_(-xi * p.grad)
    # Clear the training-loss gradients of ALL parameters, architecture ones
    # included, so the validation pass below yields a clean dL_val/da.
    model_prime.zero_grad()

    # Copies that will hold w+ and w- (they share a and any buffers with w').
    model_pos = copy.deepcopy(model_prime)
    model_neg = copy.deepcopy(model_prime)

    # One validation pass at w' gives both the first-order term dL_val/da
    # (kept in a_prime[i].grad) and the perturbation direction v = dL_val/dw'.
    loss = F.nll_loss(model_prime(x_val), y_val)
    loss.backward()

    # model is never modified before the final loop, so its weights are
    # still the original w and no extra copy is needed to read them.
    _, w = get_a_w_params(model)
    a_pos, w_pos = get_a_w_params(model_pos)
    a_neg, w_neg = get_a_w_params(model_neg)

    # eps is scaled by the norm of the direction actually perturbed (v),
    # i.e. the shared-weight gradients only.
    sq_norm = 0.0
    for p in w_prime:
        if p.grad is not None:
            sq_norm += p.grad.data.norm(2).item() ** 2
    direction_norm = sq_norm ** (1. / 2)

    # With v == 0 the Hessian-vector product is zero; skipping it also
    # avoids dividing by a zero norm.
    second_order = direction_norm > 0
    if second_order:
        eps /= direction_norm

        # Get w+, w- around the ORIGINAL weights w.
        for p, p_prime, p_pos, p_neg in zip(w, w_prime, w_pos, w_neg):
            if p_prime.grad is not None:
                p_pos.data = p.data + eps * p_prime.grad
                p_neg.data = p.data - eps * p_prime.grad

        # Get arch gradient of the training loss at w+ and at w-.
        F.nll_loss(model_pos(x_train), y_train).backward()
        F.nll_loss(model_neg(x_train), y_train).backward()

    # Manually set the architecture gradients of model
    a, _ = get_a_w_params(model)
    for i, p in enumerate(a):
        first_order = a_prime[i].grad
        if first_order is None:
            continue
        grad = first_order.detach().clone()
        if second_order:
            g_pos, g_neg = a_pos[i].grad, a_neg[i].grad
            if g_pos is not None and g_neg is not None:
                # Finite difference of GRADIENTS (not parameter values).
                grad -= xi * (g_pos - g_neg) / (2 * eps)
        p.grad = grad

def get_a_w_params(model):
    """Split ``model``'s parameters into architecture and ordinary ones.

    A parameter counts as architectural when its name, as reported by
    ``model.named_parameters()``, contains ``'twiddle'`` or ``'permutation'``.

    Returns:
        A pair ``(arch_params, model_params)`` of lists, each keeping the
        order in which ``named_parameters()`` yields its entries.
    """
    arch_markers = ('twiddle', 'permutation')
    arch_params, model_params = [], []
    for param_name, param in model.named_parameters():
        is_arch = any(marker in param_name for marker in arch_markers)
        bucket = arch_params if is_arch else model_params
        bucket.append(param)
    return arch_params, model_params


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

    test_acc = correct / len(test_loader.dataset)
    return test_acc, test_loss
