import os
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from tqdm import tqdm
import os
import logging
from scipy.io import savemat

from arg_parser import parse_args

import sys
sys.path.append("../")
sys.path.append("../QCFS")
import CHT
import Models
from dst_scheduler import DSTScheduler
from utils import replace_maxpool2d_by_avgpool2d

def seed_all(seed):
    """Seed every random number generator and force deterministic kernels.

    Seeds torch (CPU and all CUDA devices), NumPy and Python's ``random``
    with the same value, then disables cuDNN autotuning and requires torch
    to use deterministic algorithms only.

    Args:
        seed: Integer seed shared by all generators.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seeder in seeders:
        seeder(seed)

    # Determinism over speed: no cuDNN benchmarking, and raise on
    # nondeterministic ops.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)


def _warmup_then_cosine(optimizer, start_factor, epochs, num_warmup_epochs):
    """Return a per-epoch schedule: linear warmup, then cosine decay to 0.001.

    ``T_max`` is clamped to at least 1 because ``CosineAnnealingLR`` divides by
    it; with ``epochs <= num_warmup_epochs`` the unclamped value would be <= 0
    and the schedule would fail once the warmup milestone is reached.
    """
    t_max = max(1, epochs - num_warmup_epochs)
    return optim.lr_scheduler.SequentialLR(
        optimizer,
        schedulers=[
            optim.lr_scheduler.LinearLR(optimizer, start_factor=start_factor, end_factor=1.0, total_iters=num_warmup_epochs),
            optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=t_max, eta_min=0.001),
        ],
        milestones=[num_warmup_epochs],
    )


def optimizer_and_lrscheduler(model, args):
    """Build the SGD optimizer and per-epoch LR scheduler for ``args.dataset``.

    Every schedule starts with a 10-epoch linear warmup. CIFAR10/CIFAR100 and
    TINY then use cosine annealing (eta_min=0.001); IMAGENET uses a StepLR
    (step_size=30, gamma=0.1).

    Args:
        model: Module whose parameters are optimized.
        args: Namespace providing ``dataset``, ``architecture``, ``lr`` and
            ``epochs``.

    Returns:
        ``(optimizer, scheduler)``; the scheduler is meant to be stepped once
        per epoch.

    Raises:
        NotImplementedError: If ``args.dataset`` is not supported.
    """
    num_warmup_epochs = 10
    dataset = args.dataset.upper()

    if dataset in ('CIFAR10', 'CIFAR100'):
        if args.architecture == 'ResNet50_CIFAR':
            weight_decay, start_factor = 1e-4, 0.001
        else:
            weight_decay, start_factor = 5e-4, 0.1 * 0.5
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=weight_decay, nesterov=True)
        scheduler = _warmup_then_cosine(optimizer, start_factor, args.epochs, num_warmup_epochs)

    elif dataset == 'IMAGENET':
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-4)
        scheduler = optim.lr_scheduler.SequentialLR(optimizer, schedulers=[
            optim.lr_scheduler.LinearLR(optimizer, start_factor=0.1, end_factor=1.0, total_iters=num_warmup_epochs),
            optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1),
        ], milestones=[num_warmup_epochs])

    elif dataset == 'TINY':
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-4, nesterov=True)
        scheduler = _warmup_then_cosine(optimizer, 0.1, args.epochs, num_warmup_epochs)

    else:
        raise NotImplementedError(f"Unsupported dataset: {args.dataset}.")

    return optimizer, scheduler

def test_model(model, test_loader, device):
    """Evaluate ``model`` on ``test_loader`` without tracking gradients.

    The model is switched to eval mode and left in eval mode on return.

    Args:
        model: Classifier whose output has class scores along dim 1 (as
            required by ``CrossEntropyLoss`` and the argmax below).
        test_loader: Iterable of ``(data, target)`` batches.
        device: Device the batches are moved to.

    Returns:
        ``(accuracy, loss)``. Both are averaged per sample, so a smaller
        final batch is not over-weighted.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    # Sum per-sample losses and divide by the sample count for a true mean.
    criterion = nn.CrossEntropyLoss(reduction='sum')
    correct = 0
    total = 0
    loss_sum = 0.0

    model.eval()
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss_sum += criterion(output, target).item()
            correct += (output.argmax(dim=1) == target).sum().item()
            total += target.size(0)

    if total == 0:
        raise ValueError("test_loader yielded no samples; cannot compute accuracy/loss.")

    return correct / total, loss_sum / total


# Evaluation helper, not a pytest test: stop pytest from collecting it by name.
test_model.__test__ = False



def _train_one_epoch(model, pruner, train_loader, optimizer, criterion, device):
    """Run one pass over ``train_loader``; return (accuracy, mean per-batch loss).

    When ``pruner`` is given it is called after every backward pass, and the
    optimizer steps only if it returns a truthy value.
    """
    model.train()
    total, correct, epoch_loss = 0, 0, 0.0
    for data, target in tqdm(train_loader, disable=True):  # progress bar off (debug toggle)
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)

        total += len(target)
        correct += (target == output.max(1)[1]).sum().item()
        epoch_loss += loss.item()

        loss.backward()

        # The pruner decides whether this iteration's optimizer step happens
        # (presumably skipped on DST topology-update steps; confirm in DSTScheduler).
        if pruner is None or pruner():
            optimizer.step()

    return correct / total, epoch_loss / len(train_loader)


def _save_best_checkpoint(model, pruner, save, epoch):
    """Write the model (and pruner, if any) state dicts under ``../input/<save>``."""
    save_path = os.path.join("../input", save)
    os.makedirs(save_path, exist_ok=True)
    # Checkpoints from epoch 100 onwards are written with a different suffix.
    suffix = 'SNM' if epoch < 100 else 'SNM_200'

    torch.save(model.state_dict(), os.path.join(save_path, 'best_model.pth' + suffix))
    if pruner is not None:
        torch.save(pruner.state_dict(), os.path.join(save_path, 'pruner.pth' + suffix))


def train_model(model, pruner, train_loader, test_loader, optimizer, epochs, scheduler, save, device):
    """Train ``model`` for ``epochs`` epochs, validating and evolving it each epoch.

    Each epoch runs one training pass, steps ``scheduler`` (if any), and evaluates
    on ``test_loader``. When ``save`` is not None, it saves a checkpoint whenever
    validation accuracy improves. It then calls ``CHT.CHT_evolve(model)``
    under ``no_grad``.

    Args:
        model: Network to train; moved to ``device``.
        pruner: Optional callable DST scheduler gating ``optimizer.step()``; it
            is also checkpointed via ``state_dict()``.
        train_loader: Iterable of ``(data, target)`` training batches.
        test_loader: Iterable of ``(data, target)`` validation batches.
        optimizer: Optimizer over ``model``'s parameters.
        epochs: Number of epochs to run.
        scheduler: Optional LR scheduler stepped once per epoch.
        save: Relative checkpoint directory under ``../input``, or None to skip saving.
        device: Device spec accepted by ``torch.device``.

    Returns:
        Four numpy arrays, one entry per epoch: validation accuracy,
        validation loss, training accuracy, and training loss.
    """
    device = torch.device(device)
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()

    best_val_acc = 0.
    val_acc_history, val_loss_history, train_acc_history, train_loss_history = [], [], [], []

    for epoch in range(epochs):
        train_acc, train_loss = _train_one_epoch(model, pruner, train_loader, optimizer, criterion, device)

        if scheduler is not None:
            scheduler.step()

        val_accuracy, val_loss = test_model(model, test_loader, device)
        val_acc_history.append(val_accuracy)
        val_loss_history.append(val_loss)
        train_acc_history.append(train_acc)
        train_loss_history.append(train_loss)
        print(f'Epoch {epoch}, acc: {val_accuracy}, loss: {val_loss}')

        # Keep the best checkpoint by validation accuracy.
        if save is not None and val_accuracy > best_val_acc:
            _save_best_checkpoint(model, pruner, save, epoch)
            best_val_acc = val_accuracy

        # Topology evolution step from the CHT module (mutates the model in place).
        with torch.no_grad():
            CHT.CHT_evolve(model)

    return np.array(val_acc_history), np.array(val_loss_history), np.array(train_acc_history), np.array(train_loss_history)


def main(args):
    """Build the model and loaders, train with optional DST pruning, and save metrics.

    When ``args.save`` is set, checkpoints and a ``.mat`` file of the per-epoch
    metrics are written under ``../input/<architecture>/<dataset>/...``.
    """
    device = torch.device(args.device)

    save = None
    if args.save:
        save = os.path.join(
            args.architecture,
            args.dataset,
            f'conv_{args.conv_sparsity}',
            f's_{args.linear_sparsity}',
            f'd_{args.dropout}/onefc_{args.one_fc}',
            f'lr_{args.lr}/bs_{args.bs}',
        )

    model, num_activations, train_loader, test_loader = Models.prepare_model_and_loader(args)
    model.to(device)
    model = replace_maxpool2d_by_avgpool2d(model)  # SNM variant: average pooling instead of max pooling

    # Fixed optimizer/schedule for this experiment (optimizer_and_lrscheduler is bypassed).
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4, nesterov=True)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 120, 160, 200, 240], gamma=0.2)

    pruner = None
    use_dst = args.linear_sparsity > 0.0 and not args.one_fc
    if use_dst:
        steps_per_epoch = len(train_loader)
        end_epoch = args.epochs * 0.75 if (args.adaptive_zeta or args.EM_S) else args.epochs
        pruner = DSTScheduler(
            model,
            optimizer,
            alpha=args.zeta,
            delta=args.update_interval * steps_per_epoch,
            sparsity_distribution=args.sparsity_distribution,
            static_topo=False,
            T_end=end_epoch * steps_per_epoch,
            ignore_linear_layers=False,
            grad_accumulation_n=1,
            args=args,
        )

    acc_val, loss_val, acc_train, loss_train = train_model(
        model, pruner, train_loader, test_loader, optimizer, args.epochs, scheduler, save, device
    )

    if not args.save:
        return

    out_dir = os.path.join("../input", save)
    os.makedirs(out_dir, exist_ok=True)
    metrics = {'acc_val': acc_val, 'loss_val': loss_val, 'acc_train': acc_train, 'loss_train': loss_train}
    savemat(os.path.join(out_dir, 'res.matSNM200'), metrics)


if __name__ == "__main__":
    args = parse_args()
    seed_all(args.seed)
    print(args)
    # Reject linear_sparsity != 0 when conv_sparsity == 0. This is an explicit
    # check rather than `assert`, which is stripped under `python -O`.
    if args.conv_sparsity == 0.0 and args.linear_sparsity != 0.0:
        raise ValueError(
            f"linear_sparsity must be 0.0 when conv_sparsity is 0.0 (got {args.linear_sparsity})."
        )

    log_file = "../input_cnn/error.log"
    # logging's FileHandler fails if the directory is missing; create it first.
    os.makedirs(os.path.dirname(log_file), exist_ok=True)
    logging.basicConfig(
        filename=log_file,
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(message)s",
    )

    try:
        main(args)
    except Exception:
        # Top-level boundary: record the failure (with args) in the log file,
        # echo the traceback to stderr so it is not silent, and exit non-zero.
        logging.exception(f"exception in main\n{args}")
        import traceback
        traceback.print_exc()
        sys.exit(1)



