from tqdm import tqdm
import torch
import os
from training.stats_utils import get_fs_stats, get_sharpness_stats
import wandb
import pickle
from collections import defaultdict

def train(args, model, criterion, optimizer, train_loader, val_loader, test_loader):
    """Train ``model`` for ``args.epochs`` epochs, evaluating, checkpointing and logging along the way.

    Epoch 0 performs no optimisation step; it only evaluates the initial model.
    Every later epoch runs one pass over ``train_loader``. After an epoch, when
    ``epoch % args.eval_model_every_k_epochs == 0`` or ``epoch <= 10``, the model
    is evaluated on the val/test loaders. The following checkpoints are then
    written to ``<args.local_run_path>/ckpt``:

    * ``best_val_err.pt`` and ``best_test_err.pt``, when a new best is reached;
    * ``epochs=<epoch>.pt`` for the current epoch.

    Metrics are logged to wandb unless ``args.debug`` is set. After the last
    epoch, the collected ``stats`` are pickled to
    ``<args.local_run_path>/stats_dict.pkl``. When ``args.get_online_fs_stats``
    is set, the per-example count of correct -> incorrect transitions is
    pickled to ``<args.local_run_path>/fs_online.pkl``.

    Args:
        args: Run configuration namespace. This function also writes
            ``batches_per_epoch``, ``ckpt_path``, ``stats_dict_path`` and
            (optionally) ``fs_dict_path`` onto it. ``args.ckpt_path`` must not
            already exist (``os.makedirs`` is called without ``exist_ok``).
        model: Model to train.
        criterion: Loss called as ``criterion(y_hat, y)``. Its result is
            ``.mean()``-reduced here, so it presumably returns per-example
            losses (``reduction='none'``) — confirm.
        optimizer: Optimizer. It must also expose ``first_step`` and
            ``second_step`` when ``args.sam_rho != 0``.
        train_loader, val_loader, test_loader: Iterables yielding
            ``(x, y, orig_idx, raw_idx)`` batches.
    """
    stats = {}

    best_val_err, best_test_err = 100, 100
    args.batches_per_epoch = len(train_loader)

    args.ckpt_path = os.path.join(args.local_run_path, 'ckpt')
    # No exist_ok: refuses to reuse (and overwrite) an existing checkpoint directory.
    os.makedirs(args.ckpt_path)

    # Flattened views share storage with the parameters, so l2_reg always sees their current values.
    param_cache = {'param_views': [param.view(-1) for _, param in model.named_parameters()]}

    # orig_idx -> list of correctness flags, one per recorded epoch (epoch 0 comes from eval()).
    online_fs = defaultdict(list) if args.get_online_fs_stats else None

    loss_upweight_indices = (
        torch.load(args.loss_upweight_idx_path).to(args.device) if args.loss_upweight_idx_path else None
    )

    # Running train-set metrics for the current epoch (reset at the start of every training epoch).
    epoch_train_loss, epoch_correct, epoch_total = 0.0, 0, 0

    if not args.debug:
        wandb.init(entity=args.wandb_team,
                project=args.exp_name,
                name=args.run_name,
                config=args,
                dir=args.local_run_path)

        wandb.define_metric('epochs')
        wandb.define_metric('*', step_metric='epochs')

    # Epoch 0 is just eval
    for epoch in range(args.epochs + 1):
        # Train
        if epoch:
            model.train()
            fs_set = set()
            # Reset here rather than only after an eval. Otherwise epochs that skip
            # evaluation keep accumulating into these counters, while the loss is
            # still divided by a single epoch's len(train_loader) below.
            epoch_train_loss, epoch_correct, epoch_total = 0.0, 0, 0

            for x, y, orig_idx, raw_idx in tqdm(train_loader, desc=f'Training epoch {epoch}'):
                optimizer.zero_grad()
                x, y = x.to(args.device), y.to(args.device)

                if args.get_sharpness_stats:
                    # ~50% slowdown (major slowdown), Updates 'stats' dict
                    get_sharpness_stats(args, stats, model, criterion, x, y, orig_idx, raw_idx)

                y_hat = model(x)
                loss = criterion(y_hat, y)
                _, y_pred = torch.max(y_hat, 1)
                acc = y_pred == y

                if online_fs is not None:
                    # One bulk transfer per tensor instead of an .item() call per example.
                    for orig_i, is_correct in zip(orig_idx.tolist(), acc.tolist()):
                        # An example seen again within the same epoch overwrites its latest flag.
                        if orig_i in fs_set:
                            online_fs[orig_i][-1] = is_correct
                        else:
                            online_fs[orig_i].append(is_correct)
                            fs_set.add(orig_i)

                if args.get_fs_stats:
                    # ~2it/s slowdown (negligible slowdown), Updates 'stats' dict
                    get_fs_stats(args, stats, y_hat, y, acc, loss, orig_idx, raw_idx)

                if loss_upweight_indices is not None:
                    is_upweighted = (orig_idx.to(args.device).unsqueeze(1) == loss_upweight_indices).any(dim=1)
                    # Build the weights on loss's own device/dtype. Multiply out of place
                    # so the `loss` tensor handed to get_fs_stats above is not mutated.
                    weights = torch.ones_like(loss)
                    weights[is_upweighted] = args.loss_upweight_weight
                    loss = loss * weights

                loss = loss.mean() + l2_reg(args.l2_reg, param_cache)
                loss.backward()

                if args.sam_rho != 0.0:
                    # NOTE(review): relies on the custom optimizer's first_step/second_step
                    # semantics. Open questions: does first_step clear the gradients from the
                    # backward above before this second backward, and is step() meant to run
                    # after second_step? Confirm against its implementation.
                    optimizer.first_step()
                    sam_loss = criterion(model(x), y).mean() + l2_reg(args.l2_reg, param_cache)
                    sam_loss.backward()
                    optimizer.second_step()

                optimizer.step()

                # The accumulated train loss includes upweighting and the L2 term.
                epoch_train_loss += loss.item()
                epoch_correct += acc.sum().item()
                epoch_total += len(acc)

        lr_scheduler(optimizer, epoch, args.lr_epoch_interval, args.lr_max, args.lr_decay)

        if epoch == args.error_barrier_epoch and not args.aus:
            # Draw and discard one batch `error_barrier_reroll` times. This presumably
            # advances the loader's shuffling RNG so later epochs see a different order — confirm.
            for _ in range(args.error_barrier_reroll):
                for _ in train_loader:
                    break

        # Eval (always during epochs 0-10, then every eval_model_every_k_epochs)
        if epoch % args.eval_model_every_k_epochs == 0 or epoch <= 10:
            model.eval()

            if epoch == 0:
                # No training pass has happened yet. Measure train metrics with a full
                # eval pass, which also records the initial online-forgetting flags.
                avg_train_loss, train_err = eval(args, model, criterion, train_loader, split='train', epoch=epoch, online_fs=online_fs)
            else:
                # Metrics accumulated during this epoch's training pass.
                avg_train_loss, train_err = epoch_train_loss / len(train_loader), 100 * (1 - epoch_correct / epoch_total)

            avg_val_loss, val_err = eval(args, model, criterion, val_loader, split='val', epoch=epoch)
            avg_test_loss, test_err = eval(args, model, criterion, test_loader, split='test', epoch=epoch)

            stat_str = \
            (f'\n'
                f'Stats after {epoch} epochs:\n'
                f'\n'
                f'avg_train_loss:\t {avg_train_loss:.3f}\n'
                f'avg_val_loss:\t {avg_val_loss:.3f}\n'
                f'avg_test_loss:\t {avg_test_loss:.3f}\n'
                f'\n'
                f'train_err:\t {train_err:.3f}\n'
                f'val_err:\t {val_err:.3f}\n'
                f'test_err:\t {test_err:.3f}\n'
                f'\n'
                f'lr:\t\t {optimizer.param_groups[0]["lr"]:.6f}\n')

            print(stat_str)

            # Save best model
            if val_err < best_val_err:
                torch.save({'last': model.state_dict()}, os.path.join(args.ckpt_path, 'best_val_err.pt'))
                print(f'New best val err! {val_err:.3f} (ckpt saved).')
                best_val_err = val_err

            if test_err < best_test_err:
                torch.save({'last': model.state_dict()}, os.path.join(args.ckpt_path, 'best_test_err.pt'))
                print(f'New best test err! {test_err:.3f} (ckpt saved).')
                best_test_err = test_err

            # Save current model
            torch.save({'last': model.state_dict()}, os.path.join(args.ckpt_path, f'epochs={epoch}.pt'))
            print(f'Saved model after {epoch} epochs.\n')

            # Wandb
            if not args.debug:
                wandb.log({
                        'epochs': epoch,
                        'batches_per_epoch': args.batches_per_epoch,
                        'avg_train_loss': avg_train_loss,
                        'train_err': train_err,
                        'avg_val_loss': avg_val_loss,
                        'val_err': val_err,
                        'avg_test_loss': avg_test_loss,
                        'test_err': test_err,
                        'best_val_err': best_val_err,
                        'best_test_err': best_test_err,
                        'lr': optimizer.param_groups[0]['lr']
                    })

    # Save stats
    stats_dict = {'get_fs_stats': args.get_fs_stats, 'get_sharpness_stats': args.get_sharpness_stats, 'stats': stats}
    args.stats_dict_path = os.path.join(args.local_run_path, 'stats_dict.pkl')
    with open(args.stats_dict_path, 'wb') as f:
        pickle.dump(stats_dict, f)

    if online_fs is not None:
        # Per example: number of correct -> incorrect transitions between consecutive recorded flags.
        fs_online = {k: sum([v[i] and not v[i + 1] for i in range(len(v) - 1)]) for k, v in online_fs.items()}
        args.fs_dict_path = os.path.join(args.local_run_path, 'fs_online.pkl')

        with open(args.fs_dict_path, 'wb') as f:
            pickle.dump(fs_online, f)

    if not args.debug:
        wandb.finish()

def eval(args, model, criterion, loader, split, epoch, online_fs=None):
    """Evaluate ``model`` on ``loader`` under ``torch.no_grad()``.

    This function does not call ``model.eval()`` itself; the caller is
    responsible for putting the model in the right mode.

    Args:
        args: Namespace providing ``device``.
        model: Model under evaluation.
        criterion: Loss called as ``criterion(y_hat, y)``. Its result is
            ``.mean()``-reduced per batch here.
        loader: Iterable yielding ``(x, y, orig_idx, raw_idx)`` batches;
            ``raw_idx`` is ignored.
        split: Split name, used only in the progress-bar label.
        epoch: Epoch number, used only in the progress-bar label.
        online_fs: Optional ``defaultdict(list)`` mapping ``orig_idx`` to a
            list of correctness flags. When given, one flag per example is
            appended for this pass. A repeated example within the pass
            overwrites its latest flag.

    Returns:
        ``(avg_loss, err)``:

        * ``avg_loss`` is the mean over batches of each batch's mean loss.
        * ``err`` is the error rate in percent.
    """
    loss, correct, total = 0, 0, 0

    fs_set = set()
    with torch.no_grad():
        for x, y, orig_idx, _ in tqdm(loader, desc=f'{split.capitalize()} eval after {epoch} epochs'):
            x, y = x.to(args.device), y.to(args.device)
            y_hat = model(x)
            loss += criterion(y_hat, y).mean().item()

            _, y_pred = torch.max(y_hat, 1)
            acc = y_pred == y

            if online_fs is not None:
                # One bulk transfer per tensor instead of an .item() call per example.
                for orig_i, is_correct in zip(orig_idx.tolist(), acc.tolist()):
                    if orig_i in fs_set:
                        online_fs[orig_i][-1] = is_correct
                    else:
                        online_fs[orig_i].append(is_correct)
                        fs_set.add(orig_i)

            # Tensor reduction; Python's sum() would iterate the tensor one element at a time.
            correct += acc.sum().item()
            total += len(acc)

    avg_loss = loss / len(loader)
    err = 100 * (1 - correct / total)

    return avg_loss, err

def lr_scheduler(optimizer, current_epoch, lr_epoch_interval, lr_max, lr_decay):
    """Step-decay the learning rate of every param group at milestone epochs.

    When ``current_epoch`` is one of the milestones, each param group's LR is
    set to ``lr_max * lr_decay ** k``. Here ``k`` is the number of milestones
    ``<= current_epoch``. On any other epoch the optimizer is left untouched.

    Args:
        optimizer: Object with a ``param_groups`` list of dicts holding ``'lr'``.
        current_epoch: Epoch number being checked.
        lr_epoch_interval: A single milestone epoch (``int``) or a collection
            of milestone epochs.
        lr_max: Base learning rate.
        lr_decay: Multiplicative decay factor applied once per milestone reached.
    """
    milestones = [lr_epoch_interval] if isinstance(lr_epoch_interval, int) else lr_epoch_interval

    if current_epoch not in milestones:
        return

    n_decays = len([m for m in milestones if m <= current_epoch])
    new_lr = lr_max * lr_decay ** n_decays
    for group in optimizer.param_groups:
        group['lr'] = new_lr

def l2_reg(l2_reg, cache):
    """Return the L2 penalty ``0.5 * l2_reg * sum(p ** 2)`` over the cached parameter views.

    Args:
        l2_reg: Regularization coefficient.
        cache: Dict whose ``'param_views'`` entry is a list of tensors that
            ``torch.cat`` can join (presumably the 1-D views built in ``train``).

    Returns:
        A scalar tensor that is differentiable w.r.t. the viewed parameters.
        The squared-norm sum is cast with ``.float()`` before scaling.
    """
    coeff = l2_reg * 0.5
    flat_params = torch.cat(cache['param_views'])
    squared_norm = flat_params.pow(2).sum()
    return coeff * squared_norm.float()
