import os
import numpy as np
import json
import torch
import torch.nn as nn
import torch.optim as optim
import wandb

from tqdm import tqdm
from torch.cuda.amp import GradScaler, autocast

from dataloader import get_imagenet_dataloaders, cifar_dataset, get_text_dataloaders

# Module-wide default device: CUDA when torch reports it as available, else CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def adjust_learning_rate(lr, optimizer, epoch):
    """Apply a step-decay learning rate to every parameter group.

    The base rate ``lr`` is multiplied by 0.1 once for each completed block
    of 30 epochs (``epoch // 30``), and the result is written to the ``'lr'``
    entry of every group in ``optimizer.param_groups``.

    The optimizer is modified in place; nothing is returned.
    """
    decayed_lr = lr * (0.1 ** (epoch // 30))
    for group in optimizer.param_groups:
        group['lr'] = decayed_lr

def adjust_weight_decay(weight_decay, optimizer, epoch):
    """Write the weight decay for ``epoch`` into every parameter group.

    At epochs 25 and 50 exactly, the value stored is ``weight_decay * 0.5``;
    at every other epoch the unmodified ``weight_decay`` is stored. The
    halving is therefore not cumulative across calls.

    The optimizer is modified in place; nothing is returned.
    """
    effective_wd = weight_decay * 0.5 if epoch in (25, 50) else weight_decay
    for group in optimizer.param_groups:
        group['weight_decay'] = effective_wd
        
def compute_grad_norm(model, norm_type=2):
    """Return the global ``norm_type``-norm of all gradients held by ``model``.

    Args:
        model: Object exposing ``parameters()``; parameters whose ``grad`` is
            ``None`` are skipped.
        norm_type: Order of the norm. Finite orders combine the per-parameter
            norms as ``(sum_i ||g_i||^p) ** (1 / p)``; ``float('inf')``
            returns the largest per-parameter max-norm.

    Returns:
        The norm as a Python float, or ``0.0`` when no parameter has a
        gradient.
    """
    grads = [p.grad.detach() for p in model.parameters() if p.grad is not None]
    if not grads:
        return 0.0
    if norm_type == float('inf'):
        # The power-sum formula below degenerates for p = inf (x ** inf and
        # total ** 0), so the infinity norm is taken as a plain maximum.
        return max(g.norm(norm_type).item() for g in grads)
    total = sum((g.norm(norm_type).item() ** norm_type for g in grads), 0.0)
    return total ** (1.0 / norm_type)

def compute_weight_norm(model, norm_type=2):
    """Return the global ``norm_type``-norm of all parameters of ``model``.

    Args:
        model: Object exposing ``parameters()``; every parameter is included,
            whether or not it requires gradients.
        norm_type: Order of the norm. Finite orders combine the per-parameter
            norms as ``(sum_i ||w_i||^p) ** (1 / p)``; ``float('inf')``
            returns the largest per-parameter max-norm.

    Returns:
        The norm as a Python float, or ``0.0`` when the model has no
        parameters.
    """
    weights = [p.detach() for p in model.parameters()]
    if not weights:
        return 0.0
    if norm_type == float('inf'):
        # The power-sum formula below degenerates for p = inf (x ** inf and
        # total ** 0), so the infinity norm is taken as a plain maximum.
        return max(w.norm(norm_type).item() for w in weights)
    total = sum((w.norm(norm_type).item() ** norm_type for w in weights), 0.0)
    return total ** (1.0 / norm_type)

def validate(step, lin_conv_model, val_loader, loss_fn):
    """Compute validation loss and top-1/top-5 accuracy, then log them.

    Args:
        step: Zero-based epoch or step index; reported as ``step + 1``.
        lin_conv_model: Model to evaluate. It is left in eval mode, so the
            caller must switch back to ``train()`` afterwards.
        val_loader: Iterable whose batches expose ``batch[0]['images']`` and
            ``batch[0]['labels']`` and which has a ``size`` attribute
            (presumably the number of validation samples -- confirm against
            the dataloader implementation).
        loss_fn: Callable ``loss_fn(preds, labels)`` returning a scalar tensor.

    Returns:
        Tuple ``(avg_val_loss, top1_acc, top5_acc)``.
    """
    lin_conv_model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in tqdm(val_loader, desc='Iterating over validation set...', dynamic_ncols=True):
            # Images are permuted from channels-last to channels-first.
            img = batch[0]['images'].permute(0, 3, 1, 2)
            # reshape(-1) rather than squeeze(): a bare squeeze() also removes
            # the batch dimension when a batch holds a single sample, which
            # makes the loss call fail on a trailing batch of size one.
            label = batch[0]['labels'].reshape(-1).long()
            preds = lin_conv_model(img)
            loss = loss_fn(preds, label)
            # Weight by batch size so the division below is a per-sample mean.
            val_loss += loss.item() * img.shape[0]
    avg_val_loss = val_loss / val_loader.size
    print(f'Step {step + 1}, Validation Loss {avg_val_loss}')
    # Accuracy is measured in a second pass over the loader by eval_loop.
    top1_acc, top5_acc = eval_loop(lin_conv_model, val_loader, device)
    print(f'Step {step + 1}, Top-1 Accuracy {top1_acc}, Top-5 Accuracy {top5_acc}')
    wandb.log({'top1_acc': top1_acc, 'val_loss': avg_val_loss})
    return avg_val_loss, top1_acc, top5_acc

def imagenet_finetune(args, lin_conv_model, orig):
    """Fine-tune ``lin_conv_model`` on ImageNet and evaluate it against ``orig``.

    Args:
        args: Run configuration. Fields read here or in the helpers below:
            ``ft_batch``, ``pretrained``, ``tune_linear``, ``model``,
            ``ft_lr``, ``ft_wd``, ``use_scheduler``, ``scheduler_type``,
            ``warmup_epochs`` (cosine only), ``ft_epochs``, ``use_amp``,
            ``evaluate`` (``'epoch'`` or ``'steps'``), ``eval_steps``
            (``'steps'`` only) and ``exp_name``.
        lin_conv_model: Model to fine-tune; it is trained in place.
        orig: Reference model; only evaluated on the validation loader.

    Returns:
        Tuple ``(lin_conv_model, best_top1, log_json, orig_top1_acc)`` where
        ``best_top1`` is the maximum top-1 accuracy recorded and ``log_json``
        is the dictionary also written to
        ``logs/<exp_name>/ft_results.json``.
    """
    train_loader, val_loader = get_imagenet_dataloaders(batch_size=args.ft_batch)

    # Re-initialise the classifier head unless pretrained weights are in use.
    head = getattr(lin_conv_model, 'fc', None)
    if isinstance(head, nn.Linear) and not args.pretrained:
        nn.init.normal_(head.weight, mean=0.0, std=0.02)
        if head.bias is not None:
            nn.init.zeros_(head.bias)

    if args.tune_linear:
        for module in lin_conv_model.modules():
            is_hidden_linear = isinstance(module, nn.Linear) and module.out_features != 1000
            if is_hidden_linear or isinstance(module, nn.BatchNorm2d):
                # requires_grad_() updates the module's parameters; assigning
                # a `requires_grad` attribute on the Module would only create
                # an unused Python attribute.
                module.requires_grad_(True)
    else:
        for param in lin_conv_model.parameters():
            param.requires_grad = True

    loss_fn = nn.CrossEntropyLoss(label_smoothing=0.05)
    lin_conv_model = lin_conv_model.to(device)
    optimizer = _build_imagenet_optimizer(args, lin_conv_model)
    scheduler = None
    if args.use_scheduler:
        scheduler = _build_imagenet_scheduler(args, optimizer, len(train_loader))
    # ReduceLROnPlateau is stepped with the validation loss; every other
    # scheduler is configured in optimizer steps and stepped once per batch.
    uses_plateau = isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau)

    val_losses = []
    train_losses = []
    top1_accs = []
    top5_accs = []
    weight_norms = []
    grad_norms = []

    def run_validation(step):
        """Validate once, record the metrics and feed the plateau scheduler."""
        avg_val_loss, top1_acc, top5_acc = validate(step, lin_conv_model, val_loader, loss_fn)
        val_losses.append(avg_val_loss)
        top1_accs.append(top1_acc)
        top5_accs.append(top5_acc)
        if uses_plateau:
            scheduler.step(avg_val_loss)

    scaler = GradScaler(enabled=args.use_amp)
    global_step = 0  # number of training batches processed so far
    for epoch in range(args.ft_epochs):
        if scheduler is None:
            adjust_learning_rate(args.ft_lr, optimizer, epoch)
        if args.evaluate == 'epoch':
            run_validation(epoch)
        epoch_loss = 0.0
        lin_conv_model.train()
        for batch in tqdm(train_loader, desc='Iterating over training set...', dynamic_ncols=True):
            if args.evaluate == 'steps' and global_step % args.eval_steps == 0:
                run_validation(global_step)
                lin_conv_model.train()  # validate() leaves the model in eval mode
            # Images are permuted from channels-last to channels-first.
            img = batch[0]['images'].permute(0, 3, 1, 2)
            # reshape(-1) keeps the batch dimension even for a one-sample batch.
            label = batch[0]['labels'].reshape(-1).long()
            with autocast(enabled=args.use_amp):
                preds = lin_conv_model(img)
                loss = loss_fn(preds, label)
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            # Unscale first so the logged gradient norm is not multiplied by
            # the AMP loss scale; scaler.step() will not unscale a second time.
            scaler.unscale_(optimizer)
            grad_norms.append(compute_grad_norm(lin_conv_model))
            weight_norms.append(compute_weight_norm(lin_conv_model))
            scaler.step(optimizer)
            scaler.update()
            if scheduler is not None and not uses_plateau:
                scheduler.step()
            batch_loss = loss.item()
            train_losses.append(batch_loss)
            # Running mean over the last 20 batches smooths the logged curve.
            wandb.log({'train_loss': np.mean(train_losses[-20:])})
            epoch_loss += batch_loss * img.shape[0]
            global_step += 1
        avg_train_loss = epoch_loss / train_loader.size
        print(f'Epoch {epoch + 1}, Training Loss {avg_train_loss}')

    top1_acc, top5_acc = eval_loop(lin_conv_model, val_loader, device)
    top1_accs.append(top1_acc)
    top5_accs.append(top5_acc)
    # args.ft_epochs equals the last `epoch + 1` and is defined for zero epochs.
    print(f'Epoch {args.ft_epochs}, Top-1 Accuracy {top1_acc}, Top-5 Accuracy {top5_acc}')

    log_json = {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'top1_accs': top1_accs,
        'top5_accs': top5_accs,
        'grad_norms': grad_norms,
        'weight_norms': weight_norms,
    }
    os.makedirs(f'logs/{args.exp_name}', exist_ok=True)
    with open(f'logs/{args.exp_name}/ft_results.json', 'w') as f:
        json.dump(log_json, f)

    orig_top1_acc, _ = eval_loop(orig, val_loader, device)
    return lin_conv_model, max(top1_accs), log_json, orig_top1_acc


def _build_imagenet_optimizer(args, model):
    """Return AdamW for ``model``; for ``rn18`` bias/norm parameters get no decay."""
    if args.model != 'rn18':
        return optim.AdamW(model.parameters(), lr=args.ft_lr, weight_decay=args.ft_wd)
    no_decay_tags = ('bias', 'bn', 'ln', 'norm')
    decay_params = []
    no_decay_params = []
    for name, param in model.named_parameters():
        if any(tag in name for tag in no_decay_tags):
            no_decay_params.append(param)
        else:
            decay_params.append(param)
    return optim.AdamW(
        [
            {'params': decay_params, 'weight_decay': args.ft_wd},
            {'params': no_decay_params, 'weight_decay': 0.0},
        ],
        lr=args.ft_lr, betas=(0.9, 0.999), eps=1e-8,
    )


def _build_imagenet_scheduler(args, optimizer, steps_per_epoch):
    """Return the scheduler named by ``args.scheduler_type`` (plateau as fallback)."""
    import math  # local import: `math` is not imported at module level

    total_steps = steps_per_epoch * args.ft_epochs
    kind = args.scheduler_type
    if kind == 'cosine':
        warmup_steps = args.warmup_epochs * steps_per_epoch
        # Guard against a non-positive T_max when warm-up spans the whole run.
        cosine_steps = max(1, total_steps - warmup_steps)
        sched_warmup = optim.lr_scheduler.LinearLR(
            optimizer, start_factor=0.1, end_factor=1.0, total_iters=warmup_steps)
        sched_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cosine_steps)
        return optim.lr_scheduler.SequentialLR(
            optimizer, schedulers=[sched_warmup, sched_cosine], milestones=[warmup_steps])
    if kind == 'onecycle':
        return optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=3 * args.ft_lr, total_steps=total_steps, pct_start=0.03,
            anneal_strategy='cos', div_factor=5, final_div_factor=1000)
    if kind == 'lambda':
        warmup_steps = int(0.05 * total_steps)
        decay_steps = max(1, total_steps - warmup_steps)

        def lr_lambda(step):
            # Linear warm-up followed by half-cosine decay to zero.
            if step < warmup_steps:
                return step / warmup_steps
            progress = min(1.0, (step - warmup_steps) / decay_steps)
            return 0.5 * (1 + math.cos(math.pi * progress))

        return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    return optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10)

def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy percentages for a batch of predictions.

    Args:
        output: Score tensor; the ``max(topk)`` highest-scoring indices are
            taken along dimension 1.
        target: Ground-truth class indices; ``target.size(0)`` is used as the
            batch size.
        topk: Iterable of ``k`` values to evaluate.

    Returns:
        A list with one single-element tensor per ``k``, each holding the
        percentage (0-100) of samples whose target is among the top ``k``
        predictions.
    """
    with torch.no_grad():
        largest_k = max(topk)
        num_samples = target.size(0)
        # Transpose so that row r holds the (r + 1)-th best guess per sample.
        ranked = output.topk(largest_k, dim=1, largest=True, sorted=True)[1].t()
        hits = ranked.eq(target.reshape(1, -1).expand_as(ranked))
        percent_per_hit = 100.0 / num_samples
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(percent_per_hit)
            for k in topk
        ]

def eval_loop(model, val_loader, device):
    """Return sample-weighted top-1 and top-5 accuracy of ``model``.

    Args:
        model: Model to evaluate; switched to eval mode.
        val_loader: Iterable whose batches expose ``batch[0]['images']`` and
            ``batch[0]['labels']``; images are permuted from channels-last to
            channels-first before the forward pass.
        device: Accepted for interface compatibility; not used in the body.

    Returns:
        Tuple ``(top1_acc, top5_acc)`` of percentages averaged over all
        samples seen.
    """
    model.eval()
    weighted_top1 = 0
    weighted_top5 = 0
    samples_seen = 0
    progress = tqdm(val_loader, desc = 'Iterating over test batches...', dynamic_ncols = True)
    with torch.no_grad():
        for batch in progress:
            payload = batch[0]
            images = payload['images'].permute(0, 3, 1, 2)
            batch_top1, batch_top5 = accuracy(model(images), payload['labels'], topk=(1, 5))
            batch_count = images.size(0)
            # Per-batch percentages are re-weighted by batch size so that a
            # smaller final batch does not skew the average.
            weighted_top1 += batch_top1.item() * batch_count
            weighted_top5 += batch_top5.item() * batch_count
            samples_seen += batch_count
    return weighted_top1 / samples_seen, weighted_top5 / samples_seen

def eval_accuracy(model, loader):
    """Return the fraction of correctly classified samples in ``loader``.

    ``loader`` yields ``(inputs, targets)`` pairs, which are moved to the
    module-level ``device``. The prediction is the argmax of the model output
    along dimension 1. The model is switched to eval mode.
    """
    model.eval()
    hits = 0
    seen = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            predicted = model(inputs).argmax(1)
            hits += predicted.eq(targets).sum().item()
            seen += targets.size(0)
    return hits / seen

def cifar_finetune(exp_name, model, orig, epochs=1, lr=1e-4, weight_decay=1e-4):
    """Fine-tune ``model`` on CIFAR with Adam and compare it against ``orig``.

    Args:
        exp_name: Experiment name; accepted for interface compatibility and
            not used in the body.
        model: Model to fine-tune; trained in place.
        orig: Reference model, evaluated once on the test loader.
        epochs: Number of training epochs (default 1).
        lr: Adam learning rate (default 1e-4).
        weight_decay: Adam weight decay (default 1e-4).

    Returns:
        Tuple ``(best_acc, orig_acc)``: the highest per-epoch test accuracy of
        ``model`` and the test accuracy of ``orig``, as returned by
        ``eval_accuracy``.

    Raises:
        ValueError: If ``epochs`` is smaller than 1.
    """
    if epochs < 1:
        raise ValueError(f'epochs must be >= 1, got {epochs}')
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    # Inputs are moved to `device` below, so both models must live there too;
    # Module.to() is a no-op when they already do.
    model = model.to(device)
    orig = orig.to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()
    train_loader, test_loader = cifar_dataset()
    accs = []
    for epoch in range(epochs):
        epoch_losses = []
        model.train()
        for x, y in tqdm(train_loader, desc='Iterating over CIFAR'):
            x, y = x.to(device), y.to(device)
            opt.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            opt.step()
            epoch_losses.append(loss.item())
        print(f'Epoch {epoch+1}: train loss = {np.mean(epoch_losses):.3f}')
        acc = eval_accuracy(model, test_loader)
        accs.append(acc)
        print(f'Epoch {epoch+1}: test acc = {acc:.3f}')
    orig_acc = eval_accuracy(orig, test_loader)
    return max(accs), orig_acc

def shift_logits_labels(input_ids, lm_logits):
    """Align logits with next-token labels and flatten both.

    The last position is dropped from ``lm_logits`` (second-to-last axis) and
    the first position is dropped from ``input_ids`` (last axis), so that the
    logits at position ``t`` are paired with the token at position ``t + 1``.

    Returns:
        Tuple ``(shift_logits, shift_labels)``: logits reshaped to
        ``(-1, lm_logits.size(-1))`` and labels flattened to one dimension.
    """
    vocab_dim = lm_logits.size(-1)
    flat_logits = lm_logits[..., :-1, :].contiguous().view(-1, vocab_dim)
    flat_labels = input_ids[..., 1:].contiguous().view(-1)
    return flat_logits, flat_labels

def get_grad_norm(model):
    """Return the global L2 norm of all gradients held by ``model``.

    Parameters whose ``grad`` is ``None`` are skipped. The result is a Python
    float: the square root of the summed squared per-parameter L2 norms, and
    ``0.0`` when no parameter has a gradient.
    """
    squared_norms = (
        p.grad.data.norm(2).item() ** 2
        for p in model.parameters()
        if p.grad is not None
    )
    return sum(squared_norms, 0.0) ** 0.5

def wikitext_finetune(args, rnn_model, orig):
    """Fine-tune ``rnn_model`` as a language model and track test perplexity.

    Each epoch runs validation (saving the weights when the validation loss
    is the best so far), evaluates test perplexity, then trains for one pass
    over the training loader. A final perplexity evaluation follows the last
    epoch.

    Args:
        args: Run configuration. Fields read: ``model``, ``ft_batch``,
            ``num_workers``, ``seq_len``, ``tune_rnn``, ``ft_lr``, ``ft_wd``,
            ``use_scheduler``, ``scheduler_type``, ``warmup_epochs`` (cosine
            only), ``ft_epochs``, ``exp_name`` and, optionally,
            ``accumulation`` (gradient-accumulation steps, default 1).
        rnn_model: Model to fine-tune; called as ``rnn_model(input_ids=...)``
            and expected to return an object with a ``logits`` attribute.
        orig: Reference model; accepted for interface compatibility and not
            evaluated here (see ``orig_perplexity`` below).

    Returns:
        Tuple ``(rnn_model, best_perplexity, log_json, orig_perplexity)``
        where ``log_json`` is also written to
        ``logs/<exp_name>/ft_results.json``.
    """
    train_loader, val_loader, test_loader, _vocab_size = get_text_dataloaders(
        args.model, args.ft_batch, args.num_workers, args.seq_len)

    if args.tune_rnn:
        for module in rnn_model.modules():
            if isinstance(module, (nn.RNN, nn.LayerNorm)):
                # requires_grad_() updates the module's parameters; assigning
                # a `requires_grad` attribute on the Module has no effect.
                module.requires_grad_(True)
    else:
        for param in rnn_model.parameters():
            param.requires_grad = True

    # ignore_index is left at the CrossEntropyLoss default; no padding id is
    # available in this function.
    loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)
    rnn_model = rnn_model.to(device)
    optimizer = optim.AdamW(rnn_model.parameters(), lr=args.ft_lr, weight_decay=args.ft_wd)

    # NOTE(review): the attribute name `accumulation` is an assumption --
    # confirm against the argument parser. Without it every batch is a step.
    accumulation = max(1, int(getattr(args, 'accumulation', 1)))

    scheduler = None
    if args.use_scheduler:
        if args.scheduler_type == 'cosine':
            # The schedule is expressed in optimizer updates, not batches.
            updates_per_epoch = max(1, len(train_loader) // accumulation)
            total_updates = args.ft_epochs * updates_per_epoch
            warmup_updates = args.warmup_epochs * updates_per_epoch
            cosine_updates = max(1, total_updates - warmup_updates)
            sched_warmup = optim.lr_scheduler.LinearLR(
                optimizer, start_factor=0.1, end_factor=1.0, total_iters=warmup_updates)
            sched_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cosine_updates)
            scheduler = optim.lr_scheduler.SequentialLR(
                optimizer, schedulers=[sched_warmup, sched_cosine], milestones=[warmup_updates])
        else:
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, mode='min', factor=0.1, patience=10)
    # ReduceLROnPlateau needs the validation loss; other schedulers step per update.
    uses_plateau = isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau)

    val_losses = []
    train_losses = []
    epoch_train_losses = []
    perplexities = []

    for epoch in range(args.ft_epochs):
        # ---- validation ----
        rnn_model.eval()
        valid_loss = 0.0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc='Iterating over validation set...'):
                input_ids = batch['input_ids'].to(device)
                lm_logits = rnn_model(input_ids=input_ids).logits
                shift_logits, shift_labels = shift_logits_labels(input_ids, lm_logits)
                valid_loss += loss_fn(shift_logits, shift_labels).item()
        avg_val_loss = valid_loss / len(val_loader)
        wandb.log({'val_loss': avg_val_loss})
        val_losses.append(avg_val_loss)
        if uses_plateau:
            scheduler.step(avg_val_loss)
        # The current loss is already in the list, so `<=` means "new best".
        if avg_val_loss <= min(val_losses):
            os.makedirs('saved_models', exist_ok=True)
            torch.save(rnn_model.state_dict(), f'saved_models/{args.exp_name}.pt')
        print(f'Epoch {epoch + 1}, Validation Loss {avg_val_loss}')

        # NOTE(review): eval_ppl_runner is neither defined nor imported in the
        # visible part of this file -- confirm it is provided elsewhere.
        perplexity = eval_ppl_runner(rnn_model, test_loader, device)
        print(f'Wikitext-103 Test Perplexity {perplexity}')
        perplexities.append(perplexity)

        # ---- training ----
        rnn_model.train()
        # Start each epoch from clean gradients; a partial accumulation window
        # left over from the previous epoch is discarded.
        optimizer.zero_grad()
        train_loss = 0.0
        for i, batch in enumerate(tqdm(train_loader, desc='Iterating over train loader')):
            input_ids = batch['input_ids'].to(device)
            lm_logits = rnn_model(input_ids=input_ids).logits
            shift_logits, shift_labels = shift_logits_labels(input_ids, lm_logits)
            loss = loss_fn(shift_logits, shift_labels)
            batch_loss = loss.item()
            train_losses.append(batch_loss)
            if i % 20 == 0:
                wandb.log({'ce_loss': np.mean(train_losses[-20:])})
            # Average over the accumulation window so the update magnitude
            # does not depend on the number of accumulated batches.
            (loss / accumulation).backward()
            wandb.log({'grad_norm': get_grad_norm(rnn_model)})
            if (i + 1) % accumulation == 0:
                # Clip once, on the fully accumulated gradient.
                nn.utils.clip_grad_norm_(rnn_model.parameters(), max_norm=0.25)
                optimizer.step()
                if scheduler is not None and not uses_plateau:
                    scheduler.step()
                optimizer.zero_grad()
            train_loss += batch_loss
        avg_train_loss = train_loss / len(train_loader)
        epoch_train_losses.append(avg_train_loss)
        print(f'Epoch {epoch + 1}, Training Loss {avg_train_loss}')

    perplexity = eval_ppl_runner(rnn_model, test_loader, device)
    print(f'Wikitext-103 Test Perplexity {perplexity}')
    perplexities.append(perplexity)

    log_json = {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'test_perplexity': perplexities,
        'epoch_train_losses': epoch_train_losses,
    }
    os.makedirs(f'logs/{args.exp_name}', exist_ok=True)
    with open(f'logs/{args.exp_name}/ft_results.json', 'w') as f:
        json.dump(log_json, f)

    # Hard-coded reference value; `orig` itself is not evaluated.
    orig_perplexity = 37.50
    return rnn_model, min(perplexities), log_json, orig_perplexity