import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
from utils.agent_utils import action_one_hot

# Used by the badminton dataloader.
def DataPrepocessing(df, min_len=3, player_name='CHOU Tien Chen'):
    """Convert a stroke-level DataFrame into per-rally lists of transitions.

    Rows are grouped into rallies by runs of consecutive equal ``rally``
    values. Within each rally, every stroke whose ``player`` equals
    ``player_name`` becomes one transition dict. Strokes by anyone else are
    not emitted; instead, the most recent such stroke in the same rally
    supplies the ``prev_state`` / ``prev_action`` of the next
    ``player_name`` stroke.

    Args:
        df: DataFrame with at least the columns ``rally``, ``player``,
            ``player_type``, ``opponent_type``, ``player_location_x/y``,
            ``moving_x/y``, ``opponent_location_x/y`` and ``landing_x/y``.
        min_len: Only rallies with strictly more than ``min_len`` rows are
            kept. The count is taken after the type-11 filter but before
            rows with missing shot types are dropped, so a kept rally may
            end up with fewer rows.
        player_name: Value of the ``player`` column whose strokes become
            transitions.

    Returns:
        A list with one entry per kept rally, in group order. Each entry is
        a (possibly empty) list of dicts with the tensor values
        ``"prev_state"``, ``"prev_action"``, ``"state"`` and ``"action"``.
        The zero placeholders are 17 and 30 elements long. This presumes
        ``action_one_hot(..., num_classes=11)`` yields a flat 11-element
        vector (TODO confirm).
    """
    columns = ['rally', 'player',
               'player_type', 'opponent_type',
               'player_location_x', 'player_location_y',
               'moving_x', 'moving_y',
               'opponent_location_x', 'opponent_location_y',
               'landing_x', 'landing_y']
    df_way = df[columns].copy()

    # Drop rows whose own or opposing shot type is 11. Presumably this is
    # because the one-hot encoding below uses num_classes=11.
    df_way = df_way[df_way['player_type'] != 11].reset_index(drop=True)
    df_way = df_way[df_way['opponent_type'] != 11].reset_index(drop=True)

    # A new group starts whenever the rally id changes between consecutive rows.
    df_way['group'] = (df_way['rally'] != df_way['rally'].shift()).cumsum()
    group_sizes = df_way.groupby('group').size()
    valid_groups = group_sizes[group_sizes > min_len].index
    df_way = df_way[df_way['group'].isin(valid_groups)]

    trajs = []
    for _, rally in df_way.groupby('group'):
        rally = drop_row(rally)
        traj = []
        # Most recent non-`player_name` (state, action) seen in THIS rally.
        # It is reset per rally so a stroke never inherits context from a
        # previous rally, and the names are always bound.
        opp_s, opp_a = None, None
        # Landing point of the previous stroke; (0, 0) before the first one.
        prev_landing = (0.0, 0.0)

        for i, row in enumerate(rally.to_dict('records')):
            # The first stroke of a rally has no preceding shot, so class 0 is used.
            opp_idx = 0 if i == 0 else int(row['opponent_type'])
            opp_type = action_one_hot(torch.tensor([opp_idx]), num_classes=11)
            player_type = action_one_hot(torch.tensor([int(row['player_type'])]), num_classes=11)

            state = _stroke_state(row, opp_type, prev_landing)
            action = _stroke_action(row, player_type, opp_type, prev_landing)

            if row['player'] == player_name:
                if opp_s is None:
                    # Nothing from the other side yet in this rally.
                    prev_state = torch.zeros(17)
                    prev_action = torch.zeros(30)
                else:
                    prev_state = opp_s
                    prev_action = opp_a
                traj.append({
                    "prev_state": prev_state,
                    "prev_action": prev_action,
                    "state": state,
                    "action": action,
                })
            else:
                opp_s, opp_a = state, action

            prev_landing = (row['landing_x'], row['landing_y'])

        trajs.append(traj)

    return trajs


def _stroke_state(row, opp_type, prev_landing):
    """Concatenate ``opp_type`` with both players' locations and the previous landing point."""
    return torch.cat([
        opp_type,
        torch.tensor([
            row['player_location_x'], row['player_location_y'],
            row['opponent_location_x'], row['opponent_location_y'],
            prev_landing[0], prev_landing[1],
        ], dtype=torch.float32),
    ])


def _stroke_action(row, player_type, opp_type, prev_landing):
    """Concatenate this stroke's type/landing/movement with ``opp_type``, previous landing and opponent location."""
    return torch.cat([
        player_type,
        torch.tensor([
            row['landing_x'], row['landing_y'],
            row['moving_x'], row['moving_y'],
        ], dtype=torch.float32),
        opp_type,
        torch.tensor([
            prev_landing[0], prev_landing[1],
            row['opponent_location_x'], row['opponent_location_y'],
        ], dtype=torch.float32),
    ])


def drop_row(df):
    """Return only the rows of *df* whose ``player_type`` and ``opponent_type`` are both present.

    A row is discarded if either column is missing (NaN/None). The original
    index is preserved and *df* itself is left untouched.
    """
    required = ['player_type', 'opponent_type']
    keep = df[required].notna().all(axis=1)
    return df.loc[keep]