import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch import distributed as dist
from torch.optim.lr_scheduler import LambdaLR

from typing import Dict, List, Literal, Optional, Tuple, Union
import copy


from scipy.optimize import linear_sum_assignment

from models.continual.continual_model import ContinualOCL




class PostRelpay(ContinualOCL):
    """Continual-learning wrapper that fine-tunes ``net`` on replayed data.

    At the end of every task (``end_task``), the wrapped network is run over the
    task dataset. Its ``'representation'`` and (rescaled) ``'reconstruction'``
    outputs are stored on the CPU, and rank 0's copy is broadcast to all ranks.

    During the last epoch of every task after the first (``inter_task`` ->
    ``post_replay``), the trainable parameters are optimized with Adam on the
    ``'loss'`` entry returned by the network. The replay set mixes stored
    reconstructions from earlier tasks with real images from the current task.

    Note: the class name keeps its original spelling because callers refer to it.
    """

    def __init__(
        self,
        replay_size,
        n_epochs,
        lr,
        weight_decay,
        net,
        num_task: int = 2,
        isolation: Optional[bool] = None,
        isolation_parameters: Optional[List[str]] = None,
    ):
        """Store the replay/fine-tuning hyper-parameters.

        Args:
            replay_size: Number of samples drawn from each source (every recorded
                task and the current dataset) in ``post_replay``. ``-1`` means the
                size is derived from the recorded data and the current dataset.
            n_epochs: Number of passes over the replay data in ``post_replay``.
            lr: Adam learning rate used in ``post_replay``.
            weight_decay: Adam weight decay used in ``post_replay``.
            net: Wrapped network, forwarded to ``ContinualOCL``.
            num_task: Forwarded to ``ContinualOCL``.
            isolation: Forwarded to ``ContinualOCL``.
            isolation_parameters: Forwarded to ``ContinualOCL``. ``None`` becomes a
                fresh empty list per instance, avoiding a shared mutable default.
        """
        if isolation_parameters is None:
            isolation_parameters = []
        super().__init__(net, num_task, isolation, isolation_parameters)

        self.replay_size = replay_size

        # Not referenced in this class; kept in case the base class or callers use it.
        self.loss_fn = nn.MSELoss()

        # One CPU tensor per recorded task, appended by ``warm_up_step``.
        self.slots_rec = []
        self.recons_rec = []

        self.n_epochs = n_epochs
        self.lr = lr
        self.weight_decay = weight_decay

        # Set by ``begin_task``.
        self.dataset = None
        self.is_main = None

        # LR schedule for ``post_replay``. Warmup and decay lengths are fractions
        # of the total number of optimizer steps.
        self.warmup_steps_pct = 0.02
        self.decay_steps_pct = 0.2
        self.scheduler_gamma = 0.5

    def warm_up_step(self, dataset, is_main):
        """Record the network's outputs over ``dataset`` for later replay.

        Runs ``self.net`` without gradients over ``dataset`` in chunks. It collects
        ``outputs['representation']`` and the rescaled
        ``outputs['reconstruction']``, broadcasts rank 0's tensors to every rank,
        and appends CPU copies to ``self.slots_rec`` / ``self.recons_rec``.
        Finally it switches ``self.net`` to training mode.

        Requires CUDA and an initialized ``torch.distributed`` process group.

        Args:
            dataset: Indexable dataset whose items are dicts with an ``'image'``
                tensor.
            is_main: Non-main processes zero their copies before the broadcast
                from rank 0 (assumes ``is_main`` is true exactly on rank 0).
        """
        if self.current_task() < 0:
            return

        num_samples = len(dataset)
        # Aim for roughly 500 chunks. Clamp to 1 so datasets with fewer than 500
        # items work (a zero step would make ``range`` raise).
        batch_size = max(1, num_samples // 500)
        num_batches = num_samples // batch_size
        # Log about twice per pass; clamp to avoid a modulo-by-zero.
        log_every = max(1, num_batches // 2)

        slots_rec = []
        recons_rec = []
        with torch.no_grad():
            for itr, start in enumerate(range(0, num_samples, batch_size)):
                if itr % log_every == 0:
                    print(f'- Record for Replay [{itr:4.0f}/{num_batches} itr]')
                stop = min(start + batch_size, num_samples)
                image = torch.stack([dataset[j]['image'] for j in range(start, stop)]).cuda()
                outputs = self.net(image)
                # (x + 1) / 2 presumably maps [-1, 1] outputs to the [0, 1] image
                # range fed to the net during replay. TODO confirm.
                reconstruction = (outputs['reconstruction'] + 1.0) * 0.5
                slots_rec.append(outputs['representation'])
                recons_rec.append(reconstruction)

            slots_rec = torch.cat(slots_rec)
            recons_rec = torch.cat(recons_rec)

            if not is_main:
                slots_rec = torch.zeros_like(slots_rec)
                recons_rec = torch.zeros_like(recons_rec)

            # Every rank ends up with rank 0's records.
            dist.barrier()
            dist.broadcast(slots_rec, src=0)
            dist.broadcast(recons_rec, src=0)
            slots_rec = slots_rec.clone().detach().cpu()
            recons_rec = recons_rec.clone().detach().cpu()

            print(f'- Record slots {slots_rec.shape} for Replay')
            print(f'- Record attns {recons_rec.shape} for Replay')

            slots_rec.requires_grad = False
            recons_rec.requires_grad = False

            self.slots_rec.append(slots_rec)
            self.recons_rec.append(recons_rec)

        self.net.train()

    def post_replay(self,):
        """Fine-tune trainable parameters on replayed data (tasks >= 1 only).

        Builds a replay set from up to ``replay_size`` stored reconstructions of
        every recorded task plus up to ``replay_size`` real images from the
        current dataset. It then runs ``n_epochs`` of Adam on ``outputs['loss']``
        returned by ``self.net``, using a linear-warmup / exponential-decay LR
        schedule. Training is skipped if the replay set yields no full batch.
        """
        if self.current_task() < 1:
            return

        print(f'\n\n- Training with Replay for Semantic Drift\n')

        dataset = self.dataset
        # NOTE(review): training below runs with the net in eval mode (and it is
        # left there). Presumably deliberate; confirm.
        self.net.eval()

        parameters = []
        for name, param in self.named_parameters():
            if param.requires_grad:
                parameters.append(param)
                print('*** Training: ', name, param.size())
            else:
                print('*** Freezing: ', name, param.size())

        optimizer = optim.Adam(parameters, lr=self.lr, weight_decay=self.weight_decay)

        replay_size = self.replay_size
        if replay_size == -1:
            # ``self.recons_rec`` is a list with one tensor per recorded task, so it
            # has no ``.shape``. Use the largest per-task record count, capped by the
            # current dataset size. NOTE(review): confirm this is the intended size.
            current_dataset_size = len(dataset)
            recorded = max((rec.shape[0] for rec in self.recons_rec), default=current_dataset_size)
            replay_size = min(recorded, current_dataset_size)

        print(f'\n\n- Training Decoder with {replay_size} Replay Samples\n')

        # Slice with the resolved ``replay_size``. ``self.replay_size == -1`` would
        # turn into ``[:-1]`` and silently drop one sample per source.
        replay_data = []
        for data in self.recons_rec:
            perm = torch.randperm(data.shape[0])[:replay_size]
            replay_data.append(data[perm].cpu())
        perm = torch.randperm(len(dataset))[:replay_size]
        replay_data.append(torch.stack([dataset[j.item()]['image'] for j in perm]).cpu())
        replay_data = torch.cat(replay_data, dim=0)
        len_replay = replay_data.shape[0]

        batch_size = 32
        val_dataloader = torch.utils.data.DataLoader(
            replay_data,
            batch_size=batch_size,
            num_workers=4,
            pin_memory=True,
            drop_last=True,
            shuffle=True,
        )

        num_batches = len(val_dataloader)
        total_steps = num_batches * self.n_epochs
        if total_steps == 0:
            # Fewer samples than one batch (drop_last=True) or n_epochs == 0. The
            # LR schedule would divide by zero, and there is nothing to train on.
            print(f'- Skip Training with Replay: {len_replay} samples, {self.n_epochs} epochs')
            return

        warmup_steps = self.warmup_steps_pct * total_steps
        decay_steps = self.decay_steps_pct * total_steps
        scheduler_gamma = self.scheduler_gamma

        def warm_and_decay_lr_scheduler(step: int):
            # Linear warmup to the base lr, times gamma ** (step / decay_steps).
            assert step < total_steps + 1
            factor = step / warmup_steps if step < warmup_steps else 1
            factor *= scheduler_gamma ** (step / decay_steps)
            return factor

        scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=warm_and_decay_lr_scheduler)

        # Log about four times per epoch. Clamp so small replay sets do not cause a
        # modulo-by-zero.
        log_every = max(1, (len_replay // batch_size) // 4)
        for epoch in range(self.n_epochs):
            total_loss = 0.0
            for itr, image in enumerate(val_dataloader):
                image = image.cuda()

                optimizer.zero_grad()
                outputs = self.net(image)
                loss = outputs['loss']
                loss.backward()
                optimizer.step()
                scheduler.step()

                if itr % log_every == 0:
                    print(f'- Training with Replay' + ' '.join([
                        f'[{itr:4.0f}/{num_batches} itr]'
                        f'[{epoch:3.0f}/{self.n_epochs} ep]',
                        f'[loss: {loss.item():1.6f}]',
                        f'[lr: {optimizer.param_groups[0]["lr"]:1.6f}]',
                    ]))

                total_loss += loss.item()
            # Mean over the batches actually processed this epoch.
            print(f'>>> ', ' '.join([
                f'[{epoch:3.0f}/{self.n_epochs} ep]',
                f'[avg loss: {total_loss / num_batches:1.6f}]',
            ]), '\n')

    def end_task(self, **kwargs):
        """Record replay data for the finished task, then run the base-class end hook."""
        self.warm_up_step(self.dataset, self.is_main)
        # NOTE(review): ``ene_task_`` looks like a typo for ``end_task_`` (compare
        # ``begin_task_``). It must match the name defined on ContinualOCL; verify.
        self.ene_task_()

    def begin_task(self, dataset, is_main, end_epoch, **kwargs):
        """Run the base-class begin hook and remember the task's dataset and epoch budget.

        Args:
            dataset: Dataset of the new task, used by ``post_replay`` and ``end_task``.
            is_main: Whether this process is the main process.
            end_epoch: Number of epochs for this task. ``post_replay`` runs when
                ``inter_task`` sees epoch ``end_epoch - 1``.
        """
        self.begin_task_()
        self.current_epoch = 0
        self.end_epoch = end_epoch
        self.dataset = dataset
        self.is_main = is_main

    def inter_task(self, **kwargs):
        """Advance the epoch counter, running ``post_replay`` at the task's last epoch."""
        if self.current_epoch == self.end_epoch - 1:
            self.post_replay()
        self.current_epoch += 1

    def forward(self, x, index=None):
        """Return ``self.forward_(x)`` outputs with ``'replay_loss'`` set to 0.

        ``index`` is accepted for interface compatibility and is unused.
        """
        outputs = self.forward_(x)
        outputs.update({
            'replay_loss': 0
        })

        return outputs
