from tqdm import tqdm
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR
import os
from training.stats_utils import get_fs_stats, get_sharpness_stats
import wandb
import pickle
from collections import defaultdict
from data_utils import get_dataloaders

def train(args, model, criterion, optimizer, train_loader, val_loader, test_loader, val_criterion=None):
    """Run the training loop with periodic evaluation, checkpointing and logging.

    Epoch 0 only evaluates the model; epochs ``1..args.epochs`` train on
    ``train_loader`` first. Evaluation runs for every epoch up to and including
    epoch 10, and afterwards whenever
    ``epoch % args.eval_model_every_k_epochs == 0``.

    Side effects:
        * Writes checkpoints to ``<args.local_run_path>/ckpt``.
        * Writes ``stats_dict.pkl`` (and ``fs_online.pkl`` when
          ``args.get_online_fs_stats`` is set) to ``args.local_run_path``.
        * Logs to wandb unless ``args.debug`` is set.
        * Sets ``args.batches_per_epoch``, ``args.ckpt_path``,
          ``args.stats_dict_path`` and (optionally) ``args.fs_dict_path``.

    Args:
        args: Configuration object; its attributes select the optional
            features (accelerate, CRS resampling, SAM, group DRO, mixup/cutmix,
            loss upweighting, statistics collection, LR schedule).
        model: Network called as ``model(x)``.
        criterion: Training loss, called as ``criterion(y_hat, y)`` or
            ``criterion(y_hat, y, groups)`` when ``args.group_dro`` is set.
            Its result is reduced with ``.mean()`` here.
        optimizer: Optimizer. When ``args.sam_rho != 0`` it must also provide
            ``first_step()`` and ``second_step()``.
        train_loader: Iterable yielding ``(x, y, orig_idx, raw_idx)`` batches.
        val_loader: Validation loader, or a falsy value to skip validation.
        test_loader: Test loader.
        val_criterion: Loss used for evaluation; defaults to ``criterion``.

    Returns:
        None.
    """
    if args.eiil:
        print(f"Train with lr_max = {args.lr_max}, weight_deay = {args.l2_reg}, epoch = {args.epochs}")
    if val_criterion is None:
        val_criterion = criterion
    # Define cosine lr (stepped once per batch, hence epochs * batches steps in total)
    cosine_lr_scheduler = CosineAnnealingLR(optimizer, args.epochs * len(train_loader), eta_min=0, last_epoch=-1)

    # Prepare accelerate if using it
    if args.accelerate:
        from accelerate import Accelerator
        accelerator = Accelerator()

        model, optimizer, train_loader, val_loader, test_loader, cosine_lr_scheduler = accelerator.prepare(
            model, optimizer, train_loader, val_loader, test_loader, cosine_lr_scheduler
        )

    stats = {}

    best_val_err, best_test_err = 100, 100
    args.batches_per_epoch = len(train_loader)

    args.ckpt_path = os.path.join(args.local_run_path, 'ckpt')
    os.makedirs(args.ckpt_path, exist_ok=True)

    # Flat views of every parameter, reused by l2_reg() on each step.
    param_cache = {'param_views': [param.view(-1) for _, param in model.named_parameters()]}

    # Maps orig_idx -> list of per-epoch correctness flags (None disables tracking).
    online_fs = defaultdict(list) if args.get_online_fs_stats else None

    # NOTE: torch.load unpickles the file; only pass trusted paths.
    if not args.accelerate:
        loss_upweight_indices = torch.load(args.loss_upweight_idx_path).to(args.device) if args.loss_upweight_idx_path else None
    else:
        loss_upweight_indices = torch.load(args.loss_upweight_idx_path) if args.loss_upweight_idx_path else None

    if not args.debug:
        wandb.init(entity=args.wandb_team,
                project=args.exp_name,
                name=args.run_name,
                config=args,
                dir=args.local_run_path)

        wandb.define_metric('epochs')
        wandb.define_metric('*', step_metric='epochs')

    if args.crs:
        # NOTE: pickle.load can execute arbitrary code; crs_stats_path must be a trusted file.
        with open(args.crs_stats_path, 'rb') as f:
            proxy_stats = pickle.load(f)

        idx_to_norms = proxy_stats['gn']

        def calculate_local_average(numbers, i, bin_width):
            # Mean of a window of up to bin_width entries centred on i, clipped at the ends.
            half_width = bin_width//2
            start_index = max(0, i-half_width)
            end_index = min(len(numbers), i+half_width+1)
            return sum(numbers[start_index:end_index]) / (end_index-start_index)

        # Smooth every per-sample series over neighbouring epochs.
        for k, v in idx_to_norms.items():
            idx_to_norms[k] = [calculate_local_average(v, i, args.crs_epoch_bin) for i in range(len(v))]

        num_keys = len(idx_to_norms)
        cus_count = int(args.cus_percentile/100*num_keys)
        cds_count = int(args.cds_percentile/100*num_keys)

        # NOTE(review): the per-epoch tables are hard-coded to 200 entries, so every
        # series must hold at least 200 values and args.epochs must not exceed 200.
        cus_indices = [[] for _ in range(200)]
        cds_indices = [[] for _ in range(200)]

        for i in range(200):
            indexed_values = [(k, v[i]) for k, v in idx_to_norms.items()]

            top_sorted = sorted(indexed_values, key=lambda x: x[1], reverse=True)
            bottom_sorted = sorted(indexed_values, key=lambda x: x[1])

            # Keys with the largest / smallest smoothed value at epoch index i.
            cus_indices[i] = [k for k, _ in top_sorted[:cus_count]]
            cds_indices[i] = [k for k, _ in bottom_sorted[:cds_count]]

    # Running training statistics of the epoch currently being trained.
    epoch_train_loss, epoch_correct, epoch_total = 0, 0, 0

    # Epoch 0 is just eval
    for epoch in range(args.epochs+1):
        # Train
        if epoch:
            model.train()
            fs_set = set()

            # Reset for every training epoch. Resetting only after an evaluation
            # would let the sums span several epochs whenever evaluation is
            # skipped, inflating avg_train_loss (divided by a single epoch's
            # batch count) and blending train_err across epochs.
            epoch_train_loss, epoch_correct, epoch_total = 0, 0, 0

            if args.crs: # This does NOT work with accelerate currently!
                train_loader, _, _, _ = get_dataloaders(args, cus_indices[epoch-1], cds_indices[epoch-1])

            for x, y, orig_idx, raw_idx in tqdm(train_loader, desc=f'Training epoch {epoch}'):
                optimizer.zero_grad()
                if not args.accelerate:
                    x, y = x.to(args.device), y.to(args.device)

                if args.get_sharpness_stats:
                    # ~50% slowdown (major slowdown), Updates 'stats' dict
                    get_sharpness_stats(args, stats, model, criterion, x, y, orig_idx, raw_idx)

                y_hat = model(x)
                if args.group_dro:
                    groups = raw_idx.to(args.device)
                    loss = criterion(y_hat, y, groups)
                else:
                    loss = criterion(y_hat, y)
                _, y_pred = torch.max(y_hat, 1)
                if args.use_mixup or args.use_cutmix:
                    # Targets are reduced with argmax over dim 1 before comparison.
                    acc = (y_pred == torch.argmax(y, dim=1))
                else:
                    acc = y_pred == y

                if online_fs is not None:
                    # One correctness flag per sample and epoch; a sample seen again
                    # within the same epoch overwrites its latest flag.
                    for orig_i, is_correct in zip(orig_idx.tolist(), acc.tolist()):
                        if orig_i in fs_set:
                            online_fs[orig_i][-1] = is_correct
                        else:
                            online_fs[orig_i].append(is_correct)
                            fs_set.add(orig_i)

                if args.get_fs_stats:
                    # ~2it/s slowdown (negligible slowdown), Updates 'stats' dict
                    get_fs_stats(args, stats, y_hat, y, acc, loss, orig_idx, raw_idx)

                if loss_upweight_indices is not None:
                    # Without accelerate the upweight indices live on args.device,
                    # so the batch indices are moved there before comparing.
                    batch_idx = orig_idx if args.accelerate else orig_idx.to(args.device)
                    is_upweighted = (batch_idx.unsqueeze(1) == loss_upweight_indices).any(dim=1)
                    loss *= torch.where(is_upweighted,
                                        args.loss_upweight_weight,
                                        torch.tensor(1.0, dtype=torch.float32))

                loss = loss.mean() + l2_reg(args.l2_reg, param_cache)
                if not args.accelerate:
                    loss.backward()
                else:
                    accelerator.backward(loss)

                if args.sam_rho != 0.0:
                    # NOTE(review): relies on a custom optimizer exposing
                    # first_step()/second_step(); step() is still called below.
                    optimizer.first_step()
                    if args.group_dro:
                        sam_loss = criterion(model(x), y, groups).mean() + l2_reg(args.l2_reg, param_cache)
                    else:
                        sam_loss = criterion(model(x), y).mean() + l2_reg(args.l2_reg, param_cache)

                    if not args.accelerate:
                        sam_loss.backward()
                    else:
                        accelerator.backward(sam_loss)
                    optimizer.second_step()

                optimizer.step()
                if args.lr_scheduler == 'cosine':
                    cosine_lr_scheduler.step()

                epoch_train_loss += loss.item()
                epoch_correct += acc.sum().item()
                epoch_total += len(acc)

        if args.lr_scheduler == 'piecewise':
            lr_scheduler(optimizer, epoch, args.lr_epoch_interval, args.lr_max, args.lr_decay)

        if epoch == args.error_barrier_epoch and not args.aus:
            # Start (and immediately abandon) extra passes over the loader;
            # presumably this advances the shuffling RNG - TODO confirm.
            for _ in range(args.error_barrier_reroll):
                for _ in train_loader: break

        # Eval
        if epoch % args.eval_model_every_k_epochs == 0 or epoch<=10:
            model.eval()

            if epoch == 0:
                # Nothing has been trained yet, so measure the train split explicitly.
                avg_train_loss, train_err = eval(args, model, val_criterion, train_loader, split='train', epoch=epoch, online_fs=online_fs)
            else:
                # Reuse the running statistics gathered while training this epoch.
                avg_train_loss, train_err = epoch_train_loss/len(train_loader), 100*(1-epoch_correct/epoch_total)

            if val_loader:
                avg_val_loss, val_err = eval(args, model, val_criterion, val_loader, split='val', epoch=epoch)
            else:
                avg_val_loss = 0.
                val_err = 0.
            avg_test_loss, test_err = eval(args, model, val_criterion, test_loader, split='test', epoch=epoch)

            stat_str = \
            (f'\n'
                f'Stats after {epoch} epochs:\n'
                f'\n'
                f'avg_train_loss:\t {avg_train_loss:.3f}\n'
                f'avg_val_loss:\t {avg_val_loss:.3f}\n'
                f'avg_test_loss:\t {avg_test_loss:.3f}\n'
                f'\n'
                f'train_err:\t {train_err:.3f}\n'
                f'val_err:\t {val_err:.3f}\n'
                f'test_err:\t {test_err:.3f}\n'
                f'\n'
                f'lr:\t\t {optimizer.param_groups[0]["lr"]:.6f}\n')

            print(stat_str)

            # Save best model
            if val_err < best_val_err:
                torch.save({'last': model.state_dict()}, os.path.join(args.ckpt_path, f'best_val_err.pt'))
                print(f'New best val err! {val_err:.3f} (ckpt saved).')
                best_val_err = val_err

            if test_err < best_test_err:
                torch.save({'last': model.state_dict()}, os.path.join(args.ckpt_path, f'best_test_err.pt'))
                print(f'New best test err! {test_err:.3f} (ckpt saved).')
                best_test_err = test_err

            # Save current model
            if not args.accelerate:
                torch.save({'last': model.state_dict()}, os.path.join(args.ckpt_path, f'epochs={epoch}.pt'))
                print(f'Saved model after {epoch} epochs.\n')
            else:
                accelerator.wait_for_everyone()
                unwrapped_model = accelerator.unwrap_model(model)
                unwrapped_model.save_pretrained(os.path.join(args.ckpt_path, f'epochs={epoch}.pt'), save_function=accelerator.save)

            # Wandb
            if not args.debug:
                wandb.log({
                        'epochs': epoch,
                        'batches_per_epoch': args.batches_per_epoch,
                        'avg_train_loss': avg_train_loss,
                        'train_err': train_err,
                        'avg_val_loss': avg_val_loss,
                        'val_err': val_err,
                        'avg_test_loss': avg_test_loss,
                        'test_err': test_err,
                        'best_val_err': best_val_err,
                        'best_test_err': best_test_err,
                        'lr': optimizer.param_groups[0]['lr']
                    })

    # Save stats
    stats_dict = {'get_fs_stats':args.get_fs_stats,'get_sharpness_stats':args.get_sharpness_stats,'stats':stats}
    args.stats_dict_path = os.path.join(args.local_run_path, 'stats_dict.pkl')
    with open(args.stats_dict_path, 'wb') as f:
        pickle.dump(stats_dict, f)

    if online_fs is not None:
        # Per sample: number of correct -> incorrect transitions between consecutive records.
        fs_online = {k:sum([v[i] and not v[i+1] for i in range(len(v)-1)]) for k,v in online_fs.items()}
        args.fs_dict_path = os.path.join(args.local_run_path, 'fs_online.pkl')

        with open(args.fs_dict_path, 'wb') as f:
            pickle.dump(fs_online, f)

    if not args.debug: wandb.finish()

def eval(args, model, criterion, loader, split, epoch, online_fs=None):
    """Evaluate ``model`` on ``loader`` and return ``(avg_loss, err)``.

    The caller is responsible for putting the model into eval mode; this
    function only disables gradient tracking.

    Args:
        args: Configuration object. Reads ``accelerate``, ``device``,
            ``use_mixup`` and ``use_cutmix``.
        model: Network called as ``model(x)``.
        criterion: Loss called as ``criterion(y_hat, y)``; its result is
            reduced with ``.mean()``.
        loader: Iterable yielding ``(x, y, orig_idx, _)`` batches.
        split: Split name, used in the progress-bar label. For ``'train'``
            under mixup/cutmix, targets are reduced with argmax over dim 1.
        epoch: Epoch number, used only in the progress-bar label.
        online_fs: Optional mapping ``orig_idx -> list of correctness flags``.
            When given, one flag per sample is appended for this pass (a
            sample seen again in the same pass overwrites its latest flag).

    Returns:
        Tuple ``(avg_loss, err)``: the mean of the per-batch mean losses, and
        the error rate in percent.

    Raises:
        ZeroDivisionError: If ``loader`` is empty.
    """
    loss, correct, total = 0, 0, 0

    # Loop-invariant: whether targets must be reduced with argmax before comparing.
    argmax_targets = (args.use_mixup or args.use_cutmix) and (split == 'train')

    fs_set = set()
    with torch.no_grad():
        for x, y, orig_idx, _ in tqdm(loader, desc=f'{split.capitalize()} eval after {epoch} epochs'):
            if not args.accelerate:
                x, y = x.to(args.device), y.to(args.device)
            y_hat = model(x)
            loss += criterion(y_hat, y).mean().item()

            _, y_pred = torch.max(y_hat, 1)
            if argmax_targets:
                acc = (y_pred == torch.argmax(y, dim=1))
            else:
                acc = y_pred == y

            if online_fs is not None:
                # Convert once per batch instead of calling .item() per element.
                for orig_i, is_correct in zip(orig_idx.tolist(), acc.tolist()):
                    if orig_i in fs_set:
                        online_fs[orig_i][-1] = is_correct
                    else:
                        online_fs[orig_i].append(is_correct)
                        fs_set.add(orig_i)

            # Tensor-level reduction; the builtin sum() would iterate element by element.
            correct += acc.sum().item()
            total += len(acc)

    avg_loss = loss/len(loader)
    err = 100*(1-correct/total)

    return avg_loss, err

def lr_scheduler(optimizer, current_epoch, lr_epoch_interval, lr_max, lr_decay):
    """Apply a piecewise-constant learning-rate decay at milestone epochs.

    The learning rate is only touched when ``current_epoch`` is one of the
    milestones. It is then set, for every parameter group, to
    ``lr_max * lr_decay ** k`` where ``k`` is the number of milestones that
    are less than or equal to ``current_epoch``.

    Args:
        optimizer: Object exposing ``param_groups`` (a list of dicts with an ``'lr'`` key).
        current_epoch: Epoch that has just been reached.
        lr_epoch_interval: A single milestone epoch (int) or a collection of milestones.
        lr_max: Base learning rate before any decay.
        lr_decay: Multiplicative decay factor applied once per passed milestone.

    Returns:
        None; ``optimizer.param_groups`` is updated in place.
    """
    # A lone int is treated as a one-element milestone list.
    milestones = [lr_epoch_interval] if isinstance(lr_epoch_interval, int) else lr_epoch_interval

    # Guard clause: nothing to do between milestones.
    if current_epoch not in milestones:
        return

    passed = len([m for m in milestones if current_epoch >= m])
    new_lr = lr_max * (lr_decay ** passed)
    for group in optimizer.param_groups:
        group['lr'] = new_lr

def l2_reg(l2_reg, cache):
    """Return the L2 penalty ``l2_reg * 0.5 * sum(p ** 2)`` over all cached parameters.

    Args:
        l2_reg: Regularisation coefficient.
        cache: Dict whose ``'param_views'`` entry is a list of 1-D tensors
            to be concatenated and penalised.

    Returns:
        A scalar float32 tensor holding the penalty.
    """
    flat_params = torch.cat(cache['param_views'])
    # Sum of squares first, cast to float32 afterwards.
    squared_norm = torch.sum(flat_params ** 2).float()
    return l2_reg * 0.5 * squared_norm
