import os 
import random
import numpy as np
from copy import deepcopy
from tqdm import tqdm

import torch
from torch import nn
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM
from jiwer import wer


def set_seed(seed):
    """Seed every random number generator this script relies on.

    Exports ``PYTHONHASHSEED`` and seeds Python's ``random``, NumPy and
    PyTorch (plus the CUDA generators when CUDA is available), then asks
    cuDNN to use deterministic algorithms.

    Args:
        seed: Integer seed applied to all generators.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    # CPU-side generators, seeded in the same order as before.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # GPU generators only exist when CUDA is usable.
    if torch.cuda.is_available():
        for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seed_fn(seed)

    torch.backends.cudnn.deterministic = True


def setup_optimizer(params, params2, opt_name='AdamW', lr=1e-4, lr2=1e-4, beta=0.9, weight_decay=0., scheduler=None, step_size=1, gamma=0.7):
    """Build the pair of optimizers used by the two adaptation stages.

    The rest of this file always indexes the result as ``optimizer[0]`` and
    ``optimizer[1]``, so a two-element list is returned for every optimizer
    type. (Previously ``opt_name='Adam'`` returned a single optimizer and
    silently ignored ``params2``/``lr2``, which made that option unusable
    with the other helpers here.)

    Args:
        params: Parameters handed to the first optimizer.
        params2: Parameters handed to the second optimizer.
        opt_name: Name of a class in ``torch.optim``.
        lr: Learning rate of the first optimizer.
        lr2: Learning rate of the second optimizer.
        beta: First Adam beta; only used when ``opt_name == 'Adam'``.
        weight_decay: Weight decay passed to both optimizers.
        scheduler: Optional string that evaluates to a scheduler class.
        step_size: Forwarded to the scheduler constructor.
        gamma: Forwarded to the scheduler constructor.

    Returns:
        ``(optimizers, scheduler)`` where ``optimizers`` is a list of two
        optimizer instances and ``scheduler`` is ``None`` when no scheduler
        was requested.
    """
    opt = getattr(torch.optim, opt_name)
    print(f'[INFO]    optimizer: {opt}')
    print(f'[INFO]    scheduler: {scheduler}')

    # Only the plain 'Adam' option customises the betas, as before.
    extra_kwargs = {'betas': (beta, 0.999)} if opt_name == 'Adam' else {}
    optimizer = [
        opt(params, lr=lr, weight_decay=weight_decay, **extra_kwargs),
        opt(params2, lr=lr2, weight_decay=weight_decay, **extra_kwargs),
    ]

    if scheduler is None:
        return optimizer, None

    # SECURITY: `scheduler` is evaluated with eval(); it must only ever come
    # from a trusted source (it is a CLI argument in this script).
    # NOTE(review): the scheduler receives the *list* of optimizers, exactly
    # as the original non-Adam path did. torch LR schedulers presumably
    # expect a single Optimizer, so this likely raises TypeError -- confirm
    # the intended wiring (one scheduler per optimizer?) before relying on it.
    return optimizer, eval(scheduler)(optimizer, step_size=step_size, gamma=gamma)


def softmax_entropy(x, dim=2):
    """Return the Shannon entropy of ``softmax(x)`` taken along ``dim``.

    Args:
        x: Tensor of logits.
        dim: Axis over which the softmax is normalised and the entropy summed.

    Returns:
        Tensor with ``dim`` reduced away, holding one entropy per slice.
    """
    probs = x.softmax(dim)
    log_probs = x.log_softmax(dim)
    return -(probs * log_probs).sum(dim)

def mcc_loss(x, reweight=False, dim=2, class_num=32):
    """Class-confusion loss computed from a single utterance's logits.

    The leading axis of ``x`` is squeezed away and the result is fed to
    ``mm``, so ``x`` is expected to carry a batch dimension of size 1
    (the original annotations describe it as ``(1, L, D)``).

    Args:
        x: Logits tensor.
        reweight: When true, each frame's contribution is scaled by
            ``1 + exp(-entropy)`` of that frame, renormalised so the weights
            sum to ``x.shape[1]``.
        dim: Axis used for the softmax over classes.
        class_num: Divisor applied to the off-diagonal mass.

    Returns:
        Scalar tensor: (sum of all entries - trace) of the normalised
        class-by-class matrix, divided by ``class_num``.
    """
    probs = x.softmax(dim).squeeze(0)  # frames x classes

    if not reweight:
        confusion = probs.transpose(1, 0).mm(probs)  # classes x classes
    else:
        # Frame weights are detached: they scale the loss but get no gradient.
        # The entropy is always taken over axis 2, independent of `dim`.
        frame_weight = softmax_entropy(x, dim=2).detach().squeeze(0)
        frame_weight = 1 + torch.exp(-frame_weight)
        frame_weight = x.shape[1] * frame_weight / torch.sum(frame_weight)
        confusion = probs.mul(frame_weight.view(-1, 1)).transpose(1, 0).mm(probs)

    # Divide by the per-row sums (broadcast along the last axis).
    confusion = confusion / torch.sum(confusion, dim=1)
    off_diagonal_mass = torch.sum(confusion) - torch.trace(confusion)

    return off_diagonal_mass / class_num

def tc_reg_loss(x, non_blank):
    """Temporal-consistency regulariser over self-attended features.

    Per the original annotations ``x`` is laid out as ``(1, C, L)``. The
    features are mixed with a softmax self-attention over the ``L`` axis
    (plus a residual), and the loss is the mean L2 distance between each
    selected frame and the frame one step before it.

    Args:
        x: Feature tensor (assumed ``(1, C, L)`` -- TODO confirm with caller).
        non_blank: Boolean mask over frames; only positions from index 1
            onwards that are true contribute.

    Returns:
        Mean over the selected frames of the L2 norm of the difference.
    """
    shift = 1  # distance, in frames, between the two compared positions

    frames = x.transpose(1, 2)
    attention = torch.softmax(torch.matmul(frames, x), dim=-1)
    attended = torch.matmul(attention, frames) + frames  # residual connection

    keep = non_blank[:, shift:]
    later = attended[:, shift:][keep]
    earlier = attended[:, :-shift][keep]

    return torch.norm(later - earlier, p=2, dim=-1).mean(0)


def collect_params(model, bias_only=False, train_feature=False, train_all=False, train_LN=True):
    """Select the parameters to adapt and enable gradients on them.

    Walks ``model.named_modules()`` and picks parameters according to the
    flags. Every selected parameter gets ``requires_grad = True`` and is
    returned exactly once, even when several flags (or several ancestor
    modules) reach it.

    Args:
        model: Module whose parameters are inspected.
        bias_only: With ``train_LN``, restrict the normalisation parameters
            to ``bias`` (otherwise both ``weight`` and ``bias``).
        train_feature: Also select every parameter of modules whose second
            dotted name component is ``feature_extractor`` or
            ``feature_projection``.
        train_all: Select every parameter of the model.
        train_LN: Select the affine parameters of ``LayerNorm`` and
            ``GroupNorm`` modules.

    Returns:
        ``(params, names)``: the selected parameters and their fully
        qualified names, in selection order.
    """
    params = []
    names = []
    # id()s of parameters already selected. `named_parameters()` is recursive,
    # so the same tensor is reachable from every ancestor module; without
    # this, `train_all` handed the optimizer the same parameter many times.
    seen = set()
    trainable = ['bias'] if bias_only else ['weight', 'bias']
    feature_parts = ('feature_extractor', 'feature_projection')

    def _select(module_name, param_name, param):
        # Record one parameter unless it was already selected.
        if id(param) in seen:
            return
        seen.add(id(param))
        param.requires_grad = True
        params.append(param)
        # The root module is named '' -- avoid a leading dot in that case.
        names.append(f"{module_name}.{param_name}" if module_name else param_name)

    for nm, m in model.named_modules():
        print(nm)  # module trace, kept from the original implementation

        if train_LN and isinstance(m, (nn.LayerNorm, nn.GroupNorm)):
            for param_name, p in m.named_parameters():
                if param_name in trainable:
                    _select(nm, param_name, p)

        if train_feature:
            parts = str(nm).split('.')
            if len(parts) > 1 and parts[1] in feature_parts:
                for param_name, p in m.named_parameters():
                    _select(nm, param_name, p)

        if train_all:
            for param_name, p in m.named_parameters():
                _select(nm, param_name, p)

    return params, names



def copy_model_and_optimizer(model, optimizer, scheduler):
    """Snapshot model, optimizer and scheduler state for a later reset.

    Args:
        model: Module whose ``state_dict`` is copied.
        optimizer: Indexable holding the two optimizers; entries 0 and 1 are
            snapshotted.
        scheduler: Scheduler to snapshot, or ``None``.

    Returns:
        ``(model_state, optimizer_state, scheduler_state)`` where
        ``optimizer_state`` is a two-element list and ``scheduler_state`` is
        ``None`` when no scheduler was given. All states are deep copies.
    """
    model_state = deepcopy(model.state_dict())
    optimizer_state = [deepcopy(optimizer[idx].state_dict()) for idx in (0, 1)]
    scheduler_state = None if scheduler is None else deepcopy(scheduler.state_dict())
    return model_state, optimizer_state, scheduler_state

def load_model_and_optimizer(model, optimizer, model_state, optimizer_state, scheduler_state):
    """Reset model, both optimizers and the scheduler to saved states.

    Args:
        model: Module to restore (strict key matching).
        optimizer: Indexable holding the two optimizers; entries 0 and 1 are
            restored.
        model_state: Saved model ``state_dict``.
        optimizer_state: Saved optimizer states, indexed like ``optimizer``.
        scheduler_state: Saved scheduler state, used only when the
            module-level ``scheduler`` is not ``None``.

    Returns:
        ``(model, optimizer, scheduler)``; the last item is ``None`` when
        there is no scheduler.
    """
    model.load_state_dict(model_state, strict=True)
    for idx in (0, 1):
        optimizer[idx].load_state_dict(optimizer_state[idx])

    # `scheduler` is not a parameter: it is looked up in the enclosing
    # (module) scope at call time, as in the original implementation.
    if scheduler is None:
        return model, optimizer, None

    scheduler.load_state_dict(scheduler_state)
    return model, optimizer, scheduler
    

def configure_model(model):
    """Freeze the whole model and hand it back.

    Gradients are switched off for every parameter; in this file
    ``collect_params`` is what later re-enables them on the subset that
    should be adapted.
    """
    model.requires_grad_(requires_grad=False)
    return model

def forward_and_adapt(x, model, optimizer, em_coef=0.9, step=0, reweight=False, temp=1., not_blank=True, repeat_inference=True, warmup_steps=10, tc_coef=0.3):
    """Run one test-time adaptation step and return logits.

    Adaptation happens in two stages selected by ``step``:

    * ``step < warmup_steps``: entropy / class-confusion loss only, applied
      with ``optimizer[0]``.
    * otherwise: the same losses plus ``tc_coef`` times the
      temporal-consistency regulariser on the feature-extractor output,
      applied with ``optimizer[1]``.

    Args:
        x: Input batch passed to ``model``.
        model: Model exposing ``.logits`` on its output and, for the second
            stage, ``model.wav2vec2.feature_extractor``.
        optimizer: Indexable with the stage-1 and stage-2 optimizers.
        em_coef: Weight of the entropy term; ``1 - em_coef`` weights the
            class-confusion term. A term is skipped when its weight is not
            positive.
        step: 0-based index of the current adaptation step.
        reweight: Forwarded to ``mcc_loss``.
        temp: Temperature dividing the logits before the losses.
        not_blank: In the first stage, restrict the entropy term to frames
            whose argmax is not index 0 and weight each frame by
            ``sigmoid(entropy)``.
        repeat_inference: Re-run the model without gradients after the
            update and return those logits instead.
        warmup_steps: Number of initial steps that use the first stage
            (previously hard-coded to 10).
        tc_coef: Weight of the temporal-consistency term (previously
            hard-coded to 0.3).

    Returns:
        Logits: post-update when ``repeat_inference`` is true, otherwise the
        pre-update logits (still attached to the autograd graph).
    """
    first_stage = step < warmup_steps

    # forward
    outputs = model(x).logits
    if not first_stage:
        feats = model.wav2vec2.feature_extractor(x)
    predicted_ids = torch.argmax(outputs, dim=-1)
    non_blank = predicted_ids != 0  # index 0 is treated as the blank token

    # adapt
    loss = 0
    if em_coef > 0:
        entropy = softmax_entropy(outputs / temp)
        if first_stage and not_blank:
            e_non_blank = entropy[non_blank]
            # Sigmoid of the entropy: less confident frames weigh more.
            weight = 1 / (1 + torch.exp(-e_non_blank))
            e_loss = (weight * e_non_blank).mean()
        else:
            # NOTE(review): in the second stage the original code used this
            # all-frame entropy in both arms of `if not_blank`, i.e. the flag
            # has no effect there. Preserved as-is -- confirm the intent.
            e_loss = entropy.mean(0).mean()
        loss += e_loss * em_coef
    if 1 - em_coef > 0:
        c_loss = mcc_loss(outputs / temp, reweight)
        loss += c_loss * (1 - em_coef)

    if not first_stage:
        loss += tc_coef * tc_reg_loss(feats, non_blank)

    model.zero_grad()
    loss.backward()
    optimizer[0 if first_stage else 1].step()

    # inference again
    if repeat_inference:
        with torch.no_grad():
            outputs = model(x).logits

    return outputs

import argparse

if __name__ == '__main__':
    # Command-line driver: for every batch of the dataset, decode once with
    # the unadapted model, run `steps` adaptation steps, decode again at the
    # configured display steps, and finally report/log corpus-level WER.
    SAMPLE_RATE = 16000
    parser = argparse.ArgumentParser(description="TTA ASR")
    parser.add_argument('--asr', type=str, default="facebook/wav2vec2-base-960h")
    parser.add_argument('--steps', type=int, default=40)
    parser.add_argument('--episodic', action='store_true')
    parser.add_argument('--opt', type=str, default='AdamW')
    parser.add_argument('--dataset_name', type=str, default='librispeech')
    parser.add_argument('--dataset_dir', type=str, default='/path/to/LibriSpeech')
    # NOTE(review): the default is a list, but a value given on the command
    # line arrives as a plain string (no nargs) -- verify what load_dataset
    # expects before changing this.
    parser.add_argument('--split', default=['test-other'])
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--lr2', type=float, default=1e-4)
    parser.add_argument('--em_coef', type=float, default=1.)
    parser.add_argument('--reweight', action='store_true')
    parser.add_argument('--bias_only', action='store_true')
    parser.add_argument('--train_feature', action='store_true')
    parser.add_argument('--train_all', action='store_true')
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--temp', type=float, default=2.5)
    parser.add_argument('--non_blank', action='store_true')
    parser.add_argument('--log_dir', type=str, default='./exps')
    parser.add_argument('--extra_noise', type=float, default=0.)
    parser.add_argument('--scheduler', default=None)

    args = parser.parse_args()
    asr = args.asr
    steps = args.steps
    episodic = args.episodic
    opt = args.opt
    dataset_dir = args.dataset_dir
    dataset_name = args.dataset_name
    split = args.split
    lr = args.lr
    em_coef = args.em_coef
    lr2 = args.lr2
    reweight = args.reweight
    batch_size = args.batch_size
    temp =  args.temp
    non_blank = args.non_blank
    log_dir = args.log_dir
    extra_noise = args.extra_noise
    scheduler = args.scheduler

    bias_only = args.bias_only
    train_feature = args.train_feature
    # NOTE(review): --train_all is parsed and logged, but both collect_params
    # calls below pass False for it.
    train_all = args.train_all
    skip_short_thd = None
    train_LN = True

    exp_name = 'lm_'+dataset_name+'_'+str(lr)+'_'+str(em_coef)+'_'+str(lr2)+'_'+str(steps)+'_'+str(temp)+'_'+asr.split('/')[-1]+'_'+'non_blank'+str(non_blank)+'_noise_'+str(extra_noise)+'_rew_'+str(reweight)

    set_seed(42)

    from data import load_dataset
    dataset = load_dataset(split, dataset_name, dataset_dir, batch_size, extra_noise)
    gt_texts = []
    ori_transcriptions = []
    durations = []
    werrs = []

    print('------------------------------------')
    print(f'exp: {exp_name}')
    print(f'eposidic? {episodic}')
    print(f'lr = {lr}')
    print(f'optim = {opt}')
    print(f'step = {steps}')
    print(f'em_coef = {em_coef}')
    print(f'lr2 = {lr2}')
    print(f'reweight = {reweight}')
    print(f'batch size = {batch_size}')
    print(f'temperature = {temp}')
    print(f'non_blank = {str(non_blank)}')
    print(f'extra_noise = {extra_noise}')
    print(f'scheduler = {str(scheduler)}')
    print(f'bias_only = {bias_only}')
    print(f'train_feature = {train_feature}')
    print(f'train_all = {train_all}')
    print(f'train_LN = {train_LN}')

    # load model and tokenizer
    processor = Wav2Vec2Processor.from_pretrained(asr, sampling_rate=SAMPLE_RATE, return_attention_mask=True)
    decoder_processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm")
    model = Wav2Vec2ForCTC.from_pretrained(asr).eval().cuda()

    # Freeze everything, then re-enable gradients on the two parameter sets
    # (one per adaptation stage / optimizer).
    model = configure_model(model)
    params, param_names = collect_params(model, bias_only, train_feature, False, True)
    params2, param_names2 = collect_params(model, bias_only, False, False, train_LN)
    optimizer, scheduler = setup_optimizer(params, params2, opt, lr, lr2, scheduler=scheduler)

    if episodic:
        model_state, optimizer_state, scheduler_state = copy_model_and_optimizer(model, optimizer, scheduler)

    count = 0
    # 0-based adaptation steps after which the output is decoded and scored.
    # Steps that can never be reached are dropped; otherwise their list would
    # stay empty and the final wer() call would see mismatched lengths.
    to_display = [shown for shown in [19] if shown < steps]
    transcription_dict = {shown: [] for shown in to_display}

    beam_width = 5
    alpha = 0.5

    for batch in tqdm(dataset):
        lens, wavs, texts, files = batch

        inputs = processor(wavs, return_tensors="pt", padding="longest")
        input_values = inputs.input_values.cuda()
        duration = input_values.shape[1] / SAMPLE_RATE
        durations.append(duration)

        if episodic:
            # Start every utterance from the pristine model/optimizer state.
            model, optimizer, scheduler = load_model_and_optimizer(model, optimizer, model_state, optimizer_state, scheduler_state)

        # vanilla forward
        with torch.no_grad():
            outputs = model(input_values).logits
        ori_transcription = decoder_processor.batch_decode(outputs.detach().cpu().numpy(), beam_width=beam_width, alpha=alpha).text

        ori_transcriptions += ori_transcription
        # Record the references together with the baseline so both lists
        # stay aligned even when the utterance is skipped below.
        gt_texts += texts
        ori_wer = wer(list(texts), list(ori_transcription))
        print("original WER: ", ori_wer)

        if skip_short_thd is not None:
            if outputs.shape[1] <= skip_short_thd:
                print(f'do not adapt since length is {outputs.shape[1]}')
                count += 1
                # Not adapted: fall back to the unadapted transcription so
                # the per-step lists keep one entry per reference.
                for shown in to_display:
                    transcription_dict[shown] += ori_transcription
                continue

        for i in range(steps):
            outputs = forward_and_adapt(input_values, model, optimizer, em_coef, i, reweight, temp, non_blank )

            if i in to_display:
                transcription = decoder_processor.batch_decode(outputs.detach().cpu().numpy(), beam_width=beam_width, alpha=alpha).text
                ada_wer = wer(list(texts), list(transcription))
                print("adapt-{} WER:  {}".format(i+1, ada_wer))
                if i == 19:
                    werr = ori_wer - ada_wer
                    werrs.append(werr)
                transcription_dict[i] += transcription

        del input_values
        torch.cuda.empty_cache()


    print("asr:", asr)
    print(f'non-adapted count = {count}')
    print(f'dataset num = {len(dataset)}')
    # Corpus-level baseline WER, computed once and reused for the log file.
    ori_total_wer = wer(gt_texts, ori_transcriptions)
    print("original WER:", ori_total_wer)

    wer_dict = {}
    for i, trans in transcription_dict.items():
        wer_dict[i] = wer(gt_texts, trans)
        print("TTA-{} WER: {}".format(i+1, wer_dict[i]))
    print('------------------------------------')


    # Creates missing parent directories too and tolerates an existing one.
    os.makedirs(log_dir, exist_ok=True)

    with open(os.path.join(log_dir, exp_name), 'w') as f:
        f.write(f"original WER: {ori_total_wer}\n")
        for i, wer_res in wer_dict.items():
            f.write(f"TTA-{i+1} WER: {wer_res}\n")

        f.write(f'eposidic? {episodic}\n')
        f.write(f'lr = {lr}\n')
        f.write(f'lr2 = {lr2}\n')
        f.write(f'optim = {opt}\n')
        f.write(f'step = {steps}\n')
        f.write(f'em_coef = {em_coef}\n')
        f.write(f'reweight = {reweight}\n')
        f.write(f'batch size = {batch_size}\n')
        f.write(f'temperature = {temp}\n')
        f.write(f'non_blank = {str(non_blank)}\n')
        f.write(f'extra_noise = {extra_noise}\n')
        f.write(f'scheduler = {str(scheduler)}\n')
        f.write(f'bias_only = {str(bias_only)}\n')
        f.write(f'train_feature = {str(train_feature)}\n')
        f.write(f'train_all = {str(train_all)}\n')
        f.write(f'train_LN = {str(train_LN)}\n')











