import os 
import random
import math
import numpy as np
from copy import deepcopy
from tqdm import tqdm
import pandas as pd

import torch
from torch import nn
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from jiwer import wer

from sam import SAM

def set_seed(seed):
    """Seed every random number generator this script relies on.

    Sets ``PYTHONHASHSEED`` in the environment, seeds Python's ``random``,
    NumPy and PyTorch (CPU, plus all CUDA devices when CUDA is available),
    and switches cuDNN to deterministic mode.

    Args:
        seed: Integer seed applied to all generators.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    # The CPU-side generators all take the seed as their single argument.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = True


def setup_optimizer(params, opt_name='AdamW', lr=1e-4, beta=0.9, weight_decay=0., scheduler=None, step_size=1, gamma=0.7):
    """Build a ``torch.optim`` optimizer and, optionally, an LR scheduler.

    Args:
        params: Iterable of parameters (or parameter groups) to optimize.
        opt_name: Name of an optimizer class in ``torch.optim``.
        lr: Learning rate.
        beta: First-moment coefficient; used as ``betas=(beta, 0.999)`` for
            ``Adam`` and ``AdamW`` and ignored for other optimizers.
        weight_decay: Weight decay passed to the optimizer.
        scheduler: ``None`` for no scheduler, otherwise a string naming the
            scheduler class. A bare class name from
            ``torch.optim.lr_scheduler`` (e.g. ``"StepLR"``) is looked up
            directly; anything else is evaluated as a Python expression
            (e.g. ``"torch.optim.lr_scheduler.StepLR"``).
        step_size: ``step_size`` keyword passed to the scheduler.
        gamma: ``gamma`` keyword passed to the scheduler.

    Returns:
        Tuple ``(optimizer, scheduler_or_None)``.

    Raises:
        AttributeError: If ``opt_name`` is not an attribute of ``torch.optim``.
    """
    opt = getattr(torch.optim, opt_name)
    print(f'[INFO]    optimizer: {opt}')
    print(f'[INFO]    scheduler: {scheduler}')

    if opt_name in ('Adam', 'AdamW'):
        # Both Adam variants take ``betas``; previously ``beta`` was silently
        # dropped for the default AdamW optimizer.
        optimizer = opt(params,
                        lr=lr,
                        betas=(beta, 0.999),
                        weight_decay=weight_decay)
    else:
        optimizer = opt(params, lr=lr, weight_decay=weight_decay)

    if scheduler is None:
        return optimizer, None

    # Prefer a plain attribute lookup so bare names such as "StepLR" work.
    scheduler_cls = getattr(torch.optim.lr_scheduler, scheduler, None)
    if scheduler_cls is None:
        # NOTE(security): eval() executes arbitrary code. It is kept only for
        # backward compatibility with dotted expressions; never feed it
        # untrusted input.
        scheduler_cls = eval(scheduler)
    return optimizer, scheduler_cls(optimizer, step_size=step_size, gamma=gamma)


def softmax_entropy(x, dim=2):
    """Return the entropy of the softmax distribution defined by logits ``x``.

    Args:
        x: Tensor of logits.
        dim: Dimension over which the softmax is taken and then reduced.

    Returns:
        Tensor with ``dim`` reduced away, holding ``-sum(p * log p)``.
    """
    probs = torch.softmax(x, dim)
    log_probs = torch.log_softmax(x, dim)
    return torch.sum(probs * log_probs, dim).neg()

def mcc_loss(x, reweight=False, dim=2, class_num=32):
    """Compute a class-confusion loss from logits ``x``.

    The softmax probabilities are squeezed on dimension 0 and multiplied with
    their own transpose to form a class-by-class matrix, which is normalised
    by its row sums. The loss is the off-diagonal mass of that matrix divided
    by ``class_num``.

    Args:
        x: Logits; the matrix products require the squeezed probabilities to
            be 2-D (the original comments assume shape ``(1, L, D)``).
        reweight: If true, each row of the probabilities is weighted by
            ``1 + exp(-entropy)``, rescaled so the weights sum to
            ``x.shape[1]``.
        dim: Softmax dimension for the probabilities.
        class_num: Divisor applied to the off-diagonal sum.

    Returns:
        Scalar tensor with the loss value.
    """
    probs = torch.softmax(x, dim).squeeze(0)

    if reweight:
        # Entropy is always taken over dimension 2 here, independent of `dim`.
        row_entropy = softmax_entropy(x, dim=2).detach().squeeze(0)
        row_weight = 1 + torch.exp(-row_entropy)
        row_weight = x.shape[1] * row_weight / torch.sum(row_weight)
        weighted_probs = probs * row_weight.view(-1, 1)
        confusion = weighted_probs.transpose(0, 1).mm(probs)
    else:
        confusion = probs.transpose(0, 1).mm(probs)

    # The 1-D row-sum vector broadcasts along the last dimension.
    confusion = confusion / confusion.sum(dim=1)
    off_diagonal = confusion.sum() - confusion.trace()
    return off_diagonal / class_num


def collect_params(model, bias_only=False, train_feature=False, train_all=False, train_LN=True):
    """Select the parameters to adapt, enable their gradients and return them.

    Walks ``model.named_modules()`` and collects parameters according to the
    flags below. Every selected parameter gets ``requires_grad = True`` and is
    returned exactly once, even when several flags (or several ancestor
    modules) select it.

    Args:
        model: Module whose parameters are inspected.
        bias_only: With ``train_LN``, restrict normalization parameters to
            ``bias`` (otherwise ``weight`` and ``bias``).
        train_feature: Collect all parameters of modules whose second name
            component is ``feature_extractor`` or ``feature_projection``.
        train_all: Collect every parameter of the model.
        train_LN: Collect affine parameters of ``nn.LayerNorm`` and
            ``nn.GroupNorm`` modules.

    Returns:
        Tuple ``(params, names)`` of the selected parameters and their
        dotted names, in order of first selection.
    """
    trainable = ['bias'] if bias_only else ['weight', 'bias']
    params = []
    names = []
    # Identity-based de-duplication: named_parameters() is recursive, so the
    # same tensor is reachable through every ancestor module.
    seen_ids = set()

    def _select(module_name, param_name, param):
        """Record ``param`` once and turn its gradient on."""
        if id(param) in seen_ids:
            return
        seen_ids.add(id(param))
        param.requires_grad = True
        params.append(param)
        # The root module has an empty name; avoid a leading dot.
        names.append(f"{module_name}.{param_name}" if module_name else param_name)

    for nm, m in model.named_modules():
        print(nm)
        if train_LN and isinstance(m, (nn.LayerNorm, nn.GroupNorm)):
            for param_name, p in m.named_parameters():
                if param_name in trainable:
                    _select(nm, param_name, p)

        if train_feature:
            name_parts = str(nm).split('.')
            if len(name_parts) > 1 and name_parts[1] in ('feature_extractor', 'feature_projection'):
                for param_name, p in m.named_parameters():
                    _select(nm, param_name, p)

        if train_all:
            for param_name, p in m.named_parameters():
                _select(nm, param_name, p)

    return params, names


def copy_model_and_optimizer(model, optimizer, scheduler):
    """Snapshot model, optimizer and (optional) scheduler state.

    Each snapshot is a deep copy of the object's ``state_dict()``, so later
    updates to the live objects do not alter it.

    Args:
        model: Object exposing ``state_dict()``.
        optimizer: Object exposing ``state_dict()``.
        scheduler: Object exposing ``state_dict()``, or ``None``.

    Returns:
        Tuple ``(model_state, optimizer_state, scheduler_state)``; the last
        entry is ``None`` when no scheduler was given.
    """
    def _snapshot(stateful):
        return deepcopy(stateful.state_dict())

    model_snapshot = _snapshot(model)
    optimizer_snapshot = _snapshot(optimizer)
    scheduler_snapshot = None if scheduler is None else _snapshot(scheduler)
    return model_snapshot, optimizer_snapshot, scheduler_snapshot

def load_model_and_optimizer(model, optimizer, model_state, optimizer_state, scheduler_state, scheduler=None):
    """Restore model, optimizer and (optional) scheduler state from copies.

    Args:
        model: Module whose state is overwritten (strict key matching).
        optimizer: Optimizer whose state is overwritten.
        model_state: State dict to load into ``model``.
        optimizer_state: State dict to load into ``optimizer``.
        scheduler_state: State dict for ``scheduler``; may be ``None``.
        scheduler: Scheduler to restore. Defaults to ``None`` (no scheduler),
            which keeps existing five-argument calls working.

    Returns:
        Tuple ``(model, optimizer, scheduler)``; the last entry is ``None``
        when no scheduler was supplied.
    """
    model.load_state_dict(model_state, strict=True)
    optimizer.load_state_dict(optimizer_state)
    # The scheduler is an explicit argument: the function used to read an
    # undefined name here and only worked when a module global existed.
    if scheduler is not None and scheduler_state is not None:
        scheduler.load_state_dict(scheduler_state)
    return model, optimizer, scheduler
    

def configure_model(model):
    """Freeze the model: turn off gradients for all of its parameters.

    Parameters that should be adapted are expected to be re-enabled
    afterwards (see ``collect_params``).

    Args:
        model: Module to freeze in place.

    Returns:
        The same ``model`` object.
    """
    for param in model.parameters():
        param.requires_grad_(False)
    return model


def update_ema(ema, new_data, momentum=0.9):
    """Update an exponential moving average with a new observation.

    Args:
        ema: Current average, or ``None`` if nothing has been observed yet.
        new_data: New observation.
        momentum: Weight kept on the previous average; the new observation
            receives ``1 - momentum``. Defaults to the previously hard-coded
            value ``0.9``.

    Returns:
        ``new_data`` when ``ema`` is ``None``; otherwise
        ``momentum * ema + (1 - momentum) * new_data``.
    """
    if ema is None:
        return new_data
    # Computed under no_grad so tensor inputs do not extend an autograd graph.
    with torch.no_grad():
        return momentum * ema + (1 - momentum) * new_data
        

def forward_and_adapt(x, model, optimizer, em_coef=0.9, reweight=False, temp=1., model_state=None, optimizer_state=None, repeat_inference=True, ema=None):
    """Run one two-pass entropy-minimisation update and return fresh logits.

    The optimizer must expose ``first_step`` and ``second_step`` (the script
    passes a ``SAM`` instance). Positions whose entropy is at or above a
    fixed margin are excluded from the loss in both passes.

    Args:
        x: Model input batch.
        model: Model whose call result has a ``.logits`` attribute.
        optimizer: Two-step optimizer (``first_step`` / ``second_step``).
        em_coef: Unused by this function.
        reweight: Unused by this function.
        temp: Temperature dividing the logits before the entropy is taken.
        model_state: Unused by this function.
        optimizer_state: Unused by this function.
        repeat_inference: If true, re-run the model without gradients after
            the update and return those logits; otherwise return the logits
            of the first forward pass.
        ema: Running average of the second-pass loss, or ``None``.

    Returns:
        Tuple ``(logits, ema)``.
    """
    # NOTE(review): 32 is hard-coded; presumably the model's vocabulary size.
    entropy_margin = math.log(32) * 0.6

    # First pass: entropy of the reliable (low-entropy) positions.
    outputs = model(x).logits
    first_entropy = softmax_entropy(outputs / temp)
    keep_first = first_entropy < entropy_margin
    first_loss = first_entropy[keep_first].mean(0).mean()
    first_loss.backward()
    optimizer.first_step(zero_grad=True)

    # Second pass: same positions, filtered once more at the new weights.
    second_logits = model(x).logits
    second_entropy = softmax_entropy(second_logits / temp)[keep_first]
    keep_second = second_entropy < entropy_margin
    loss_second = second_entropy[keep_second].mean(0).mean()
    loss_second.backward()
    optimizer.second_step(zero_grad=True)

    # Track a moving average of the loss; an empty selection yields NaN and
    # is skipped.
    second_value = loss_second.item()
    if not math.isnan(second_value):
        ema = update_ema(ema, second_value)

    # NOTE(review): only a message is printed here; no reset is performed.
    if ema is not None and ema < 0.2:
        print("ema < 0.2, now reset the model")

    model.zero_grad()

    if repeat_inference:
        with torch.no_grad():
            outputs = model(x).logits
    return outputs, ema

import argparse

if __name__ == '__main__':
    # Script entry point: adapt a wav2vec2 CTC model on each batch of the
    # dataset with a SAM optimizer, decode after selected adaptation steps,
    # and write WER summaries plus a per-utterance CSV to the log directory.
    SAMPLE_RATE = 16000
    # 1-based adaptation steps after which the adapted output is decoded and
    # stored (episodic mode only).
    CHECKPOINT_STEPS = (1, 3, 5, 10, 20, 40)

    parser = argparse.ArgumentParser(description="TTA ASR")
    parser.add_argument('--asr', type=str, default="facebook/wav2vec2-base-960h")
    parser.add_argument('--steps', type=int, default=40)
    parser.add_argument('--episodic', action='store_true')
    parser.add_argument('--div_coef', type=float, default=0.)
    parser.add_argument('--opt', type=str, default='AdamW')
    parser.add_argument('--dataset_name', type=str, default='librispeech')
    parser.add_argument('--dataset_dir', type=str, default='/path/to/LibriSpeech')
    parser.add_argument('--split', default=['test-other'])
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--em_coef', type=float, default=1.)
    parser.add_argument('--reweight', action='store_true')
    parser.add_argument('--bias_only', action='store_true')
    parser.add_argument('--train_feature', action='store_true')
    parser.add_argument('--train_all', action='store_true')
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--temp', type=float, default=2.5)
    parser.add_argument('--non_blank', action='store_true')
    parser.add_argument('--log_dir', type=str, default='./exps')
    parser.add_argument('--extra_noise', type=float, default=0.)
    parser.add_argument('--noise_type', default=None)
    parser.add_argument('--noise_snr', default=None)
    parser.add_argument('--scheduler', default=None)

    args = parser.parse_args()
    asr = args.asr
    steps = args.steps
    episodic = args.episodic
    opt = args.opt
    dataset_dir = args.dataset_dir
    dataset_name = args.dataset_name
    split = args.split
    lr = args.lr
    em_coef = args.em_coef
    reweight = args.reweight
    batch_size = args.batch_size
    temp = args.temp
    non_blank = args.non_blank
    log_dir = args.log_dir
    extra_noise = args.extra_noise
    noise_type = args.noise_type
    noise_snr = args.noise_snr
    div_coef = args.div_coef
    bias_only = args.bias_only
    train_feature = args.train_feature
    train_all = args.train_all
    skip_short_thd = None
    train_LN = True

    # --scheduler is only logged: this script never builds an LR scheduler
    # object. Keep the CLI string apart from the (absent) scheduler so the
    # state snapshot/restore helpers are not handed a plain string.
    scheduler_name = args.scheduler
    scheduler = None

    set_seed(42)

    asr_tag = asr.split('/')[-1]
    exp_name = (
        f"{dataset_name}_{lr}_{em_coef}_{steps}_{temp}_{asr_tag}"
        f"_non_blank{non_blank}_noise_{extra_noise}_rew_{reweight}"
        f"_ntype_{noise_type}_nsnr_{noise_snr}"
    )

    from data import load_dataset
    dataset = load_dataset(dataset_name, dataset_dir, batch_size, extra_noise, noise_type, noise_snr)

    # Adapted transcriptions, keyed by the adaptation step they were taken at.
    transcriptions = {n: [] for n in CHECKPOINT_STEPS}
    gt_texts = []
    ori_transcriptions = []
    durations = []
    werrs = []

    # Configuration summary, shared by the console output and the log file.
    config_lines = [
        f'eposidic? {episodic}',
        f'lr = {lr}',
        f'optim = {opt}',
        f'step = {steps}',
        f'em_coef = {em_coef}',
        f'reweight = {reweight}',
        f'batch size = {batch_size}',
        f'temperature = {temp}',
        f'non_blank = {non_blank}',
        f'extra_noise = {extra_noise}',
        f'noise_type = {noise_type}',
        f'noise_snr = {noise_snr}',
        f'scheduler = {scheduler_name}',
        f'div_coef = {div_coef}',
        f'bias_only = {bias_only}',
        f'train_feature = {train_feature}',
        f'train_all = {train_all}',
        f'train_LN = {train_LN}',
    ]

    print('------------------------------------')
    print(f'exp: {exp_name}')
    for line in config_lines:
        print(line)

    # load model and tokenizer
    processor = Wav2Vec2Processor.from_pretrained(asr, sampling_rate=SAMPLE_RATE, return_attention_mask=True)
    model = Wav2Vec2ForCTC.from_pretrained(asr).eval().cuda()

    model = configure_model(model)
    params, param_names = collect_params(model, bias_only, train_feature, train_all, train_LN)

    base_optimizer = torch.optim.AdamW
    optimizer = SAM(params, base_optimizer, lr=lr, weight_decay=0.)

    # Always define the snapshots: they are passed to forward_and_adapt even
    # in non-episodic runs, where they used to be undefined names.
    model_state = optimizer_state = scheduler_state = None
    if episodic:
        model_state, optimizer_state, scheduler_state = copy_model_and_optimizer(model, optimizer, scheduler)

    print(param_names)
    count = 0

    for batch in tqdm(dataset):
        lens, wavs, texts, files = batch

        inputs = processor(wavs, return_tensors="pt", padding="longest")
        input_values = inputs.input_values.cuda()
        durations.append(input_values.shape[1] / SAMPLE_RATE)

        if episodic:
            # Start every utterance from the original weights.
            model, optimizer, scheduler = load_model_and_optimizer(model, optimizer, model_state, optimizer_state, scheduler_state)

        # vanilla forward
        with torch.no_grad():
            outputs = model(input_values).logits
        predicted_ids = torch.argmax(outputs, dim=-1)
        ori_transcription = processor.batch_decode(predicted_ids)
        ori_transcriptions += ori_transcription
        # Recorded together with the unadapted hypothesis so references and
        # hypotheses stay aligned even when adaptation is skipped.
        gt_texts += texts
        ori_wer = wer(list(texts), list(ori_transcription))
        print("original WER: ", ori_wer)

        # WER reduction after 10 steps; stays NaN when it was not measured so
        # the CSV columns always have one row per processed batch.
        werr = float('nan')

        if skip_short_thd is not None and outputs.shape[1] <= skip_short_thd:
            print(f'do not adapt since length is {outputs.shape[1]}')
            count += 1
            if episodic:
                # The unadapted output stands in at every checkpoint this run
                # would have recorded.
                for n in CHECKPOINT_STEPS:
                    if n <= steps:
                        transcriptions[n] += ori_transcription
        else:
            ema = None
            for i in range(steps):
                outputs, ema = forward_and_adapt(input_values, model, optimizer, em_coef, reweight, temp, model_state, optimizer_state, True, ema)
                step_no = i + 1
                if episodic and step_no in transcriptions:
                    predicted_ids = torch.argmax(outputs, dim=-1)
                    transcription = processor.batch_decode(predicted_ids)
                    ada_wer = wer(list(texts), list(transcription))
                    # Label padded to 14 characters, matching the former
                    # hand-aligned messages.
                    print(f"adapt-{step_no} WER:".ljust(14), ada_wer)
                    transcriptions[step_no] += transcription
                    if step_no == 10:
                        werr = ori_wer - ada_wer

        werrs.append(werr)
        del input_values
        torch.cuda.empty_cache()

    # Corpus-level results. A checkpoint is reported only if it was recorded
    # for every utterance (i.e. episodic mode and enough steps); otherwise
    # jiwer would fail on mismatched lengths.
    results = [("original WER", wer(gt_texts, ori_transcriptions))]
    for n in CHECKPOINT_STEPS:
        hypotheses = transcriptions[n]
        if hypotheses and len(hypotheses) == len(gt_texts):
            results.append((f"TTA-{n} WER", wer(gt_texts, hypotheses)))

    print("asr:", asr)
    print(f'non-adapted count = {count}')
    print(f'dataset num = {len(dataset)}')
    for label, value in results:
        print(f"{label}:", value)
    print('------------------------------------')

    # exist_ok avoids the check-then-create race and supports nested paths.
    os.makedirs(log_dir, exist_ok=True)

    with open(os.path.join(log_dir, exp_name), 'w') as f:
        for label, value in results:
            f.write(f"{label}: {value}\n")
        for line in config_lines:
            f.write(line + '\n')

    csv_path = os.path.join(log_dir, exp_name + '.csv')
    df = pd.DataFrame({'duration': durations, 'WERR': werrs})
    df.to_csv(csv_path)
    









