import pickle
import torch
import numpy as np
import dataset
from torch.utils.data import DataLoader
from objective import CL_Loss, SP_Loss
from backbone import GPT2Encoder, LLMEncoder
import pytorch_lightning as pl
import os
import bitsandbytes as bnb

def get_inv_sigma_Ti(dim_Ti):
    """Return the inverse covariance of a unit-rate Brownian bridge's interior points.

    A document with ``dim_Ti`` sentences is modelled as a Brownian bridge on the
    integer times ``0 .. T`` with ``T = dim_Ti - 1``, pinned at the first and the
    last sentence. Its ``T - 1`` interior points have covariance

        Sigma_T[i, j] = min(i, j) * (T - max(i, j)) / T,    i, j = 1 .. T-1,

    whose inverse is exactly the tridiagonal matrix with ``2`` on the diagonal
    and ``-1`` on the first off-diagonals (the second-difference operator).
    Building it directly avoids an O(T^3) ``np.linalg.inv`` and its rounding.

    Args:
        dim_Ti: number of sentences in the document (an ``int`` or a 0-d integer
            tensor). It is converted with ``int()`` and never modified in place.

    Returns:
        A ``float64`` ndarray of shape ``(dim_Ti - 2, dim_Ti - 2)``; an empty
        ``(0, 0)`` array for a two-sentence document.

    Raises:
        ValueError: if ``dim_Ti < 2``.
    """
    # Convert first: ``dim_Ti -= 1`` on a tensor view would mutate the caller's batch.
    n_interior = int(dim_Ti) - 2
    if n_interior < 0:
        raise ValueError(
            f"get_inv_sigma_Ti needs a document of at least 2 sentences, got {int(dim_Ti)}")

    precision = 2.0 * np.eye(n_interior)
    # np.arange(-1) / np.arange(0) are empty, so tiny sizes need no special case.
    k = np.arange(n_interior - 1)
    precision[k, k + 1] = -1.0
    precision[k + 1, k] = -1.0
    return precision

def create_dataloader(dataset, config, shuffle=True):
    """Wrap ``dataset`` in a ``DataLoader`` configured from the experiment config.

    Args:
        dataset: any map-style dataset accepted by ``torch.utils.data.DataLoader``.
        config: nested dict; reads ``config['optim_params']['batch_size']`` and
            ``config['experiment_params']['data_loader_workers']``.
        shuffle: whether to reshuffle every epoch (``False`` keeps dataset order).

    Returns:
        A ``DataLoader`` with pinned memory that keeps the final partial batch.
    """
    batch_size = config['optim_params']['batch_size']
    workers = config['experiment_params']['data_loader_workers']
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        pin_memory=True,
        drop_last=False,
        num_workers=workers,
    )


class CL_Encoder(pl.LightningModule):
    """Lightning module that trains a sentence encoder with ``CL_Loss``.

    ``config`` is a nested dict; this class reads its ``model_params``,
    ``optim_params``, ``data_params`` and ``experiment_params`` sections.
    Construction builds the train/test datasets and the language encoder.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self._set_dataset()
        self._set_language_encoder()

    def configure_optimizers(self):
        """Build the optimizer selected by the config.

        bitsandbytes' 8-bit Adam is used whenever ``model_params.load_in_8bit``
        is set, regardless of ``optim_params.optimizer_name``; otherwise the
        name must be ``'AdamW'`` or ``'SGD'``.

        Returns:
            ``([optimizer], [])`` -- one optimizer and no LR schedulers.

        Raises:
            ValueError: if ``optimizer_name`` is not a supported value.
        """
        optim_params = self.config['optim_params']
        if self.config['model_params']['load_in_8bit']:
            optimizer = bnb.optim.Adam8bit(
                self.parameters(),
                lr=optim_params['learning_rate'],
                weight_decay=optim_params['decay_factor'])
        elif optim_params['optimizer_name'] == 'AdamW':
            optimizer = torch.optim.AdamW(
                self.parameters(),
                lr=optim_params['learning_rate'],
                weight_decay=optim_params['decay_factor'])
        elif optim_params['optimizer_name'] == 'SGD':
            optimizer = torch.optim.SGD(
                self.parameters(),
                lr=optim_params['learning_rate'],
                momentum=optim_params['momentum'])
        else:
            # Without this branch `optimizer` is unbound and the user sees an
            # UnboundLocalError instead of the actual configuration problem.
            raise ValueError(
                f"Unsupported optimizer_name {optim_params['optimizer_name']!r}; "
                "expected 'AdamW' or 'SGD' (or set model_params.load_in_8bit)")
        return [optimizer], []

    def train_dataloader(self):
        """Shuffled loader over ``self.train_dataset``."""
        return create_dataloader(self.train_dataset, self.config)

    def test_dataloader(self):
        """Unshuffled loader over ``self.test_dataset``."""
        return create_dataloader(self.test_dataset, self.config, shuffle=False)

    def _set_dataset(self):
        """Create ``self.train_dataset`` and ``self.test_dataset`` from ``data_params``."""
        self.train_dataset = dataset.Dataset(
            model_name=self.config['model_params']['model_name'],
            train=True,
            config=self.config['data_params']
        )
        self.test_dataset = dataset.Dataset(
            model_name=self.config['model_params']['model_name'],
            train=False,
            config=self.config['data_params']
        )

    def _set_language_encoder(self):
        """Create ``self.model``: ``GPT2Encoder`` for ``'gpt2'``, ``LLMEncoder`` otherwise.

        Token embeddings are then resized to the training tokenizer's vocabulary
        size (presumably because the dataset tokenizer adds tokens -- TODO confirm).
        """
        if self.config['model_params']['model_name'] == 'gpt2':
            self.model = GPT2Encoder(
                hidden_dim=self.config['model_params']['hidden_size'],
                latent_dim=self.config['model_params']['latent_dim'],
                )
        else:
            self.model = LLMEncoder(
                model_name=self.config['model_params']['model_name'],
                hidden_dim=self.config['model_params']['hidden_size'],
                latent_dim=self.config['model_params']['latent_dim'],
                load_in_8bit=self.config['model_params']['load_in_8bit'],
            )

        self.model.model.resize_token_embeddings(len(self.train_dataset.tokenizer))

    def forward(self, input_ids, attention_mask):
        """Encode token ids with ``self.model`` and return its features."""
        feats = self.model.forward(input_ids=input_ids, attention_mask=attention_mask)
        return feats

    def get_feats(self, obs):
        """Tokenize raw text ``obs`` and encode it.

        Tokenization happens on ``experiment_params.device``; ids and masks are
        cut to ``train_dataset.max_length`` along dim 1 before encoding.
        """
        input_ids_i, attention_mask_i = self.train_dataset.tokenize_text(
            obs, device=self.config['experiment_params']['device'])
        input_ids_i = input_ids_i[:, :self.train_dataset.max_length]
        attention_mask_i = attention_mask_i[:, :self.train_dataset.max_length]
        feats_i = self.forward(input_ids=input_ids_i, attention_mask=attention_mask_i)
        return feats_i

    def get_losses_for_batch(self, batch):
        """Compute ``CL_Loss`` for a batch with keys ``y0``/``yt``/``yT``.

        The three texts are encoded; their positions ``t1``, ``t2``, ``T`` and
        ``total_t`` (as ``max_seq_len``) are cast to float. ``alpha`` and
        ``var`` are fixed to 0.
        """
        # Releases cached CUDA blocks every batch (costly; presumably to limit
        # fragmentation with large encoders -- TODO confirm it is still needed).
        torch.cuda.empty_cache()
        obs_0 = batch['y0']
        obs_t = batch['yt']
        obs_T = batch['yT']
        t_s = batch['t1'].float()
        ts = batch['t2'].float()
        Ts = batch['T'].float()
        feats_0 = self.get_feats(obs_0)
        feats_t = self.get_feats(obs_t)
        feats_T = self.get_feats(obs_T)
        loss_fn = CL_Loss(
            z_0=feats_0,
            z_t=feats_t,
            z_T=feats_T,
            t_=t_s,
            t=ts,
            T=Ts,
            alpha=0,
            var=0,
            eps=self.config['model_params']['eps'],
            max_seq_len=batch['total_t'].float(),
        )
        loss = loss_fn.get_loss()
        return loss

    def training_step(self, batch, config):
        """Log and return the training loss.

        Lightning passes the batch index as the second positional argument; it
        is named ``config`` here for backward compatibility and is unused.
        """
        loss = self.get_losses_for_batch(batch)
        self.log('train_loss', loss.cpu().detach(), prog_bar=True, on_step=True, sync_dist=True, batch_size=self.config['optim_params']['batch_size'])
        return loss

    def test_step(self, batch, config):
        """Log and return the test loss (second argument is the unused batch index)."""
        loss = self.get_losses_for_batch(batch=batch)
        self.log('test_loss', loss.cpu().detach(), prog_bar=True, on_step=True, sync_dist=True, batch_size=self.config['optim_params']['batch_size'])
        return loss

    def save(self, directory):
        """Save ``model.mlp`` and ``model.feature_extractor`` state dicts into ``directory``."""
        torch.save(self.model.mlp.state_dict(), os.path.join(directory, "mlp.pt"))
        torch.save(self.model.feature_extractor.state_dict(), os.path.join(directory, "feature_extractor.pt"))


class SP_Encoder(CL_Encoder):
    def __init__(self, config):
        """Build the SP encoder and the unshuffled loader used for sigma estimation.

        ``CL_Encoder.__init__`` already calls ``_set_dataset`` and
        ``_set_language_encoder`` (dispatching through ``self``, so overrides on
        this class are honoured). They are deliberately not repeated here:
        repeating them would rebuild both datasets and reload the whole language
        model, with both model copies alive during the second load.
        """
        super().__init__(config)
        # Filled in by on_train_epoch_start; both entries are None until then.
        self.sigma_info = {"sigma": None, "sigma_inv": None}
        # Covariance estimates appended by on_train_epoch_start when it estimates sigma.
        self.all_sigma = {"sigma": [], "var": []}
        self._set_eval_dataloader()
        # _load_sigma_Ti_inv() is intentionally not called: sigma_Ti inverses are
        # computed on the fly by get_inv_sigma_Ti.

    def _load_sigma_Ti_inv(self):
        """Unpickle ``doc_params.sigma_ti_path`` into ``self.all_sigma_Ti_inv``.

        Presumably a mapping from document length to the inverse bridge
        covariance -- TODO confirm against the file producer. ``pickle`` can run
        arbitrary code, so only point this at trusted files.
        """
        sigma_ti_path = self.config['doc_params']['sigma_ti_path']
        with open(sigma_ti_path, 'rb') as handle:
            self.all_sigma_Ti_inv = pickle.load(handle)

    def _set_eval_dataloader(self):
        """Create ``self.eval_train_dataloader`` over the training split with ``single=True``.

        The loader is unshuffled; ``on_train_epoch_start`` iterates it and
        relies on sentences of a document arriving consecutively (presumably
        how the ``single`` dataset yields them -- TODO confirm).
        """
        model_name = self.config['model_params']['model_name']
        single_sentence_train = dataset.Dataset(
            model_name=model_name,
            train=True,
            config=self.config['data_params'],
            single=True,
        )
        self.eval_train_dataloader = create_dataloader(
            single_sentence_train, self.config, shuffle=False)

    def get_losses_for_batch(self, batch):
        """Compute ``SP_Loss`` for one batch.

        Encodes ``y1``, ``y2``, ``y3`` and the document's ``first_sent`` /
        ``last_sent`` (in that order), casts the positions ``t1``, ``t2``,
        ``t3`` and ``total_t`` to float, and passes the current
        ``sigma_info['sigma_inv']``.

        NOTE(review): ``sigma_inv`` is None until ``on_train_epoch_start`` has
        stored an estimate (e.g. in a test-only run) -- confirm SP_Loss accepts that.
        """
        torch.cuda.empty_cache()
        times = {key: batch[key].float() for key in ('t1', 't2', 't3', 'total_t')}
        z1, z2, z3, z_first, z_last = (
            self.get_feats(batch[key])
            for key in ('y1', 'y2', 'y3', 'first_sent', 'last_sent')
        )
        loss_fn = SP_Loss(
            z1=z1,
            z2=z2,
            z3=z3,
            t1=times['t1'],
            t2=times['t2'],
            t3=times['t3'],
            total_t=times['total_t'],
            z_first=z_first,
            z_last=z_last,
            sigma_inv=self.sigma_info['sigma_inv'],
        )
        return loss_fn.get_loss()
    
    def _accumulate_bridge_scatter(self, dim):
        """Sum Brownian-bridge residual scatter matrices over ``eval_train_dataloader``.

        Sentences are regrouped into documents using ``sentence_id`` and
        ``total_doc_sentences``; a document may straddle batch boundaries. For a
        completed document of ``n`` sentences (bridge length ``T = n - 1``), the
        residuals ``R`` of the ``n - 2`` interior sentences from the bridge mean
        ``z_0 + (z_T - z_0) * t / T`` contribute ``R @ inv(Sigma_T) @ R.T``, with
        ``inv(Sigma_T)`` from ``get_inv_sigma_Ti``.

        Args:
            dim: latent dimension of the encoder features.

        Returns:
            ``(scatter, n_interior)``: the summed ``(dim, dim)`` matrix on
            ``self.device`` and the total number of interior sentences (the
            maximum-likelihood normaliser for ``scatter``).
        """
        scatter = torch.zeros(dim, dim).to(self.device)
        n_interior = 0
        sigma_Ti_inv_cache = {}  # document length -> inverse bridge covariance tensor
        doc_feats = []
        for batch in self.eval_train_dataloader:
            feats = self.get_feats(batch['text']).detach()
            for feat, sent_id, doc_len in zip(feats, batch['sentence_id'], batch['total_doc_sentences']):
                if sent_id == 0:
                    doc_feats = [feat]
                    continue
                doc_feats.append(feat)
                if sent_id != doc_len - 1:
                    continue  # document not finished yet

                n_sents = int(doc_len)
                doc = torch.stack(doc_feats)  # one row per sentence seen for this document
                # Bridge mean at every position t = 0 .. n_sents - 1 (endpoints are exact).
                t = torch.arange(n_sents, device=doc.device).float().unsqueeze(1)
                mu = doc[0] + (doc[-1] - doc[0]) * t / (n_sents - 1)
                resid = (doc - mu)[1:-1, :].transpose(0, 1)  # drop the pinned endpoints

                if n_sents not in sigma_Ti_inv_cache:
                    sigma_Ti_inv_cache[n_sents] = torch.from_numpy(
                        get_inv_sigma_Ti(n_sents)).to(self.device).to(resid.dtype)
                scatter += resid @ sigma_Ti_inv_cache[n_sents] @ resid.transpose(0, 1)
                n_interior += n_sents - 2
        return scatter, n_interior

    def on_train_epoch_start(self):
        """Re-estimate the bridge covariance ``sigma`` and its inverse for this epoch.

        ``doc_params.sigma_type`` selects the estimator:

        * ``"standard"``: ``sigma = scatter / n_interior * sigma_multiplier``,
          inverted after adding ``sigma_eps * I``.
        * ``"var"``: shrink ``scatter / n_interior`` towards ``var * I`` (``var``
          being its mean eigenvalue, ``trace / dim``) with weight ``var_eps``,
          then scale by ``sigma_multiplier``.
        * ``"eye"``: identity covariance, inverted after adding ``sigma_eps * I``.

        Results go to ``self.sigma_info``; estimated matrices are also appended
        to ``self.all_sigma``.

        Raises:
            ValueError: for an unknown ``sigma_type``, or when no document with
                more than two sentences was found to estimate from.
        """
        dim = self.config['model_params']['latent_dim']
        doc_params = self.config['doc_params']

        if doc_params['sigma_type'] in ["standard", "var"]:
            with torch.no_grad():
                # Estimate in eval mode, then restore the previous mode so this
                # hook does not leave the encoder in eval mode for training steps.
                was_training = self.model.training
                self.model.eval()
                print('Start estimating sigmas...')
                try:
                    scatter, n_interior = self._accumulate_bridge_scatter(dim)
                finally:
                    self.model.train(was_training)

                if n_interior == 0:
                    raise ValueError(
                        "Cannot estimate sigma: eval_train_dataloader contains no "
                        "document with more than two sentences")

                eye = torch.eye(dim).to(self.device)
                if doc_params['sigma_type'] == "standard":
                    sigma = scatter / n_interior * doc_params['sigma_multiplier']
                    sigma_inv = torch.inverse(sigma + doc_params['sigma_eps'] * eye)

                    self.all_sigma["sigma"].append(sigma.cpu().detach().numpy())

                    print(f'Initial: sigma_eig:{torch.linalg.eigvals(sigma)} \n  ')
                    print(f'Final: sigma:{sigma} \n     sigma_inv:{sigma_inv} \n ')

                else:
                    init_sigma = scatter / n_interior
                    # Mean eigenvalue of the estimate: the isotropic shrinkage target.
                    init_var = torch.trace(init_sigma) / dim
                    sigma = (1 - doc_params['var_eps']) * init_sigma + doc_params['var_eps'] * init_var * eye
                    sigma = sigma * doc_params['sigma_multiplier']
                    sigma_inv = torch.inverse(sigma)

                    self.all_sigma["sigma"].append(init_sigma.cpu().detach().numpy())
                    self.all_sigma["var"].append(init_var.cpu().detach().numpy())

                    print(f'Initial: sigma_eig:{torch.linalg.eigvals(init_sigma)} \n    var:{init_var}')
                    print(f'Final: sigma:{sigma} \n     sigma_inv:{sigma_inv} \n ')

                self.sigma_info['sigma'] = sigma
                self.sigma_info['sigma_inv'] = sigma_inv

        elif doc_params['sigma_type'] == "eye":
            sigma = torch.eye(dim).to(self.device)
            self.sigma_info['sigma'] = sigma
            self.sigma_info['sigma_inv'] = torch.inverse(sigma + doc_params['sigma_eps'] * torch.eye(dim).to(self.device))

        else:
            raise ValueError("Invalid sigma type")