import re, os, shutil, random
import datetime
import numpy as np
from tqdm import tqdm
import torch
import torch.optim as optim
import torch.nn as nn
from copy import deepcopy
from torch.utils.tensorboard import SummaryWriter
from model import MLP, ResNetMini
from dataloader import get_MNIST_loaders, get_CIFAR10_loader
import more_itertools



def save_model(state_dict, is_best, log_dir):
    """Write a checkpoint as ``<log_dir>/latest.pth`` and, when flagged, ``best.pth``.

    Args:
        state_dict: object handed to ``torch.save`` (typically a module's state dict).
        is_best: when truthy, the same object is additionally saved as ``best.pth``.
        log_dir: destination directory (joined with ``'/'``; must already exist).
    """
    targets = ['latest.pth']
    if is_best:
        targets.append('best.pth')
    # "latest" is always written first, then "best" if requested.
    for target in targets:
        torch.save(state_dict, log_dir + '/' + target)


def save_py(log_dir, py_dir='./'):
    """Snapshot the ``*.py`` files found directly in *py_dir* into ``<log_dir>/codes``.

    The ``codes`` directory is only created when at least one matching file exists.

    Args:
        log_dir: run directory that receives the ``codes`` sub-directory.
        py_dir: directory scanned (non-recursively) for ``.py`` files.
    """
    codes_dir = os.path.join(log_dir, 'codes')
    sources = [name for name in os.listdir(py_dir) if name.endswith(".py")]
    if sources:
        os.makedirs(codes_dir, exist_ok=True)
    for name in sources:
        shutil.copy(os.path.join(py_dir, name), os.path.join(codes_dir, name))


def gradient_cosine_similarity(x, y):
    """Return element-wise cosine similarities between two parallel gradient structures.

    *x* and *y* are sequences of the same length whose items are either
    tensors or (recursively) further lists/tuples of tensors. Each pair of
    tensors is flattened and compared with ``torch.cosine_similarity``; the
    result mirrors the nesting of the inputs, with a Python ``float`` at each
    leaf.

    Args:
        x: first (possibly nested) sequence of tensors.
        y: second sequence with the same structure as *x*.

    Returns:
        A list with the same nesting as the inputs, containing floats.

    Raises:
        ValueError: if two paired sequences have different lengths.
        TypeError: if paired items are not both tensors or both sequences.
    """
    if len(x) != len(y):
        raise ValueError(f"length mismatch: {len(x)} vs {len(y)}")
    similarities = []
    for a, b in zip(x, y):
        if isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)):
            similarities.append(gradient_cosine_similarity(a, b))
        elif isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
            # isinstance (not type ==) so nn.Parameter and other Tensor
            # subclasses are accepted; detach so no graph is kept alive.
            sim = torch.cosine_similarity(a.detach().flatten(), b.detach().flatten(), dim=0)
            similarities.append(float(sim))
        else:
            raise TypeError(
                f"cannot compare items of type {type(a).__name__} and {type(b).__name__}"
            )
    return similarities


def validation(net, valid_loader, criterion, device):
    """Evaluate *net* over *valid_loader* and return ``(mean_loss, accuracy)``.

    The model is put in eval mode and run under ``torch.no_grad()`` so no
    autograd graph is built; its previous train/eval mode is restored
    afterwards, even if an exception is raised.

    Args:
        net: module called as ``net(inputs)``; predictions are the argmax over dim 1.
        valid_loader: iterable of ``(inputs, labels)`` batches.
        criterion: loss function. Its output is summed per batch and then
            divided by the sample count, so it is presumably configured with
            ``reduction='none'`` (as the caller in this file does).
        device: device the batches are moved to.

    Returns:
        Tuple ``(valid_loss, valid_accuracy)`` of Python floats: summed loss
        divided by the number of samples, and the fraction of correct argmax
        predictions.

    Raises:
        ValueError: if the loader yields no samples.
    """
    was_training = net.training
    net.eval()
    total_loss, total_correct, total_count = 0.0, 0, 0
    try:
        with torch.no_grad():
            for inputs, labels in tqdm(valid_loader):
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = net(inputs)
                loss = criterion(outputs, labels)
                total_correct += int((torch.argmax(outputs, dim=1) == labels).sum().item())
                # .item() accumulates in Python float (double), avoiding float32 drift.
                total_loss += float(torch.sum(loss).item())
                total_count += inputs.shape[0]
    finally:
        net.train(was_training)
    if total_count == 0:
        raise ValueError("valid_loader yielded no samples")
    return total_loss / total_count, total_correct / total_count


def train(seed, dataset, is_lr, is_compare, is_dp, sigma, sigma_0, batch_size, repeat_n, total_clip_threshold, clip_threshlod_per_layer, epochs=25):
    """Train a model on MNIST or CIFAR10, logging to TensorBoard and checkpointing every epoch.

    Two gradient sources are supported:

    * ``is_lr=False``: ordinary backpropagation on the mean per-sample loss.
    * ``is_lr=True``: gradients produced by the model's own ``net.backward``
      after noisy forward passes. One parameterised layer at a time is
      perturbed, using ``repeat_n[l]`` replicated copies of the batch. With
      ``is_dp=True``, ``net.dp_controller`` is called first and per-layer
      clip thresholds are passed to ``net.backward``.

    Args:
        seed: seed applied to torch, numpy and ``random``.
        dataset: ``'MNIST'`` (MLP model) or ``'CIFAR10'`` (ResNetMini model).
        is_lr: use the LR (noise-based) gradient path instead of backprop.
        is_compare: with ``is_lr``, also compute a backprop gradient each
            step and track its per-module cosine similarity to the LR gradient.
        is_dp: enable the ``dp_controller`` / clipped ``backward`` path.
        sigma: value passed to ``net.set_sigma`` before every LR step.
        sigma_0: value passed to ``net.dp_controller``.
        batch_size: loader batch size.
        repeat_n: per-layer replication counts, indexed by layer; each entry
            is converted with ``int``.
        total_clip_threshold: passed to ``net.dp_controller``. When
            ``clip_threshlod_per_layer`` is ``None``, it also supplies the
            per-layer threshold ``total_clip_threshold / sqrt(layer_n)``.
        clip_threshlod_per_layer: optional sequence of per-layer clip thresholds.
        epochs: the epoch loop runs from ``scheduler.last_epoch`` up to, but
            excluding, this value (default 25, the previously hard-coded value).

    Raises:
        ValueError: if ``dataset`` is not supported.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    device = torch.device('cuda:0')

    # Pick model and loader factory together so an unknown dataset fails fast
    # with a clear message (previously `net` was left unbound).
    if dataset == 'MNIST':
        net = MLP(default_mode="logit").to(device)
        make_loaders = get_MNIST_loaders
    elif dataset == 'CIFAR10':
        net = ResNetMini(default_mode="weight").to(device)
        make_loaders = get_CIFAR10_loader
    else:
        raise ValueError(f"unsupported dataset: {dataset!r} (expected 'MNIST' or 'CIFAR10')")
    net.turn_off_antivariable()

    train_loader, valid_loader = make_loaders(data_dir='../data/', batch_size=batch_size)

    optimizer = optim.Adam(net.parameters(), lr=1e-2)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.85)
    criterion = nn.CrossEntropyLoss(reduction='none')

    # MMDDhhmmss. strftime is used because slicing str(now()) breaks when
    # microsecond == 0 (datetime.__str__ then omits the fractional part).
    current_time = datetime.datetime.now().strftime('%m%d%H%M%S')
    print(current_time)
    run_prefix = 'LR_' if is_lr else 'BP_'
    log_dir = f'./logs/{dataset}/{type(net).__name__}/{run_prefix}{current_time}'
    writer = SummaryWriter(log_dir=log_dir)

    try:
        net.train()
        best_accuracy = -1.
        layer_n = len(net.module_w_para)
        T = 0
        for epoch in range(scheduler.last_epoch, epochs):
            train_loss, train_accuracy, train_count = 0., 0., 0
            if is_lr and is_compare:
                cosine_similarity_list = []

            pbar = tqdm(enumerate(train_loader), total=len(train_loader))
            for i, (inputs, labels) in pbar:
                T += 1
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                if is_lr:
                    net.set_sigma(sigma)
                    # Unperturbed per-sample loss, used as a baseline by the DP path.
                    with torch.no_grad():
                        outputs = net(inputs)
                        loss_0 = criterion(outputs, labels)
                    if is_dp:
                        with torch.no_grad():
                            net.dp_controller(sigma_0, repeat_n, total_clip_threshold, loss_0)
                    net.eval()
                    with torch.no_grad():
                        for l in range(layer_n):
                            # NOTE(review): repeat(n, 1, 1, 1) assumes 4-D image batches — confirm for MNIST.
                            inputs_ = inputs.repeat(int(repeat_n[l]), 1, 1, 1).to(device)
                            labels_ = labels.repeat(int(repeat_n[l])).to(device)

                            # Inject noise only into layer l on this pass.
                            add_noise = [j == l for j in range(layer_n)]
                            outputs = net(inputs_, add_noise)
                            loss = criterion(outputs, labels_)
                            if is_dp:
                                clip_threshold_l = clip_threshlod_per_layer[l] if clip_threshlod_per_layer is not None else total_clip_threshold/(layer_n)**0.5
                                net.backward(loss, grad_sample=True, clip_threshold=clip_threshold_l, batch_size=inputs.shape[0], loss0=loss_0)
                            else:
                                net.backward(loss)
                            train_accuracy += torch.sum(torch.where(labels_ == torch.argmax(outputs, 1), 1, 0)).cpu().detach().numpy()
                            train_loss += torch.sum(loss).cpu().detach().numpy()
                            train_count += len(labels_)
                    optimizer.step()

                    net.train()
                    if is_compare:
                        # NOTE(review): grad_bp is computed AFTER optimizer.step(),
                        # i.e. at the updated parameters, while grad_lr presumably
                        # reflects the pre-step ones — confirm this is intended.
                        grad_lr = net.fetch_gradient()
                        optimizer.zero_grad()
                        criterion(net(inputs), labels).mean().backward()
                        grad_bp = net.fetch_gradient()
                        optimizer.zero_grad()
                        cosine_similarity_list.append(gradient_cosine_similarity(grad_bp, grad_lr))
                        mean_sim = np.mean([list(more_itertools.collapse(sim)) for sim in cosine_similarity_list], axis=0)
                        pbar.set_description(f"grad_sim: {np.round(mean_sim[:5], 4)}")
                else:
                    outputs = net(inputs)
                    loss = criterion(outputs, labels)
                    loss.mean().backward()
                    train_accuracy += torch.sum(torch.where(labels == torch.argmax(outputs, 1), 1, 0)).cpu().detach().numpy()
                    train_loss += torch.sum(loss).cpu().detach().numpy()
                    train_count += len(labels)
                    optimizer.step()

            scheduler.step()

            # log and save
            train_loss /= train_count
            train_accuracy /= train_count
            valid_loss, valid_accuracy = validation(net, valid_loader, criterion, device)
            log_info = f'Train Epoch:{epoch:3d} || train loss:{train_loss:.2e} train accuracy:{train_accuracy*100:.2f}% ' + \
                    f'valid loss:{valid_loss:.4e} valid accuracy:{valid_accuracy*100:.2f}% lr:{scheduler.get_last_lr()[0]:.2e} '
            is_best = valid_accuracy >= best_accuracy
            save_model(net.state_dict(), is_best, log_dir)
            torch.save(optimizer.state_dict(), log_dir + '/optimizer.pth')
            torch.save(scheduler.state_dict(), log_dir + '/scheduler.pth')
            if is_best:
                best_accuracy = valid_accuracy
            writer.add_scalar('loss/train_loss', train_loss, epoch)
            writer.add_scalar('loss/valid_loss', valid_loss, epoch)
            writer.add_scalar('accuracy/train_accuracy', train_accuracy, epoch)
            writer.add_scalar('accuracy/valid_accuracy', valid_accuracy, epoch)
            if is_lr and is_compare:
                log_info = log_info + f"grad_sim: {np.round(mean_sim, 4)}"
                for i, item in enumerate(mean_sim):
                    writer.add_scalar('grad_sim/module_'+str(i), item, epoch)
            print(log_info)
    finally:
        # Flush pending TensorBoard events and release the event file.
        writer.close()

